diff options
-rw-r--r-- | demos/DemoRD2.m | 9 | ||||
-rw-r--r-- | main_func/FISTA_REC.m | 206 |
2 files changed, 89 insertions, 126 deletions
diff --git a/demos/DemoRD2.m b/demos/DemoRD2.m index 518e24b..b991e70 100644 --- a/demos/DemoRD2.m +++ b/demos/DemoRD2.m @@ -60,7 +60,7 @@ params.vol_geom = vol_geom; params.sino = Sino3D; params.iterFISTA = 40; params.L_const = 7.6789e+08; -params.Regul_LambdaTV = 0.005; % TV regularization parameter for FISTA-TV +params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV params.weights = Weights3D; params.show = 1; params.maxvalplot = 2.5; params.slice = 10; @@ -76,7 +76,7 @@ params.vol_geom = vol_geom; params.sino = Sino3D; params.iterFISTA = 40; params.L_const = 7.6789e+08; -params.Regul_LambdaTV = 0.005; % TV regularization parameter for FISTA-TV +params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter params.Ring_Alpha = 21; % to boost ring removal procedure params.weights = Weights3D; @@ -93,10 +93,9 @@ params.proj_geom = proj_geom; % pass geometry to the function params.vol_geom = vol_geom; params.sino = Sino3D; params.iterFISTA = 40; -params.Regul_LambdaTV = 0.005; % TV regularization parameter for FISTA-TV +params.Regul_Lambda_FGPTV = 0.005; % TV regularization parameter for FGP-TV params.Regul_LambdaHO = 200; % regularization parameter for LLT problem -params.Regul_tauHO = 0.0005; % time-step parameter for the explicit scheme -params.Regul_iterHO = 250; % the max number of TV iterations +params.Regul_tauLLT = 0.0005; % time-step parameter for the explicit scheme params.Ring_LambdaR_L1 = 0.002; % Soft-Thresh L1 ring variable parameter params.Ring_Alpha = 21; % to boost ring removal procedure params.weights = Weights3D; diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m index 688dcc3..18e430e 100644 --- a/main_func/FISTA_REC.m +++ b/main_func/FISTA_REC.m @@ -6,30 +6,33 @@ function [X, output] = FISTA_REC(params) % - .proj_geom (geometry of the projector) [required] % - .vol_geom (geometry of the reconstructed object) [required] % - .sino (vectorized in 2D or 3D sinogram) [required] -% - .iterFISTA (iterations for the main loop) +% - .iterFISTA (iterations for the main loop, default 40) % - .L_const (Lipschitz constant, default Power method) ) % - .X_ideal (ideal image, if given) % - .weights (statisitcal weights, size of the sinogram) % - .ROI (Region-of-interest, only if X_ideal is given) -% - .Regul_LambdaTV (TV regularization parameter, default 0 - reg. TV is switched off) -% - .Regul_tol (tolerance to terminate TV regularization, default 1.0e-04) -% - .Regul_iterTV (iterations for the TV penalty, default 0) -% - .Regul_LambdaHO (Higher Order LLT regularization parameter, default 0 - LLT reg. switched off) -% - .Regul_iterHO (iterations for HO penalty, default 50) -% - .Regul_tauHO (time step parameter for HO term) -% - .Ring_LambdaR_L1 (regularization parameter for L1 ring minimization, if lambdaR_L1 > 0 then switch on ring removal, default 0) +% - .fidelity (choose between "LS" and "student" data fidelities, default LS) +% - .initialize (a 'warm start' using SIRT method from ASTRA) +%----------------Regularization choices------------------------ +% - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) +% - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) +% - .Regul_Lambda_L1 (L1 regularization by soft-thresholding) +% - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) +% - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) +% - .Regul_Iterations (iterations for the selected penalty, default 25) +% - .Regul_tauLLT (time step parameter for LLT term) +% - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) % - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) -% - .fidelity (choose between "LS" and "student" data fidelities) -% - .initializ (a 'warm start' using SIRT method from ASTRA) -% - .precondition (1 - switch on Fourier filtering before backprojection) +%----------------Visualization parameters------------------------ % - .show (visualize reconstruction 1/0, (0 default)) % - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) % - .slice (for 3D volumes - slice number to imshow) % ___Output___: % 1. X - reconstructed image/volume -% 2. Resid_error - residual error (if X_ideal is given) -% 3. value of the objective function -% 4. forward projection of X +% 2. output - structure with +% - .Resid_error - residual error (if X_ideal is given) +% - .objective: value of the objective function +% - .L_const: Lipshitz constant to avoid recalculations % References: % 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse @@ -61,7 +64,7 @@ end if (isfield(params,'iterFISTA')) iterFISTA = params.iterFISTA; else - iterFISTA = 30; + iterFISTA = 40; end if (isfield(params,'weights')) weights = params.weights; @@ -72,7 +75,7 @@ if (isfield(params,'L_const')) L_const = params.L_const; else % using Power method (PM) to establish L constant - niter = 6; % number of iteration for PM + niter = 8; % number of iteration for PM x = rand(N,N,SlicesZ); sqweight = sqrt(weights); [sino_id, y] = astra_create_sino3d_cuda(x, proj_geom, vol_geom); @@ -100,20 +103,30 @@ if (isfield(params,'ROI')) else ROI = find(X_ideal>=0.0); end -if (isfield(params,'Regul_LambdaTV')) - lambdaTV = params.Regul_LambdaTV; +if (isfield(params,'Regul_Lambda_FGPTV')) + lambdaFGP_TV = params.Regul_Lambda_FGPTV; else - lambdaTV = 0; + lambdaFGP_TV = 0; +end +if (isfield(params,'Regul_Lambda_SBTV')) + lambdaSB_TV = params.Regul_Lambda_SBTV; +else + lambdaSB_TV = 0; +end +if (isfield(params,'Regul_Lambda_L1')) + lambdaL1 = params.Regul_Lambda_L1; +else + lambdaL1 = 0; end if (isfield(params,'Regul_tol')) tol = params.Regul_tol; else tol = 1.0e-04; end -if (isfield(params,'Regul_iterTV')) - iterTV = params.Regul_iterTV; +if (isfield(params,'Regul_Iterations')) + IterationsRegul = params.Regul_Iterations; else - iterTV = 25; + IterationsRegul = 25; end if (isfield(params,'Regul_LambdaHO')) lambdaHO = params.Regul_LambdaHO; @@ -125,8 +138,8 @@ if (isfield(params,'Regul_iterHO')) else iterHO = 50; end -if (isfield(params,'Regul_tauHO')) - tauHO = params.Regul_tauHO; +if (isfield(params,'Regul_tauLLT')) + tauHO = params.Regul_tauLLT; else tauHO = 0.0001; end @@ -191,132 +204,83 @@ end Resid_error = zeros(iterFISTA,1); % error vector objective = zeros(iterFISTA,1); % obhective vector -if (lambdaR_L1 > 0) - % do reconstruction WITH ring removal (Group-Huber fidelity) t = 1; X_t = X; - add_ring = zeros(size(sino),'single'); % size of sinogram array + % add_ring = zeros(size(sino),'single'); % size of sinogram array r = zeros(Detectors,SlicesZ, 'single'); % 2D array (for 3D data) of sparse "ring" vectors - r_x = r; + r_x = r; % another ring variable - % iterations loop + + % Outer iterations loop for i = 1:iterFISTA X_old = X; - t_old = t; - r_old = r; - + t_old = t; + r_old = r; + [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - for kkk = 1:anglesNumb - add_ring(:,kkk,:) = squeeze(sino(:,kkk,:)) - alpha_ring.*r_x; - end - - residual = weights.*(sino_updt - add_ring); - - vec = sum(residual,2); - if (SlicesZ > 1) - vec = squeeze(vec(:,1,:)); - end - - r = r_x - (1./L_const).*vec; - + if (lambdaR_L1 > 1) + % add ring removal part (Group-Huber fidelity) + for kkk = 1:anglesNumb + % add_ring(:,kkk,:) = squeeze(sino(:,kkk,:)) - alpha_ring.*r_x; + residual(:,kkk,:) = weights(:,kkk,:).*(sino_updt(:,kkk,:) - (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); + end + + vec = sum(residual,2); + if (SlicesZ > 1) + vec = squeeze(vec(:,1,:)); + end + r = r_x - (1./L_const).*vec; + else + % no ring removal + residual = weights.*(sino_updt - sino); + end + % residual = weights.*(sino_updt - add_ring); + [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); X = X_t - (1/L_const).*x_temp; astra_mex_data3d('delete', sino_id); astra_mex_data3d('delete', id); - if ((lambdaTV > 0) && (lambdaHO == 0)) - [X, f_val] = FGP_TV(single(X), lambdaTV, iterTV, tol); % TV regularization using FISTA + if (lambdaFGP_TV > 0) + % FGP-TV regularization + [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso'); objective(i) = 0.5.*norm(residual(:))^2 + f_val; - % X = SplitBregman_TV(single(X), lambdaTV, iterTV, tol); % TV-Split Bregman regularization on CPU (memory limited) - elseif ((lambdaHO > 0) && (lambdaTV == 0)) - % Higher Order regularization - X = LLT_model(single(X), lambdaHO, tauHO, iterHO, tol, 0); % LLT higher order model - elseif ((lambdaTV > 0) && (lambdaHO > 0)) - %X1 = SplitBregman_TV(single(X), lambdaTV, iterTV, tol); % TV-Split Bregman regularization on CPU (memory limited) - X1 = FGP_TV(single(X), lambdaTV, iterTV, tol); % TV regularization using FISTA - X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0); % LLT higher order model - X = 0.5.*(X1 + X2); % averaged combination of two solutions - elseif ((lambdaTV == 0) && (lambdaHO == 0)) + elseif (lambdaSB_TV > 0) + % Split Bregman regularization + X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol); % (more memory efficent) + objective(i) = 0.5.*norm(residual(:))^2; + elseif (lambdaL1 > 0) + % L1 soft-threhsolding regularization + X = max(abs(X)-lambdaL1, 0).*sign(X); + objective(i) = 0.5.*norm(residual(:))^2; + elseif (lambdaHO > 0) + % Higher Order (LLT) regularization + X2 = LLT_model(single(X), lambdaHO, tauHO, IterationsRegul, 3.0e-05, 0); + X = 0.5.*(X + X2); % averaged combination of two solutions objective(i) = 0.5.*norm(residual(:))^2; end + if (lambdaR_L1 > 1) r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator + end t = (1 + sqrt(1 + 4*t^2))/2; % updating t X_t = X + ((t_old-1)/t).*(X - X_old); % updating X + + if (lambdaR_L1 > 1) r_x = r + ((t_old-1)/t).*(r - r_old); % updating r + end if (show == 1) figure(10); imshow(X(:,:,slice), [0 maxvalplot]); + if (lambdaR_L1 > 1) figure(11); plot(r); title('Rings offset vector') - pause(0.03); - end - if (strcmp(X_ideal, 'none' ) == 0) - Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); - fprintf('%s %i %s %s %.4f %s %s %.4f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); - else - fprintf('%s %i %s %s %.4f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); - end - - end - -else - % WITHOUT ring removal - t = 1; - X_t = X; - - % FISTA outer iterations loop - for i = 1:iterFISTA - - X_old = X; - t_old = t; - - [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - residual = weights.*(sino_updt - sino); - - % employ students t fidelity term - if (strcmp(fidelity,'student') == 1) - res_vec = reshape(residual, anglesNumb*Detectors*SlicesZ,1); - %s = 100; - %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); - [ff, gr] = studentst(res_vec,1); - residual = reshape(gr, Detectors, anglesNumb, SlicesZ); - end - - [id, x_temp] = astra_create_backprojection3d_cuda(residual, proj_geom, vol_geom); - X = X_t - (1/L_const).*x_temp; - astra_mex_data3d('delete', sino_id); - astra_mex_data3d('delete', id); - - if ((lambdaTV > 0) && (lambdaHO == 0)) - [X,f_val] = FGP_TV(single(X), lambdaTV, iterTV, tol); % TV regularization using FISTA - if (strcmp(fidelity,'student') == 1) - objective(i) = ff + f_val; - else - objective(i) = 0.5.*norm(residual(:))^2 + f_val; end - %X = SplitBregman_TV(single(X), lambdaTV, iterTV, tol); % TV-Split Bregman regularization on CPU (memory limited) - elseif ((lambdaHO > 0) && (lambdaTV == 0)) - % Higher Order regularization - X = LLT_model(single(X), lambdaHO, tauHO, iterHO, tol, 0); % LLT higher order model - elseif ((lambdaTV > 0) && (lambdaHO > 0)) - X1 = SplitBregman_TV(single(X), lambdaTV, iterTV, tol); % TV-Split Bregman regularization on CPU (memory limited) - X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, tol, 0); % LLT higher order model - X = 0.5.*(X1 + X2); % averaged combination of two solutions - elseif ((lambdaTV == 0) && (lambdaHO == 0)) - objective(i) = 0.5.*norm(residual(:))^2; - end - - t = (1 + sqrt(1 + 4*t^2))/2; % updating t - X_t = X + ((t_old-1)/t).*(X - X_old); % updating X - - if (show == 1) - figure(11); imshow(X(:,:,slice), [0 maxvalplot]); - pause(0.03); + pause(0.01); end if (strcmp(X_ideal, 'none' ) == 0) Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); @@ -324,7 +288,7 @@ else else fprintf('%s %i %s %s %.4f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); end - end + end end output.Resid_error = Resid_error; output.objective = objective; |