diff options
Diffstat (limited to 'src/Python')
| -rw-r--r-- | src/Python/ccpi/reconstruction/FISTAReconstructor.py | 124 | 
1 files changed, 92 insertions, 32 deletions
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index c903712..1e464a1 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -27,6 +27,7 @@ import numpy  from enum import Enum  import astra +from ccpi.reconstruction.AstraDevice import AstraDevice @@ -87,6 +88,19 @@ class FISTAReconstructor():          self.pars['SlicesZ'] = sliceZ          self.pars['output_volume'] = None + +         +        device = createAstraDevice(projector_geometry, output_geometry) +        self.setParameter(device_model=device) +        proj_geomT = projector_geometry.copy(); +        proj_geomT['DetectorRowCount'] = 1; +        vol_geomT = output_geometry.copy(); +        vol_geomT['GridSliceCount'] = 1; +        reduced_device = createAstraDevice(proj_geomT, vol_geomT) +        self.setParameter(reduced_device_model=reduced_device) + +        self.use_device = True +                  print (self.pars)          # handle optional input parameters (at instantiation) @@ -113,7 +127,9 @@ class FISTAReconstructor():                'output_volume',                'os_subsets',                'os_indices', -              'os_bins') +              'os_bins', +              'device_model', +              'reduced_device_model')          self.acceptedInputKeywords = list(kw)          # handle keyworded parameters @@ -171,7 +187,20 @@ class FISTAReconstructor():              self.pars['initialize'] = False -             +    def createAstraDevice(self, projector_geometry, output_geometry): +        '''TODO remove''' +         +        device = AstraDevice(DeviceModel.PARALLEL3D, +                    {'detectorSpacingX' : projector_geometry['DetectorSpacingX'] , +                     'detectorSpacingY' : projector_geometry['DetectorSpacingY'] , +                     'cameraX' : projector_geometry['DetectorColCount'] , +                     'cameraY' : projector_geometry['DetectorRowCount'] , +                     'angles' : projector_geometry['ProjectionAngles'] } , +                    { +                        'X' : output_geometry['GridColCount'], +                        'Y' : output_geometry['GridRowCount'] +                        'Z' : output_geometry['GridSliceCount']} ) +        return device      def setParameter(self, **kwargs):          '''set named parameter for the reconstructor engine @@ -436,13 +465,17 @@ class FISTAReconstructor():          X_t = X.copy()          # convenience variable storage          proj_geom , vol_geom, sino , \ -          SlicesZ  = self.getParameter([ 'projector_geometry' , +          SlicesZ , ring_lambda_R_L1 = self.getParameter([ 'projector_geometry' ,                                                  'output_geometry',                                                  'input_sinogram', -                                                'SlicesZ' ]) +                                                'SlicesZ' , +                                                'ring_lambda_R_L1'])          t = 1 +        device = self.getParameter('device_model') +        reduced_device = self.getParameter('reduced_device_model') +                  for i in range(self.getParameter('number_of_iterations')):              X_old = X.copy()              t_old = t @@ -453,29 +486,42 @@ class FISTAReconstructor():                  # if the geometry is parallel use slice-by-slice                  # projection-backprojection routine                  #sino_updt = zeros(size(sino),'single'); -                proj_geomT = proj_geom.copy() -                proj_geomT['DetectorRowCount'] = 1 -                vol_geomT = vol_geom.copy() -                vol_geomT['GridSliceCount'] = 1; -                self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) -                for kkk in range(SlicesZ): -                    sino_id, self.sino_updt[kkk] = \ -                             astra.creators.create_sino3d_gpu( -                                 X_t[kkk:kkk+1], proj_geomT, vol_geomT) -                    astra.matlab.data3d('delete', sino_id) +                 +                if self.use_device: +                    self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) +                     +                    for kkk in range(SlicesZ): +                        self.sino_updt[kkk] = \ +                            reduced_device.doForwardProject( X_t[kkk:kkk+1] ) +                else: +                    proj_geomT = proj_geom.copy() +                    proj_geomT['DetectorRowCount'] = 1 +                    vol_geomT = vol_geom.copy() +                    vol_geomT['GridSliceCount'] = 1; +                    self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) +                    for kkk in range(SlicesZ): +                        sino_id, self.sino_updt[kkk] = \ +                                 astra.creators.create_sino3d_gpu( +                                     X_t[kkk:kkk+1], proj_geomT, vol_geomT) +                        astra.matlab.data3d('delete', sino_id)              else:                  # for divergent 3D geometry (watch the GPU memory overflow in                  # ASTRA versions < 1.8)                  #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); -                sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( -                    X_t, proj_geom, vol_geom) +                if self.use_device: +                    self.sino_updt = device.doForwardProject(X_t) +                else: +                    sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( +                        X_t, proj_geom, vol_geom) +                    astra.matlab.data3d('delete', sino_id)              ## RING REMOVAL -            self.ringRemoval(i) +            if ring_lambda_R_L1 != 0: +                self.ringRemoval(i)              ## Projection/Backprojection Routine              self.projectionBackprojection(X, X_t) -            astra.matlab.data3d('delete', sino_id) +                          ## REGULARIZATION              X = self.regularize(X)              ## Update Loop @@ -523,6 +569,8 @@ class FISTAReconstructor():                                                    'output_geometry',                                                    'Lipschitz_constant']) +        device, reduced_device = self.getParameter(['device_model', +                                                    'reduced_device_model'])          if self.getParameter('projector_geometry')['type'] == 'parallel' or \             self.getParameter('projector_geometry')['type'] == 'fanflat' or \ @@ -530,27 +578,39 @@ class FISTAReconstructor():              # if the geometry is parallel use slice-by-slice              # projection-backprojection routine              #sino_updt = zeros(size(sino),'single'); -            proj_geomT = proj_geom.copy() -            proj_geomT['DetectorRowCount'] = 1 -            vol_geomT = vol_geom.copy() -            vol_geomT['GridSliceCount'] = 1;              x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) -             -            for kkk in range(SlicesZ): -                x_id, x_temp[kkk] = \ -                         astra.creators.create_backprojection3d_gpu( -                             residual[kkk:kkk+1], -                             proj_geomT, vol_geomT) -                astra.matlab.data3d('delete', x_id) +            if use_device: +                proj_geomT = proj_geom.copy() +                proj_geomT['DetectorRowCount'] = 1 +                vol_geomT = vol_geom.copy() +                vol_geomT['GridSliceCount'] = 1; +                 +                for kkk in range(SlicesZ): +                     +                    x_id, x_temp[kkk] = \ +                             astra.creators.create_backprojection3d_gpu( +                                 residual[kkk:kkk+1], +                                 proj_geomT, vol_geomT) +                    astra.matlab.data3d('delete', x_id) +            else: +                for kkk in range(SliceZ): +                    x_temp[kkk] = \ +                        reduced_device.doBackwardProject(residual[kkk:kkk+1])          else: -            x_id, x_temp = \ +            if use_device: +                x_id, x_temp = \                    astra.creators.create_backprojection3d_gpu( -                      residual, proj_geom, vol_geom)             +                      residual, proj_geom, vol_geom) +                astra.matlab.data3d('delete', x_id) +            else: +                x_temp = \ +                    device.doBackwardProject(residual) +                                 X = X_t - (1/L_const) * x_temp          #astra.matlab.data3d('delete', sino_id) -        astra.matlab.data3d('delete', x_id) +              def regularize(self, X):          print ("FISTA Reconstructor: regularize")  | 
