summaryrefslogtreecommitdiffstats
path: root/src/Python
diff options
context:
space:
mode:
Diffstat (limited to 'src/Python')
-rw-r--r--src/Python/test/test_reconstructor-os.py71
1 files changed, 47 insertions, 24 deletions
diff --git a/src/Python/test/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py
index 3f419cf..8a0aad8 100644
--- a/src/Python/test/test_reconstructor-os.py
+++ b/src/Python/test/test_reconstructor-os.py
@@ -13,6 +13,8 @@ from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor
import astra
import matplotlib.pyplot as plt
from ccpi.imaging.Regularizer import Regularizer
+from ccpi.reconstruction.AstraDevice import AstraDevice
+from ccpi.reconstruction.DeviceModel import DeviceModel
def RMSE(signal1, signal2):
'''RMSE Root Mean Squared Error'''
@@ -65,11 +67,22 @@ vol_geom = astra.creators.create_vol_geom( image_size_x,
## First pass the arguments to the FISTAReconstructor and test the
## Lipschitz constant
-
+astradevice = AstraDevice(DeviceModel.DeviceType.PARALLEL3D.value,
+ [proj_geom['DetectorRowCount'] ,
+ proj_geom['DetectorColCount'] ,
+ proj_geom['DetectorSpacingX'] ,
+ proj_geom['DetectorSpacingY'] ,
+ proj_geom['ProjectionAngles']
+ ],
+ [
+ vol_geom['GridColCount'],
+ vol_geom['GridRowCount'],
+ vol_geom['GridSliceCount'] ] )
fistaRecon = FISTAReconstructor(proj_geom,
vol_geom,
Sino3D ,
- weights=Weights3D)
+ weights=Weights3D,
+ device=astradevice)
print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
fistaRecon.setParameter(number_of_iterations = 12)
@@ -81,18 +94,21 @@ fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
#reg = Regularizer(Regularizer.Algorithm.FGP_TV)
#reg.setParameter(regularization_parameter=0.005,
# number_of_iterations=50)
-reg = Regularizer(Regularizer.Algorithm.LLT_model)
-reg.setParameter(regularization_parameter=25,
- time_step=0.0003,
+reg = Regularizer(Regularizer.Algorithm.FGP_TV)
+reg.setParameter(regularization_parameter=5e6,
tolerance_constant=0.0001,
- number_of_iterations=300)
+ number_of_iterations=50)
+fistaRecon.setParameter(regularizer=reg)
+lc = fistaRecon.getParameter('Lipschitz_constant')
+reg.setParameter(regularization_parameter=5e6/lc)
## Ordered subset
if True:
- subsets = 16
+ subsets = 8
fistaRecon.setParameter(subsets=subsets)
fistaRecon.createOrderedSubsets()
+else:
angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
#binEdges = numpy.linspace(angles.min(),
# angles.max(),
@@ -192,26 +208,29 @@ if True:
#if SlicesZ > 1:
# vec = vec[:,1,:] # 1 or 0?
r_x = fistaRecon.r_x
- fistaRecon.r = (r_x - (1./L_const) * vec).copy()
+ # update ring variable
+ fistaRecon.r = (r_x - (1./L_const) * vec).copy()
# subset loop
counterInd = 1
geometry_type = fistaRecon.getParameter('projector_geometry')['type']
- if geometry_type == 'parallel' or \
- geometry_type == 'fanflat' or \
- geometry_type == 'fanflat_vec' :
-
- for kkk in range(SlicesZ):
- sino_id, sinoT[kkk] = \
- astra.creators.create_sino3d_gpu(
- X_t[kkk:kkk+1], proj_geomSUB, vol_geom)
- sino_updt_Sub[kkk] = sinoT.T.copy()
-
- else:
- sino_id, sino_updt_Sub = \
- astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom)
-
- astra.matlab.data3d('delete', sino_id)
+ angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles']
+
+## if geometry_type == 'parallel' or \
+## geometry_type == 'fanflat' or \
+## geometry_type == 'fanflat_vec' :
+##
+## for kkk in range(SlicesZ):
+## sino_id, sinoT[kkk] = \
+## astra.creators.create_sino3d_gpu(
+## X_t[kkk:kkk+1], proj_geomSUB, vol_geom)
+## sino_updt_Sub[kkk] = sinoT.T.copy()
+##
+## else:
+## sino_id, sino_updt_Sub = \
+## astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom)
+##
+## astra.matlab.data3d('delete', sino_id)
for ss in range(fistaRecon.getParameter('subsets')):
print ("Subset {0}".format(ss))
@@ -276,6 +295,7 @@ if True:
else:
#PWLS model
+ # I guess we need to use mask here instead
residualSub = weights[:,CurrSubIndices,:] * \
( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() )
objective[i] = 0.5 * numpy.linalg.norm(residualSub)
@@ -310,7 +330,10 @@ if True:
# for slices:
# out = regularizer(input=X)
print ("regularizer")
- X = reg(input=X)[0]
+ reg = fistaRecon.getParameter('regularizer')
+
+ X = reg(input=X,
+ output_all=False)
## FINAL