author     Edoardo Pasca <edo.paskino@gmail.com>       2018-02-18 20:32:31 +0000
committer  GitHub <noreply@github.com>                 2018-02-18 20:32:31 +0000
commit     1f84f8da7500b4f442aa546247989f8b094e26d7 (patch)
tree       429ab01c97ead216c29f50f9249a0409805a8a6e
parent     d763dffb7cfc799ebeb0e9666702526a7f56d9ff (diff)
Arithmetic operators and Python 2/3 compatibility (#8)

* Added arithmetic operations (WIP)
* Fixes for Python 2 and 3; added arithmetic operators
* Updated operators and tests
* Updated tests
-rw-r--r--  Wrappers/Python/ccpi/framework.py     | 225
-rw-r--r--  Wrappers/Python/test/regularizers.py  |  59

2 files changed, 245 insertions, 39 deletions
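For context, here is a minimal usage sketch of the element-wise arithmetic this patch adds to `DataSet`. It assumes the constructor signature `DataSet(array, deep_copy, dimension_labels)` and the `as_array()` accessor shown in the diff below, and the `ccpi.framework` import path of the file being patched; the array shapes and values are illustrative only.

```python
# Sketch: element-wise arithmetic on DataSet as introduced by this patch.
import numpy
from ccpi.framework import DataSet

a = DataSet(numpy.ones((3, 4)), True, ['X', 'Y'])
b = DataSet(2 * numpy.ones((3, 4)), True, ['X', 'Y'])

c = a + b          # DataSet + DataSet, shapes are checked first
d = 2 * a          # scalar * DataSet dispatches to __rmul__
e = a / 2          # true division on both Python 2 and 3
f = (1 + b) ** 2   # reverse add and power also return new DataSets

print(c.as_array())   # 3.0 everywhere
print(d.as_array())   # 2.0 everywhere
```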
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py
index 035c729..3cfa2a0 100644
--- a/Wrappers/Python/ccpi/framework.py
+++ b/Wrappers/Python/ccpi/framework.py
@@ -16,6 +16,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import division
 import abc
 import numpy
 import sys
@@ -91,8 +93,10 @@ class CCPiBaseClass(ABC):
         if self.debug:
             print ("{0}: {1}".format(self.__class__.__name__, msg))
 
-class DataSet():
-    '''Generic class to hold data'''
+class DataSet(object):
+    '''Generic class to hold data
+
+    Data is currently held in a numpy arrays'''
 
     def __init__ (self, array, deep_copy=True, dimension_labels=None,
                   **kwargs):
@@ -199,8 +203,174 @@ class DataSet():
                                              numpy.shape(array)))
             self.array = array[:]
-
-
+    def checkDimensions(self, other):
+        return self.shape == other.shape
+
+    def __add__(self, other):
+        if issubclass(type(other), DataSet):
+            if self.checkDimensions(other):
+                out = self.as_array() + other.as_array()
+                return DataSet(out,
+                               deep_copy=True,
+                               dimension_labels=self.dimension_labels)
+            else:
+                raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+                                 other.shape))
+        elif isinstance(other, (int, float, complex)):
+            return DataSet(self.as_array() + other,
+                           deep_copy=True,
+                           dimension_labels=self.dimension_labels)
+        else:
+            raise TypeError('Cannot {0} DataSet with {1}'.format("add",
+                            type(other)))
+    # __add__
+
+    def __sub__(self, other):
+        if issubclass(type(other), DataSet):
+            if self.checkDimensions(other):
+                out = self.as_array() - other.as_array()
+                return DataSet(out,
+                               deep_copy=True,
+                               dimension_labels=self.dimension_labels)
+            else:
+                raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+                                 other.shape))
+        elif isinstance(other, (int, float, complex)):
+            return DataSet(self.as_array() - other,
+                           deep_copy=True,
+                           dimension_labels=self.dimension_labels)
+        else:
+            raise TypeError('Cannot {0} DataSet with {1}'.format("subtract",
+                            type(other)))
+    # __sub__
+    def __truediv__(self, other):
+        return self.__div__(other)
+
+    def __div__(self, other):
+        print ("calling __div__")
+        if issubclass(type(other), DataSet):
+            if self.checkDimensions(other):
+                out = self.as_array() / other.as_array()
+                return DataSet(out,
+                               deep_copy=True,
+                               dimension_labels=self.dimension_labels)
+            else:
+                raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+                                 other.shape))
+        elif isinstance(other, (int, float, complex)):
+            return DataSet(self.as_array() / other,
+                           deep_copy=True,
+                           dimension_labels=self.dimension_labels)
+        else:
+            raise TypeError('Cannot {0} DataSet with {1}'.format("divide",
+                            type(other)))
+    # __div__
+
+    def __pow__(self, other):
+        if issubclass(type(other), DataSet):
+            if self.checkDimensions(other):
+                out = self.as_array() ** other.as_array()
+                return DataSet(out,
+                               deep_copy=True,
+                               dimension_labels=self.dimension_labels)
+            else:
+                raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+                                 other.shape))
+        elif isinstance(other, (int, float, complex)):
+            return DataSet(self.as_array() ** other,
+                           deep_copy=True,
+                           dimension_labels=self.dimension_labels)
+        else:
+            raise TypeError('Cannot {0} DataSet with {1}'.format("power",
+                            type(other)))
+    # __pow__
+
+    def __mul__(self, other):
+        if issubclass(type(other), DataSet):
+            if self.checkDimensions(other):
+                out = self.as_array() * other.as_array()
+                return DataSet(out,
+                               deep_copy=True,
+                               dimension_labels=self.dimension_labels)
+            else:
+                raise ValueError('Wrong shape: {0} and {1}'.format(self.shape,
+                                 other.shape))
+        elif isinstance(other, (int, float, complex)):
+            return DataSet(self.as_array() * other,
+                           deep_copy=True,
+                           dimension_labels=self.dimension_labels)
+        else:
+            raise TypeError('Cannot {0} DataSet with {1}'.format("multiply",
+                            type(other)))
+    # __mul__
+
+
+    #def __abs__(self):
+    #    operation = FM.OPERATION.ABS
+    #    return self.callFieldMath(operation, None, self.mask, self.maskOnValue)
+    # __abs__
+
+    # reverse operand
+    def __radd__(self, other):
+        return self + other
+    # __radd__
+
+    def __rsub__(self, other):
+        return (-1 * self) + other
+    # __rsub__
+
+    def __rmul__(self, other):
+        return self * other
+    # __rmul__
+
+    def __rdiv__(self, other):
+        print ("call __rdiv__")
+        return pow(self / other, -1)
+    # __rdiv__
+    def __rtruediv__(self, other):
+        return self.__rdiv__(other)
+
+    def __rpow__(self, other):
+        if isinstance(other, (int, float)):
+            fother = numpy.ones(numpy.shape(self.array)) * other
+            return DataSet(fother ** self.array,
+                           dimension_labels=self.dimension_labels)
+        elif issubclass(other, DataSet):
+            if self.checkDimensions(other):
+                return DataSet(other.as_array() ** self.array,
+                               dimension_labels=self.dimension_labels)
+            else:
+                raise ValueError('Dimensions do not match')
+    # __rpow__
+
+
+    # in-place arithmetic operators:
+    # (+=, -=, *=, /= , //=,
+
+    def __iadd__(self, other):
+        return self + other
+    # __iadd__
+
+    def __imul__(self, other):
+        return self * other
+    # __imul__
+
+    def __isub__(self, other):
+        return self - other
+    # __isub__
+
+    def __idiv__(self, other):
+        print ("call __idiv__")
+        return self / other
+    # __idiv__
+
+    def __str__ (self):
+        repres = ""
+        repres += "Number of dimensions: {0}\n".format(self.number_of_dimensions)
+        repres += "Shape: {0}\n".format(self.shape)
+        repres += "Axis labels: {0}\n".format(self.dimension_labels)
+        repres += "Representation: {0}\n".format(self.array)
+        return repres
@@ -219,7 +389,9 @@ class VolumeData(DataSet):
             raise ValueError('Number of dimensions are not 2 or 3: {0}'\
                              .format(array.number_of_dimensions))
 
-            DataSet.__init__(self, array.as_array(), deep_copy,
+            #DataSet.__init__(self, array.as_array(), deep_copy,
+            #                 array.dimension_labels, **kwargs)
+            super(VolumeData, self).__init__(array.as_array(), deep_copy,
                              array.dimension_labels, **kwargs)
         elif type(array) == numpy.ndarray:
             if not ( array.ndim == 3 or array.ndim == 2 ):
@@ -236,8 +408,9 @@ class VolumeData(DataSet):
                     dimension_labels = ['horizontal' ,
                                         'vertical']
 
-            DataSet.__init__(self, array, deep_copy, dimension_labels, **kwargs)
-
+            #DataSet.__init__(self, array, deep_copy, dimension_labels, **kwargs)
+            super(VolumeData, self).__init__(array, deep_copy,
+                                             dimension_labels, **kwargs)
 
         # load metadata from kwargs if present
         for key, value in kwargs.items():
@@ -287,7 +460,7 @@ class SinogramData(DataSet):
             # assume it is parallel beam
             pass
 
-class DataSetProcessor():
+class DataSetProcessor(object):
     '''Defines a generic DataSet processor
 
     accepts DataSet as inputs and
@@ -341,6 +514,7 @@ class DataSetProcessor():
         elif self.mTime > self.runTime:
             shouldRun = True
 
+        # CHECK this
        if self.store_output and shouldRun:
             self.runTime = datetime.now()
             self.output = self.process()
@@ -405,8 +579,8 @@ class AX(DataSetProcessor):
                   'input':None,
                   }
 
-        DataSetProcessor.__init__(self, **kwargs)
-
+        #DataSetProcessor.__init__(self, **kwargs)
+        super(AX, self).__init__(**kwargs)
 
     def checkInput(self, dataset):
         return True
@@ -438,8 +612,8 @@ class PixelByPixelDataSetProcessor(DataSetProcessor):
         kwargs = {'pyfunc':None,
                   'input':None,
                   }
-        DataSetProcessor.__init__(self, **kwargs)
-
+        #DataSetProcessor.__init__(self, **kwargs)
+        super(PixelByPixelDataSetProcessor, self).__init__(**kwargs)
 
     def checkInput(self, dataset):
         return True
@@ -530,4 +704,29 @@ if __name__ == '__main__':
     chain.setInputProcessor(ax)
     print ("chain in {0} out {1}".format(ax.getOutput().as_array(),
            chain.getOutput().as_array()))
-
\ No newline at end of file
+    # testing arithmetic operations
+
+    print (b)
+    print ((b+1))
+    print ((1+b))
+
+    print (b)
+    print ((b*2))
+
+    print (b)
+    print ((2*b))
+
+    print (b)
+    print ((b/2))
+
+    print (b)
+    print ((2/b))
+
+    print (b)
+    print ((b**2))
+
+    print (b)
+    print ((2**b))
+
+
\ No newline at end of file
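The patch pairs the operator overloads with `from __future__ import division` because Python 2 and 3 route the `/` operator differently: Python 3 (and Python 2 modules using the future import) call `__truediv__`, while plain Python 2 calls `__div__`. The toy class below, which is not ccpi code, sketches the same forwarding pattern used in the patch.

```python
# Sketch of the Python 2/3 division shim: define both entry points and
# forward one to the other so behaviour is identical under both interpreters.
from __future__ import division


class Scaled(object):
    def __init__(self, value):
        self.value = value

    def __div__(self, other):          # called by plain Python 2 code
        return Scaled(self.value / other)

    def __truediv__(self, other):      # called by Python 3 (and __future__ users)
        return self.__div__(other)


print((Scaled(3) / 2).value)   # 1.5 under both Python 2 and Python 3
```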
diff --git a/Wrappers/Python/test/regularizers.py b/Wrappers/Python/test/regularizers.py
index 25873c7..04ac3aa 100644
--- a/Wrappers/Python/test/regularizers.py
+++ b/Wrappers/Python/test/regularizers.py
@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
 import matplotlib.pyplot as plt
 import numpy as np
 import os
@@ -27,7 +28,7 @@ from ccpi.filters.cpu_regularizers_boost import SplitBregman_TV , FGP_TV ,\
 #from ccpi.filters.cpu_regularizers_cython import some
 
 try:
-    from ccpi.filter import gpu_regularizers as gpu
+    from ccpi.filters import gpu_regularizers as gpu
 
     class PatchBasedRegGPU(DataSetProcessor23D):
         '''Regularizers DataSetProcessor for PatchBasedReg
@@ -40,7 +41,7 @@ try:
                       'similarity_window_ratio': None,
                       'PB_filtering_parameter': None
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        super(PatchBasedRegGPU, self).__init__(**attributes)
 
 
     def process(self):
@@ -67,7 +68,7 @@ try:
                       'similarity_window_ratio': None,
                       'PB_filtering_parameter': None
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        super(Diff4thHajiaboli, self).__init__(self, **attributes)
 
 
     def process(self):
@@ -97,7 +98,7 @@ class SBTV(DataSetProcessor23D):
                       'tolerance_constant': 0.0001,
                       'TV_penalty':0
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        super(SBTV , self).__init__(**attributes)
 
 
     def process(self):
@@ -124,7 +125,7 @@ class FGPTV(DataSetProcessor23D):
                       'tolerance_constant': 0.0001,
                       'TV_penalty':0
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        super(FGPTV, self).__init__(**attributes)
 
 
     def process(self):
@@ -153,7 +154,7 @@ class LLT(DataSetProcessor23D):
                       'tolerance_constant': 0,
                       'restrictive_Z_smoothing': None
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        super(LLT, self).__init__(**attributes)
 
 
     def process(self):
@@ -182,7 +183,7 @@ class PatchBasedReg(DataSetProcessor23D):
                       'similarity_window_ratio': None,
                       'PB_filtering_parameter': None
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        super(PatchBasedReg, self).__init__(**attributes)
 
 
     def process(self):
@@ -204,13 +205,17 @@ class TGVPD(DataSetProcessor23D):
     '''
 
-    def __init__(self):
+    def __init__(self,**kwargs):
         attributes = {'regularization_parameter':None,
                       'first_order_term': None,
                       'second_order_term': None,
                       'number_of_iterations': None
                       }
-        DataSetProcessor.__init__(self, **attributes)
+        for key, value in kwargs.items():
+            if key in attributes.keys():
+                attributes[key] = value
+
+        super(TGVPD, self).__init__(**attributes)
 
     def process(self):
@@ -247,15 +252,17 @@ if __name__ == '__main__':
                             "lena_gray_512.tif")
     Im = plt.imread(filename)
     Im = np.asarray(Im, dtype='float32')
-
-    perc = 0.15
+
+    Im = Im/255
+
+    perc = 0.075
     u0 = Im + np.random.normal(loc = Im ,
-                                  scale = perc * Im ,
-                                  size = np.shape(Im))
+                               scale = perc * Im ,
+                               size = np.shape(Im))
     # map the u0 u0->u0>0
     f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1)
     u0 = f(u0).astype('float32')
-
+
     lena = DataSet(u0, False, ['X','Y'])
 
     ## plot
@@ -272,9 +279,9 @@ if __name__ == '__main__':
 
 
     reg3 = SBTV()
-    reg3.number_of_iterations = 350
-    reg3.tolerance_constant = 0.01
-    reg3.regularization_parameter = 40
+    reg3.number_of_iterations = 40
+    reg3.tolerance_constant = 0.0001
+    reg3.regularization_parameter = 15
     reg3.TV_penalty = 0
     reg3.setInput(lena)
     dataprocessoroutput = reg3.getOutput()
@@ -293,9 +300,9 @@ if __name__ == '__main__':
     ##########################################################################
 
     reg4 = FGPTV()
-    reg4.number_of_iterations = 350
-    reg4.tolerance_constant = 0.01
-    reg4.regularization_parameter = 40
+    reg4.number_of_iterations = 200
+    reg4.tolerance_constant = 1e-4
+    reg4.regularization_parameter = 0.05
     reg4.TV_penalty = 0
     reg4.setInput(lena)
     dataprocessoroutput2 = reg4.getOutput()
@@ -313,10 +320,10 @@ if __name__ == '__main__':
     ###########################################################################
 
     reg6 = LLT()
-    reg6.regularization_parameter = 25
-    reg6.time_step = 0.0003
-    reg6.number_of_iterations = 300
-    reg6.tolerance_constant = 0.001
+    reg6.regularization_parameter = 5
+    reg6.time_step = 0.00035
+    reg6.number_of_iterations = 350
+    reg6.tolerance_constant = 0.0001
     reg6.restrictive_Z_smoothing = 0
     reg6.setInput(lena)
     llt = reg6.getOutput()
@@ -336,7 +343,7 @@ if __name__ == '__main__':
     reg7.regularization_parameter = 0.05
     reg7.searching_window_ratio = 3
     reg7.similarity_window_ratio = 1
-    reg7.PB_filtering_parameter = 0.08
+    reg7.PB_filtering_parameter = 0.06
     reg7.setInput(lena)
     pbr = reg7.getOutput()
     # plot
@@ -352,7 +359,7 @@ if __name__ == '__main__':
     ###########################################################################
 
     reg5 = TGVPD()
-    reg5.regularization_parameter = 0.05
+    reg5.regularization_parameter = 0.07
     reg5.first_order_term = 1.3
     reg5.second_order_term = 1
     reg5.number_of_iterations = 550
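The regularizers.py changes replace direct `DataSetProcessor.__init__` calls with `super()` and, for `TGVPD`, filter keyword arguments against a dict of known attributes before handing them to the base class. A minimal sketch of that pattern is below; the `Processor` and `SBTVLike` classes are hypothetical stand-ins, not the ccpi ones, and only the attribute names are taken from the patch.

```python
# Sketch of the super()-based initialisation pattern used in this commit.
from __future__ import print_function


class Processor(object):
    def __init__(self, **attributes):
        # store every configured attribute on the instance
        for key, value in attributes.items():
            setattr(self, key, value)


class SBTVLike(Processor):
    def __init__(self, **kwargs):
        attributes = {'regularization_parameter': None,
                      'number_of_iterations': 35,
                      'tolerance_constant': 0.0001,
                      'TV_penalty': 0}
        # keep only recognised keyword arguments, as TGVPD does in the patch
        for key, value in kwargs.items():
            if key in attributes:
                attributes[key] = value
        super(SBTVLike, self).__init__(**attributes)


p = SBTVLike(regularization_parameter=15)
print(p.regularization_parameter, p.number_of_iterations)   # 15 35
```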