diff options
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 24 | ||||
-rw-r--r-- | src/Python/test_reconstructor-os.py | 112 |
2 files changed, 81 insertions, 55 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index fda9cf0..85bfac5 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -583,3 +583,27 @@ class FISTAReconstructor(): string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], self.objective[i])) return (X , X_t, t) + + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index f6d7d4b..aee70a4 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -122,10 +122,13 @@ if True: X_t = X.copy() print ("initialized") proj_geom , vol_geom, sino , \ - SlicesZ = fistaRecon.getParameter(['projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) + SlicesZ, weights , alpha_ring = fistaRecon.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha']) + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) #fistaRecon.setParameter(number_of_iterations = 3) iterFISTA = fistaRecon.getParameter('number_of_iterations') @@ -136,12 +139,13 @@ if True: t = 1 + ## additional for proj_geomSUB = proj_geom.copy() fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) residual2 = fistaRecon.residual2 - sino_updt_FULL = residual.copy() + sino_updt_FULL = fistaRecon.residual.copy() print ("starting iterations") ## % Outer FISTA iterations loop @@ -156,7 +160,8 @@ if True: # hence additional work is required one solution is to work with a full # sinogram at times - + r_old = fistaRecon.r.copy() + t_old = t SlicesZ, anglesNumb, Detectors = \ numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 if (i > 1 and lambdaR_L1 > 0) : @@ -167,8 +172,8 @@ if True: (sino[:,kkk,:]).squeeze() -\ (alpha_ring * r_x) ) - r_old = fistaRecon.r.copy() - vec = residual.sum(axis = 1) + + vec = fistaRecon.residual.sum(axis = 1) #if SlicesZ > 1: # vec = vec[:,1,:] # 1 or 0? r_x = fistaRecon.r_x @@ -227,56 +232,53 @@ if True: - ## RING REMOVAL - residual = fistaRecon.residual - - lambdaR_L1 , alpha_ring , weights , L_const= \ - fistaRecon.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant']) - if lambdaR_L1 > 0 : - print ("ring removal") - residualSub = numpy.zeros(shape) -## for a chosen subset -## for kkk = 1:numProjSub -## indC = CurrSubIndeces(kkk); -## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); -## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram -## end - for kkk in range(numProjSub): - indC = CurrSubIndices[kkk] - residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ - (sino_updt_Sub[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * r_x) - # filling the full sinogram - sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + ## RING REMOVAL + residual = fistaRecon.residual + + + if lambdaR_L1 > 0 : + print ("ring removal") + residualSub = numpy.zeros(shape) + ## for a chosen subset + ## for kkk = 1:numProjSub + ## indC = CurrSubIndeces(kkk); + ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram + ## end + for kkk in range(numProjSub): + indC = CurrSubIndices[kkk] + residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ + (sino_updt_Sub[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * r_x) + # filling the full sinogram + sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() - else: - #PWLS model - residualSub = weights[:,CurrSubIndices,:] * \ - ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) - objective[i] = 0.5 * numpy.linalg.norm(residualSub) + else: + #PWLS model + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - # if geometry is 2D use slice-by-slice projection-backprojection - # routine - x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residualSub[kkk:kkk+1], - proj_geomSUB, vol_geom) - - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residualSub, proj_geomSUB, vol_geom) + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residualSub, proj_geomSUB, vol_geom) - astra.matlab.data3d('delete', x_id) - X = X_t - (1/L_const) * x_temp + astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp |