summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorjakobsj <jakobsj@users.noreply.github.com>2018-03-14 12:14:54 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2018-03-14 12:14:54 +0000
commit925d47ad8f5f6024122324af1b81a701b4ffc4aa (patch)
tree45ee2157a80cd98f06e0c2dd96d056b1c81deceb /Wrappers/Python
parentaabb107bbbfd41067deec8290d29d00bbb336af8 (diff)
downloadframework-925d47ad8f5f6024122324af1b81a701b4ffc4aa.tar.gz
framework-925d47ad8f5f6024122324af1b81a701b4ffc4aa.tar.bz2
framework-925d47ad8f5f6024122324af1b81a701b4ffc4aa.tar.xz
framework-925d47ad8f5f6024122324af1b81a701b4ffc4aa.zip
Add CGLS algorithm (#52)
* Added modular CGLS alg and use it in simple_demo * CGLS demo plot titles #51 * Increase num CGLS iters * CGLS: Added copy of b to r and x_init to x * Added clone method to DataSet Use clone in CGLS Added many plots to demo
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/framework.py10
-rw-r--r--Wrappers/Python/ccpi/reconstruction/algs.py37
-rw-r--r--Wrappers/Python/wip/simple_demo.py61
3 files changed, 106 insertions, 2 deletions
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py
index 5a507d9..b2f8a7e 100644
--- a/Wrappers/Python/ccpi/framework.py
+++ b/Wrappers/Python/ccpi/framework.py
@@ -425,6 +425,15 @@ class DataSet(object):
repres += "Axis labels: {0}\n".format(self.dimension_labels)
repres += "Representation: \n{0}\n".format(self.array)
return repres
+
+ def clone(self):
+ '''returns a copy of itself'''
+
+ return type(self)(self.array,
+ dimension_labels=self.dimension_labels,
+ deep_copy=True,
+ geometry=self.geometry )
+
@@ -895,4 +904,5 @@ if __name__ == '__main__':
geom_type='parallel', pixel_num_v=3,
pixel_num_h=5 , channels=2)
sino = SinogramData(geometry=sgeometry)
+ sino2 = sino.clone()
\ No newline at end of file
diff --git a/Wrappers/Python/ccpi/reconstruction/algs.py b/Wrappers/Python/ccpi/reconstruction/algs.py
index be0abb5..088b36e 100644
--- a/Wrappers/Python/ccpi/reconstruction/algs.py
+++ b/Wrappers/Python/ccpi/reconstruction/algs.py
@@ -21,6 +21,7 @@ import numpy
import time
from ccpi.reconstruction.funcs import BaseFunction
+from ccpi.framework import SinogramData, VolumeData
def FISTA(x_init, f=None, g=None, opt=None):
@@ -126,3 +127,39 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None):
timing = numpy.cumsum(timing[0:it+1]);
return x, it, timing, criter
+
+def CGLS(A,b,max_iter,x_init):
+ '''Conjugate Gradient Least Squares algorithm'''
+
+ r = b.clone()
+ x = x_init.clone()
+
+ d = A.adjoint(r)
+
+ normr2 = (d**2).sum()
+
+ timing = numpy.zeros(max_iter)
+ criter = numpy.zeros(max_iter)
+
+ # algorithm loop
+ for it in range(0, max_iter):
+
+ t = time.time()
+
+ Ad = A.direct(d)
+ alpha = normr2/( (Ad**2).sum() )
+ x = x + alpha*d
+ r = r - alpha*Ad
+ s = A.adjoint(r)
+
+ normr2_new = (s**2).sum()
+ beta = normr2_new/normr2
+ normr2 = normr2_new
+ d = s + beta*d
+
+ # time and criterion
+ timing[it] = time.time() - t
+ criter[it] = (r**2).sum()
+
+ return x, it, timing, criter
+
diff --git a/Wrappers/Python/wip/simple_demo.py b/Wrappers/Python/wip/simple_demo.py
index da617c1..766e448 100644
--- a/Wrappers/Python/wip/simple_demo.py
+++ b/Wrappers/Python/wip/simple_demo.py
@@ -56,7 +56,7 @@ elif test_case==2:
dist_center_detector=OrigDetec)
# ASTRA operator using volume and sinogram geometries
-Aop = AstraProjectorSimple(vg, pg, 'gpu')
+Aop = AstraProjectorSimple(vg, pg, 'cpu')
# Unused old astra projector without geometry
# Aop_old = AstraProjector(det_w, det_num, SourceOrig,
@@ -83,6 +83,7 @@ x_init = VolumeData(np.zeros(x.shape),geometry=vg)
x_fista0, it0, timing0, criter0 = FISTA(x_init, f, None)
plt.imshow(x_fista0.array)
+plt.title('FISTA0')
plt.show()
# Now least squares plus 1-norm regularization
@@ -93,6 +94,7 @@ g0 = Norm1(lam)
x_fista1, it1, timing1, criter1 = FISTA(x_init, f, g0)
plt.imshow(x_fista1.array)
+plt.title('FISTA')
plt.show()
plt.semilogy(criter1)
@@ -103,7 +105,62 @@ opt = {'tol': 1e-4, 'iter': 10000}
x_fbpd1, it_fbpd1, timing_fbpd1, criter_fbpd1 = FBPD(x_init,None,f,g0,opt=opt)
plt.imshow(x_fbpd1.array)
+plt.title('FBPD')
plt.show()
plt.semilogy(criter_fbpd1)
-plt.show() \ No newline at end of file
+plt.show()
+
+# Run CGLS, which should agree with the FISTA0
+x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(Aop, b, 1000, x_init)
+
+plt.imshow(x_CGLS.array)
+plt.title('CGLS')
+plt.title('CGLS recon, compare FISTA0')
+plt.show()
+
+plt.semilogy(criter_CGLS)
+plt.title('CGLS criterion')
+plt.show()
+
+
+#%%
+cols = 3
+rows = 2
+current = 1
+fig = plt.figure()
+# projections row
+a=fig.add_subplot(rows,cols,current)
+a.set_title('phantom {0}'.format(numpy.shape(Phantom.as_array())))
+imgplot = plt.imshow(Phantom.as_array())
+
+current = current + 1
+a=fig.add_subplot(rows,cols,current)
+a.set_title('FISTA0')
+imgplot = plt.imshow(x_fista0.as_array())
+
+current = current + 1
+a=fig.add_subplot(rows,cols,current)
+a.set_title('FISTA1')
+imgplot = plt.imshow(x_fista1.as_array())
+
+current = current + 1
+a=fig.add_subplot(rows,cols,current)
+a.set_title('FBPD')
+imgplot = plt.imshow(x_fbpd1.as_array())
+
+current = current + 1
+a=fig.add_subplot(rows,cols,current)
+a.set_title('CGLS')
+imgplot = plt.imshow(x_CGLS.as_array())
+
+current = current + 1
+a=fig.add_subplot(rows,cols,current)
+a.set_title('criteria')
+imgplot = plt.loglog(criter0 , label='FISTA0')
+imgplot = plt.loglog(criter1 , label='FISTA1')
+imgplot = plt.loglog(criter_fbpd1, label='FBPD')
+imgplot = plt.loglog(criter_CGLS, label='CGLS')
+a.legend(loc='right')
+plt.show()
+#%% \ No newline at end of file