summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py165
-rw-r--r--src/Python/test_reconstructor.py25
2 files changed, 165 insertions, 25 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index 33e67a3..fda9cf0 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -85,6 +85,7 @@ class FISTAReconstructor():
self.pars['detectors'] = detectors
self.pars['number_of_angles'] = nangles
self.pars['SlicesZ'] = sliceZ
+ self.pars['output_volume'] = None
print (self.pars)
# handle optional input parameters (at instantiation)
@@ -108,7 +109,11 @@ class FISTAReconstructor():
'regularizer' ,
'ring_lambda_R_L1',
'ring_alpha',
- 'subsets')
+ 'subsets',
+ 'output_volume',
+ 'os_subsets',
+ 'os_indices',
+ 'os_bins')
self.acceptedInputKeywords = list(kw)
# handle keyworded parameters
@@ -176,8 +181,6 @@ class FISTAReconstructor():
'''
for key , value in kwargs.items():
if key in self.acceptedInputKeywords:
- if key == 'use_studentt_fidelity':
- raise Exception('use_studentt_fidelity Not implemented')
self.pars[key] = value
else:
raise Exception('Wrong parameter {0} for '.format(key) +
@@ -382,11 +385,15 @@ class FISTAReconstructor():
counter = counter + binsDiscr[jj] - 1
-
- return IndicesReorg
+ # store the OS in parameters
+ self.setParameter(os_subsets=subsets,
+ os_bins=binsDiscr,
+ os_indices=IndicesReorg)
def prepareForIteration(self):
+ print ("FISTA Reconstructor: prepare for iteration")
+
self.residual_error = numpy.zeros((self.pars['number_of_iterations']))
self.objective = numpy.zeros((self.pars['number_of_iterations']))
@@ -401,19 +408,17 @@ class FISTAReconstructor():
if self.getParameter('Lipschitz_constant') is None:
self.pars['Lipschitz_constant'] = \
self.calculateLipschitzConstantWithPowerMethod()
+ # errors vector (if the ground truth is given)
+ self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations')));
+ # objective function values vector
+ self.objective = numpy.zeros((self.getParameter('number_of_iterations')));
# prepareForIteration
def iterate(self, Xin=None):
- # convenience variable storage
- proj_geom , vol_geom, sino , \
- SlicesZ = self.getParameter([ 'projector_geometry' ,
- 'output_geometry',
- 'input_sinogram',
- 'SlicesZ'])
-
- t = 1
+ print ("FISTA Reconstructor: iterate")
+
if Xin is None:
if self.getParameter('initialize'):
X = self.initialize()
@@ -423,15 +428,25 @@ class FISTAReconstructor():
else:
# copy by reference
X = Xin
-
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
X_t = X.copy()
+ # convenience variable storage
+ proj_geom , vol_geom, sino , \
+ SlicesZ = self.getParameter([ 'projector_geometry' ,
+ 'output_geometry',
+ 'input_sinogram',
+ 'SlicesZ' ])
+
+ t = 1
for i in range(self.getParameter('number_of_iterations')):
X_old = X.copy()
t_old = t
r_old = self.r.copy()
if self.getParameter('projector_geometry')['type'] == 'parallel' or \
- self.getParameter('projector_geometry')['type'] == 'parallel3d':
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
# if the geometry is parallel use slice-by-slice
# projection-backprojection routine
#sino_updt = zeros(size(sino),'single');
@@ -439,10 +454,9 @@ class FISTAReconstructor():
proj_geomT['DetectorRowCount'] = 1
vol_geomT = vol_geom.copy()
vol_geomT['GridSliceCount'] = 1;
- sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
+ self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
for kkk in range(SlicesZ):
- print (kkk)
- sino_id, sino_updt[kkk] = \
+ sino_id, self.sino_updt[kkk] = \
astra.creators.create_sino3d_gpu(
X_t[kkk:kkk+1], proj_geomT, vol_geomT)
astra.matlab.data3d('delete', sino_id)
@@ -450,11 +464,122 @@ class FISTAReconstructor():
# for divergent 3D geometry (watch the GPU memory overflow in
# ASTRA versions < 1.8)
#[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
- sino_id, sino_updt = astra.matlab.create_sino3d_gpu(
+ sino_id, self.sino_updt = astra.creators.create_sino3d_gpu(
X_t, proj_geom, vol_geom)
## RING REMOVAL
-
+ self.ringRemoval(i)
+ ## Projection/Backprojection Routine
+ self.projectionBackprojection(X, X_t)
+ astra.matlab.data3d('delete', sino_id)
## REGULARIZATION
+ X = self.regularize(X)
+ ## Update Loop
+ X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+ self.setParameter(output_volume=X)
+ return X
+ ## iterate
+
+ def ringRemoval(self, i):
+ print ("FISTA Reconstructor: ring removal")
+ residual = self.residual
+ lambdaR_L1 , alpha_ring , weights , L_const , sino= \
+ self.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant',
+ 'input_sinogram'])
+ r_x = self.r_x
+ sino_updt = self.sino_updt
+
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ if lambdaR_L1 > 0 :
+ for kkk in range(anglesNumb):
+
+ residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+ vec = residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:].squeeze()
+ self.r = (r_x - (1./L_const) * vec).copy()
+ self.objective[i] = (0.5 * (residual ** 2).sum())
+
+ def projectionBackprojection(self, X, X_t):
+ print ("FISTA Reconstructor: projection-backprojection routine")
+
+ # a few useful variables
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(self.getParameter('input_sinogram'))
+ residual = self.residual
+ proj_geom , vol_geom , L_const = \
+ self.getParameter(['projector_geometry' ,
+ 'output_geometry',
+ 'Lipschitz_constant'])
+
+
+ if self.getParameter('projector_geometry')['type'] == 'parallel' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ self.getParameter('projector_geometry')['type'] == 'fanflat_vec':
+ # if the geometry is parallel use slice-by-slice
+ # projection-backprojection routine
+ #sino_updt = zeros(size(sino),'single');
+ proj_geomT = proj_geom.copy()
+ proj_geomT['DetectorRowCount'] = 1
+ vol_geomT = vol_geom.copy()
+ vol_geomT['GridSliceCount'] = 1;
+ x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
+
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residual[kkk:kkk+1],
+ proj_geomT, vol_geomT)
+ astra.matlab.data3d('delete', x_id)
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residual, proj_geom, vol_geom)
+
+ X = X_t - (1/L_const) * x_temp
+ #astra.matlab.data3d('delete', sino_id)
+ astra.matlab.data3d('delete', x_id)
+
+ def regularize(self, X):
+ print ("FISTA Reconstructor: regularize")
+
+ regularizer = self.getParameter('regularizer')
+ if regularizer is not None:
+ return regularizer(input=X)
+ else:
+ return X
+
+ def updateLoop(self, i, X, X_old, r_old, t, t_old):
+ print ("FISTA Reconstructor: update loop")
+ lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
+ if lambdaR_L1 > 0:
+ self.r = numpy.max(
+ numpy.abs(self.r) - lambdaR_L1 , 0) * \
+ numpy.sign(self.r)
+ t = (1 + numpy.sqrt(1 + 4 * t**2))/2
+ X_t = X + (((t_old -1)/t) * (X - X_old))
+
+ if lambdaR_L1 > 0:
+ self.r_x = self.r + \
+ (((t_old-1)/t) * (self.r - r_old))
+
+ if self.getParameter('region_of_interest') is None:
+ string = 'Iteration Number {0} | Objective {1} \n'
+ print (string.format( i, self.objective[i]))
+ else:
+ ROI , X_ideal = fistaRecon.getParameter('region_of_interest',
+ 'ideal_image')
+ Resid_error[i] = RMSE(X*ROI, X_ideal*ROI)
+ string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
+ print (string.format(i,Resid_error[i], self.objective[i]))
+ return (X , X_t, t)
diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py
index 2f188b4..07668ba 100644
--- a/src/Python/test_reconstructor.py
+++ b/src/Python/test_reconstructor.py
@@ -100,7 +100,7 @@ if False:
counter = counter + binsDiscr[jj] - 1
-if True:
+if False:
print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
print ("prepare for iteration")
fistaRecon.prepareForIteration()
@@ -145,7 +145,8 @@ if True:
t_old = t
r_old = fistaRecon.r.copy()
if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
- fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d':
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' :
# if the geometry is parallel use slice-by-slice
# projection-backprojection routine
#sino_updt = zeros(size(sino),'single');
@@ -157,13 +158,13 @@ if True:
for kkk in range(SlicesZ):
sino_id, sino_updt[kkk] = \
astra.creators.create_sino3d_gpu(
- X_t[kkk:kkk+1], proj_geomT, vol_geomT)
+ X_t[kkk:kkk+1], proj_geom, vol_geom)
astra.matlab.data3d('delete', sino_id)
else:
# for divergent 3D geometry (watch the GPU memory overflow in
# ASTRA versions < 1.8)
#[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
- sino_id, sino_updt = astra.matlab.create_sino3d_gpu(
+ sino_id, sino_updt = astra.creators.create_sino3d_gpu(
X_t, proj_geom, vol_geom)
## RING REMOVAL
@@ -206,7 +207,8 @@ if True:
# Projection/Backprojection Routine
if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
- fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d':
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\
+ fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec':
x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
print ("Projection/Backprojection Routine")
for kkk in range(SlicesZ):
@@ -284,3 +286,16 @@ if True:
## else
## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i));
## end
+else:
+ fistaRecon = FISTAReconstructor(proj_geom,
+ vol_geom,
+ Sino3D ,
+ weights=Weights3D)
+
+ print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
+ fistaRecon.setParameter(number_of_iterations = 12)
+ fistaRecon.setParameter(Lipschitz_constant = 767893952.0)
+ fistaRecon.setParameter(ring_alpha = 21)
+ fistaRecon.setParameter(ring_lambda_R_L1 = 0.002)
+ fistaRecon.prepareForIteration()
+ X = fistaRecon.iterate(numpy.load("X.npy"))