author    Edoardo Pasca <edo.paskino@gmail.com>  2017-08-23 12:55:19 +0100
committer Edoardo Pasca <edo.paskino@gmail.com>  2017-10-11 15:48:33 +0100
commit    bc29e0690d856ad9dd147b435d34c5761556a1e5 (patch)
tree      8b22b4cf0bd178f4a96ca50b96d81610d62d3d46
parent    0611d34c31fa1e706c3bcd7e17651f7555469e00 (diff)
Regularizer.py first commit

-rw-r--r--  src/Python/Regularizer.py  322
1 file changed, 322 insertions, 0 deletions
diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py
new file mode 100644
index 0000000..15dbbb4
--- /dev/null
+++ b/src/Python/Regularizer.py
@@ -0,0 +1,322 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Aug 8 14:26:00 2017
+
+@author: ofn77899
+"""
+
+import regularizers
+import numpy as np
+from enum import Enum
+import timeit
+
+class Regularizer():
+ '''Class to handle regularizer algorithms to be used during reconstruction
+
+ Currently 5 CPU (OMP) regularization algorithms are available:
+
+ 1) SplitBregman_TV
+ 2) FGP_TV
+ 3) LLT_model
+ 4) PatchBased_Regul
+ 5) TGV_PD
+
+ Usage:
+ the regularizer can be invoked as an object or through its static methods.
+ Depending on the actual regularizer the input parameters may vary, and
+ different default settings are defined.
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+
+ out = reg(input=u0, regularization_parameter=10., number_of_iterations=30,
+ tolerance_constant=1e-4,
+ TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10.,
+ number_of_iterations=30, tolerance_constant=1e-4,
+ TV_Penalty=Regularizer.TotalVariationPenalty.l1)
+
+ A number of optional parameters can be passed or omitted:
+ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. )
+
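+ Note: each static method returns a list holding the regularizer output(s)
+ followed by the parameter dictionary and a string reporting the parameters
+ and the elapsed time, e.g. (illustrative):
+
+ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10.)
+ pars, info = out2[-2], out2[-1]
+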
+ '''
+ class Algorithm(Enum):
+ SplitBregman_TV = regularizers.SplitBregman_TV
+ FGP_TV = regularizers.FGP_TV
+ LLT_model = regularizers.LLT_model
+ PatchBased_Regul = regularizers.PatchBased_Regul
+ TGV_PD = regularizers.TGV_PD
+ # Algorithm
+
+ class TotalVariationPenalty(Enum):
+ isotropic = 0
+ l1 = 1
+ # TotalVariationPenalty
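+ # illustrative note: it is the integer .value of the chosen penalty that is
+ # forwarded to the underlying routine in __call__, e.g.
+ # Regularizer.TotalVariationPenalty.l1.value # -> 1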
+
+ def __init__(self, algorithm, debug=True):
+ self.setAlgorithm(algorithm)
+ self.debug = debug
+ # __init__
+
+ def setAlgorithm(self, algorithm):
+ self.algorithm = algorithm
+ self.pars = self.getDefaultParsForAlgorithm(algorithm)
+ # setAlgorithm
+
+ def getDefaultParsForAlgorithm(self, algorithm):
+ pars = dict()
+
+ if algorithm == Regularizer.Algorithm.SplitBregman_TV :
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['regularization_parameter'] = None
+ pars['number_of_iterations'] = 35
+ pars['tolerance_constant'] = 0.0001
+ pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+
+ elif algorithm == Regularizer.Algorithm.FGP_TV :
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['regularization_parameter'] = None
+ pars['number_of_iterations'] = 50
+ pars['tolerance_constant'] = 0.001
+ pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic
+
+ elif algorithm == Regularizer.Algorithm.LLT_model:
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['regularization_parameter'] = None
+ pars['time_step'] = None
+ pars['number_of_iterations'] = None
+ pars['tolerance_constant'] = None
+ pars['restrictive_Z_smoothing'] = 0
+
+ elif algorithm == Regularizer.Algorithm.PatchBased_Regul:
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['searching_window_ratio'] = None
+ pars['similarity_window_ratio'] = None
+ pars['PB_filtering_parameter'] = None
+ pars['regularization_parameter'] = None
+
+ elif algorithm == Regularizer.Algorithm.TGV_PD:
+ pars['algorithm'] = algorithm
+ pars['input'] = None
+ pars['first_order_term'] = None
+ pars['second_order_term'] = None
+ pars['number_of_iterations'] = None
+ pars['regularization_parameter'] = None
+
+ else:
+ raise Exception('Unknown regularizer algorithm')
+
+ return pars
+ # getDefaultParsForAlgorithm
+
+ def setParameter(self, **kwargs):
+ '''set named parameter for the regularization engine
+
+ Raises an Exception if the named parameter is not recognized.
+ Typical usage is:
+
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ reg.setParameter(input=u0)
+ reg.setParameter(regularization_parameter=10.)
+
+ It can also be used as:
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ reg.setParameter(input=u0 , regularization_parameter=10.)
+ '''
+
+ for key , value in kwargs.items():
+ if key in self.pars.keys():
+ self.pars[key] = value
+ else:
+ raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
+ # setParameter
+
+ def getParameter(self, **kwargs):
+ '''return a dictionary with the current value of the named parameter(s)'''
+ ret = {}
+ for key, value in kwargs.items():
+ if key in self.pars.keys():
+ ret[key] = self.pars[key]
+ else:
+ raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key))
+ return ret
+ # getParameter
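+ # illustrative usage (assuming a regularizer built for SplitBregman_TV):
+ # reg.getParameter(number_of_iterations=None)
+ # # -> {'number_of_iterations': 35} (the default, unless changed)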
+
+
+ def __call__(self, input = None, regularization_parameter = None, **kwargs):
+ '''Actual call for the regularizer.
+
+ One can either set the regularization parameters first and then call the
+ algorithm or set the regularization parameter during the call (as
+ is done in the static methods).
+ '''
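+ # Illustrative usage (u0 assumed to be a numpy array prepared by the caller):
+ # reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ # reg.setParameter(number_of_iterations=30, tolerance_constant=1e-4)
+ # out = reg(input=u0, regularization_parameter=10.)
+ # # equivalently, pass everything in the call itself:
+ # out = reg(input=u0, regularization_parameter=10.,
+ # number_of_iterations=30, tolerance_constant=1e-4)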
+
+ # merge any additional keyword arguments into the parameter dictionary
+ for key, value in kwargs.items():
+ self.pars[key] = value
+
+ if input is not None:
+ self.pars['input'] = input
+ if regularization_parameter is not None:
+ self.pars['regularization_parameter'] = regularization_parameter
+
+ if self.debug:
+ print ("--------------------------------------------------")
+ for key, value in self.pars.items():
+ if key== 'algorithm' :
+ print("{0} = {1}".format(key, value.__name__))
+ elif key == 'input':
+ print("{0} = {1}".format(key, np.shape(value)))
+ else:
+ print("{0} = {1}".format(key, value))
+
+
+ # check the values (not the keys) of the parameter dictionary
+ if None in self.pars.values():
+ raise Exception("Not all parameters have been provided")
+
+ input = self.pars['input']
+ regularization_parameter = self.pars['regularization_parameter']
+ if self.algorithm == Regularizer.Algorithm.SplitBregman_TV :
+ return self.algorithm(input, regularization_parameter,
+ self.pars['number_of_iterations'],
+ self.pars['tolerance_constant'],
+ self.pars['TV_penalty'].value )
+ elif self.algorithm == Regularizer.Algorithm.FGP_TV :
+ return self.algorithm(input, regularization_parameter,
+ self.pars['number_of_iterations'],
+ self.pars['tolerance_constant'],
+ self.pars['TV_penalty'].value )
+ elif self.algorithm == Regularizer.Algorithm.LLT_model :
+ #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher)
+ # no default
+ return self.algorithm(input,
+ regularization_parameter,
+ self.pars['time_step'] ,
+ self.pars['number_of_iterations'],
+ self.pars['tolerance_constant'],
+ self.pars['restrictive_Z_smoothing'] )
+ elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul :
+ # no default values; arguments are passed in the order shown below
+ return self.algorithm(input, regularization_parameter,
+ self.pars['searching_window_ratio'] ,
+ self.pars['similarity_window_ratio'] ,
+ self.pars['PB_filtering_parameter'])
+ elif self.algorithm == Regularizer.Algorithm.TGV_PD :
+ # no default values; arguments are passed in the order shown below
+ if len(np.shape(input)) == 2:
+ return self.algorithm(input, regularization_parameter,
+ self.pars['first_order_term'] ,
+ self.pars['second_order_term'] ,
+ self.pars['number_of_iterations'])
+ elif len(np.shape(input)) == 3:
+ # assuming the input is 3D: run independent 2D calls on each slice
+ out3d = input.copy()
+ for i in range(np.shape(input)[2]):
+ # pass the i-th 2D slice; input.T[i] matches the out3d.T[i] assignment below
+ out = self.algorithm(input.T[i], regularization_parameter,
+ self.pars['first_order_term'],
+ self.pars['second_order_term'],
+ self.pars['number_of_iterations'])
+ # copy the slice result into the 3D output volume
+ out3d.T[i] = out[0].copy()
+ # append the rest of the info that the algorithm returns
+ output = [out3d]
+ for i in range(1,len(out)):
+ output.append(out[i])
+ return output
+
+
+
+
+
+ # __call__
+
+ @staticmethod
+ def SplitBregman_TV(input, regularization_parameter , **kwargs):
+ start_time = timeit.default_timer()
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ out = list( reg(input, regularization_parameter, **kwargs) )
+ out.append(reg.pars)
+ txt = reg.printParametersToString()
+ txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+ out.append(txt)
+ return out
+
+ @staticmethod
+ def FGP_TV(input, regularization_parameter , **kwargs):
+ start_time = timeit.default_timer()
+ reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+ out = list( reg(input, regularization_parameter, **kwargs) )
+ out.append(reg.pars)
+ txt = reg.printParametersToString()
+ txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+ out.append(txt)
+ return out
+
+ @staticmethod
+ def LLT_model(input, regularization_parameter , time_step, number_of_iterations,
+ tolerance_constant, restrictive_Z_smoothing=0):
+ start_time = timeit.default_timer()
+ reg = Regularizer(Regularizer.Algorithm.LLT_model)
+ out = list( reg(input, regularization_parameter, time_step=time_step,
+ number_of_iterations=number_of_iterations,
+ tolerance_constant=tolerance_constant,
+ restrictive_Z_smoothing=restrictive_Z_smoothing) )
+ out.append(reg.pars)
+ txt = reg.printParametersToString()
+ txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+ out.append(txt)
+ return out
+
+ @staticmethod
+ def PatchBased_Regul(input, regularization_parameter,
+ searching_window_ratio,
+ similarity_window_ratio,
+ PB_filtering_parameter):
+ start_time = timeit.default_timer()
+ reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)
+ out = list( reg(input,
+ regularization_parameter,
+ searching_window_ratio=searching_window_ratio,
+ similarity_window_ratio=similarity_window_ratio,
+ PB_filtering_parameter=PB_filtering_parameter )
+ )
+ out.append(reg.pars)
+ txt = reg.printParametersToString()
+ txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+ out.append(txt)
+ return out
+
+ @staticmethod
+ def TGV_PD(input, regularization_parameter , first_order_term,
+ second_order_term, number_of_iterations):
+ start_time = timeit.default_timer()
+
+ reg = Regularizer(Regularizer.Algorithm.TGV_PD)
+ out = list( reg(input, regularization_parameter,
+ first_order_term=first_order_term,
+ second_order_term=second_order_term,
+ number_of_iterations=number_of_iterations) )
+ out.append(reg.pars)
+ txt = reg.printParametersToString()
+ txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time)
+ out.append(txt)
+
+ return out
+
+ def printParametersToString(self):
+ txt = r''
+ for key, value in self.pars.items():
+ if key== 'algorithm' :
+ txt += "{0} = {1}".format(key, value.__name__)
+ elif key == 'input':
+ txt += "{0} = {1}".format(key, np.shape(value))
+ else:
+ txt += "{0} = {1}".format(key, value)
+ txt += '\n'
+ return txt
+