summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorepapoutsellis <epapoutsellis@gmail.com>2019-04-04 17:40:28 +0100
committerepapoutsellis <epapoutsellis@gmail.com>2019-04-04 17:40:28 +0100
commit2a1607f35aebc1938f30ba66f700a8f893ed5be4 (patch)
tree54a601cbd08b6b85603e963b58606d6430dcd43c
parentea7113b7d86453077dc45674ab8506aac5f2b8e0 (diff)
downloadframework-2a1607f35aebc1938f30ba66f700a8f893ed5be4.tar.gz
framework-2a1607f35aebc1938f30ba66f700a8f893ed5be4.tar.bz2
framework-2a1607f35aebc1938f30ba66f700a8f893ed5be4.tar.xz
framework-2a1607f35aebc1938f30ba66f700a8f893ed5be4.zip
to work with precond
-rwxr-xr-xWrappers/Python/ccpi/framework/__init__.py1
-rwxr-xr-xWrappers/Python/ccpi/framework/framework.py3
-rw-r--r--Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py13
3 files changed, 11 insertions, 6 deletions
diff --git a/Wrappers/Python/ccpi/framework/__init__.py b/Wrappers/Python/ccpi/framework/__init__.py
index 66e2f56..229edb5 100755
--- a/Wrappers/Python/ccpi/framework/__init__.py
+++ b/Wrappers/Python/ccpi/framework/__init__.py
@@ -15,6 +15,7 @@ from datetime import timedelta, datetime
import warnings
from functools import reduce
+
from .framework import DataContainer
from .framework import ImageData, AcquisitionData
from .framework import ImageGeometry, AcquisitionGeometry
diff --git a/Wrappers/Python/ccpi/framework/framework.py b/Wrappers/Python/ccpi/framework/framework.py
index ae9faf7..07c2ead 100755
--- a/Wrappers/Python/ccpi/framework/framework.py
+++ b/Wrappers/Python/ccpi/framework/framework.py
@@ -29,6 +29,7 @@ import warnings
from functools import reduce
from numbers import Number
+
def find_key(dic, val):
"""return the key of dictionary dic given the value"""
return [k for k, v in dic.items() if v == val][0]
@@ -496,6 +497,7 @@ class DataContainer(object):
## algebra
def __add__(self, other, *args, **kwargs):
out = kwargs.get('out', None)
+
if issubclass(type(other), DataContainer):
if self.check_dimensions(other):
out = self.as_array() + other.as_array()
@@ -601,6 +603,7 @@ class DataContainer(object):
deep_copy=True,
dimension_labels=self.dimension_labels,
geometry=self.geometry)
+
else:
raise TypeError('Cannot {0} DataContainer with {1}'.format("multiply" ,
type(other)))
diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
index 1229c4e..084818c 100644
--- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
+++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py
@@ -8,7 +8,7 @@ Created on Mon Feb 4 16:18:06 2019
from ccpi.optimisation.algorithms import Algorithm
from ccpi.framework import ImageData
import numpy as np
-#import matplotlib.pyplot as plt
+import matplotlib.pyplot as plt
import time
from ccpi.optimisation.operators import BlockOperator
from ccpi.framework import BlockDataContainer
@@ -120,12 +120,13 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
objective = []
+
for i in range(niter):
# Gradient descent, Dual problem solution
y_tmp = y_old + sigma * operator.direct(xbar)
y = f.proximal_conjugate(y_tmp, sigma)
-
+
# Gradient ascent, Primal problem solution
x_tmp = x_old - tau * operator.adjoint(y)
x = g.proximal(x_tmp, tau)
@@ -135,15 +136,15 @@ def PDHG_old(f, g, operator, tau = None, sigma = None, opt = None, **kwargs):
x_old = x
y_old = y
-
+
if i%100==0:
primal = f(operator.direct(x)) + g(x)
dual = -(f.convex_conjugate(y) + g(-1*operator.adjoint(y)))
- print( i, primal, dual)
+ print( i, primal, dual, primal-dual)
- plt.imshow(x.as_array())
- plt.show()
+# plt.imshow(x.as_array())
+# plt.show()
# print(f(operator.direct(x)) + g(x), i)
t_end = time.time()