summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Python/ccpi/reconstruction/FISTAReconstructor.py354
1 files changed, 354 insertions, 0 deletions
diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
new file mode 100644
index 0000000..ea96b53
--- /dev/null
+++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -0,0 +1,354 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+import h5py
+#from ccpi.reconstruction.parallelbeam import alg
+
+from ccpi.imaging.Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+
+
+class FISTAReconstructor():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+ self.params = dict()
+ self.params['projector_geometry'] = projector_geometry
+ self.params['output_geometry'] = output_geometry
+ self.params['input_sinogram'] = input_sinogram
+ detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+ self.params['detectors'] = detectors
+ self.params['number_og_angles'] = nangles
+ self.params['SlicesZ'] = sliceZ
+
+ # Accepted input keywords
+ kw = ('number_of_iterations',
+ 'Lipschitz_constant' ,
+ 'ideal_image' ,
+ 'weights' ,
+ 'region_of_interest' ,
+ 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha')
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+
+ if not self.pars['ideal_image'] in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not self.pars['region_of_interest'] :
+ if self.pars['ideal_image'] == None:
+ pass
+ else:
+ self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+
+ if not self.pars['regularizer'] :
+ self.pars['regularizer'] = None
+ else:
+ # the regularizer must be a correctly instantiated object
+ if not self.pars['ring_lambda_R_L1']:
+ self.pars['ring_lambda_R_L1'] = 0
+ if not self.pars['ring_alpha']:
+ self.pars['ring_alpha'] = 1
+
+
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ #N = params.vol_geom.GridColCount
+ N = self.pars['output_geometry'].GridColCount
+ proj_geom = self.params['projector_geometry']
+ vol_geom = self.params['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
+ niter = 15;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights.T[0])
+ proj_geomT = proj_geom.copy();
+ proj_geomT.DetectorRowCount = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+
+ for i in range(niter):
+ if i == 0:
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+ y = sqweight * y # element wise multiplication
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+ y = sqweight*y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ del proj_geomT
+ del vol_geomT
+ else:
+ #% divergen beam geometry
+ #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights.T[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer is not None:
+ self.pars['regularizer'] = regularizer
+
+
+
+
+
+def getEntry(location):
+ for item in nx[location].keys():
+ print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5"
+##nx = h5py.File(fname, "r")
+##
+### the data are stored in a particular location in the hdf5
+##for item in nx['entry1/tomo_entry/data'].keys():
+## print (item)
+##
+##data = nx.get('entry1/tomo_entry/data/rotation_angle')
+##angles = numpy.zeros(data.shape)
+##data.read_direct(angles)
+##print (angles)
+### angles should be in degrees
+##
+##data = nx.get('entry1/tomo_entry/data/data')
+##stack = numpy.zeros(data.shape)
+##data.read_direct(stack)
+##print (data.shape)
+##
+##print ("Data Loaded")
+##
+##
+### Normalize
+##data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+##itype = numpy.zeros(data.shape)
+##data.read_direct(itype)
+### 2 is dark field
+##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+##dark = darks[0]
+##for i in range(1, len(darks)):
+## dark += darks[i]
+##dark = dark / len(darks)
+###dark[0][0] = dark[0][1]
+##
+### 1 is flat field
+##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+##flat = flats[0]
+##for i in range(1, len(flats)):
+## flat += flats[i]
+##flat = flat / len(flats)
+###flat[0][0] = dark[0][1]
+##
+##
+### 0 is projection data
+##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+##angle_proj = numpy.asarray (angle_proj)
+##angle_proj = angle_proj.astype(numpy.float32)
+##
+### normalized data are
+### norm = (projection - dark)/(flat-dark)
+##
+##def normalize(projection, dark, flat, def_val=0.1):
+## a = (projection - dark)
+## b = (flat-dark)
+## with numpy.errstate(divide='ignore', invalid='ignore'):
+## c = numpy.true_divide( a, b )
+## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+## return c
+##
+##
+##norm = [normalize(projection, dark, flat) for projection in proj]
+##norm = numpy.asarray (norm)
+##norm = norm.astype(numpy.float32)
+
+
+##niterations = 15
+##threads = 3
+##
+##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+## iteration_values, False)
+##print ("iteration values %s" % str(iteration_values))
+##
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+## numpy.double(1e-5), iteration_values , False)
+##print ("iteration values %s" % str(iteration_values))
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+## numpy.double(1e-5), iteration_values , False)
+##print ("iteration values %s" % str(iteration_values))
+##
+##
+####numpy.save("cgls_recon.npy", img_data)
+##import matplotlib.pyplot as plt
+##fig, ax = plt.subplots(1,6,sharey=True)
+##ax[0].imshow(img_cgls[80])
+##ax[0].axis('off') # clear x- and y-axes
+##ax[1].imshow(img_sirt[80])
+##ax[1].axis('off') # clear x- and y-axes
+##ax[2].imshow(img_mlem[80])
+##ax[2].axis('off') # clear x- and y-axesplt.show()
+##ax[3].imshow(img_cgls_conv[80])
+##ax[3].axis('off') # clear x- and y-axesplt.show()
+##ax[4].imshow(img_cgls_tikhonov[80])
+##ax[4].axis('off') # clear x- and y-axesplt.show()
+##ax[5].imshow(img_cgls_TVreg[80])
+##ax[5].axis('off') # clear x- and y-axesplt.show()
+##
+##
+##plt.show()
+##
+