summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py24
-rw-r--r--src/Python/test_reconstructor-os.py288
2 files changed, 143 insertions, 169 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index fda9cf0..85bfac5 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -583,3 +583,27 @@ class FISTAReconstructor():
string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
print (string.format(i,Resid_error[i], self.objective[i]))
return (X , X_t, t)
+
+ def os_iterate(self, Xin=None):
+ print ("FISTA Reconstructor: iterate")
+
+ if Xin is None:
+ if self.getParameter('initialize'):
+ X = self.initialize()
+ else:
+ N = vol_geom['GridColCount']
+ X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+ else:
+ # copy by reference
+ X = Xin
+ # store the output volume in the parameters
+ self.setParameter(output_volume=X)
+ X_t = X.copy()
+
+ # some useful constants
+ proj_geom , vol_geom, sino , \
+ SlicesZ, weights , alpha_ring ,
+ lambdaR_L1 , L_const = self.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' ,
+ 'ring_lambda_R_L1', 'Lipschitz_constant'])
diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py
index 6f3721f..aee70a4 100644
--- a/src/Python/test_reconstructor-os.py
+++ b/src/Python/test_reconstructor-os.py
@@ -122,10 +122,13 @@ if True:
X_t = X.copy()
print ("initialized")
proj_geom , vol_geom, sino , \
- SlicesZ = fistaRecon.getParameter(['projector_geometry' ,
- 'output_geometry',
- 'input_sinogram',
- 'SlicesZ'])
+ SlicesZ, weights , alpha_ring = fistaRecon.getParameter(
+ ['projector_geometry' , 'output_geometry',
+ 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha'])
+ lambdaR_L1 , alpha_ring , weights , L_const= \
+ fistaRecon.getParameter(['ring_lambda_R_L1',
+ 'ring_alpha' , 'weights',
+ 'Lipschitz_constant'])
#fistaRecon.setParameter(number_of_iterations = 3)
iterFISTA = fistaRecon.getParameter('number_of_iterations')
@@ -136,10 +139,14 @@ if True:
t = 1
+
## additional for
proj_geomSUB = proj_geom.copy()
- fistaRecon.residual2 = numpy.zeros(numpy.shape(self.pars['input_sinogram']))
+ fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram']))
+ residual2 = fistaRecon.residual2
+ sino_updt_FULL = fistaRecon.residual.copy()
+
print ("starting iterations")
## % Outer FISTA iterations loop
for i in range(fistaRecon.getParameter('number_of_iterations')):
@@ -153,157 +160,126 @@ if True:
# hence additional work is required one solution is to work with a full
# sinogram at times
- ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
- if (lambdaR_L1 > 0) :
- sino_id2, sino_updt2 = astra.creators.create_sino3d_gpu(
- X, proj_geom, vol_geom)
- astra.matlab.data3d('delete', sino_id2)
+ r_old = fistaRecon.r.copy()
+ t_old = t
+ SlicesZ, anglesNumb, Detectors = \
+ numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
+ if (i > 1 and lambdaR_L1 > 0) :
+ for kkk in range(anglesNumb):
+
+ residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+ ((sino_updt_FULL[:,kkk,:]).squeeze() - \
+ (sino[:,kkk,:]).squeeze() -\
+ (alpha_ring * r_x)
+ )
+
+ vec = fistaRecon.residual.sum(axis = 1)
+ #if SlicesZ > 1:
+ # vec = vec[:,1,:] # 1 or 0?
+ r_x = fistaRecon.r_x
+ fistaRecon.r = (r_x - (1./L_const) * vec).copy()
# subset loop
counterInd = 1
+ geometry_type = fistaRecon.getParameter('projector_geometry')['type']
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+
+ for kkk in range(SlicesZ):
+ sino_id, sinoT[kkk] = \
+ astra.creators.create_sino3d_gpu(
+ X_t[kkk:kkk+1], proj_geomSUB, vol_geom)
+ sino_updt_Sub[kkk] = sinoT.T.copy()
+
+ else:
+ sino_id, sino_updt_Sub = \
+ astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', sino_id)
+
for ss in range(fistaRecon.getParameter('subsets')):
print ("Subset {0}".format(ss))
X_old = X.copy()
t_old = t
- r_old = fistaRecon.r.copy()
-
+
# the number of projections per subset
numProjSub = fistaRecon.getParameter('os_bins')[ss]
CurrSubIndices = fistaRecon.getParameter('os_indices')\
[counterInd:counterInd+numProjSub-1]
proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces]
-
-## if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
-## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \
-## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' :
-## # if the geometry is parallel use slice-by-slice
-## # projection-backprojection routine
-## #sino_updt = zeros(size(sino),'single');
-## proj_geomT = proj_geom.copy()
-## proj_geomT['DetectorRowCount'] = 1
-## vol_geomT = vol_geom.copy()
-## vol_geomT['GridSliceCount'] = 1;
-## sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float)
-## for kkk in range(SlicesZ):
-## sino_id, sino_updt[kkk] = \
-## astra.creators.create_sino3d_gpu(
-## X_t[kkk:kkk+1], proj_geom, vol_geom)
-## astra.matlab.data3d('delete', sino_id)
-## else:
-## # for divergent 3D geometry (watch the GPU memory overflow in
-## # ASTRA versions < 1.8)
-## #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
-## sino_id, sino_updt = astra.creators.create_sino3d_gpu(
-## X_t, proj_geom, vol_geom)
-
- ## RING REMOVAL
- residual = fistaRecon.residual
- residual2 = fistaRecon.residual2
-
- lambdaR_L1 , alpha_ring , weights , L_const= \
- fistaRecon.getParameter(['ring_lambda_R_L1',
- 'ring_alpha' , 'weights',
- 'Lipschitz_constant'])
- r_x = fistaRecon.r_x
- SlicesZ, anglesNumb, Detectors = \
- numpy.shape(fistaRecon.getParameter('input_sinogram'))
- if lambdaR_L1 > 0 :
- print ("ring removal")
-## % the ring removal part (Group-Huber fidelity)
-## % first 2 iterations do additional work reconstructing whole dataset to ensure
-## % the stablility
-## if (i < 3)
-## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom);
-## astra_mex_data3d('delete', sino_id2);
-## else
-## [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom);
-## end
-
-## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
- if i < 3:
- pass
- else:
- sino_id, sino_updt = astra.creators.create_sino3d_gpu(
- X_t, proj_geomSUB, vol_geom)
-## sino_id, sino_updt = astra.creators.create_sino3d_gpu(
-## X, proj_geom, vol_geom)
-## astra.matlab.data3d('delete', sino_id)
-
- for kkk in range(anglesNumb):
-
- residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
- ((sino_updt2[:,kkk,:]).squeeze() - \
- (sino[:,kkk,:]).squeeze() -\
- (alpha_ring * r_x)
- )
- shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram')))
- shape[1] = numProjSub
- fistaRecon.residual = numpy.zeros(shape)
- if fistaRecon.residual.__hash__() != residual.__hash__():
- residual = fistaRecon.residual
-## for kkk = 1:numProjSub
-## indC = CurrSubIndeces(kkk);
-## if (i < 3)
-## residual(:,kkk,:) = squeeze(residual2(:,indC,:));
-## else
-## residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
-## end
-## end
- for kk in range(numProjSub):
- indC = fistaRecon.getParameter('os_indices')[kkk]
- if i < 3:
- residual[:,kkk,:] = residual2[:,indC,:].squeeze()
- else:
- residual(:,kkk,:) = \
- weights[:,indC,:].squeeze() * sino_updt[:,kkk,:].squeeze() - \
- sino[:,indC,:].squeeze() - alpha_ring * fistaRecon.r_x
- #squeeze(weights(:,indC,:)).* \
- # (squeeze(sino_updt(:,kkk,:)) - \
- #(squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
-
+ shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram')))
+ shape[1] = numProjSub
+ sino_updt_Sub = numpy.zeros(shape)
+
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+
+ for kkk in range(SlicesZ):
+ sino_id, sinoT = astra.creators.create_sino3d_gpu (
+ X_t[kkk:kkk+1] , proj_geomSUB, vol_geom)
+ sino_updt_Sub[kkk] = sinoT.T.copy()
+
+ else:
+ # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8)
+ sino_id, sino_updt_Sub = \
+ astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', sino_id)
- vec = residual.sum(axis = 1)
- #if SlicesZ > 1:
- # vec = vec[:,1,:].squeeze()
- fistaRecon.r = (r_x - (1./L_const) * vec).copy()
- objective[i] = (0.5 * (residual ** 2).sum())
-## % the ring removal part (Group-Huber fidelity)
-## for kkk = 1:anglesNumb
-## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).*
-## (squeeze(sino_updt(:,kkk,:)) -
-## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x));
-## end
-## vec = sum(residual,2);
-## if (SlicesZ > 1)
-## vec = squeeze(vec(:,1,:));
-## end
-## r = r_x - (1./L_const).*vec;
-## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output
+
-
- # Projection/Backprojection Routine
- if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \
- fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\
- fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec':
- x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32)
- print ("Projection/Backprojection Routine")
- for kkk in range(SlicesZ):
-
- x_id, x_temp[kkk] = \
- astra.creators.create_backprojection3d_gpu(
- residual[kkk:kkk+1],
- proj_geomT, vol_geomT)
- astra.matlab.data3d('delete', x_id)
- else:
- x_id, x_temp = \
- astra.creators.create_backprojection3d_gpu(
- residual, proj_geom, vol_geom)
+ ## RING REMOVAL
+ residual = fistaRecon.residual
+
+
+ if lambdaR_L1 > 0 :
+ print ("ring removal")
+ residualSub = numpy.zeros(shape)
+ ## for a chosen subset
+ ## for kkk = 1:numProjSub
+ ## indC = CurrSubIndeces(kkk);
+ ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
+ ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
+ ## end
+ for kkk in range(numProjSub):
+ indC = CurrSubIndices[kkk]
+ residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+ (sino_updt_Sub[:,kkk,:].squeeze() - \
+ sino[:,indC,:].squeeze() - alpha_ring * r_x)
+ # filling the full sinogram
+ sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+
+ else:
+ #PWLS model
+ residualSub = weights[:,CurrSubIndices,:] * \
+ ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() )
+ objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+
+ if geometry_type == 'parallel' or \
+ geometry_type == 'fanflat' or \
+ geometry_type == 'fanflat_vec' :
+ # if geometry is 2D use slice-by-slice projection-backprojection
+ # routine
+ x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+ for kkk in range(SlicesZ):
+
+ x_id, x_temp[kkk] = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub[kkk:kkk+1],
+ proj_geomSUB, vol_geom)
+
+ else:
+ x_id, x_temp = \
+ astra.creators.create_backprojection3d_gpu(
+ residualSub, proj_geomSUB, vol_geom)
+
+ astra.matlab.data3d('delete', x_id)
+ X = X_t - (1/L_const) * x_temp
- X = X_t - (1/L_const) * x_temp
- astra.matlab.data3d('delete', sino_id)
- astra.matlab.data3d('delete', x_id)
## REGULARIZATION
@@ -322,12 +298,9 @@ if True:
fistaRecon.r = numpy.max(
numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \
numpy.sign(fistaRecon.r)
- t = (1 + numpy.sqrt(1 + 4 * t**2))/2
- X_t = X + (((t_old -1)/t) * (X - X_old))
-
- if lambdaR_L1 > 0:
- fistaRecon.r_x = fistaRecon.r + \
- (((t_old-1)/t) * (fistaRecon.r - r_old))
+ # updating r
+ r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old)
+
if fistaRecon.getParameter('region_of_interest') is None:
string = 'Iteration Number {0} | Objective {1} \n'
@@ -340,30 +313,7 @@ if True:
string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
print (string.format(i,Resid_error[i], objective[i]))
-## if (lambdaR_L1 > 0)
-## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector
-## end
-##
-## t = (1 + sqrt(1 + 4*t^2))/2; % updating t
-## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X
-##
-## if (lambdaR_L1 > 0)
-## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r
-## end
-##
-## if (show == 1)
-## figure(10); imshow(X(:,:,slice), [0 maxvalplot]);
-## if (lambdaR_L1 > 0)
-## figure(11); plot(r); title('Rings offset vector')
-## end
-## pause(0.01);
-## end
-## if (strcmp(X_ideal, 'none' ) == 0)
-## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI));
-## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i));
-## else
-## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i));
-## end
+
else:
fistaRecon = FISTAReconstructor(proj_geom,
vol_geom,