summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py184
1 files changed, 160 insertions, 24 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index cbd27da..8318ea6 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -78,19 +78,28 @@ class FISTAReconstructor():
# handle parmeters:
# obligatory parameters
self.pars = dict()
- self.pars['projector_geometry'] = projector_geometry
- self.pars['output_geometry'] = output_geometry
- self.pars['input_sinogram'] = input_sinogram
+ self.pars['projector_geometry'] = projector_geometry # proj_geom
+ self.pars['output_geometry'] = output_geometry # vol_geom
+ self.pars['input_sinogram'] = input_sinogram # sino
detectors, nangles, sliceZ = numpy.shape(input_sinogram)
self.pars['detectors'] = detectors
- self.pars['number_og_angles'] = nangles
+ self.pars['number_of_angles'] = nangles
self.pars['SlicesZ'] = sliceZ
print (self.pars)
# handle optional input parameters (at instantiation)
# Accepted input keywords
- kw = ('number_of_iterations',
+ kw = (
+ # mandatory fields
+ 'projector_geometry',
+ 'output_geometry',
+ 'input_sinogram',
+ 'detectors',
+ 'number_of_angles',
+ 'SlicesZ',
+ # optional fields
+ 'number_of_iterations',
'Lipschitz_constant' ,
'ideal_image' ,
'weights' ,
@@ -98,8 +107,9 @@ class FISTAReconstructor():
'initialize' ,
'regularizer' ,
'ring_lambda_R_L1',
- 'ring_alpha')
- self.acceptedInputKeywords = kw
+ 'ring_alpha',
+ 'subsets')
+ self.acceptedInputKeywords = list(kw)
# handle keyworded parameters
if kwargs is not None:
@@ -122,8 +132,7 @@ class FISTAReconstructor():
if 'Lipschitz_constant' in kwargs.keys():
self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
else:
- self.pars['Lipschitz_constant'] = \
- self.calculateLipschitzConstantWithPowerMethod()
+ self.pars['Lipschitz_constant'] = None
if not 'ideal_image' in kwargs.keys():
self.pars['ideal_image'] = None
@@ -143,31 +152,44 @@ class FISTAReconstructor():
self.pars['ring_lambda_R_L1'] = 0
if not 'ring_alpha' in kwargs.keys():
self.pars['ring_alpha'] = 1
-
+
+ if not 'subsets' in kwargs.keys():
+ self.pars['subsets'] = 0
+ else:
+ self.createOrderedSubsets()
+
+ if not 'initialize' in kwargs.keys():
+ self.pars['initialize'] = False
def setParameter(self, **kwargs):
- '''set named parameter for the regularization engine
+ '''set named parameter for the reconstructor engine
raises Exception if the named parameter is not recognized
- Typical usage is:
-
- reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
- reg.setParameter(input=u0)
- reg.setParameter(regularization_parameter=10.)
- it can be also used as
- reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
- reg.setParameter(input=u0 , regularization_parameter=10.)
'''
-
for key , value in kwargs.items():
- if key in self.acceptedInputKeywords.keys():
+ if key in self.acceptedInputKeywords:
self.pars[key] = value
else:
- raise Exception('Wrong parameter {0} for '.format(key) +
- 'Reconstruction algorithm')
+ raise Exception('Wrong parameter {0} for '.format(key) +
+ 'reconstructor')
# setParameter
+
+ def getParameter(self, key):
+ if type(key) is str:
+ if key in self.acceptedInputKeywords:
+ return self.pars[key]
+ else:
+ raise Exception('Unrecongnised parameter: {0} '.format(key) )
+ elif type(key) is list:
+ outpars = []
+ for k in key:
+ outpars.append(self.getParameter(k))
+ return outpars
+ else:
+ raise Exception('Unhandled input {0}' .format(str(type(key))))
+
def calculateLipschitzConstantWithPowerMethod(self):
''' using Power method (PM) to establish L constant'''
@@ -289,5 +311,119 @@ class FISTAReconstructor():
if regularizer is not None:
self.pars['regularizer'] = regularizer
+
+ def initialize(self):
+ # convenience variable storage
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ sino = self.pars['input_sinogram']
+
+ # a 'warm start' with SIRT method
+ # Create a data object for the reconstruction
+ rec_id = astra.matlab.data3d('create', '-vol',
+ vol_geom);
+
+ #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino);
+ sinogram_id = astra.matlab.data3d('create', '-proj3d',
+ proj_geom,
+ sino)
+
+ sirt_config = astra.astra_dict('SIRT3D_CUDA')
+ sirt_config['ReconstructionDataId' ] = rec_id
+ sirt_config['ProjectionDataId'] = sinogram_id
+
+ sirt = astra.algorithm.create(sirt_config)
+ astra.algorithm.run(sirt, iterations=35)
+ X = astra.matlab.data3d('get', rec_id)
+
+ # clean up memory
+ astra.matlab.data3d('delete', rec_id)
+ astra.matlab.data3d('delete', sinogram_id)
+ astra.algorithm.delete(sirt)
+
+
+
+ return X
+
+ def createOrderedSubsets(self, subsets=None):
+ if subsets is None:
+ try:
+ subsets = self.getParameter('subsets')
+ except Exception():
+ subsets = 0
+ #return subsets
+
+ angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+
+
+
+
+
+ def prepareForIteration(self):
+ self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
+ self.objective = numpy.zeros((self.pars['number_of_iterations']))
+
+ #2D array (for 3D data) of sparse "ring"
+ detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram'])
+ self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float)
+ # another ring variable
+ self.rx = self.r.copy()
+
+ self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+
+ if self.getParameter('Lipschitz_constant') is None:
+ self.pars['Lipschitz_constant'] = \
+ self.calculateLipschitzConstantWithPowerMethod()
+
+ # prepareForIteration
+
+ def iterate(self, Xin=None):
+ # convenience variable storage
+ proj_geom , vol_geom, sino , \
+ SlicesZ = self.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ'])
+
+ t = 1
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ X = Xin.copy()
+
+ X_t = X.copy()
+
+ for i in range(self.getParameter('number_of_iterations')):
+ X_old = X.copy()
+ t_old = t
+ r_old = self.r.copy()
+ if self.pars['projector_geometry']['type'] == 'parallel' or \
+ self.pars['projector_geometry']['type'] == 'parallel3d':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+
+ #for kkk = 1:SlicesZ
+ # [sino_id, sino_updt(:,:,kkk)] =
+ # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT);
+ # astra_mex_data3d('delete', sino_id);
+ for kkk in range(SlicesZ):
+ sino_id, sino_updt[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk], proj_geomT, vol_geomT)
+
+ else:
+ # for divergent 3D geometry (watch GPU memory overflow in
+ # Astra < 1.8
+ sino_id, y = astra.creators.create_sino3d_gpu(X_t,
+ proj_geom,
+ vol_geom)
-
+
+