summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorjakobsj <jakobsj@users.noreply.github.com>2018-09-18 11:56:20 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2018-09-18 11:56:20 +0100
commitf8f294f51aac110049da3eb7c1cb8cbd6a46282f (patch)
tree0094b4d3370e619c6d1d3c9b5ec1d2272af1414e /Wrappers
parent2557fb9765d8bdbb236d3b0e3b3d6bed486839f3 (diff)
downloadframework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.tar.gz
framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.tar.bz2
framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.tar.xz
framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.zip
Add SIRT and Box constraints (#125)
* Quick prototype of SIRT with nonnegativity added * Add indicator function for boxconstrint in SIRT and FISTA with demo
Diffstat (limited to 'Wrappers')
-rw-r--r--Wrappers/Python/ccpi/framework.py7
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py63
-rwxr-xr-xWrappers/Python/ccpi/optimisation/funcs.py24
-rw-r--r--Wrappers/Python/wip/demo_test_sirt.py176
4 files changed, 270 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py
index d82010b..0c2432f 100644
--- a/Wrappers/Python/ccpi/framework.py
+++ b/Wrappers/Python/ccpi/framework.py
@@ -490,6 +490,13 @@ class DataContainer(object):
dimension_labels=self.dimension_labels,
geometry=self.geometry)
+ def minimum(self,otherscalar):
+ out = numpy.minimum(self.as_array(),otherscalar)
+ return type(self)(out,
+ deep_copy=True,
+ dimension_labels=self.dimension_labels,
+ geometry=self.geometry)
+
def sign(self):
out = numpy.sign(self.as_array() )
return type(self)(out,
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index f025bbb..570df4b 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -21,6 +21,7 @@ import numpy
import time
from ccpi.optimisation.funcs import Function
+from ccpi.framework import ImageData, AcquisitionData
def FISTA(x_init, f=None, g=None, opt=None):
'''Fast Iterative Shrinkage-Thresholding Algorithm
@@ -222,4 +223,66 @@ def CGLS(x_init, operator , data , opt=None):
criter[it] = (r**2).sum()
return x, it, timing, criter
+
+def SIRT(x_init, operator , data , opt=None, constraint=None):
+ '''Simultaneous Iterative Reconstruction Technique
+
+ Parameters:
+ x_init: initial guess
+ operator: operator for forward/backward projections
+ data: data to operate on
+ opt: additional algorithm
+ constraint: func of Indicator type specifying convex constraint.
+ '''
+
+ if opt is None:
+ opt = {'tol': 1e-4, 'iter': 1000}
+ else:
+ try:
+ max_iter = opt['iter']
+ except KeyError as ke:
+ opt[ke] = 1000
+ try:
+ opt['tol'] = 1000
+ except KeyError as ke:
+ opt[ke] = 1e-4
+ tol = opt['tol']
+ max_iter = opt['iter']
+
+ # Set default constraint to unconstrained
+ if constraint==None:
+ constraint = Function()
+
+ x = x_init.clone()
+
+ timing = numpy.zeros(max_iter)
+ criter = numpy.zeros(max_iter)
+
+ # Relaxation parameter must be strictly between 0 and 2. For now fix at 1.0
+ relax_par = 1.0
+
+ # Set up scaling matrices D and M.
+ im1 = ImageData(geometry=x_init.geometry)
+ im1.array[:] = 1.0
+ M = 1/operator.direct(im1)
+ del im1
+ aq1 = AcquisitionData(geometry=M.geometry)
+ aq1.array[:] = 1.0
+ D = 1/operator.adjoint(aq1)
+ del aq1
+
+ # algorithm loop
+ for it in range(0, max_iter):
+ t = time.time()
+ r = data - operator.direct(x)
+
+ x = constraint.prox(x + relax_par * (D*operator.adjoint(M*r)),None)
+
+ timing[it] = time.time() - t
+ if it > 0:
+ criter[it-1] = (r**2).sum()
+
+ r = data - operator.direct(x)
+ criter[it] = (r**2).sum()
+ return x, it, timing, criter
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 90bc9e3..4a0a139 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -128,4 +128,28 @@ class Norm1(Function):
def prox(self,x,tau):
return (x.abs() - tau*self.gamma).maximum(0) * x.sign()
+
+# Box constraints indicator function. Calling returns 0 if argument is within
+# the box. The prox operator is projection onto the box. Only implements one
+# scalar lower and one upper as constraint on all elements. Should generalise
+# to vectors to allow different constraints one elements.
+class IndicatorBox(Function):
+
+ def __init__(self,lower=-numpy.inf,upper=numpy.inf):
+ # Do nothing
+ self.lower = lower
+ self.upper = upper
+ super(IndicatorBox, self).__init__()
+
+ def __call__(self,x):
+
+ if (numpy.all(x.array>=self.lower) and
+ numpy.all(x.array <= self.upper) ):
+ val = 0
+ else:
+ val = numpy.inf
+ return val
+
+ def prox(self,x,tau=None):
+ return (x.maximum(self.lower)).minimum(self.upper)
diff --git a/Wrappers/Python/wip/demo_test_sirt.py b/Wrappers/Python/wip/demo_test_sirt.py
new file mode 100644
index 0000000..6f5a44d
--- /dev/null
+++ b/Wrappers/Python/wip/demo_test_sirt.py
@@ -0,0 +1,176 @@
+# This demo illustrates how to use the SIRT algorithm without and with
+# nonnegativity and box constraints. The ASTRA 2D projectors are used.
+
+# First make all imports
+from ccpi.framework import ImageData, ImageGeometry, AcquisitionGeometry, \
+ AcquisitionData
+from ccpi.optimisation.algs import FISTA, FBPD, CGLS, SIRT
+from ccpi.optimisation.funcs import Norm2sq, Norm1, TV2D, IndicatorBox
+from ccpi.astra.ops import AstraProjectorSimple
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Choose either a parallel-beam (1=parallel2D) or fan-beam (2=cone2D) test case
+test_case = 1
+
+# Set up phantom size NxN by creating ImageGeometry, initialising the
+# ImageData object with this geometry and empty array and finally put some
+# data into its array, and display as image.
+N = 128
+ig = ImageGeometry(voxel_num_x=N,voxel_num_y=N)
+Phantom = ImageData(geometry=ig)
+
+x = Phantom.as_array()
+x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5
+x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1
+
+plt.imshow(x)
+plt.title('Phantom image')
+plt.show()
+
+# Set up AcquisitionGeometry object to hold the parameters of the measurement
+# setup geometry: # Number of angles, the actual angles from 0 to
+# pi for parallel beam and 0 to 2pi for fanbeam, set the width of a detector
+# pixel relative to an object pixel, the number of detector pixels, and the
+# source-origin and origin-detector distance (here the origin-detector distance
+# set to 0 to simulate a "virtual detector" with same detector pixel size as
+# object pixel size).
+angles_num = 20
+det_w = 1.0
+det_num = N
+SourceOrig = 200
+OrigDetec = 0
+
+if test_case==1:
+ angles = np.linspace(0,np.pi,angles_num,endpoint=False)
+ ag = AcquisitionGeometry('parallel',
+ '2D',
+ angles,
+ det_num,det_w)
+elif test_case==2:
+ angles = np.linspace(0,2*np.pi,angles_num,endpoint=False)
+ ag = AcquisitionGeometry('cone',
+ '2D',
+ angles,
+ det_num,
+ det_w,
+ dist_source_center=SourceOrig,
+ dist_center_detector=OrigDetec)
+else:
+ NotImplemented
+
+# Set up Operator object combining the ImageGeometry and AcquisitionGeometry
+# wrapping calls to ASTRA as well as specifying whether to use CPU or GPU.
+Aop = AstraProjectorSimple(ig, ag, 'gpu')
+
+# Forward and backprojection are available as methods direct and adjoint. Here
+# generate test data b and do simple backprojection to obtain z.
+b = Aop.direct(Phantom)
+z = Aop.adjoint(b)
+
+plt.imshow(b.array)
+plt.title('Simulated data')
+plt.show()
+
+plt.imshow(z.array)
+plt.title('Backprojected data')
+plt.show()
+
+# Using the test data b, different reconstruction methods can now be set up as
+# demonstrated in the rest of this file. In general all methods need an initial
+# guess and some algorithm options to be set:
+x_init = ImageData(np.zeros(x.shape),geometry=ig)
+opt = {'tol': 1e-4, 'iter': 1000}
+
+# First a CGLS reconstruction can be done:
+x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt)
+
+plt.imshow(x_CGLS.array)
+plt.title('CGLS')
+plt.colorbar()
+plt.show()
+
+plt.semilogy(criter_CGLS)
+plt.title('CGLS criterion')
+plt.show()
+
+# A SIRT unconstrained reconstruction can be done: similarly:
+x_SIRT, it_SIRT, timing_SIRT, criter_SIRT = SIRT(x_init, Aop, b, opt)
+
+plt.imshow(x_SIRT.array)
+plt.title('SIRT unconstrained')
+plt.colorbar()
+plt.show()
+
+plt.semilogy(criter_SIRT)
+plt.title('SIRT unconstrained criterion')
+plt.show()
+
+# A SIRT nonnegativity constrained reconstruction can be done using the
+# additional input "constraint" set to a box indicator function with 0 as the
+# lower bound and the default upper bound of infinity:
+x_SIRT0, it_SIRT0, timing_SIRT0, criter_SIRT0 = SIRT(x_init, Aop, b, opt,
+ constraint=IndicatorBox(lower=0))
+
+plt.imshow(x_SIRT0.array)
+plt.title('SIRT nonneg')
+plt.colorbar()
+plt.show()
+
+plt.semilogy(criter_SIRT0)
+plt.title('SIRT nonneg criterion')
+plt.show()
+
+# A SIRT reconstruction with box constraints on [0,1] can also be done:
+x_SIRT01, it_SIRT01, timing_SIRT01, criter_SIRT01 = SIRT(x_init, Aop, b, opt,
+ constraint=IndicatorBox(lower=0,upper=1))
+
+plt.imshow(x_SIRT01.array)
+plt.title('SIRT box(0,1)')
+plt.colorbar()
+plt.show()
+
+plt.semilogy(criter_SIRT01)
+plt.title('SIRT box(0,1) criterion')
+plt.show()
+
+# The indicator function can also be used with the FISTA algorithm to do
+# least squares with nonnegativity constraint.
+
+# Create least squares object instance with projector, test data and a constant
+# coefficient of 0.5:
+f = Norm2sq(Aop,b,c=0.5)
+
+# Run FISTA for least squares without constraints
+x_fista, it, timing, criter = FISTA(x_init, f, None,opt)
+
+plt.imshow(x_fista.array)
+plt.title('FISTA Least squares')
+plt.show()
+
+plt.semilogy(criter)
+plt.title('FISTA Least squares criterion')
+plt.show()
+
+# Run FISTA for least squares with nonnegativity constraint
+x_fista0, it0, timing0, criter0 = FISTA(x_init, f, IndicatorBox(lower=0),opt)
+
+plt.imshow(x_fista0.array)
+plt.title('FISTA Least squares nonneg')
+plt.show()
+
+plt.semilogy(criter0)
+plt.title('FISTA Least squares nonneg criterion')
+plt.show()
+
+# Run FISTA for least squares with box constraint [0,1]
+x_fista01, it01, timing01, criter01 = FISTA(x_init, f, IndicatorBox(lower=0,upper=1),opt)
+
+plt.imshow(x_fista01.array)
+plt.title('FISTA Least squares box(0,1)')
+plt.show()
+
+plt.semilogy(criter01)
+plt.title('FISTA Least squares box(0,1) criterion')
+plt.show() \ No newline at end of file