From 74891953a24416b9680dee13354d57b42cd8f63c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 18 Feb 2019 15:23:29 +0000 Subject: added reverse multiplication of operator with number --- Wrappers/Python/ccpi/optimisation/ops.py | 37 +++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/ops.py b/Wrappers/Python/ccpi/optimisation/ops.py index 450b084..3845621 100755 --- a/Wrappers/Python/ccpi/optimisation/ops.py +++ b/Wrappers/Python/ccpi/optimisation/ops.py @@ -24,26 +24,49 @@ from ccpi.framework import AcquisitionData from ccpi.framework import ImageData from ccpi.framework import ImageGeometry from ccpi.framework import AcquisitionGeometry - +from numbers import Number # Maybe operators need to know what types they take as inputs/outputs # to not just use generic DataContainer class Operator(object): + '''Operator that maps from a space X -> Y''' + def __init__(self, **kwargs): + self.scalar = 1 + def is_linear(self): + '''Returns if the operator is linear''' + return False def direct(self,x, out=None): - return x - def adjoint(self,x, out=None): - return x + raise NotImplementedError def size(self): # To be defined for specific class raise NotImplementedError - def get_max_sing_val(self): + def norm(self): raise NotImplementedError def allocate_direct(self): + '''Allocates memory on the Y space''' raise NotImplementedError def allocate_adjoint(self): + '''Allocates memory on the X space''' + raise NotImplementedError + def range_dim(self): raise NotImplementedError + def domain_dim(self): + raise NotImplementedError + def __rmul__(self, other): + '''reverse multiplication of Operator with number sets the variable scalar in the Operator''' + assert isinstance(other, Number) + self.scalar = other + return self +class LinearOperator(Operator): + '''Operator that maps from a space X -> Y''' + def is_linear(self): + '''Returns if the operator is linear''' + return True + def adjoint(self,x, out=None): + raise NotImplementedError + class Identity(Operator): def __init__(self): self.s1 = 1.0 @@ -75,12 +98,16 @@ class TomoIdentity(Operator): super(TomoIdentity, self).__init__() def direct(self,x,out=None): + if self.scalar != 1: + x *= self.scalar if out is None: return x.copy() else: out.fill(x) def adjoint(self,x, out=None): + if self.scalar != 1: + x *= self.scalar if out is None: return x.copy() else: -- cgit v1.2.3