From 903175ed67f7645fa35edf4623b27999d6cb990f Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Fri, 20 Oct 2017 17:04:26 +0100
Subject: Further development

---
 src/Python/ccpi/fista/FISTAReconstructor.py |  24 ++++++
 src/Python/test_reconstructor-os.py         | 112 ++++++++++++++--------------
 2 files changed, 81 insertions(+), 55 deletions(-)

diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
index fda9cf0..85bfac5 100644
--- a/src/Python/ccpi/fista/FISTAReconstructor.py
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -583,3 +583,27 @@ class FISTAReconstructor():
             string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n'
             print (string.format(i,Resid_error[i], self.objective[i]))
         return (X , X_t, t)
+
+    def os_iterate(self, Xin=None):
+        print ("FISTA Reconstructor: iterate")
+        
+        if Xin is None:    
+            if self.getParameter('initialize'):
+                X = self.initialize()
+            else:
+                N = vol_geom['GridColCount']
+                X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float)
+        else:
+            # copy by reference
+            X = Xin
+        # store the output volume in the parameters
+        self.setParameter(output_volume=X)
+        X_t = X.copy()
+
+        # some useful constants
+        proj_geom , vol_geom, sino , \
+          SlicesZ, weights , alpha_ring ,
+          lambdaR_L1 , L_const = self.getParameter(
+            ['projector_geometry' , 'output_geometry',
+             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha' ,
+             'ring_lambda_R_L1', 'Lipschitz_constant'])
diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py
index f6d7d4b..aee70a4 100644
--- a/src/Python/test_reconstructor-os.py
+++ b/src/Python/test_reconstructor-os.py
@@ -122,10 +122,13 @@ if True:
     X_t = X.copy()
     print ("initialized")
     proj_geom , vol_geom, sino , \
-                      SlicesZ = fistaRecon.getParameter(['projector_geometry' ,
-                                                            'output_geometry',
-                                                            'input_sinogram',
-                                                         'SlicesZ'])
+        SlicesZ, weights , alpha_ring = fistaRecon.getParameter(
+            ['projector_geometry' , 'output_geometry',
+             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha'])
+    lambdaR_L1 , alpha_ring , weights , L_const= \
+                       fistaRecon.getParameter(['ring_lambda_R_L1',
+                                               'ring_alpha' , 'weights',
+                                               'Lipschitz_constant'])
 
     #fistaRecon.setParameter(number_of_iterations = 3)
     iterFISTA = fistaRecon.getParameter('number_of_iterations')
@@ -136,12 +139,13 @@ if True:
 
       
     t = 1
+    
 
     ## additional for 
     proj_geomSUB = proj_geom.copy()
     fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram']))
     residual2 = fistaRecon.residual2
-    sino_updt_FULL = residual.copy()
+    sino_updt_FULL = fistaRecon.residual.copy()
     
     print ("starting iterations")
 ##    % Outer FISTA iterations loop
@@ -156,7 +160,8 @@ if True:
         # hence additional work is required one solution is to work with a full
         # sinogram at times
 
-
+        r_old = fistaRecon.r.copy()
+        t_old = t
         SlicesZ, anglesNumb, Detectors = \
                     numpy.shape(fistaRecon.getParameter('input_sinogram'))        ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
         if (i > 1 and lambdaR_L1 > 0) :
@@ -167,8 +172,8 @@ if True:
                                         (sino[:,kkk,:]).squeeze() -\
                                         (alpha_ring * r_x)
                                         )
-            r_old = fistaRecon.r.copy()
-            vec = residual.sum(axis = 1)
+            
+            vec = fistaRecon.residual.sum(axis = 1)
             #if SlicesZ > 1:
             #    vec = vec[:,1,:] # 1 or 0?
             r_x = fistaRecon.r_x
@@ -227,56 +232,53 @@ if True:
             
 
 
-        ## RING REMOVAL
-        residual = fistaRecon.residual
-        
-        lambdaR_L1 , alpha_ring , weights , L_const= \
-                   fistaRecon.getParameter(['ring_lambda_R_L1',
-                                           'ring_alpha' , 'weights',
-                                           'Lipschitz_constant'])
-        if lambdaR_L1 > 0 :
-             print ("ring removal")
-             residualSub = numpy.zeros(shape)
-##             for a chosen subset
-##                for kkk = 1:numProjSub
-##                    indC = CurrSubIndeces(kkk);
-##                    residualSub(:,kkk,:) =  squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
-##                    sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
-##                end
-             for kkk in range(numProjSub):
-                 indC = CurrSubIndices[kkk]
-                 residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
-                        (sino_updt_Sub[:,kkk,:].squeeze() - \
-                         sino[:,indC,:].squeeze() - alpha_ring * r_x)
-                 # filling the full sinogram
-                 sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+            ## RING REMOVAL
+            residual = fistaRecon.residual
+            
+            
+            if lambdaR_L1 > 0 :
+                 print ("ring removal")
+                 residualSub = numpy.zeros(shape)
+    ##             for a chosen subset
+    ##                for kkk = 1:numProjSub
+    ##                    indC = CurrSubIndeces(kkk);
+    ##                    residualSub(:,kkk,:) =  squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x));
+    ##                    sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram
+    ##                end
+                 for kkk in range(numProjSub):
+                     indC = CurrSubIndices[kkk]
+                     residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+                            (sino_updt_Sub[:,kkk,:].squeeze() - \
+                             sino[:,indC,:].squeeze() - alpha_ring * r_x)
+                     # filling the full sinogram
+                     sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
 
-        else:
-            #PWLS model
-            residualSub = weights[:,CurrSubIndices,:] * \
-                          ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() )
-            objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+            else:
+                #PWLS model
+                residualSub = weights[:,CurrSubIndices,:] * \
+                              ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() )
+                objective[i] = 0.5 * numpy.linalg.norm(residualSub)
 
-        if geometry_type == 'parallel' or \
-           geometry_type == 'fanflat' or \
-           geometry_type == 'fanflat_vec' :
-            # if geometry is 2D use slice-by-slice projection-backprojection
-            # routine
-            x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
-            for kkk in range(SlicesZ):
-                
-                x_id, x_temp[kkk] = \
-                         astra.creators.create_backprojection3d_gpu(
-                             residualSub[kkk:kkk+1],
-                             proj_geomSUB, vol_geom)
-                
-        else:
-            x_id, x_temp = \
-                  astra.creators.create_backprojection3d_gpu(
-                      residualSub, proj_geomSUB, vol_geom)
+            if geometry_type == 'parallel' or \
+               geometry_type == 'fanflat' or \
+               geometry_type == 'fanflat_vec' :
+                # if geometry is 2D use slice-by-slice projection-backprojection
+                # routine
+                x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+                for kkk in range(SlicesZ):
+                    
+                    x_id, x_temp[kkk] = \
+                             astra.creators.create_backprojection3d_gpu(
+                                 residualSub[kkk:kkk+1],
+                                 proj_geomSUB, vol_geom)
+                    
+            else:
+                x_id, x_temp = \
+                      astra.creators.create_backprojection3d_gpu(
+                          residualSub, proj_geomSUB, vol_geom)
 
-        astra.matlab.data3d('delete', x_id)
-        X = X_t - (1/L_const) * x_temp
+            astra.matlab.data3d('delete', x_id)
+            X = X_t - (1/L_const) * x_temp
 
         
 
-- 
cgit v1.2.3