diff options
Diffstat (limited to 'src/Python/Regularizer.py')
-rw-r--r-- | src/Python/Regularizer.py | 322 |
1 files changed, 0 insertions, 322 deletions
diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py deleted file mode 100644 index 15dbbb4..0000000 --- a/src/Python/Regularizer.py +++ /dev/null @@ -1,322 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Tue Aug 8 14:26:00 2017 - -@author: ofn77899 -""" - -import regularizers -import numpy as np -from enum import Enum -import timeit - -class Regularizer(): - '''Class to handle regularizer algorithms to be used during reconstruction - - Currently 5 CPU (OMP) regularization algorithms are available: - - 1) SplitBregman_TV - 2) FGP_TV - 3) LLT_model - 4) PatchBased_Regul - 5) TGV_PD - - Usage: - the regularizer can be invoked as object or as static method - Depending on the actual regularizer the input parameter may vary, and - a different default setting is defined. - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - - out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., - number_of_iterations=30, tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - A number of optional parameters can be passed or skipped - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) - - ''' - class Algorithm(Enum): - SplitBregman_TV = regularizers.SplitBregman_TV - FGP_TV = regularizers.FGP_TV - LLT_model = regularizers.LLT_model - PatchBased_Regul = regularizers.PatchBased_Regul - TGV_PD = regularizers.TGV_PD - # Algorithm - - class TotalVariationPenalty(Enum): - isotropic = 0 - l1 = 1 - # TotalVariationPenalty - - def __init__(self , algorithm, debug = True): - self.setAlgorithm ( algorithm ) - self.debug = debug - # __init__ - - def setAlgorithm(self, algorithm): - self.algorithm = algorithm - self.pars = self.getDefaultParsForAlgorithm(algorithm) - # setAlgorithm - - def getDefaultParsForAlgorithm(self, algorithm): - pars = dict() - - if algorithm == Regularizer.Algorithm.SplitBregman_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 35 - pars['tolerance_constant'] = 0.0001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.FGP_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 50 - pars['tolerance_constant'] = 0.001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.LLT_model: - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['time_step'] = None - pars['number_of_iterations'] = None - pars['tolerance_constant'] = None - pars['restrictive_Z_smoothing'] = 0 - - elif algorithm == Regularizer.Algorithm.PatchBased_Regul: - pars['algorithm'] = algorithm - pars['input'] = None - pars['searching_window_ratio'] = None - pars['similarity_window_ratio'] = None - pars['PB_filtering_parameter'] = None - pars['regularization_parameter'] = None - - elif algorithm == Regularizer.Algorithm.TGV_PD: - pars['algorithm'] = algorithm - pars['input'] = None - pars['first_order_term'] = None - pars['second_order_term'] = None - pars['number_of_iterations'] = None - pars['regularization_parameter'] = None - - else: - raise Exception('Unknown regularizer algorithm') - - return pars - # parsForAlgorithm - - def setParameter(self, **kwargs): - '''set named parameter for the regularization engine - - raises Exception if the named parameter is not recognized - Typical usage is: - - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0) - reg.setParameter(regularization_parameter=10.) - - it can be also used as - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0 , regularization_parameter=10.) - ''' - - for key , value in kwargs.items(): - if key in self.pars.keys(): - self.pars[key] = value - else: - raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) - # setParameter - - def getParameter(self, **kwargs): - ret = {} - for key , value in kwargs.items(): - if key in self.pars.keys(): - ret[key] = self.pars[key] - else: - raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) - # setParameter - - - def __call__(self, input = None, regularization_parameter = None, **kwargs): - '''Actual call for the regularizer. - - One can either set the regularization parameters first and then call the - algorithm or set the regularization parameter during the call (as - is done in the static methods). - ''' - - if kwargs is not None: - for key, value in kwargs.items(): - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - if input is not None: - self.pars['input'] = input - if regularization_parameter is not None: - self.pars['regularization_parameter'] = regularization_parameter - - if self.debug: - print ("--------------------------------------------------") - for key, value in self.pars.items(): - if key== 'algorithm' : - print("{0} = {1}".format(key, value.__name__)) - elif key == 'input': - print("{0} = {1}".format(key, np.shape(value))) - else: - print("{0} = {1}".format(key, value)) - - - if None in self.pars: - raise Exception("Not all parameters have been provided") - - input = self.pars['input'] - regularization_parameter = self.pars['regularization_parameter'] - if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.LLT_model : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) - elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['searching_window_ratio'] , - self.pars['similarity_window_ratio'] , - self.pars['PB_filtering_parameter']) - elif self.algorithm == Regularizer.Algorithm.TGV_PD : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - if len(np.shape(input)) == 2: - return self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - elif len(np.shape(input)) == 3: - #assuming it's 3D - # run independent calls on each slice - out3d = input.copy() - for i in range(np.shape(input)[2]): - out = self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - # copy the result in the 3D image - out3d.T[i] = out[0].copy() - # append the rest of the info that the algorithm returns - output = [out3d] - for i in range(1,len(out)): - output.append(out[i]) - return output - - - - - - # __call__ - - @staticmethod - def SplitBregman_TV(input, regularization_parameter , **kwargs): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - out.append(txt) - return out - - @staticmethod - def FGP_TV(input, regularization_parameter , **kwargs): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.FGP_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - out.append(txt) - return out - - @staticmethod - def LLT_model(input, regularization_parameter , time_step, number_of_iterations, - tolerance_constant, restrictive_Z_smoothing=0): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.LLT_model) - out = list( reg(input, regularization_parameter, time_step=time_step, - number_of_iterations=number_of_iterations, - tolerance_constant=tolerance_constant, - restrictive_Z_smoothing=restrictive_Z_smoothing) ) - out.append(reg.pars) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - out.append(txt) - return out - - @staticmethod - def PatchBased_Regul(input, regularization_parameter, - searching_window_ratio, - similarity_window_ratio, - PB_filtering_parameter): - start_time = timeit.default_timer() - reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) - out = list( reg(input, - regularization_parameter, - searching_window_ratio=searching_window_ratio, - similarity_window_ratio=similarity_window_ratio, - PB_filtering_parameter=PB_filtering_parameter ) - ) - out.append(reg.pars) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - out.append(txt) - return out - - @staticmethod - def TGV_PD(input, regularization_parameter , first_order_term, - second_order_term, number_of_iterations): - start_time = timeit.default_timer() - - reg = Regularizer(Regularizer.Algorithm.TGV_PD) - out = list( reg(input, regularization_parameter, - first_order_term=first_order_term, - second_order_term=second_order_term, - number_of_iterations=number_of_iterations) ) - out.append(reg.pars) - txt = reg.printParametersToString() - txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) - out.append(txt) - - return out - - def printParametersToString(self): - txt = r'' - for key, value in self.pars.items(): - if key== 'algorithm' : - txt += "{0} = {1}".format(key, value.__name__) - elif key == 'input': - txt += "{0} = {1}".format(key, np.shape(value)) - else: - txt += "{0} = {1}".format(key, value) - txt += '\n' - return txt - |