diff options
-rw-r--r-- | src/Python/test_reconstructor-os.py | 136 |
1 files changed, 21 insertions, 115 deletions
diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index 8820db7..f6d7d4b 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -250,102 +250,34 @@ if True: sino[:,indC,:].squeeze() - alpha_ring * r_x) # filling the full sinogram sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() - - -## % if geometry is 2D use slice-by-slice projection-backprojection routine -## for kkk = 1:SlicesZ -## [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom); -## sino_updt_Sub(:,:,kkk) = sinoT'; -## astra_mex_data2d('delete', sino_id); -## end - -## -## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if i < 3: - pass - else: - sino_id, sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geomSUB, vol_geom) -## sino_id, sino_updt = astra.creators.create_sino3d_gpu( -## X, proj_geom, vol_geom) -## astra.matlab.data3d('delete', sino_id) - - for kkk in range(anglesNumb): - - residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt2[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) - shape[1] = numProjSub - fistaRecon.residual = numpy.zeros(shape) - if fistaRecon.residual.__hash__() != residual.__hash__(): - residual = fistaRecon.residual -## for kkk = 1:numProjSub -## indC = CurrSubIndeces(kkk); -## if (i < 3) -## residual(:,kkk,:) = squeeze(residual2(:,indC,:)); -## else -## residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); -## end -## end - for kk in range(numProjSub): - indC = fistaRecon.getParameter('os_indices')[kkk] - if i < 3: - residual[:,kkk,:] = residual2[:,indC,:].squeeze() - else: - residual[:,kkk,:] = \ - weights[:,indC,:].squeeze() * sino_updt[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * fistaRecon.r_x - #squeeze(weights(:,indC,:)).* \ - # (squeeze(sino_updt(:,kkk,:)) - \ - #(squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); - - - - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - fistaRecon.r = (r_x - (1./L_const) * vec).copy() - objective[i] = (0.5 * (residual ** 2).sum()) -## % the ring removal part (Group-Huber fidelity) -## for kkk = 1:anglesNumb -## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* -## (squeeze(sino_updt(:,kkk,:)) - -## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); -## end -## vec = sum(residual,2); -## if (SlicesZ > 1) -## vec = squeeze(vec(:,1,:)); -## end -## r = r_x - (1./L_const).*vec; -## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output - - + else: + #PWLS model + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) - # Projection/Backprojection Routine - if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ - fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - print ("Projection/Backprojection Routine") + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) for kkk in range(SlicesZ): x_id, x_temp[kkk] = \ astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + else: x_id, x_temp = \ astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) + residualSub, proj_geomSUB, vol_geom) - X = X_t - (1/L_const) * x_temp - astra.matlab.data3d('delete', sino_id) astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp + ## REGULARIZATION @@ -364,12 +296,9 @@ if True: fistaRecon.r = numpy.max( numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ numpy.sign(fistaRecon.r) - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - fistaRecon.r_x = fistaRecon.r + \ - (((t_old-1)/t) * (fistaRecon.r - r_old)) + # updating r + r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) + if fistaRecon.getParameter('region_of_interest') is None: string = 'Iteration Number {0} | Objective {1} \n' @@ -382,30 +311,7 @@ if True: string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], objective[i])) -## if (lambdaR_L1 > 0) -## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector -## end -## -## t = (1 + sqrt(1 + 4*t^2))/2; % updating t -## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X -## -## if (lambdaR_L1 > 0) -## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r -## end -## -## if (show == 1) -## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); -## if (lambdaR_L1 > 0) -## figure(11); plot(r); title('Rings offset vector') -## end -## pause(0.01); -## end -## if (strcmp(X_ideal, 'none' ) == 0) -## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); -## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); -## else -## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); -## end + else: fistaRecon = FISTAReconstructor(proj_geom, vol_geom, |