diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-31 11:45:39 +0000 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-31 11:45:39 +0000 |
commit | 8bc351e5cba87769c3ed0e0c872e823f7f5943fb (patch) | |
tree | 7598f5d3bb72cfa2fee9b40c2c4ce5f9bb1b0b8a | |
parent | 1a7d0382db817b5d35f3f45516301cc8003d9b2f (diff) | |
download | regularization-8bc351e5cba87769c3ed0e0c872e823f7f5943fb.tar.gz regularization-8bc351e5cba87769c3ed0e0c872e823f7f5943fb.tar.bz2 regularization-8bc351e5cba87769c3ed0e0c872e823f7f5943fb.tar.xz regularization-8bc351e5cba87769c3ed0e0c872e823f7f5943fb.zip |
fixed setParameter
fixed setParameter
allows regularizer to output simply the image rather than list.
-rw-r--r-- | src/Python/ccpi/imaging/Regularizer.py | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py index 8ab6c6a..23799d6 100644 --- a/src/Python/ccpi/imaging/Regularizer.py +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -108,6 +108,8 @@ class Regularizer(): else: raise Exception('Unknown regularizer algorithm') + + self.acceptedInputKeywords = pars.keys() return pars # parsForAlgorithm @@ -134,17 +136,24 @@ class Regularizer(): raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) # setParameter - def getParameter(self, **kwargs): - ret = {} - for key , value in kwargs.items(): - if key in self.pars.keys(): - ret[key] = self.pars[key] + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars else: - raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) - # setParameter + raise Exception('Unhandled input {0}' .format(str(type(key)))) + # getParameter - def __call__(self, input = None, regularization_parameter = None, **kwargs): + def __call__(self, input = None, regularization_parameter = None, + output_all = False, **kwargs): '''Actual call for the regularizer. One can either set the regularization parameters first and then call the @@ -179,19 +188,19 @@ class Regularizer(): input = self.pars['input'] regularization_parameter = self.pars['regularization_parameter'] if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['number_of_iterations'], self.pars['tolerance_constant'], self.pars['TV_penalty'].value ) elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['number_of_iterations'], self.pars['tolerance_constant'], self.pars['TV_penalty'].value ) elif self.algorithm == Regularizer.Algorithm.LLT_model : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - return self.algorithm(input, + ret = self.algorithm(input, regularization_parameter, self.pars['time_step'] , self.pars['number_of_iterations'], @@ -200,7 +209,7 @@ class Regularizer(): elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['searching_window_ratio'] , self.pars['similarity_window_ratio'] , self.pars['PB_filtering_parameter']) @@ -208,7 +217,7 @@ class Regularizer(): #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default if len(np.shape(input)) == 2: - return self.algorithm(input, regularization_parameter, + ret = self.algorithm(input, regularization_parameter, self.pars['first_order_term'] , self.pars['second_order_term'] , self.pars['number_of_iterations']) @@ -227,11 +236,14 @@ class Regularizer(): output = [out3d] for i in range(1,len(out)): output.append(out[i]) - return output + ret = output - + if output_all: + return ret + else: + return ret[0] # __call__ |