diff options
-rw-r--r-- | demos/Demo_Phantom3D_Parallel.m | 5 | ||||
-rw-r--r-- | main_func/FISTA_REC.m | 179 |
2 files changed, 108 insertions, 76 deletions
diff --git a/demos/Demo_Phantom3D_Parallel.m b/demos/Demo_Phantom3D_Parallel.m index fd8096a..ac9827c 100644 --- a/demos/Demo_Phantom3D_Parallel.m +++ b/demos/Demo_Phantom3D_Parallel.m @@ -33,6 +33,7 @@ params.sino = single(sino_tomophan3D); % sinogram params.iterFISTA = 5; %max number of outer iterations
params.X_ideal = TomoPhantom; % ideal phantom
params.show = 1; % visualize reconstruction on each iteration
+params.subsets = 12;
params.slice = round(N/2); params.maxvalplot = 1;
tic; [X_FISTA, output] = FISTA_REC(params); toc;
@@ -44,6 +45,6 @@ figure(2); subplot(1,2,1); imshow(X_FISTA(:,:,params.slice),[0 params.maxvalplot]); title('FISTA-LS reconstruction'); colorbar;
subplot(1,2,2); imshow(Resid3D(:,:,params.slice),[0 0.1]); title('residual'); colorbar;
figure(3);
-subplot(1,2,1); plot(error_FISTA); title('RMSE plot'); colorbar;
-subplot(1,2,2); plot(obj_FISTA); title('Objective plot'); colorbar;
+subplot(1,2,1); plot(error_FISTA); title('RMSE plot');
+subplot(1,2,2); plot(obj_FISTA); title('Objective plot');
%%
\ No newline at end of file diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index dde0e73..bea1860 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -106,14 +106,14 @@ if (isfield(params,'L_const')) else % using Power method (PM) to establish L constant fprintf('%s %s %s \n', 'Calculating Lipshitz constant for',proj_geom.type, 'beam geometry...'); - if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) - % for 2D geometry we can do just one selected slice + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + % for 2D geometry we can do just one selected slice niter = 15; % number of iteration for the PM x1 = rand(N,N,1); sqweight = sqrt(weights(:,:,1)); [sino_id, y] = astra_create_sino_cuda(x1, proj_geom, vol_geom); y = sqweight.*y'; - astra_mex_data2d('delete', sino_id); + astra_mex_data2d('delete', sino_id); for i = 1:niter [x1] = astra_create_backprojection_cuda((sqweight.*y)', proj_geom, vol_geom); s = norm(x1(:)); @@ -121,9 +121,9 @@ else [sino_id, y] = astra_create_sino_cuda(x1, proj_geom, vol_geom); y = sqweight.*y'; astra_mex_data2d('delete', sino_id); - end + end elseif (strcmp(proj_geom.type,'cone') || strcmp(proj_geom.type,'parallel3d') || strcmp(proj_geom.type,'parallel3d_vec') || strcmp(proj_geom.type,'cone_vec')) - % 3D geometry + % 3D geometry niter = 8; % number of iteration for PM x1 = rand(N,N,SlicesZ); sqweight = sqrt(weights); @@ -268,7 +268,7 @@ if (isfield(params,'initialize')) X = params.initialize; if ((size(X,1) ~= N) || (size(X,2) ~= N) || (size(X,3) ~= SlicesZ)) error('%s \n', 'The initialized volume has different dimensions!'); - end + end else X = zeros(N,N,SlicesZ, 'single'); % storage for the solution end @@ -320,10 +320,11 @@ if (subsets == 0) t_old = t; r_old = r; - % if the geometry is 2D use slice-by-slice projection-backprojection routine - if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + % if geometry is 2D use slice-by-slice projection-backprojection routine sino_updt = zeros(size(sino),'single'); - for kkk = 1:SlicesZ + for kkk = 1:SlicesZ [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geom, vol_geom); sino_updt(:,:,kkk) = sinoT'; astra_mex_data2d('delete', sino_id); @@ -359,12 +360,11 @@ if (subsets == 0) else % no ring removal (LS model) residual = weights.*(sino_updt - sino); - objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + objective(i) = 0.5*norm(residual(:)); % for the objective function output end - % if the geometry is 2D use slice-by-slice projection-backprojection routine - if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) x_temp = zeros(size(X),'single'); for kkk = 1:SlicesZ [x_temp(:,:,kkk)] = astra_create_backprojection_cuda(squeeze(residual(:,:,kkk))', proj_geom, vol_geom); @@ -373,9 +373,9 @@ if (subsets == 0) [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); astra_mex_data3d('delete', id); end - X = X_t - (1/L_const).*x_temp; + X = X_t - (1/L_const).*x_temp; - % ----------------Regularization part------------------------ + % ----------------Regularization part------------------------% if (lambdaFGP_TV > 0) % FGP-TV regularization if ((strcmp('2D', Dimension) == 1)) @@ -484,94 +484,128 @@ if (subsets == 0) end end else - % Ordered Subsets (OS) FISTA reconstruction routine (normally one order of magnitude faster than classical) + % Ordered Subsets (OS) FISTA reconstruction routine (normally one order of magnitude faster than the classical version) t = 1; X_t = X; - proj_geomSUB = proj_geom; - + proj_geomSUB = proj_geom; r = zeros(Detectors,SlicesZ, 'single'); % 2D array (for 3D data) of sparse "ring" vectors r_x = r; % another ring variable residual2 = zeros(size(sino),'single'); + sino_updt_FULL = zeros(size(sino),'single'); % Outer FISTA iterations loop - for i = 1:iterFISTA + for i = 1:iterFISTA - % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required - % one solution is to work with a full sinogram at times - if ((i >= 3) && (lambdaR_L1 > 0)) - [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); - astra_mex_data3d('delete', sino_id2); + if ((i > 1) && (lambdaR_L1 > 0)) + % in order to make Group-Huber fidelity work with ordered subsets + % we still need to work with full sinogram + + % the offset variable must be calculated for the whole + % updated sinogram - sino_updt_FULL + for kkk = 1:anglesNumb + residual2(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt_FULL(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); + end + + r_old = r; + vec = sum(residual2,2); + if (SlicesZ > 1) + vec = squeeze(vec(:,1,:)); + end + r = r_x - (1./L_const).*vec; % update ring variable end % subsets loop counterInd = 1; for ss = 1:subsets X_old = X; - t_old = t; - r_old = r; + t_old = t; numProjSub = binsDiscr(ss); % the number of projections per subset CurrSubIndeces = IndicesReorg(counterInd:(counterInd + numProjSub - 1)); % extract indeces attached to the subset proj_geomSUB.ProjectionAngles = angles(CurrSubIndeces); + sino_updt_Sub = zeros(Detectors, numProjSub, SlicesZ,'single'); if (lambdaR_L1 > 0) - - % the ring removal part (Group-Huber fidelity) - % first 2 iterations do additional work reconstructing whole dataset to ensure - % the stablility - if (i < 3) - [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - astra_mex_data3d('delete', sino_id2); + % Group-Huber fidelity (ring removal) + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + % if geometry is 2D use slice-by-slice projection-backprojection routine + for kkk = 1:SlicesZ + [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom); + sino_updt_Sub(:,:,kkk) = sinoT'; + astra_mex_data2d('delete', sino_id); + end else - [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); - end - - for kkk = 1:anglesNumb - residual2(:,kkk,:) = squeeze(weights(:,kkk,:)).*(squeeze(sino_updt2(:,kkk,:)) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); + % for 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8) + [sino_id, sino_updt_Sub] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); + astra_mex_data3d('delete', sino_id); end - residual = zeros(Detectors, numProjSub, SlicesZ,'single'); + residualSub = zeros(Detectors, numProjSub, SlicesZ,'single'); % residual for a chosen subset for kkk = 1:numProjSub indC = CurrSubIndeces(kkk); - if (i < 3) - residual(:,kkk,:) = squeeze(residual2(:,indC,:)); - else - residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram + end + + elseif (studentt > 0) + % student t data fidelity + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + % if geometry is 2D use slice-by-slice projection-backprojection routine + for kkk = 1:SlicesZ + [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom); + sino_updt_Sub(:,:,kkk) = sinoT'; + astra_mex_data2d('delete', sino_id); end + else + % for 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8) + [sino_id, sino_updt_Sub] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); + astra_mex_data3d('delete', sino_id); end - vec = sum(residual2,2); - if (SlicesZ > 1) - vec = squeeze(vec(:,1,:)); + + % artifacts removal with Students t penalty + residualSub = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt_Sub - squeeze(sino(:,CurrSubIndeces,:))); + + for kkk = 1:SlicesZ + res_vec = reshape(residualSub(:,:,kkk), Detectors*numProjSub, 1); % 1D vectorized sinogram + %s = 100; + %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); + [ff, gr] = studentst(res_vec, 1); + residualSub(:,:,kkk) = reshape(gr, Detectors, numProjSub); end - r = r_x - (1./L_const).*vec; + objective(i) = ff; % for the objective function output else - [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); - - if (studentt == 1) - % artifacts removal with Students t penalty - residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:))); - + % PWLS model + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + % if geometry is 2D use slice-by-slice projection-backprojection routine for kkk = 1:SlicesZ - res_vec = reshape(residual(:,:,kkk), Detectors*numProjSub, 1); % 1D vectorized sinogram - %s = 100; - %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); - [ff, gr] = studentst(res_vec, 1); - residual(:,:,kkk) = reshape(gr, Detectors, numProjSub); + [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom); + sino_updt_Sub(:,:,kkk) = sinoT'; + astra_mex_data2d('delete', sino_id); end - objective(i) = ff; % for the objective function output else - % no ring removal (LS model) - residual = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt - squeeze(sino(:,CurrSubIndeces,:))); + % for 3D geometry (watch the GPU memory overflow in earlier ASTRA versions < 1.8) + [sino_id, sino_updt_Sub] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); + astra_mex_data3d('delete', sino_id); end + residualSub = squeeze(weights(:,CurrSubIndeces,:)).*(sino_updt_Sub - squeeze(sino(:,CurrSubIndeces,:))); + objective(i) = 0.5*norm(residualSub(:)); % for the objective function output end - [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geomSUB, vol_geom); - X = X_t - (1/L_const).*x_temp; - astra_mex_data3d('delete', sino_id); - astra_mex_data3d('delete', id); + if (strcmp(proj_geom.type,'parallel') || strcmp(proj_geom.type,'fanflat') || strcmp(proj_geom.type,'fanflat_vec')) + % if geometry is 2D use slice-by-slice projection-backprojection routine + x_temp = zeros(size(X),'single'); + for kkk = 1:SlicesZ + [x_temp(:,:,kkk)] = astra_create_backprojection_cuda(squeeze(residualSub(:,:,kkk))', proj_geomSUB, vol_geom); + end + else + [id, x_temp] = astra_create_backprojection3d_cuda(residualSub, proj_geomSUB, vol_geom); + astra_mex_data3d('delete', id); + end + + X = X_t - (1/L_const).*x_temp; - % regularization + % ----------------Regularization part------------------------% if (lambdaFGP_TV > 0) % FGP-TV regularization if ((strcmp('2D', Dimension) == 1)) @@ -653,20 +687,17 @@ else end end - if (lambdaR_L1 > 0) - r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector - end - t = (1 + sqrt(1 + 4*t^2))/2; % updating t X_t = X + ((t_old-1)/t).*(X - X_old); % updating X - - if (lambdaR_L1 > 0) - r_x = r + ((t_old-1)/t).*(r - r_old); % updating r - end - counterInd = counterInd + numProjSub; end + % working with a 'ring vector' + if (lambdaR_L1 > 0) + r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector + r_x = r + ((t_old-1)/t).*(r - r_old); % updating r + end + if (show == 1) figure(10); imshow(X(:,:,slice), [0 maxvalplot]); if (lambdaR_L1 > 0) |