diff options
| author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-31 11:45:39 +0000 | 
|---|---|---|
| committer | Edoardo Pasca <edo.paskino@gmail.com> | 2018-01-19 14:26:06 +0000 | 
| commit | a5ee66a4aee472ab72d204783b5e3da4b4f65beb (patch) | |
| tree | 2ea21e075737a8fa51a2a809e2824074b0c7dc30 | |
| parent | 43b6f16ea68523f9d13457b17b44181222f1e6c1 (diff) | |
| download | regularization-a5ee66a4aee472ab72d204783b5e3da4b4f65beb.tar.gz regularization-a5ee66a4aee472ab72d204783b5e3da4b4f65beb.tar.bz2 regularization-a5ee66a4aee472ab72d204783b5e3da4b4f65beb.tar.xz regularization-a5ee66a4aee472ab72d204783b5e3da4b4f65beb.zip | |
fixed setParameter
fixed setParameter
allows regularizer to output simply the image rather than list.
| -rw-r--r-- | src/Python/ccpi/imaging/Regularizer.py | 42 | 
1 files changed, 27 insertions, 15 deletions
| diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py index 8ab6c6a..23799d6 100644 --- a/src/Python/ccpi/imaging/Regularizer.py +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -108,6 +108,8 @@ class Regularizer():          else:              raise Exception('Unknown regularizer algorithm') + +        self.acceptedInputKeywords = pars.keys()          return pars      # parsForAlgorithm @@ -134,17 +136,24 @@ class Regularizer():                  raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))      # setParameter -    def getParameter(self, **kwargs): -        ret = {} -        for key , value in kwargs.items(): -            if key in self.pars.keys(): -                ret[key] = self.pars[key] +    def getParameter(self, key): +        if type(key) is str: +            if key in self.acceptedInputKeywords: +                return self.pars[key] +            else: +                raise Exception('Unrecongnised parameter: {0} '.format(key) ) +        elif type(key) is list: +            outpars = [] +            for k in key: +                outpars.append(self.getParameter(k)) +            return outpars          else: -            raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) -    # setParameter +            raise Exception('Unhandled input {0}' .format(str(type(key)))) +        # getParameter -    def __call__(self, input = None, regularization_parameter = None, **kwargs): +    def __call__(self, input = None, regularization_parameter = None, +                 output_all = False, **kwargs):          '''Actual call for the regularizer.           One can either set the regularization parameters first and then call the @@ -179,19 +188,19 @@ class Regularizer():          input = self.pars['input']          regularization_parameter = self.pars['regularization_parameter']          if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : -            return self.algorithm(input, regularization_parameter, +            ret = self.algorithm(input, regularization_parameter,                                self.pars['number_of_iterations'],                                self.pars['tolerance_constant'],                                self.pars['TV_penalty'].value )              elif self.algorithm == Regularizer.Algorithm.FGP_TV : -            return self.algorithm(input, regularization_parameter, +            ret = self.algorithm(input, regularization_parameter,                                self.pars['number_of_iterations'],                                self.pars['tolerance_constant'],                                self.pars['TV_penalty'].value )          elif self.algorithm == Regularizer.Algorithm.LLT_model :              #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)              # no default -            return self.algorithm(input,  +            ret = self.algorithm(input,                                 regularization_parameter,                                self.pars['time_step'] ,                                 self.pars['number_of_iterations'], @@ -200,7 +209,7 @@ class Regularizer():          elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :              #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)              # no default -            return self.algorithm(input, regularization_parameter, +            ret = self.algorithm(input, regularization_parameter,                                    self.pars['searching_window_ratio'] ,                                     self.pars['similarity_window_ratio'] ,                                     self.pars['PB_filtering_parameter']) @@ -208,7 +217,7 @@ class Regularizer():              #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)              # no default              if len(np.shape(input)) == 2: -                return self.algorithm(input, regularization_parameter, +                ret = self.algorithm(input, regularization_parameter,                                    self.pars['first_order_term'] ,                                     self.pars['second_order_term'] ,                                     self.pars['number_of_iterations']) @@ -227,11 +236,14 @@ class Regularizer():                  output = [out3d]                  for i in range(1,len(out)):                      output.append(out[i]) -                return output +                ret = output -             +        if output_all: +            return ret +        else: +            return ret[0]      # __call__ | 
