diff options
| -rwxr-xr-x | Wrappers/Python/ccpi/framework/BlockDataContainer.py | 19 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py | 39 | ||||
| -rw-r--r-- | Wrappers/Python/wip/CGLS_tikhonov.py | 45 | 
3 files changed, 58 insertions, 45 deletions
diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py index 8152bff..d509d25 100755 --- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py +++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py @@ -19,14 +19,8 @@ class BlockDataContainer(object):      '''Class to hold DataContainers as column vector'''
      __array_priority__ = 1
      def __init__(self, *args, **kwargs):
 -        '''containers must be consistent in shape'''
 +        ''''''
          self.containers = args
 -        for i, co in enumerate(args):
 -            if i == 0:
 -                shape = co.shape
 -            else:
 -                if shape != co.shape:
 -                    raise ValueError('Expected shape is {} got {}'.format(shape, co.shape))
          self.index = 0
          #shape = kwargs.get('shape', None)
          #if shape is None:
 @@ -38,7 +32,7 @@ class BlockDataContainer(object):          if len(args) != n_elements:
              raise ValueError(
                      'Dimension and size do not match: expected {} got {}'
 -                    .format(n_elements,len(args)))
 +                    .format(n_elements, len(args)))
      def __iter__(self):
 @@ -60,7 +54,6 @@ class BlockDataContainer(object):          if isinstance(other, Number):
              return True   
          elif isinstance(other, list):
 -            # TODO look elements should be numbers
              for ot in other:
                  if not isinstance(ot, (Number,\
                                   numpy.int, numpy.int8, numpy.int16, numpy.int32, numpy.int64,\
 @@ -72,10 +65,12 @@ class BlockDataContainer(object):          elif isinstance(other, numpy.ndarray):
              return self.shape == other.shape
          return len(self.containers) == len(other.containers)
 +
      def get_item(self, row):
          if row > self.shape[0]:
              raise ValueError('Requested row {} > max {}'.format(row, self.shape[0]))
          return self.containers[row]
 +
      def __getitem__(self, row):
          return self.get_item(row)
 @@ -308,8 +303,4 @@ class BlockDataContainer(object):      def __itruediv__(self, other):
          '''Inline truedivision'''
          return self.__idiv__(other)
 -    #@property
 -    #def T(self):
 -    #   '''return the transposed of self'''
 -    #   shape = (self.shape[1], self.shape[0])
 -    #   return type(self)(*self.containers, shape=shape)
 +
 diff --git a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py index c9bf794..f102f1e 100755 --- a/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py +++ b/Wrappers/Python/ccpi/optimisation/operators/BlockOperator.py @@ -81,6 +81,28 @@ class BlockOperator(Operator):                      prod += self.get_item(row,col).direct(x.get_item(col))              res.append(prod)          return BlockDataContainer(*res, shape=shape) + +    def adjoint(self, x, out=None): +        '''Adjoint operation for the BlockOperator + +        BlockOperator may contain both LinearOperator and Operator +        This method exists in BlockOperator as it is not known what type of +        Operator it will contain. + +        Raises: ValueError if the contained Operators are not linear +        ''' +        if not functools.reduce(lambda x,y: x and y, self.operators.is_linear(), True): +            raise ValueError('Not all operators in Block are linear.') +        shape = self.get_output_shape(x.shape, adjoint=True) +        res = [] +        for row in range(self.shape[1]): +            for col in range(self.shape[0]): +                if col == 0: +                    prod = self.get_item(col,row).adjoint(x.get_item(row)) +                else: +                    prod += self.get_item(col,row).adjoint(x.get_item(row)) +            res.append(prod) +        return BlockDataContainer(*res, shape=shape)      def get_output_shape(self, xshape, adjoint=False):          sshape = self.shape[1] @@ -142,22 +164,7 @@ class BlockLinearOperator(BlockOperator):              if not op.is_linear():                  raise ValueError('Operator {} must be LinearOperator'.format(i))          super(BlockLinearOperator, self).__init__(*args, **kwargs) -     -    def adjoint(self, x, out=None): -        '''Adjoint operation for the BlockOperator -         -        only available on BlockLinearOperator -        ''' -        shape = self.get_output_shape(x.shape, adjoint=True) -        res = [] -        for row in range(self.shape[1]): -            for col in range(self.shape[0]): -                if col == 0: -                    prod = self.get_item(col,row).adjoint(x.get_item(row)) -                else: -                    prod += self.get_item(col,row).adjoint(x.get_item(row)) -            res.append(prod) -        return BlockDataContainer(*res, shape=shape) + diff --git a/Wrappers/Python/wip/CGLS_tikhonov.py b/Wrappers/Python/wip/CGLS_tikhonov.py index 7178510..f247896 100644 --- a/Wrappers/Python/wip/CGLS_tikhonov.py +++ b/Wrappers/Python/wip/CGLS_tikhonov.py @@ -105,46 +105,57 @@ B = BlockDataContainer(b,  # setup a tomo identity  Ibig = 1e5 * TomoIdentity(geometry=ig)  Ismall = 1e-5 * TomoIdentity(geometry=ig) +Iok = 1e1 * TomoIdentity(geometry=ig)  # composite operator  Kbig = BlockOperator(A, Ibig, shape=(2,1))  Ksmall = BlockOperator(A, Ismall, shape=(2,1)) -     +Kok = BlockOperator(A, Iok, shape=(2,1)) +  #out = K.direct(X_init)  f = Norm2sq(Kbig,B)  f.L = 0.00003  fsmall = Norm2sq(Ksmall,B) -f.L = 0.00003 -     +fsmall.L = 0.00003 + +fok = Norm2sq(Kok,B) +fok.L = 0.00003 +  simplef = Norm2sq(A, b)  simplef.L = 0.00003  gd = GradientDescent( x_init=x_init, objective_function=simplef,                       rate=simplef.L)  gd.max_iteration = 10 -     + +Kbig.direct(X_init) +Kbig.adjoint(B)  cg = CGLS()  cg.set_up(X_init, Kbig, B ) -cg.max_iteration = 1 +cg.max_iteration = 5  cgsmall = CGLS()  cgsmall.set_up(X_init, Ksmall, B ) -cgsmall.max_iteration = 1 +cgsmall.max_iteration = 5  cgs = CGLS()  cgs.set_up(x_init, A, b )  cgs.max_iteration = 6 -# #     + +cgok = CGLS() +cgok.set_up(X_init, Kok, B ) +cgok.max_iteration = 6 +# #  #out.__isub__(B)  #out2 = K.adjoint(out)  #(2.0*self.c)*self.A.adjoint( self.A.direct(x) - self.b )  for _ in gd: -    print ("iteration {} {}".format(gd.iteration, gd.get_current_loss())) +    print ("iteration {} {}".format(gd.iteration, gd.get_last_loss()))  cg.run(10, lambda it,val: print ("iteration {} objective {}".format(it,val))) @@ -152,6 +163,7 @@ cgs.run(10, lambda it,val: print ("iteration {} objective {}".format(it,val)))  cgsmall.run(10, lambda it,val: print ("iteration {} objective {}".format(it,val)))  cgsmall.run(10, lambda it,val: print ("iteration {} objective {}".format(it,val))) +cgok.run(10, verbose=True)  # #    for _ in cg:  #    print ("iteration {} {}".format(cg.iteration, cg.get_current_loss()))  # #     @@ -164,19 +176,22 @@ cgsmall.run(10, lambda it,val: print ("iteration {} objective {}".format(it,val)  #    print ("iteration {} {}".format(cgs.iteration, cgs.get_current_loss()))  # #      fig = plt.figure() -plt.subplot(1,5,1) +plt.subplot(1,6,1)  plt.imshow(Phantom.subset(vertical=0).as_array())  plt.title('Simulated Phantom') -plt.subplot(1,5,2) +plt.subplot(1,6,2)  plt.imshow(gd.get_output().subset(vertical=0).as_array())  plt.title('Simple Gradient Descent') -plt.subplot(1,5,3) +plt.subplot(1,6,3)  plt.imshow(cgs.get_output().subset(vertical=0).as_array())  plt.title('Simple CGLS') -plt.subplot(1,5,4) -plt.imshow(cg.get_output().get_item(0,0).subset(vertical=0).as_array()) +plt.subplot(1,6,4) +plt.imshow(cg.get_output().get_item(0).subset(vertical=0).as_array())  plt.title('Composite CGLS\nbig lambda') -plt.subplot(1,5,5) -plt.imshow(cgsmall.get_output().get_item(0,0).subset(vertical=0).as_array()) +plt.subplot(1,6,5) +plt.imshow(cgsmall.get_output().get_item(0).subset(vertical=0).as_array())  plt.title('Composite CGLS\nsmall lambda') +plt.subplot(1,6,6) +plt.imshow(cgok.get_output().get_item(0).subset(vertical=0).as_array()) +plt.title('Composite CGLS\nok lambda')  plt.show()  | 
