diff options
Diffstat (limited to 'src/Python')
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 24 | ||||
-rw-r--r-- | src/Python/test_reconstructor-os.py | 288 |
2 files changed, 143 insertions, 169 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index fda9cf0..85bfac5 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -583,3 +583,27 @@ class FISTAReconstructor(): string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], self.objective[i])) return (X , X_t, t) + + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index 6f3721f..aee70a4 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -122,10 +122,13 @@ if True: X_t = X.copy() print ("initialized") proj_geom , vol_geom, sino , \ - SlicesZ = fistaRecon.getParameter(['projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) + SlicesZ, weights , alpha_ring = fistaRecon.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha']) + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) #fistaRecon.setParameter(number_of_iterations = 3) iterFISTA = fistaRecon.getParameter('number_of_iterations') @@ -136,10 +139,14 @@ if True: t = 1 + ## additional for proj_geomSUB = proj_geom.copy() - fistaRecon.residual2 = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) + residual2 = fistaRecon.residual2 + sino_updt_FULL = fistaRecon.residual.copy() + print ("starting iterations") ## % Outer FISTA iterations loop for i in range(fistaRecon.getParameter('number_of_iterations')): @@ -153,157 +160,126 @@ if True: # hence additional work is required one solution is to work with a full # sinogram at times - ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if (lambdaR_L1 > 0) : - sino_id2, sino_updt2 = astra.creators.create_sino3d_gpu( - X, proj_geom, vol_geom) - astra.matlab.data3d('delete', sino_id2) + r_old = fistaRecon.r.copy() + t_old = t + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 + if (i > 1 and lambdaR_L1 > 0) : + for kkk in range(anglesNumb): + + residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt_FULL[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + + vec = fistaRecon.residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:] # 1 or 0? + r_x = fistaRecon.r_x + fistaRecon.r = (r_x - (1./L_const) * vec).copy() # subset loop counterInd = 1 + geometry_type = fistaRecon.getParameter('projector_geometry')['type'] + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) + for ss in range(fistaRecon.getParameter('subsets')): print ("Subset {0}".format(ss)) X_old = X.copy() t_old = t - r_old = fistaRecon.r.copy() - + # the number of projections per subset numProjSub = fistaRecon.getParameter('os_bins')[ss] CurrSubIndices = fistaRecon.getParameter('os_indices')\ [counterInd:counterInd+numProjSub-1] proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces] - -## if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ -## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ -## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : -## # if the geometry is parallel use slice-by-slice -## # projection-backprojection routine -## #sino_updt = zeros(size(sino),'single'); -## proj_geomT = proj_geom.copy() -## proj_geomT['DetectorRowCount'] = 1 -## vol_geomT = vol_geom.copy() -## vol_geomT['GridSliceCount'] = 1; -## sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) -## for kkk in range(SlicesZ): -## sino_id, sino_updt[kkk] = \ -## astra.creators.create_sino3d_gpu( -## X_t[kkk:kkk+1], proj_geom, vol_geom) -## astra.matlab.data3d('delete', sino_id) -## else: -## # for divergent 3D geometry (watch the GPU memory overflow in -## # ASTRA versions < 1.8) -## #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); -## sino_id, sino_updt = astra.creators.create_sino3d_gpu( -## X_t, proj_geom, vol_geom) - - ## RING REMOVAL - residual = fistaRecon.residual - residual2 = fistaRecon.residual2 - - lambdaR_L1 , alpha_ring , weights , L_const= \ - fistaRecon.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant']) - r_x = fistaRecon.r_x - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(fistaRecon.getParameter('input_sinogram')) - if lambdaR_L1 > 0 : - print ("ring removal") -## % the ring removal part (Group-Huber fidelity) -## % first 2 iterations do additional work reconstructing whole dataset to ensure -## % the stablility -## if (i < 3) -## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); -## astra_mex_data3d('delete', sino_id2); -## else -## [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); -## end - -## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if i < 3: - pass - else: - sino_id, sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geomSUB, vol_geom) -## sino_id, sino_updt = astra.creators.create_sino3d_gpu( -## X, proj_geom, vol_geom) -## astra.matlab.data3d('delete', sino_id) - - for kkk in range(anglesNumb): - - residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt2[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) - shape[1] = numProjSub - fistaRecon.residual = numpy.zeros(shape) - if fistaRecon.residual.__hash__() != residual.__hash__(): - residual = fistaRecon.residual -## for kkk = 1:numProjSub -## indC = CurrSubIndeces(kkk); -## if (i < 3) -## residual(:,kkk,:) = squeeze(residual2(:,indC,:)); -## else -## residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); -## end -## end - for kk in range(numProjSub): - indC = fistaRecon.getParameter('os_indices')[kkk] - if i < 3: - residual[:,kkk,:] = residual2[:,indC,:].squeeze() - else: - residual(:,kkk,:) = \ - weights[:,indC,:].squeeze() * sino_updt[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * fistaRecon.r_x - #squeeze(weights(:,indC,:)).* \ - # (squeeze(sino_updt(:,kkk,:)) - \ - #(squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); - + shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) + shape[1] = numProjSub + sino_updt_Sub = numpy.zeros(shape) + + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT = astra.creators.create_sino3d_gpu ( + X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8) + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - fistaRecon.r = (r_x - (1./L_const) * vec).copy() - objective[i] = (0.5 * (residual ** 2).sum()) -## % the ring removal part (Group-Huber fidelity) -## for kkk = 1:anglesNumb -## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* -## (squeeze(sino_updt(:,kkk,:)) - -## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); -## end -## vec = sum(residual,2); -## if (SlicesZ > 1) -## vec = squeeze(vec(:,1,:)); -## end -## r = r_x - (1./L_const).*vec; -## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + - - # Projection/Backprojection Routine - if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ - fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - print ("Projection/Backprojection Routine") - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) + ## RING REMOVAL + residual = fistaRecon.residual + + + if lambdaR_L1 > 0 : + print ("ring removal") + residualSub = numpy.zeros(shape) + ## for a chosen subset + ## for kkk = 1:numProjSub + ## indC = CurrSubIndeces(kkk); + ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram + ## end + for kkk in range(numProjSub): + indC = CurrSubIndices[kkk] + residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ + (sino_updt_Sub[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * r_x) + # filling the full sinogram + sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + + else: + #PWLS model + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) + + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residualSub, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp - X = X_t - (1/L_const) * x_temp - astra.matlab.data3d('delete', sino_id) - astra.matlab.data3d('delete', x_id) ## REGULARIZATION @@ -322,12 +298,9 @@ if True: fistaRecon.r = numpy.max( numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ numpy.sign(fistaRecon.r) - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - fistaRecon.r_x = fistaRecon.r + \ - (((t_old-1)/t) * (fistaRecon.r - r_old)) + # updating r + r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) + if fistaRecon.getParameter('region_of_interest') is None: string = 'Iteration Number {0} | Objective {1} \n' @@ -340,30 +313,7 @@ if True: string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], objective[i])) -## if (lambdaR_L1 > 0) -## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector -## end -## -## t = (1 + sqrt(1 + 4*t^2))/2; % updating t -## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X -## -## if (lambdaR_L1 > 0) -## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r -## end -## -## if (show == 1) -## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); -## if (lambdaR_L1 > 0) -## figure(11); plot(r); title('Rings offset vector') -## end -## pause(0.01); -## end -## if (strcmp(X_ideal, 'none' ) == 0) -## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); -## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); -## else -## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); -## end + else: fistaRecon = FISTAReconstructor(proj_geom, vol_geom, |