From e412e73b172a99776d108d3b4c1f8662a20e9fce Mon Sep 17 00:00:00 2001
From: Daniil Kazantsev <dkazanc@hotmail.com>
Date: Thu, 3 Aug 2017 00:26:46 +0100
Subject: 2D or 3D regularization choices added

---
 main_func/FISTA_REC.m | 114 ++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 93 insertions(+), 21 deletions(-)

diff --git a/main_func/FISTA_REC.m b/main_func/FISTA_REC.m
index 1e93719..43ed0cb 100644
--- a/main_func/FISTA_REC.m
+++ b/main_func/FISTA_REC.m
@@ -3,6 +3,7 @@ function [X,  output] = FISTA_REC(params)
 % <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
 % ___Input___:
 % params.[] file:
+%----------------General Parameters------------------------
 %       - .proj_geom (geometry of the projector) [required]
 %       - .vol_geom (geometry of the reconstructed object) [required]
 %       - .sino (vectorized in 2D or 3D sinogram) [required]
@@ -13,12 +14,19 @@ function [X,  output] = FISTA_REC(params)
 %       - .ROI (Region-of-interest, only if X_ideal is given)
 %       - .initialize (a 'warm start' using SIRT method from ASTRA)
 %----------------Regularization choices------------------------
-%       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
-%       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
-%       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+%       1 .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+%       2 .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+%       3 .Regul_LambdaLLT (Higher order LLT regularization parameter)
+%          3.1 .Regul_tauLLT (time step parameter for LLT (HO) term)
+%       4 .Regul_LambdaPatchBased (Patch-based nonlocal regularization parameter)
+%                 4.1  .Regul_PB_SearchW (ratio of the searching window (e.g. 3 = (2*3+1) = 7 pixels window))
+%                 4.2  .Regul_PB_SimilW (ratio of the similarity window (e.g. 1 = (2*1+1) = 3 pixels window))
+%                 4.3  .Regul_PB_h (PB penalty function threshold)
+%       5 .Regul_LambdaTGV (Total Generalized variation regularization parameter)
 %       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
 %       - .Regul_Iterations (iterations for the selected penalty, default 25)
-%       - .Regul_tauLLT (time step parameter for LLT term)
+%       - .Regul_Dimension ('2D' or '3D' way to apply regularization, '3D' is the default)
+%----------------Ring removal------------------------
 %       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
 %       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
 %----------------Visualization parameters------------------------
@@ -150,8 +158,8 @@ if (isfield(params,'Regul_Iterations'))
 else
     IterationsRegul = 25;
 end
-if (isfield(params,'Regul_LambdaHO'))
-    lambdaHO = params.Regul_LambdaHO;
+if (isfield(params,'Regul_LambdaLLT'))
+    lambdaHO = params.Regul_LambdaLLT;
 else
     lambdaHO = 0;
 end
@@ -165,6 +173,26 @@ if (isfield(params,'Regul_tauLLT'))
 else
     tauHO = 0.0001;
 end
+if (isfield(params,'Regul_LambdaPatchBased'))
+    lambdaPB = params.Regul_LambdaPatchBased;
+else
+    lambdaPB = 0;
+end
+if (isfield(params,'Regul_PB_SearchW'))
+    SearchW = params.Regul_PB_SearchW;
+else
+    SearchW = 3; % default
+end
+if (isfield(params,'Regul_PB_SimilW'))
+    SimilW = params.Regul_PB_SimilW;
+else
+    SimilW = 1; % default
+end
+if (isfield(params,'Regul_PB_h'))
+    h_PB = params.Regul_PB_h;
+else
+    h_PB = 0.1; % default
+end
 if (isfield(params,'Ring_LambdaR_L1'))
     lambdaR_L1 = params.Ring_LambdaR_L1;
 else
@@ -175,6 +203,14 @@ if (isfield(params,'Ring_Alpha'))
 else
     alpha_ring = 1;
 end
+if (isfield(params,'Regul_Dimension'))
+    Dimension = params.Regul_Dimension;
+    if ((strcmp('2D', Dimension) ~= 1) && (strcmp('3D', Dimension) ~= 1))
+        Dimension = '3D';
+    end
+else
+    Dimension = '3D';
+end
 if (isfield(params,'show'))
     show = params.show;
 else
@@ -293,21 +329,57 @@ for i = 1:iterFISTA
     astra_mex_data3d('delete', sino_id);
     astra_mex_data3d('delete', id);
     
-    if (lambdaFGP_TV > 0)
-        % FGP-TV regularization
-        [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso');
-        objective(i) = objective(i) + f_val;
-    end
-    if (lambdaSB_TV > 0)
-        % Split Bregman regularization
-        X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol);  % (more memory efficent)
-    end
-    if (lambdaHO > 0)
-        % Higher Order (LLT) regularization
-        X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0);
-        X = 0.5.*(X + X2); % averaged combination of two solutions
-    end
-    
+    % regularization
+     if (lambdaFGP_TV > 0)
+         % FGP-TV regularization
+        if ((strcmp('2D', Dimension) == 1))
+            % 2D regularization
+            for kkk = 1:SlicesZ
+                [X(:,:,kkk), f_val] = FGP_TV(single(X(:,:,kkk)), lambdaFGP_TV, IterationsRegul, tol, 'iso');
+            end
+        else
+            % 3D regularization
+            [X, f_val] = FGP_TV(single(X), lambdaFGP_TV, IterationsRegul, tol, 'iso');
+        end
+         objective(i) = objective(i) + f_val;
+     end
+     if (lambdaSB_TV > 0)
+         % Split Bregman regularization
+        if ((strcmp('2D', Dimension) == 1))
+            % 2D regularization
+            for kkk = 1:SlicesZ
+                X(:,:,kkk) = SplitBregman_TV(single(X(:,:,kkk)), lambdaSB_TV, IterationsRegul, tol);  % (more memory efficent)
+            end
+        else
+            % 3D regularization
+            X = SplitBregman_TV(single(X), lambdaSB_TV, IterationsRegul, tol);  % (more memory efficent)
+        end
+     end
+     if (lambdaHO > 0)
+         % Higher Order (LLT) regularization
+        X2 = zeros(N,N,SlicesZ,'single');
+        if ((strcmp('2D', Dimension) == 1))
+            % 2D regularization
+            for kkk = 1:SlicesZ
+                X2(:,:,kkk) = LLT_model(single(X(:,:,kkk)), lambdaHO, tauHO, iterHO, 3.0e-05, 0);
+            end
+        else
+            % 3D regularization
+            X2 = LLT_model(single(X), lambdaHO, tauHO, iterHO, 3.0e-05, 0);
+        end
+         X = 0.5.*(X + X2); % averaged combination of two solutions
+     end
+    if (lambdaPB > 0)
+        % Patch-Based regularization (can be slow on CPU)
+        if ((strcmp('2D', Dimension) == 1))
+            % 2D regularization
+            for kkk = 1:SlicesZ
+                X(:,:,kkk) = PatchBased_Regul(single(X(:,:,kkk)), SearchW, SimilW, h_PB, lambdaPB);
+            end
+        else
+            X = PatchBased_Regul(single(X), SearchW, SimilW, h_PB, lambdaPB);
+        end
+    end  
     
     
     if (lambdaR_L1 > 0)
-- 
cgit v1.2.3