From 640d707362604d803bcf17f882dc70af83299bb3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 31 Oct 2017 11:48:29 +0000 Subject: first working OS FISTA python test --- src/Python/test/test_reconstructor-os.py | 71 +++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 24 deletions(-) diff --git a/src/Python/test/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py index 3f419cf..8a0aad8 100644 --- a/src/Python/test/test_reconstructor-os.py +++ b/src/Python/test/test_reconstructor-os.py @@ -13,6 +13,8 @@ from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor import astra import matplotlib.pyplot as plt from ccpi.imaging.Regularizer import Regularizer +from ccpi.reconstruction.AstraDevice import AstraDevice +from ccpi.reconstruction.DeviceModel import DeviceModel def RMSE(signal1, signal2): '''RMSE Root Mean Squared Error''' @@ -65,11 +67,22 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, ## First pass the arguments to the FISTAReconstructor and test the ## Lipschitz constant - +astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value, + [proj_geom['DetectorRowCount'] , + proj_geom['DetectorColCount'] , + proj_geom['DetectorSpacingX'] , + proj_geom['DetectorSpacingY'] , + proj_geom['ProjectionAngles'] + ], + [ + vol_geom['GridColCount'], + vol_geom['GridRowCount'], + vol_geom['GridSliceCount'] ] ) fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D , - weights=Weights3D) + weights=Weights3D, + device=astradevice) print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) fistaRecon.setParameter(number_of_iterations = 12) @@ -81,18 +94,21 @@ fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) #reg = Regularizer(Regularizer.Algorithm.FGP_TV) #reg.setParameter(regularization_parameter=0.005, # number_of_iterations=50) -reg = Regularizer(Regularizer.Algorithm.LLT_model) -reg.setParameter(regularization_parameter=25, - time_step=0.0003, +reg = Regularizer(Regularizer.Algorithm.FGP_TV) +reg.setParameter(regularization_parameter=5e6, tolerance_constant=0.0001, - number_of_iterations=300) + number_of_iterations=50) +fistaRecon.setParameter(regularizer=reg) +lc = fistaRecon.getParameter('Lipschitz_constant') +reg.setParameter(regularization_parameter=5e6/lc) ## Ordered subset if True: - subsets = 16 + subsets = 8 fistaRecon.setParameter(subsets=subsets) fistaRecon.createOrderedSubsets() +else: angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] #binEdges = numpy.linspace(angles.min(), # angles.max(), @@ -192,26 +208,29 @@ if True: #if SlicesZ > 1: # vec = vec[:,1,:] # 1 or 0? r_x = fistaRecon.r_x - fistaRecon.r = (r_x - (1./L_const) * vec).copy() + # update ring variable + fistaRecon.r = (r_x - (1./L_const) * vec).copy() # subset loop counterInd = 1 geometry_type = fistaRecon.getParameter('projector_geometry')['type'] - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - - for kkk in range(SlicesZ): - sino_id, sinoT[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomSUB, vol_geom) - sino_updt_Sub[kkk] = sinoT.T.copy() - - else: - sino_id, sino_updt_Sub = \ - astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) - - astra.matlab.data3d('delete', sino_id) + angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] + +## if geometry_type == 'parallel' or \ +## geometry_type == 'fanflat' or \ +## geometry_type == 'fanflat_vec' : +## +## for kkk in range(SlicesZ): +## sino_id, sinoT[kkk] = \ +## astra.creators.create_sino3d_gpu( +## X_t[kkk:kkk+1], proj_geomSUB, vol_geom) +## sino_updt_Sub[kkk] = sinoT.T.copy() +## +## else: +## sino_id, sino_updt_Sub = \ +## astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) +## +## astra.matlab.data3d('delete', sino_id) for ss in range(fistaRecon.getParameter('subsets')): print ("Subset {0}".format(ss)) @@ -276,6 +295,7 @@ if True: else: #PWLS model + # I guess we need to use mask here instead residualSub = weights[:,CurrSubIndices,:] * \ ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) objective[i] = 0.5 * numpy.linalg.norm(residualSub) @@ -310,7 +330,10 @@ if True: # for slices: # out = regularizer(input=X) print ("regularizer") - X = reg(input=X)[0] + reg = fistaRecon.getParameter('regularizer') + + X = reg(input=X, + output_all=False) ## FINAL -- cgit v1.2.3