diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Python/setup.py | 1 | ||||
| -rw-r--r-- | src/Python/test_regularizers.py | 195 | 
2 files changed, 168 insertions, 28 deletions
diff --git a/src/Python/setup.py b/src/Python/setup.py index a4eed14..0468722 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -53,6 +53,7 @@ setup(                                        "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",                                        "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",                                        "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", +                                      "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c",                                        "..\\..\\main_func\\regularizers_CPU\\utils.c"                                          ],                               include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ),  diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6abfba4..6a34749 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -47,6 +47,8 @@ class Regularizer():          SplitBregman_TV = regularizers.SplitBregman_TV          FGP_TV = regularizers.FGP_TV          LLT_model = regularizers.LLT_model +        PatchBased_Regul = regularizers.PatchBased_Regul +        TGV_PD = regularizers.TGV_PD      # Algorithm      class TotalVariationPenalty(Enum): @@ -55,13 +57,17 @@ class Regularizer():      # TotalVariationPenalty      def __init__(self , algorithm): -         +        self.setAlgorithm ( algorithm ) +    # __init__ +     +    def setAlgorithm(self, algorithm):          self.algorithm = algorithm          self.pars = self.parsForAlgorithm(algorithm) -    # __init__ +    # setAlgorithm      def parsForAlgorithm(self, algorithm):          pars = dict() +                  if algorithm == Regularizer.Algorithm.SplitBregman_TV :              pars['algorithm'] = algorithm              pars['input'] = None @@ -69,6 +75,7 @@ class Regularizer():              pars['number_of_iterations'] = 35              pars['tolerance_constant'] = 0.0001              pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic +                      elif algorithm == Regularizer.Algorithm.FGP_TV :              pars['algorithm'] = algorithm              pars['input'] = None @@ -76,6 +83,7 @@ class Regularizer():              pars['number_of_iterations'] = 50              pars['tolerance_constant'] = 0.001              pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic +                      elif algorithm == Regularizer.Algorithm.LLT_model:              pars['algorithm'] = algorithm              pars['input'] = None @@ -85,6 +93,24 @@ class Regularizer():              pars['tolerance_constant'] = None              pars['restrictive_Z_smoothing'] = 0 +        elif algorithm == Regularizer.Algorithm.PatchBased_Regul: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['searching_window_ratio'] = None +            pars['similarity_window_ratio'] = None +            pars['PB_filtering_parameter'] = None +            pars['regularization_parameter'] = None +             +        elif algorithm == Regularizer.Algorithm.TGV_PD: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['first_order_term'] = None +            pars['second_order_term'] = None +            pars['number_of_iterations'] = None +            pars['regularization_parameter'] = None +             +             +                      return pars      # parsForAlgorithm @@ -98,6 +124,8 @@ class Regularizer():          self.pars['regularization_parameter'] = regularization_parameter          #for key, value in self.pars.items():          #        print("{0} = {1}".format(key, value)) +        if None in self.pars: +                raise Exception("Not all parameters have been provided")          if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :              return self.algorithm(input, regularization_parameter, @@ -112,15 +140,27 @@ class Regularizer():          elif self.algorithm == Regularizer.Algorithm.LLT_model :              #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)              # no default -            if None in self.pars: -                raise Exception("Not all parameters have been provided") -            else: -                return self.algorithm(input,  -                                  regularization_parameter, -                                  self.pars['time_step'] ,  -                                  self.pars['number_of_iterations'], -                                  self.pars['tolerance_constant'], -                                  self.pars['restrictive_Z_smoothing'] ) +            return self.algorithm(input,  +                              regularization_parameter, +                              self.pars['time_step'] ,  +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['restrictive_Z_smoothing'] ) +        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            return self.algorithm(input, regularization_parameter, +                                  self.pars['searching_window_ratio'] ,  +                                  self.pars['similarity_window_ratio'] ,  +                                  self.pars['PB_filtering_parameter']) +        elif self.algorithm == Regularizer.Algorithm.TGV_PD : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            return self.algorithm(input, regularization_parameter, +                                  self.pars['first_order_term'] ,  +                                  self.pars['second_order_term'] ,  +                                  self.pars['number_of_iterations']) +                  # __call__ @@ -142,13 +182,40 @@ class Regularizer():      @staticmethod      def LLT_model(input, regularization_parameter , time_step, number_of_iterations,                    tolerance_constant, restrictive_Z_smoothing=0): -        reg = Regularizer(Regularizer.Algorithm.FGP_TV) +        reg = Regularizer(Regularizer.Algorithm.LLT_model)          out = list( reg(input, regularization_parameter, time_step=time_step,                           number_of_iterations=number_of_iterations,                          tolerance_constant=tolerance_constant,                           restrictive_Z_smoothing=restrictive_Z_smoothing) )          out.append(reg.pars)          return out +     +    @staticmethod +    def PatchBased_Regul(input, regularization_parameter, +                        searching_window_ratio,  +                        similarity_window_ratio, +                        PB_filtering_parameter): +        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)    +        out = list( reg(input,  +                        regularization_parameter, +                        searching_window_ratio=searching_window_ratio,  +                        similarity_window_ratio=similarity_window_ratio, +                        PB_filtering_parameter=PB_filtering_parameter ) +            ) +        out.append(reg.pars) +        return out +     +    @staticmethod +    def TGV_PD(input, regularization_parameter , first_order_term,  +               second_order_term, number_of_iterations): +         +        reg = Regularizer(Regularizer.Algorithm.TGV_PD) +        out = list( reg(input, regularization_parameter,  +                        first_order_term=first_order_term,  +                        second_order_term=second_order_term, +                        number_of_iterations=number_of_iterations) ) +        out.append(reg.pars) +        return out  #Example: @@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im)))  f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)  u0 = f(u0).astype('float32') -# plot  +## plot   fig = plt.figure() -a=fig.add_subplot(2,3,1) -a.set_title('Original') -imgplot = plt.imshow(Im) +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) -a=fig.add_subplot(2,3,2) +a=fig.add_subplot(2,3,1)  a.set_title('noise')  imgplot = plt.imshow(u0) - +reg_output = []  ##############################################################################  # Call regularizer @@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe            TV_Penalty=Regularizer.TotalVariationPenalty.l1)  out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )  pars = out2[2] +reg_output.append(out2) -a=fig.add_subplot(2,3,3) +a=fig.add_subplot(2,3,2)  a.set_title('SplitBregman_TV')  textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'  textstr = textstr % (pars['regularization_parameter'],  @@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)  # place a text box in upper left in axes coords  a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,          verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0])  ###################### FGP_TV #########################################  # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05,                            number_of_iterations=10)  pars = out2[-1] -a=fig.add_subplot(2,3,4) +reg_output.append(out2) + +a=fig.add_subplot(2,3,3)  a.set_title('FGP_TV')  textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s'  textstr = textstr % (pars['regularization_parameter'],  @@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)  # place a text box in upper left in axes coords  a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,          verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0])  ###################### LLT_model #########################################  # * u0 = Im + .03*randn(size(Im)); % adding noise  # [Den] = LLT_model(single(u0), 10, 0.1, 1); -out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., -                          time_step=0.1, -                          tolerance_constant=1e-4, -                          number_of_iterations=10) +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +#input, regularization_parameter , time_step, number_of_iterations, +#                  tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +                          time_step=0.0003, +                          tolerance_constant=0.0001, +                          number_of_iterations=300)  pars = out2[-1] -a=fig.add_subplot(2,3,5) +reg_output.append(out2) + +a=fig.add_subplot(2,3,4)  a.set_title('LLT_model')  textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f'  textstr = textstr % (pars['regularization_parameter'],  @@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)  # place a text box in upper left in axes coords  a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,          verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) + +###################### PatchBased_Regul ######################################### +# Quick 2D denoising example in Matlab:    +#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +#   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +                          searching_window_ratio=3, +                          similarity_window_ratio=1, +                          PB_filtering_parameter=0.08) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) +a.set_title('PatchBased_Regul') +textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' +textstr = textstr % (pars['regularization_parameter'],  +                     pars['searching_window_ratio'],  +                     pars['similarity_window_ratio'], +                     pars['PB_filtering_parameter']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) + + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab:    +#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +#   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +                          first_order_term=1.3, +                          second_order_term=1, +                          number_of_iterations=550) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) +a.set_title('TGV_PD') +textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' +textstr = textstr % (pars['regularization_parameter'],  +                     pars['first_order_term'],  +                     pars['second_order_term'], +                     pars['number_of_iterations']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0])  | 
