diff options
Diffstat (limited to 'src/Python/test')
| -rw-r--r-- | src/Python/test/test_reconstructor.py | 67 | 
1 files changed, 47 insertions, 20 deletions
diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py index 3342301..343b9bb 100644 --- a/src/Python/test/test_reconstructor.py +++ b/src/Python/test/test_reconstructor.py @@ -12,6 +12,9 @@ import numpy  from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor  import astra  import matplotlib.pyplot as plt +from ccpi.imaging.Regularizer import Regularizer +from ccpi.reconstruction.AstraDevice import AstraDevice +from ccpi.reconstruction.DeviceModel import DeviceModel  def RMSE(signal1, signal2):      '''RMSE Root Mean Squared Error''' @@ -22,7 +25,23 @@ def RMSE(signal1, signal2):          return err      else:          raise Exception('Input signals must have the same shape') -   + +def createAstraDevice(projector_geometry, output_geometry): +    '''TODO remove''' +     +    device = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value, +                [projector_geometry['DetectorSpacingX'] , +                 projector_geometry['DetectorSpacingY'] , +                 projector_geometry['DetectorColCount'] , +                 projector_geometry['DetectorRowCount'] , +                 projector_geometry['ProjectionAngles'] +                 ], +                [ +                    output_geometry['GridColCount'], +                    output_geometry['GridRowCount'],  +                    output_geometry['GridSliceCount'] ] ) +    return device +  filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5'  nx = h5py.File(filename, "r")  #getEntry(nx, '/') @@ -65,23 +84,23 @@ vol_geom = astra.creators.create_vol_geom( image_size_x,  ## First pass the arguments to the FISTAReconstructor and test the  ## Lipschitz constant -fistaRecon = FISTAReconstructor(proj_geom, -                                vol_geom, -                                Sino3D , -                                weights=Weights3D) - -print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) -fistaRecon.setParameter(number_of_iterations = 12) -fistaRecon.setParameter(Lipschitz_constant = 767893952.0) -fistaRecon.setParameter(ring_alpha = 21) -fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) - -reg = Regularizer(Regularizer.Algorithm.LLT_model) -reg.setParameter(regularization_parameter=25, -                          time_step=0.0003, -                          tolerance_constant=0.0001, -                          number_of_iterations=300) -fistaRecon.setParameter(regularizer = reg) +##fistaRecon = FISTAReconstructor(proj_geom, +##                                vol_geom, +##                                Sino3D , +##                                weights=Weights3D) +## +##print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +##fistaRecon.setParameter(number_of_iterations = 12) +##fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +##fistaRecon.setParameter(ring_alpha = 21) +##fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +## +##reg = Regularizer(Regularizer.Algorithm.LLT_model) +##reg.setParameter(regularization_parameter=25, +##                          time_step=0.0003, +##                          tolerance_constant=0.0001, +##                          number_of_iterations=300) +##fistaRecon.setParameter(regularizer=reg)  ## Ordered subset  if False: @@ -294,16 +313,24 @@ if False:  ##            fprintf('%s %i  %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i));  ##        end  else: +     +    # create a device for forward/backprojection +    astradevice = createAstraDevice(proj_geom, vol_geom)      fistaRecon = FISTAReconstructor(proj_geom,                                  vol_geom,                                  Sino3D , -                                weights=Weights3D) +                                device = astradevice, +                                weights=Weights3D +                                )      print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) -    fistaRecon.setParameter(number_of_iterations = 12) +    fistaRecon.setParameter(number_of_iterations = 3)      fistaRecon.setParameter(Lipschitz_constant = 767893952.0)      fistaRecon.setParameter(ring_alpha = 21)      fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +     +          fistaRecon.prepareForIteration()      X = fistaRecon.iterate(numpy.load("X.npy"))      numpy.save("X_out.npy", X)  | 
