diff options
author | jakobsj <jakobsj@users.noreply.github.com> | 2018-09-18 11:56:20 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2018-09-18 11:56:20 +0100 |
commit | f8f294f51aac110049da3eb7c1cb8cbd6a46282f (patch) | |
tree | 0094b4d3370e619c6d1d3c9b5ec1d2272af1414e /Wrappers | |
parent | 2557fb9765d8bdbb236d3b0e3b3d6bed486839f3 (diff) | |
download | framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.tar.gz framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.tar.bz2 framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.tar.xz framework-f8f294f51aac110049da3eb7c1cb8cbd6a46282f.zip |
Add SIRT and Box constraints (#125)
* Quick prototype of SIRT with nonnegativity added
* Add indicator function for boxconstrint in SIRT and FISTA with demo
Diffstat (limited to 'Wrappers')
-rw-r--r-- | Wrappers/Python/ccpi/framework.py | 7 | ||||
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algs.py | 63 | ||||
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/funcs.py | 24 | ||||
-rw-r--r-- | Wrappers/Python/wip/demo_test_sirt.py | 176 |
4 files changed, 270 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py index d82010b..0c2432f 100644 --- a/Wrappers/Python/ccpi/framework.py +++ b/Wrappers/Python/ccpi/framework.py @@ -490,6 +490,13 @@ class DataContainer(object): dimension_labels=self.dimension_labels, geometry=self.geometry) + def minimum(self,otherscalar): + out = numpy.minimum(self.as_array(),otherscalar) + return type(self)(out, + deep_copy=True, + dimension_labels=self.dimension_labels, + geometry=self.geometry) + def sign(self): out = numpy.sign(self.as_array() ) return type(self)(out, diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py index f025bbb..570df4b 100755 --- a/Wrappers/Python/ccpi/optimisation/algs.py +++ b/Wrappers/Python/ccpi/optimisation/algs.py @@ -21,6 +21,7 @@ import numpy import time from ccpi.optimisation.funcs import Function +from ccpi.framework import ImageData, AcquisitionData def FISTA(x_init, f=None, g=None, opt=None): '''Fast Iterative Shrinkage-Thresholding Algorithm @@ -222,4 +223,66 @@ def CGLS(x_init, operator , data , opt=None): criter[it] = (r**2).sum() return x, it, timing, criter + +def SIRT(x_init, operator , data , opt=None, constraint=None): + '''Simultaneous Iterative Reconstruction Technique + + Parameters: + x_init: initial guess + operator: operator for forward/backward projections + data: data to operate on + opt: additional algorithm + constraint: func of Indicator type specifying convex constraint. + ''' + + if opt is None: + opt = {'tol': 1e-4, 'iter': 1000} + else: + try: + max_iter = opt['iter'] + except KeyError as ke: + opt[ke] = 1000 + try: + opt['tol'] = 1000 + except KeyError as ke: + opt[ke] = 1e-4 + tol = opt['tol'] + max_iter = opt['iter'] + + # Set default constraint to unconstrained + if constraint==None: + constraint = Function() + + x = x_init.clone() + + timing = numpy.zeros(max_iter) + criter = numpy.zeros(max_iter) + + # Relaxation parameter must be strictly between 0 and 2. For now fix at 1.0 + relax_par = 1.0 + + # Set up scaling matrices D and M. + im1 = ImageData(geometry=x_init.geometry) + im1.array[:] = 1.0 + M = 1/operator.direct(im1) + del im1 + aq1 = AcquisitionData(geometry=M.geometry) + aq1.array[:] = 1.0 + D = 1/operator.adjoint(aq1) + del aq1 + + # algorithm loop + for it in range(0, max_iter): + t = time.time() + r = data - operator.direct(x) + + x = constraint.prox(x + relax_par * (D*operator.adjoint(M*r)),None) + + timing[it] = time.time() - t + if it > 0: + criter[it-1] = (r**2).sum() + + r = data - operator.direct(x) + criter[it] = (r**2).sum() + return x, it, timing, criter diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py index 90bc9e3..4a0a139 100755 --- a/Wrappers/Python/ccpi/optimisation/funcs.py +++ b/Wrappers/Python/ccpi/optimisation/funcs.py @@ -128,4 +128,28 @@ class Norm1(Function): def prox(self,x,tau): return (x.abs() - tau*self.gamma).maximum(0) * x.sign() + +# Box constraints indicator function. Calling returns 0 if argument is within +# the box. The prox operator is projection onto the box. Only implements one +# scalar lower and one upper as constraint on all elements. Should generalise +# to vectors to allow different constraints one elements. +class IndicatorBox(Function): + + def __init__(self,lower=-numpy.inf,upper=numpy.inf): + # Do nothing + self.lower = lower + self.upper = upper + super(IndicatorBox, self).__init__() + + def __call__(self,x): + + if (numpy.all(x.array>=self.lower) and + numpy.all(x.array <= self.upper) ): + val = 0 + else: + val = numpy.inf + return val + + def prox(self,x,tau=None): + return (x.maximum(self.lower)).minimum(self.upper) diff --git a/Wrappers/Python/wip/demo_test_sirt.py b/Wrappers/Python/wip/demo_test_sirt.py new file mode 100644 index 0000000..6f5a44d --- /dev/null +++ b/Wrappers/Python/wip/demo_test_sirt.py @@ -0,0 +1,176 @@ +# This demo illustrates how to use the SIRT algorithm without and with +# nonnegativity and box constraints. The ASTRA 2D projectors are used. + +# First make all imports +from ccpi.framework import ImageData, ImageGeometry, AcquisitionGeometry, \ + AcquisitionData +from ccpi.optimisation.algs import FISTA, FBPD, CGLS, SIRT +from ccpi.optimisation.funcs import Norm2sq, Norm1, TV2D, IndicatorBox +from ccpi.astra.ops import AstraProjectorSimple + +import numpy as np +import matplotlib.pyplot as plt + +# Choose either a parallel-beam (1=parallel2D) or fan-beam (2=cone2D) test case +test_case = 1 + +# Set up phantom size NxN by creating ImageGeometry, initialising the +# ImageData object with this geometry and empty array and finally put some +# data into its array, and display as image. +N = 128 +ig = ImageGeometry(voxel_num_x=N,voxel_num_y=N) +Phantom = ImageData(geometry=ig) + +x = Phantom.as_array() +x[round(N/4):round(3*N/4),round(N/4):round(3*N/4)] = 0.5 +x[round(N/8):round(7*N/8),round(3*N/8):round(5*N/8)] = 1 + +plt.imshow(x) +plt.title('Phantom image') +plt.show() + +# Set up AcquisitionGeometry object to hold the parameters of the measurement +# setup geometry: # Number of angles, the actual angles from 0 to +# pi for parallel beam and 0 to 2pi for fanbeam, set the width of a detector +# pixel relative to an object pixel, the number of detector pixels, and the +# source-origin and origin-detector distance (here the origin-detector distance +# set to 0 to simulate a "virtual detector" with same detector pixel size as +# object pixel size). +angles_num = 20 +det_w = 1.0 +det_num = N +SourceOrig = 200 +OrigDetec = 0 + +if test_case==1: + angles = np.linspace(0,np.pi,angles_num,endpoint=False) + ag = AcquisitionGeometry('parallel', + '2D', + angles, + det_num,det_w) +elif test_case==2: + angles = np.linspace(0,2*np.pi,angles_num,endpoint=False) + ag = AcquisitionGeometry('cone', + '2D', + angles, + det_num, + det_w, + dist_source_center=SourceOrig, + dist_center_detector=OrigDetec) +else: + NotImplemented + +# Set up Operator object combining the ImageGeometry and AcquisitionGeometry +# wrapping calls to ASTRA as well as specifying whether to use CPU or GPU. +Aop = AstraProjectorSimple(ig, ag, 'gpu') + +# Forward and backprojection are available as methods direct and adjoint. Here +# generate test data b and do simple backprojection to obtain z. +b = Aop.direct(Phantom) +z = Aop.adjoint(b) + +plt.imshow(b.array) +plt.title('Simulated data') +plt.show() + +plt.imshow(z.array) +plt.title('Backprojected data') +plt.show() + +# Using the test data b, different reconstruction methods can now be set up as +# demonstrated in the rest of this file. In general all methods need an initial +# guess and some algorithm options to be set: +x_init = ImageData(np.zeros(x.shape),geometry=ig) +opt = {'tol': 1e-4, 'iter': 1000} + +# First a CGLS reconstruction can be done: +x_CGLS, it_CGLS, timing_CGLS, criter_CGLS = CGLS(x_init, Aop, b, opt) + +plt.imshow(x_CGLS.array) +plt.title('CGLS') +plt.colorbar() +plt.show() + +plt.semilogy(criter_CGLS) +plt.title('CGLS criterion') +plt.show() + +# A SIRT unconstrained reconstruction can be done: similarly: +x_SIRT, it_SIRT, timing_SIRT, criter_SIRT = SIRT(x_init, Aop, b, opt) + +plt.imshow(x_SIRT.array) +plt.title('SIRT unconstrained') +plt.colorbar() +plt.show() + +plt.semilogy(criter_SIRT) +plt.title('SIRT unconstrained criterion') +plt.show() + +# A SIRT nonnegativity constrained reconstruction can be done using the +# additional input "constraint" set to a box indicator function with 0 as the +# lower bound and the default upper bound of infinity: +x_SIRT0, it_SIRT0, timing_SIRT0, criter_SIRT0 = SIRT(x_init, Aop, b, opt, + constraint=IndicatorBox(lower=0)) + +plt.imshow(x_SIRT0.array) +plt.title('SIRT nonneg') +plt.colorbar() +plt.show() + +plt.semilogy(criter_SIRT0) +plt.title('SIRT nonneg criterion') +plt.show() + +# A SIRT reconstruction with box constraints on [0,1] can also be done: +x_SIRT01, it_SIRT01, timing_SIRT01, criter_SIRT01 = SIRT(x_init, Aop, b, opt, + constraint=IndicatorBox(lower=0,upper=1)) + +plt.imshow(x_SIRT01.array) +plt.title('SIRT box(0,1)') +plt.colorbar() +plt.show() + +plt.semilogy(criter_SIRT01) +plt.title('SIRT box(0,1) criterion') +plt.show() + +# The indicator function can also be used with the FISTA algorithm to do +# least squares with nonnegativity constraint. + +# Create least squares object instance with projector, test data and a constant +# coefficient of 0.5: +f = Norm2sq(Aop,b,c=0.5) + +# Run FISTA for least squares without constraints +x_fista, it, timing, criter = FISTA(x_init, f, None,opt) + +plt.imshow(x_fista.array) +plt.title('FISTA Least squares') +plt.show() + +plt.semilogy(criter) +plt.title('FISTA Least squares criterion') +plt.show() + +# Run FISTA for least squares with nonnegativity constraint +x_fista0, it0, timing0, criter0 = FISTA(x_init, f, IndicatorBox(lower=0),opt) + +plt.imshow(x_fista0.array) +plt.title('FISTA Least squares nonneg') +plt.show() + +plt.semilogy(criter0) +plt.title('FISTA Least squares nonneg criterion') +plt.show() + +# Run FISTA for least squares with box constraint [0,1] +x_fista01, it01, timing01, criter01 = FISTA(x_init, f, IndicatorBox(lower=0,upper=1),opt) + +plt.imshow(x_fista01.array) +plt.title('FISTA Least squares box(0,1)') +plt.show() + +plt.semilogy(criter01) +plt.title('FISTA Least squares box(0,1) criterion') +plt.show()
\ No newline at end of file |