diff options
| -rwxr-xr-x | Wrappers/Python/ccpi/framework/__init__.py | 1 | ||||
| -rwxr-xr-x | Wrappers/Python/ccpi/framework/framework.py | 3 | ||||
| -rw-r--r-- | Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py | 13 | 
3 files changed, 11 insertions, 6 deletions
| diff --git a/Wrappers/Python/ccpi/framework/__init__.py b/Wrappers/Python/ccpi/framework/__init__.py index 66e2f56..229edb5 100755 --- a/Wrappers/Python/ccpi/framework/__init__.py +++ b/Wrappers/Python/ccpi/framework/__init__.py @@ -15,6 +15,7 @@ from datetime import timedelta, datetime  import warnings
  from functools import reduce
 +
  from .framework import DataContainer
  from .framework import ImageData, AcquisitionData
  from .framework import ImageGeometry, AcquisitionGeometry
 diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py index ae9faf7..07c2ead 100755 --- a/Wrappers/Python/ccpi/framework/framework.py +++ b/Wrappers/Python/ccpi/framework/framework.py @@ -29,6 +29,7 @@ import warnings  from functools import reduce  from numbers import Number +  def find_key(dic, val):      """return the key of dictionary dic given the value"""      return [k for k, v in dic.items() if v == val][0] @@ -496,6 +497,7 @@ class DataContainer(object):      ## algebra       def __add__(self, other, *args, **kwargs):          out = kwargs.get('out', None) +                  if issubclass(type(other), DataContainer):                  if self.check_dimensions(other):                  out = self.as_array() + other.as_array() @@ -601,6 +603,7 @@ class DataContainer(object):                                 deep_copy=True,                                  dimension_labels=self.dimension_labels,                                 geometry=self.geometry) +                      else:              raise TypeError('Cannot {0} DataContainer with {1}'.format("multiply" ,                              type(other))) diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py index 1229c4e..084818c 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py @@ -8,7 +8,7 @@ Created on Mon Feb  4 16:18:06 2019  from ccpi.optimisation.algorithms import Algorithm  from ccpi.framework import ImageData  import numpy as np -#import matplotlib.pyplot as plt +import matplotlib.pyplot as plt  import time  from ccpi.optimisation.operators import BlockOperator  from ccpi.framework import BlockDataContainer @@ -120,12 +120,13 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):      objective = [] +          for i in range(niter):          # Gradient descent, Dual problem solution          y_tmp = y_old + sigma * operator.direct(xbar)          y = f.proximal_conjugate(y_tmp, sigma) -         +          # Gradient ascent, Primal problem solution          x_tmp = x_old - tau * operator.adjoint(y)          x = g.proximal(x_tmp, tau) @@ -135,15 +136,15 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):          x_old = x          y_old = y    -         +                          if i%100==0:              primal = f(operator.direct(x)) + g(x)              dual = -(f.convex_conjugate(y) + g(-1*operator.adjoint(y))) -            print( i, primal, dual) +            print( i, primal, dual, primal-dual) -            plt.imshow(x.as_array()) -            plt.show() +#            plt.imshow(x.as_array()) +#            plt.show()  #            print(f(operator.direct(x)) + g(x), i)      t_end = time.time()         | 
