diff options
-rw-r--r-- | Wrappers/Python/ccpi/plugins/regularisers.py | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/Wrappers/Python/ccpi/plugins/regularisers.py b/Wrappers/Python/ccpi/plugins/regularisers.py index f665a04..77543f9 100644 --- a/Wrappers/Python/ccpi/plugins/regularisers.py +++ b/Wrappers/Python/ccpi/plugins/regularisers.py @@ -43,11 +43,16 @@ class ROF_TV(Function): 'number_of_iterations' :self.iterationsTV ,\ 'time_marching_parameter':self.time_marchstep} - out = regularisers.ROF_TV(pars['input'], + res = regularisers.ROF_TV(pars['input'], pars['regularization_parameter'], pars['number_of_iterations'], pars['time_marching_parameter'], self.device) - return DataContainer(out) + if out is not None: + out.fill(res) + else: + out = x.copy() + out.fill(res) + return out class FGP_TV(Function): def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,nonnegativity,printing,device): @@ -63,7 +68,7 @@ class FGP_TV(Function): # evaluate objective function of TV gradient EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2) return 0.5*EnergyValTV[0] - def prox(self,x,tau): + def proximal(self,x,tau, out=None): pars = {'algorithm' : FGP_TV, \ 'input' : np.asarray(x.as_array(), dtype=np.float32),\ 'regularization_parameter':self.lambdaReg*tau, \ @@ -73,16 +78,20 @@ class FGP_TV(Function): 'nonneg': self.nonnegativity ,\ 'printingOut': self.printing} - out = regularisers.FGP_TV(pars['input'], + res = regularisers.FGP_TV(pars['input'], pars['regularization_parameter'], pars['number_of_iterations'], pars['tolerance_constant'], pars['methodTV'], pars['nonneg'], - pars['printingOut'], self.device) - return DataContainer(out) - - + self.device) + if out is not None: + out.fill(res) + else: + out = x.copy() + out.fill(res) + return out + class SB_TV(Function): def __init__(self,lambdaReg,iterationsTV,tolerance,methodTV,printing,device): # set parameters @@ -96,7 +105,7 @@ class SB_TV(Function): # evaluate objective function of TV gradient EnergyValTV = TV_ENERGY(np.asarray(x.as_array(), dtype=np.float32), np.asarray(x.as_array(), dtype=np.float32), self.lambdaReg, 2) return 0.5*EnergyValTV[0] - def prox(self,x,tau): + def proximal(self,x,tau, out=None): pars = {'algorithm' : SB_TV, \ 'input' : np.asarray(x.as_array(), dtype=np.float32),\ 'regularization_parameter':self.lambdaReg*tau, \ @@ -105,10 +114,15 @@ class SB_TV(Function): 'methodTV': self.methodTV ,\ 'printingOut': self.printing} - out = regularisers.SB_TV(pars['input'], + res = regularisers.SB_TV(pars['input'], pars['regularization_parameter'], pars['number_of_iterations'], pars['tolerance_constant'], pars['methodTV'], pars['printingOut'], self.device) - return DataContainer(out) + if out is not None: + out.fill(res) + else: + out = x.copy() + out.fill(res) + return out |