summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2017-10-11 16:04:49 +0100
committerEdoardo Pasca <edo.paskino@gmail.com>2017-10-11 16:04:49 +0100
commitad260ad1ab2e44c8c6d3daa9fe9d24c55bf5f280 (patch)
treed417fb4780b142adb817d53cc0903e51ca70ee12 /src
parent5c978b706192bc5885c7e5001a4bc4626f63d29f (diff)
parentf7e1cf04f791898737bc15b0eb437abc2c5d9305 (diff)
downloadregularization-ad260ad1ab2e44c8c6d3daa9fe9d24c55bf5f280.tar.gz
regularization-ad260ad1ab2e44c8c6d3daa9fe9d24c55bf5f280.tar.bz2
regularization-ad260ad1ab2e44c8c6d3daa9fe9d24c55bf5f280.tar.xz
regularization-ad260ad1ab2e44c8c6d3daa9fe9d24c55bf5f280.zip
Merge branch 'pythonize' of https://github.com/vais-ral/CCPi-FISTA_Reconstruction into pythonize
Conflicts: demos/exportDemoRD2Data.m main_func/regularizers_CPU/FGP_TV_core.h main_func/regularizers_CPU/LLT_model.c src/Python/test/readhd5.py src/Python/test_reconstructor.py src/Python/test_regularizers.py
Diffstat (limited to 'src')
-rw-r--r--src/Python/ccpi/fista/FISTAReconstructor.py389
-rw-r--r--src/Python/ccpi/fista/Reconstructor.py425
-rw-r--r--src/Python/ccpi/fista/__init__.py0
-rw-r--r--src/Python/test/astra_test.py85
-rw-r--r--src/Python/test/readhd5.py1
-rw-r--r--src/Python/test_reconstructor.py65
-rw-r--r--src/Python/test_regularizers.py73
7 files changed, 988 insertions, 50 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py
new file mode 100644
index 0000000..1e76815
--- /dev/null
+++ b/src/Python/ccpi/fista/FISTAReconstructor.py
@@ -0,0 +1,389 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+#from ccpi.reconstruction.parallelbeam import alg
+
+#from ccpi.imaging.Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+
+
+class FISTAReconstructor():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+ # handle parmeters:
+ # obligatory parameters
+ self.pars = dict()
+ self.pars['projector_geometry'] = projector_geometry
+ self.pars['output_geometry'] = output_geometry
+ self.pars['input_sinogram'] = input_sinogram
+ detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+ self.pars['detectors'] = detectors
+ self.pars['number_og_angles'] = nangles
+ self.pars['SlicesZ'] = sliceZ
+
+ print (self.pars)
+ # handle optional input parameters (at instantiation)
+
+ # Accepted input keywords
+ kw = ('number_of_iterations',
+ 'Lipschitz_constant' ,
+ 'ideal_image' ,
+ 'weights' ,
+ 'region_of_interest' ,
+ 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha')
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+
+ if not 'ideal_image' in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not 'region_of_interest'in kwargs.keys() :
+ if self.pars['ideal_image'] == None:
+ pass
+ else:
+ self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+
+ if not 'regularizer' in kwargs.keys() :
+ self.pars['regularizer'] = None
+ else:
+ # the regularizer must be a correctly instantiated object
+ if not 'ring_lambda_R_L1' in kwargs.keys():
+ self.pars['ring_lambda_R_L1'] = 0
+ if not 'ring_alpha' in kwargs.keys():
+ self.pars['ring_alpha'] = 1
+
+
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ N = self.pars['output_geometry']['GridColCount']
+ proj_geom = self.pars['projector_geometry']
+ vol_geom = self.pars['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+
+
+ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #print('Calculating Lipshitz constant for parallel beam geometry...')
+ niter = 5;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights[0])
+ proj_geomT = proj_geom.copy();
+ proj_geomT['DetectorRowCount'] = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
+
+ for i in range(niter):
+ # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT);
+ # s = norm(x1(:));
+ # x1 = x1/s;
+ # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ # y = sqweight.*y;
+ # astra_mex_data3d('delete', sino_id);
+ # astra_mex_data3d('delete', id);
+ #print ("iteration {0}".format(i))
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geomT,
+ vol_geomT)
+
+ y = (sqweight * y).copy() # element wise multiplication
+
+ #b=fig.add_subplot(2,1,2)
+ #imgplot = plt.imshow(x1[0])
+ #plt.show()
+
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+ del x1
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),
+ proj_geomT,
+ vol_geomT)
+ del y
+
+
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = (x1/s).copy();
+
+ # ### this line?
+ # sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ # proj_geomT,
+ # vol_geomT);
+ # y = sqweight * y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx)
+ print ("iteration {0} s= {1}".format(i,s))
+
+ #end
+ del proj_geomT
+ del vol_geomT
+ #plt.show()
+ else:
+ #% divergen beam geometry
+ print('Calculating Lipshitz constant for divergen beam geometry...')
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer is not None:
+ self.pars['regularizer'] = regularizer
+
+
+
+
+
+def getEntry(location, nx):
+ for item in nx[location].keys():
+ print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5"
+##nx = h5py.File(fname, "r")
+##
+### the data are stored in a particular location in the hdf5
+##for item in nx['entry1/tomo_entry/data'].keys():
+## print (item)
+##
+##data = nx.get('entry1/tomo_entry/data/rotation_angle')
+##angles = numpy.zeros(data.shape)
+##data.read_direct(angles)
+##print (angles)
+### angles should be in degrees
+##
+##data = nx.get('entry1/tomo_entry/data/data')
+##stack = numpy.zeros(data.shape)
+##data.read_direct(stack)
+##print (data.shape)
+##
+##print ("Data Loaded")
+##
+##
+### Normalize
+##data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+##itype = numpy.zeros(data.shape)
+##data.read_direct(itype)
+### 2 is dark field
+##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+##dark = darks[0]
+##for i in range(1, len(darks)):
+## dark += darks[i]
+##dark = dark / len(darks)
+###dark[0][0] = dark[0][1]
+##
+### 1 is flat field
+##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+##flat = flats[0]
+##for i in range(1, len(flats)):
+## flat += flats[i]
+##flat = flat / len(flats)
+###flat[0][0] = dark[0][1]
+##
+##
+### 0 is projection data
+##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+##angle_proj = numpy.asarray (angle_proj)
+##angle_proj = angle_proj.astype(numpy.float32)
+##
+### normalized data are
+### norm = (projection - dark)/(flat-dark)
+##
+##def normalize(projection, dark, flat, def_val=0.1):
+## a = (projection - dark)
+## b = (flat-dark)
+## with numpy.errstate(divide='ignore', invalid='ignore'):
+## c = numpy.true_divide( a, b )
+## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+## return c
+##
+##
+##norm = [normalize(projection, dark, flat) for projection in proj]
+##norm = numpy.asarray (norm)
+##norm = norm.astype(numpy.float32)
+
+
+##niterations = 15
+##threads = 3
+##
+##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+##
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+## iteration_values, False)
+##print ("iteration values %s" % str(iteration_values))
+##
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+## numpy.double(1e-5), iteration_values , False)
+##print ("iteration values %s" % str(iteration_values))
+##iteration_values = numpy.zeros((niterations,))
+##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+## numpy.double(1e-5), iteration_values , False)
+##print ("iteration values %s" % str(iteration_values))
+##
+##
+####numpy.save("cgls_recon.npy", img_data)
+##import matplotlib.pyplot as plt
+##fig, ax = plt.subplots(1,6,sharey=True)
+##ax[0].imshow(img_cgls[80])
+##ax[0].axis('off') # clear x- and y-axes
+##ax[1].imshow(img_sirt[80])
+##ax[1].axis('off') # clear x- and y-axes
+##ax[2].imshow(img_mlem[80])
+##ax[2].axis('off') # clear x- and y-axesplt.show()
+##ax[3].imshow(img_cgls_conv[80])
+##ax[3].axis('off') # clear x- and y-axesplt.show()
+##ax[4].imshow(img_cgls_tikhonov[80])
+##ax[4].axis('off') # clear x- and y-axesplt.show()
+##ax[5].imshow(img_cgls_TVreg[80])
+##ax[5].axis('off') # clear x- and y-axesplt.show()
+##
+##
+##plt.show()
+##
+
diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py
new file mode 100644
index 0000000..d29ac0d
--- /dev/null
+++ b/src/Python/ccpi/fista/Reconstructor.py
@@ -0,0 +1,425 @@
+# -*- coding: utf-8 -*-
+###############################################################################
+#This work is part of the Core Imaging Library developed by
+#Visual Analytics and Imaging System Group of the Science Technology
+#Facilities Council, STFC
+#
+#Copyright 2017 Edoardo Pasca, Srikanth Nagella
+#Copyright 2017 Daniil Kazantsev
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#http://www.apache.org/licenses/LICENSE-2.0
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+###############################################################################
+
+
+
+import numpy
+import h5py
+from ccpi.reconstruction.parallelbeam import alg
+
+from Regularizer import Regularizer
+from enum import Enum
+
+import astra
+
+
+
+class FISTAReconstructor():
+ '''FISTA-based reconstruction algorithm using ASTRA-toolbox
+
+ '''
+ # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>>
+ # ___Input___:
+ # params.[] file:
+ # - .proj_geom (geometry of the projector) [required]
+ # - .vol_geom (geometry of the reconstructed object) [required]
+ # - .sino (vectorized in 2D or 3D sinogram) [required]
+ # - .iterFISTA (iterations for the main loop, default 40)
+ # - .L_const (Lipschitz constant, default Power method) )
+ # - .X_ideal (ideal image, if given)
+ # - .weights (statisitcal weights, size of the sinogram)
+ # - .ROI (Region-of-interest, only if X_ideal is given)
+ # - .initialize (a 'warm start' using SIRT method from ASTRA)
+ #----------------Regularization choices------------------------
+ # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter)
+ # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter)
+ # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter)
+ # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04)
+ # - .Regul_Iterations (iterations for the selected penalty, default 25)
+ # - .Regul_tauLLT (time step parameter for LLT term)
+ # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal)
+ # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1)
+ #----------------Visualization parameters------------------------
+ # - .show (visualize reconstruction 1/0, (0 default))
+ # - .maxvalplot (maximum value to use for imshow[0 maxvalplot])
+ # - .slice (for 3D volumes - slice number to imshow)
+ # ___Output___:
+ # 1. X - reconstructed image/volume
+ # 2. output - a structure with
+ # - .Resid_error - residual error (if X_ideal is given)
+ # - .objective: value of the objective function
+ # - .L_const: Lipshitz constant to avoid recalculations
+
+ # References:
+ # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+ # Problems" by A. Beck and M Teboulle
+ # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo
+ # 3. "A novel tomographic reconstruction method based on the robust
+ # Student's t function for suppressing data outliers" D. Kazantsev et.al.
+ # D. Kazantsev, 2016-17
+ def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs):
+ self.params = dict()
+ self.params['projector_geometry'] = projector_geometry
+ self.params['output_geometry'] = output_geometry
+ self.params['input_sinogram'] = input_sinogram
+ detectors, nangles, sliceZ = numpy.shape(input_sinogram)
+ self.params['detectors'] = detectors
+ self.params['number_og_angles'] = nangles
+ self.params['SlicesZ'] = sliceZ
+
+ # Accepted input keywords
+ kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' ,
+ 'weights' , 'region_of_interest' , 'initialize' ,
+ 'regularizer' ,
+ 'ring_lambda_R_L1',
+ 'ring_alpha')
+
+ # handle keyworded parameters
+ if kwargs is not None:
+ for key, value in kwargs.items():
+ if key in kw:
+ #print("{0} = {1}".format(key, value))
+ self.pars[key] = value
+
+ # set the default values for the parameters if not set
+ if 'number_of_iterations' in kwargs.keys():
+ self.pars['number_of_iterations'] = kwargs['number_of_iterations']
+ else:
+ self.pars['number_of_iterations'] = 40
+ if 'weights' in kwargs.keys():
+ self.pars['weights'] = kwargs['weights']
+ else:
+ self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram']))
+ if 'Lipschitz_constant' in kwargs.keys():
+ self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant']
+ else:
+ self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod()
+
+ if not self.pars['ideal_image'] in kwargs.keys():
+ self.pars['ideal_image'] = None
+
+ if not self.pars['region_of_interest'] :
+ if self.pars['ideal_image'] == None:
+ pass
+ else:
+ self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0)
+
+ if not self.pars['regularizer'] :
+ self.pars['regularizer'] = None
+ else:
+ # the regularizer must be a correctly instantiated object
+ if not self.pars['ring_lambda_R_L1']:
+ self.pars['ring_lambda_R_L1'] = 0
+ if not self.pars['ring_alpha']:
+ self.pars['ring_alpha'] = 1
+
+
+
+
+ def calculateLipschitzConstantWithPowerMethod(self):
+ ''' using Power method (PM) to establish L constant'''
+
+ #N = params.vol_geom.GridColCount
+ N = self.pars['output_geometry'].GridColCount
+ proj_geom = self.params['projector_geometry']
+ vol_geom = self.params['output_geometry']
+ weights = self.pars['weights']
+ SlicesZ = self.pars['SlicesZ']
+
+ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
+ #% for parallel geometry we can do just one slice
+ #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...');
+ niter = 15;# % number of iteration for the PM
+ #N = params.vol_geom.GridColCount;
+ #x1 = rand(N,N,1);
+ x1 = numpy.random.rand(1,N,N)
+ #sqweight = sqrt(weights(:,:,1));
+ sqweight = numpy.sqrt(weights.T[0])
+ proj_geomT = proj_geom.copy();
+ proj_geomT.DetectorRowCount = 1;
+ vol_geomT = vol_geom.copy();
+ vol_geomT['GridSliceCount'] = 1;
+
+
+ for i in range(niter):
+ if i == 0:
+ #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+ y = sqweight * y # element wise multiplication
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id)
+
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT);
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+ y = sqweight*y;
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ del proj_geomT
+ del vol_geomT
+ else
+ #% divergen beam geometry
+ #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...');
+ niter = 8; #% number of iteration for PM
+ x1 = numpy.random.rand(SlicesZ , N , N);
+ #sqweight = sqrt(weights);
+ sqweight = numpy.sqrt(weights.T[0])
+
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom);
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ astra.matlab.data3d('delete', sino_id);
+
+ for i in range(niter):
+ #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom);
+ idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ proj_geom,
+ vol_geom)
+ s = numpy.linalg.norm(x1)
+ ### this line?
+ x1 = x1/s;
+ ### this line?
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom);
+ sino_id, y = astra.creators.create_sino3d_gpu(x1,
+ proj_geom,
+ vol_geom);
+
+ y = sqweight*y;
+ #astra_mex_data3d('delete', sino_id);
+ #astra_mex_data3d('delete', id);
+ astra.matlab.data3d('delete', sino_id);
+ astra.matlab.data3d('delete', idx);
+ #end
+ #clear x1
+ del x1
+
+ return s
+
+
+ def setRegularizer(self, regularizer):
+ if regularizer
+ self.pars['regularizer'] = regularizer
+
+
+
+
+
+def getEntry(location):
+ for item in nx[location].keys():
+ print (item)
+
+
+print ("Loading Data")
+
+##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif"
+####ind = [i * 1049 for i in range(360)]
+#### use only 360 images
+##images = 200
+##ind = [int(i * 1049 / images) for i in range(images)]
+##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None)
+
+#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs"
+fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs"
+nx = h5py.File(fname, "r")
+
+# the data are stored in a particular location in the hdf5
+for item in nx['entry1/tomo_entry/data'].keys():
+ print (item)
+
+data = nx.get('entry1/tomo_entry/data/rotation_angle')
+angles = numpy.zeros(data.shape)
+data.read_direct(angles)
+print (angles)
+# angles should be in degrees
+
+data = nx.get('entry1/tomo_entry/data/data')
+stack = numpy.zeros(data.shape)
+data.read_direct(stack)
+print (data.shape)
+
+print ("Data Loaded")
+
+
+# Normalize
+data = nx.get('entry1/tomo_entry/instrument/detector/image_key')
+itype = numpy.zeros(data.shape)
+data.read_direct(itype)
+# 2 is dark field
+darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ]
+dark = darks[0]
+for i in range(1, len(darks)):
+ dark += darks[i]
+dark = dark / len(darks)
+#dark[0][0] = dark[0][1]
+
+# 1 is flat field
+flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ]
+flat = flats[0]
+for i in range(1, len(flats)):
+ flat += flats[i]
+flat = flat / len(flats)
+#flat[0][0] = dark[0][1]
+
+
+# 0 is projection data
+proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ]
+angle_proj = numpy.asarray (angle_proj)
+angle_proj = angle_proj.astype(numpy.float32)
+
+# normalized data are
+# norm = (projection - dark)/(flat-dark)
+
+def normalize(projection, dark, flat, def_val=0.1):
+ a = (projection - dark)
+ b = (flat-dark)
+ with numpy.errstate(divide='ignore', invalid='ignore'):
+ c = numpy.true_divide( a, b )
+ c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0
+ return c
+
+
+norm = [normalize(projection, dark, flat) for projection in proj]
+norm = numpy.asarray (norm)
+norm = norm.astype(numpy.float32)
+
+#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm,
+# angles = angle_proj, center_of_rotation = 86.2 ,
+# flat_field = flat, dark_field = dark,
+# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+
+#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj,
+# angles = angle_proj, center_of_rotation = 86.2 ,
+# flat_field = flat, dark_field = dark,
+# iterations = 15, resolution = 1, isLogScale = False, threads = 3)
+#img_cgls = recon.reconstruct()
+#
+#pars = dict()
+#pars['algorithm'] = Reconstructor.Algorithm.SIRT
+#pars['projection_data'] = proj
+#pars['angles'] = angle_proj
+#pars['center_of_rotation'] = numpy.double(86.2)
+#pars['flat_field'] = flat
+#pars['iterations'] = 15
+#pars['dark_field'] = dark
+#pars['resolution'] = 1
+#pars['isLogScale'] = False
+#pars['threads'] = 3
+#
+#img_sirt = recon.reconstruct(pars)
+#
+#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM
+#img_mlem = recon.reconstruct()
+
+############################################################
+############################################################
+#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV
+#recon.pars['regularize'] = numpy.double(0.1)
+#img_cgls_conv = recon.reconstruct()
+
+niterations = 15
+threads = 3
+
+img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False)
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ iteration_values, False)
+print ("iteration values %s" % str(iteration_values))
+
+iteration_values = numpy.zeros((niterations,))
+img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+iteration_values = numpy.zeros((niterations,))
+img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads,
+ numpy.double(1e-5), iteration_values , False)
+print ("iteration values %s" % str(iteration_values))
+
+
+##numpy.save("cgls_recon.npy", img_data)
+import matplotlib.pyplot as plt
+fig, ax = plt.subplots(1,6,sharey=True)
+ax[0].imshow(img_cgls[80])
+ax[0].axis('off') # clear x- and y-axes
+ax[1].imshow(img_sirt[80])
+ax[1].axis('off') # clear x- and y-axes
+ax[2].imshow(img_mlem[80])
+ax[2].axis('off') # clear x- and y-axesplt.show()
+ax[3].imshow(img_cgls_conv[80])
+ax[3].axis('off') # clear x- and y-axesplt.show()
+ax[4].imshow(img_cgls_tikhonov[80])
+ax[4].axis('off') # clear x- and y-axesplt.show()
+ax[5].imshow(img_cgls_TVreg[80])
+ax[5].axis('off') # clear x- and y-axesplt.show()
+
+
+plt.show()
+
+#viewer = edo.CILViewer()
+#viewer.setInputAsNumpy(img_cgls2)
+#viewer.displaySliceActor(0)
+#viewer.startRenderLoop()
+
+import vtk
+
+def NumpyToVTKImageData(numpyarray):
+ if (len(numpy.shape(numpyarray)) == 3):
+ doubleImg = vtk.vtkImageData()
+ shape = numpy.shape(numpyarray)
+ doubleImg.SetDimensions(shape[0], shape[1], shape[2])
+ doubleImg.SetOrigin(0,0,0)
+ doubleImg.SetSpacing(1,1,1)
+ doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1)
+ #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation())
+ doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1)
+
+ for i in range(shape[0]):
+ for j in range(shape[1]):
+ for k in range(shape[2]):
+ doubleImg.SetScalarComponentFromDouble(
+ i,j,k,0, numpyarray[i][j][k])
+ #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) )
+ # rescale to appropriate VTK_UNSIGNED_SHORT
+ stats = vtk.vtkImageAccumulate()
+ stats.SetInputData(doubleImg)
+ stats.Update()
+ iMin = stats.GetMin()[0]
+ iMax = stats.GetMax()[0]
+ scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin)
+
+ shiftScaler = vtk.vtkImageShiftScale ()
+ shiftScaler.SetInputData(doubleImg)
+ shiftScaler.SetScale(scale)
+ shiftScaler.SetShift(iMin)
+ shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT)
+ shiftScaler.Update()
+ return shiftScaler.GetOutput()
+
+#writer = vtk.vtkMetaImageWriter()
+#writer.SetFileName(alg + "_recon.mha")
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2))
+#writer.Write()
diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/Python/ccpi/fista/__init__.py
diff --git a/src/Python/test/astra_test.py b/src/Python/test/astra_test.py
new file mode 100644
index 0000000..42c375a
--- /dev/null
+++ b/src/Python/test/astra_test.py
@@ -0,0 +1,85 @@
+import astra
+import numpy
+import filefun
+
+
+# read in the same data as the DemoRD2
+angles = filefun.dlmread("DemoRD2/angles.csv")
+darks_ar = filefun.dlmread("DemoRD2/darks_ar.csv", separator=",")
+flats_ar = filefun.dlmread("DemoRD2/flats_ar.csv", separator=",")
+
+if True:
+ Sino3D = numpy.load("DemoRD2/Sino3D.npy")
+else:
+ sino = filefun.dlmread("DemoRD2/sino_01.csv", separator=",")
+ a = map (lambda x:x, numpy.shape(sino))
+ a.append(20)
+
+ Sino3D = numpy.zeros(tuple(a), dtype="float")
+
+ for i in range(1,numpy.shape(Sino3D)[2]+1):
+ print("Read file DemoRD2/sino_%02d.csv" % i)
+ sino = filefun.dlmread("DemoRD2/sino_%02d.csv" % i, separator=",")
+ Sino3D.T[i-1] = sino.T
+
+Weights3D = numpy.asarray(Sino3D, dtype="float")
+
+##angles_rad = angles*(pi/180); % conversion to radians
+##size_det = size(data_raw3D,1); % detectors dim
+##angSize = size(data_raw3D, 2); % angles dim
+##slices_tot = size(data_raw3D, 3); % no of slices
+##recon_size = 950; % reconstruction size
+
+
+angles_rad = angles * numpy.pi /180.
+size_det, angSize, slices_tot = numpy.shape(Sino3D)
+size_det, angSize, slices_tot = [int(i) for i in numpy.shape(Sino3D)]
+recon_size = 950
+Z_slices = 3;
+det_row_count = Z_slices;
+
+#proj_geom = astra_create_proj_geom('parallel3d', 1, 1,
+# det_row_count, size_det, angles_rad);
+
+detectorSpacingX = 1.0
+detectorSpacingY = detectorSpacingX
+proj_geom = astra.create_proj_geom('parallel3d',
+ detectorSpacingX,
+ detectorSpacingY,
+ det_row_count,
+ size_det,
+ angles_rad)
+
+#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices);
+vol_geom = astra.create_vol_geom(recon_size,recon_size,Z_slices);
+
+sino = numpy.zeros((size_det, angSize, slices_tot), dtype="float")
+
+#weights = ones(size(sino));
+weights = numpy.ones(numpy.shape(sino))
+
+#####################################################################
+## PowerMethod for Lipschitz constant
+
+N = vol_geom['GridColCount']
+x1 = numpy.random.rand(1,N,N)
+#sqweight = sqrt(weights(:,:,1));
+sqweight = numpy.sqrt(weights.T[0]).T
+##proj_geomT = proj_geom;
+proj_geomT = proj_geom.copy()
+##proj_geomT.DetectorRowCount = 1;
+proj_geomT['DetectorRowCount'] = 1
+##vol_geomT = vol_geom;
+vol_geomT = vol_geom.copy()
+##vol_geomT.GridSliceCount = 1;
+vol_geomT['GridSliceCount'] = 1
+
+##[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
+
+#sino_id, y = astra.create_sino3d_gpu(x1, proj_geomT, vol_geomT);
+sino_id, y = astra.create_sino(x1, proj_geomT, vol_geomT);
+
+##y = sqweight.*y;
+##astra_mex_data3d('delete', sino_id);
+
+
diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py
index b042341..10f25ed 100644
--- a/src/Python/test/readhd5.py
+++ b/src/Python/test/readhd5.py
@@ -25,6 +25,7 @@ angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
angles_rad = numpy.asarray(nx.get('/angles_rad'))
recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
+
slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
#from ccpi.viewer.CILViewer2D import CILViewer2D
diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py
index 0fd08f5..a4a622b 100644
--- a/src/Python/test_reconstructor.py
+++ b/src/Python/test_reconstructor.py
@@ -9,7 +9,7 @@ Based on DemoRD2.m
import h5py
import numpy
-from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor
+from ccpi.fista.FISTAReconstructor import FISTAReconstructor
import astra
##def getEntry(nx, location):
@@ -23,10 +23,10 @@ nx = h5py.File(filename, "r")
entries = [entry for entry in nx['/'].keys()]
print (entries)
-Sino3D = numpy.asarray(nx.get('/Sino3D'))
-Weights3D = numpy.asarray(nx.get('/Weights3D'))
+Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32")
+Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32")
angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0]
-angles_rad = numpy.asarray(nx.get('/angles_rad'))
+angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32")
recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0]
size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0]
slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0]
@@ -58,21 +58,23 @@ vol_geom = astra.creators.create_vol_geom( image_size_x,
## First pass the arguments to the FISTAReconstructor and test the
## Lipschitz constant
-#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D )
+fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D , weights=Weights3D)
+print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant']))
#N = params.vol_geom.GridColCount
pars = dict()
-pars['projector_geometry'] = proj_geom
-pars['output_geometry'] = vol_geom
-pars['input_sinogram'] = Sino3D
+pars['projector_geometry'] = proj_geom.copy()
+pars['output_geometry'] = vol_geom.copy()
+pars['input_sinogram'] = Sino3D.copy()
sliceZ , nangles , detectors = numpy.shape(Sino3D)
pars['detectors'] = detectors
pars['number_of_angles'] = nangles
pars['SlicesZ'] = sliceZ
-pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram']))
-
+#pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram']))
+pars['weights'] = Weights3D.copy()
+
N = pars['output_geometry']['GridColCount']
proj_geom = pars['projector_geometry']
vol_geom = pars['output_geometry']
@@ -82,7 +84,7 @@ SlicesZ = pars['SlicesZ']
if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
#% for parallel geometry we can do just one slice
print('Calculating Lipshitz constant for parallel beam geometry...')
- niter = 15;# % number of iteration for the PM
+ niter = 5;# % number of iteration for the PM
#N = params.vol_geom.GridColCount;
#x1 = rand(N,N,1);
x1 = numpy.random.rand(1,N,N)
@@ -96,7 +98,8 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
#[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT);
import matplotlib.pyplot as plt
- fig = plt.figure()
+ fig = []
+ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
#a.set_title('Lipschitz')
for i in range(niter):
@@ -107,14 +110,27 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
# y = sqweight.*y;
# astra_mex_data3d('delete', sino_id);
# astra_mex_data3d('delete', id);
- print ("iteration {0}".format(i))
+ #print ("iteration {0}".format(i))
+ fig.append(plt.figure())
+
+ a=fig[-1].add_subplot(1,2,1)
+ a.text(0.05, 0.95, "iteration {0}, x1".format(i), transform=a.transAxes,
+ fontsize=14,verticalalignment='top', bbox=props)
+
+ imgplot = plt.imshow(x1[0].copy())
+
+
sino_id, y = astra.creators.create_sino3d_gpu(x1,
proj_geomT,
vol_geomT)
- #a=fig.add_subplot(2,1,1)
- #imgplot = plt.imshow(y[0])
+ a=fig[-1].add_subplot(1,2,2)
+ a.text(0.05, 0.95, "iteration {0}, y".format(i),
+ transform=a.transAxes, fontsize=14,verticalalignment='top',
+ bbox=props)
+
+ imgplot = plt.imshow(y[0].copy())
- y = sqweight * y # element wise multiplication
+ y = (sqweight * y) # element wise multiplication
#b=fig.add_subplot(2,1,2)
#imgplot = plt.imshow(x1[0])
@@ -122,15 +138,17 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
#astra_mex_data3d('delete', sino_id);
astra.matlab.data3d('delete', sino_id)
+ del x1
- idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,
+ idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y),
proj_geomT,
- vol_geomT);
- print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1)))
+ vol_geomT)
+ del y
+
+
s = numpy.linalg.norm(x1)
### this line?
- x1 = x1/s;
- print ("x1 {0}".format(x1.T[:4].T))
+ x1 = (x1/s)
# ### this line?
# sino_id, y = astra.creators.create_sino3d_gpu(x1,
@@ -138,10 +156,13 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'):
# vol_geomT);
# y = sqweight * y;
astra.matlab.data3d('delete', sino_id);
- astra.matlab.data3d('delete', idx);
+ astra.matlab.data3d('delete', idx)
+ print ("iteration {0} s= {1}".format(i,s))
+
#end
del proj_geomT
del vol_geomT
+ #plt.show()
else:
#% divergen beam geometry
print('Calculating Lipshitz constant for divergen beam geometry...')
diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py
index 5804897..e76262c 100644
--- a/src/Python/test_regularizers.py
+++ b/src/Python/test_regularizers.py
@@ -5,15 +5,15 @@ Created on Fri Aug 4 11:10:05 2017
@author: ofn77899
"""
-from ccpi.viewer.CILViewer2D import Converter
-import vtk
+#from ccpi.viewer.CILViewer2D import Converter
+#import vtk
import matplotlib.pyplot as plt
import numpy as np
import os
from enum import Enum
import timeit
-
+#from PIL import Image
#from Regularizer import Regularizer
from ccpi.imaging.Regularizer import Regularizer
@@ -46,12 +46,21 @@ def nrmse(im1, im2):
# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0;
# u = SplitBregman_TV(single(u0), 10, 30, 1e-04);
-filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif"
-reader = vtk.vtkTIFFReader()
-reader.SetFileName(os.path.normpath(filename))
-reader.Update()
-#vtk returns 3D images, let's take just the one slice there is as 2D
-Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255
+
+#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif"
+filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif"
+#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif'
+
+#reader = vtk.vtkTIFFReader()
+#reader.SetFileName(os.path.normpath(filename))
+#reader.Update()
+Im = plt.imread(filename)
+#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255
+#img.show()
+Im = np.asarray(Im, dtype='float32')
+
+
+
#imgplot = plt.imshow(Im)
perc = 0.05
@@ -68,7 +77,7 @@ fig = plt.figure()
a=fig.add_subplot(2,3,1)
a.set_title('noise')
-imgplot = plt.imshow(u0)
+imgplot = plt.imshow(u0,cmap="gray")
reg_output = []
##############################################################################
@@ -80,6 +89,7 @@ reg_output = []
use_object = True
if use_object:
reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV)
+ print (reg.pars)
reg.setParameter(input=u0)
reg.setParameter(regularization_parameter=10.)
# or
@@ -113,7 +123,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(plotme)
+imgplot = plt.imshow(plotme,cmap="gray")
###################### FGP_TV #########################################
# u = FGP_TV(single(u0), 0.05, 100, 1e-04);
@@ -131,8 +141,12 @@ textstr = out2[-1]
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
- verticalalignment='top', bbox=props)
+ verticalalignment='top', bbox=props)
imgplot = plt.imshow(reg_output[-1][0])
+# place a text box in upper left in axes coords
+a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
###################### LLT_model #########################################
# * u0 = Im + .03*randn(size(Im)); % adding noise
@@ -149,13 +163,16 @@ pars = out2[-2]
reg_output.append(out2)
a=fig.add_subplot(2,3,4)
+
textstr = out2[-1]
+
# these are matplotlib.patch.Patch properties
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
- verticalalignment='top', bbox=props)
-imgplot = plt.imshow(reg_output[-1][0])
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
+
# ###################### PatchBased_Regul #########################################
# # Quick 2D denoising example in Matlab:
@@ -164,9 +181,9 @@ imgplot = plt.imshow(reg_output[-1][0])
# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);
out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05,
- searching_window_ratio=3,
- similarity_window_ratio=1,
- PB_filtering_parameter=0.08)
+ searching_window_ratio=3,
+ similarity_window_ratio=1,
+ PB_filtering_parameter=0.08)
pars = out2[-2]
reg_output.append(out2)
@@ -180,20 +197,20 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
verticalalignment='top', bbox=props)
-imgplot = plt.imshow(reg_output[-1][0])
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
-###################### TGV_PD #########################################
-# Quick 2D denoising example in Matlab:
-# Im = double(imread('lena_gray_256.tif'))/255; % loading image
-# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
-# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
+# ###################### TGV_PD #########################################
+# # Quick 2D denoising example in Matlab:
+# # Im = double(imread('lena_gray_256.tif'))/255; % loading image
+# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise
+# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550);
out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05,
- first_order_term=1.3,
- second_order_term=1,
- number_of_iterations=550)
+ first_order_term=1.3,
+ second_order_term=1,
+ number_of_iterations=550)
pars = out2[-2]
reg_output.append(out2)
@@ -207,8 +224,8 @@ textstr = out2[-1]
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# place a text box in upper left in axes coords
a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14,
- verticalalignment='top', bbox=props)
-imgplot = plt.imshow(reg_output[-1][0])
+ verticalalignment='top', bbox=props)
+imgplot = plt.imshow(reg_output[-1][0],cmap="gray")
plt.show()