diff options
-rw-r--r-- | demos/exportDemoRD2Data.m | 2 | ||||
-rw-r--r-- | main_func/regularizers_CPU/FGP_TV_core.h | 2 | ||||
-rw-r--r-- | main_func/regularizers_CPU/LLT_model.c | 6 | ||||
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 389 | ||||
-rw-r--r-- | src/Python/ccpi/fista/Reconstructor.py | 425 | ||||
-rw-r--r-- | src/Python/ccpi/fista/__init__.py | 0 | ||||
-rw-r--r-- | src/Python/test/astra_test.py | 85 | ||||
-rw-r--r-- | src/Python/test/readhd5.py | 1 | ||||
-rw-r--r-- | src/Python/test_reconstructor.py | 65 | ||||
-rw-r--r-- | src/Python/test_regularizers.py | 73 |
10 files changed, 990 insertions, 58 deletions
diff --git a/demos/exportDemoRD2Data.m b/demos/exportDemoRD2Data.m index 0ca036b..028353b 100644 --- a/demos/exportDemoRD2Data.m +++ b/demos/exportDemoRD2Data.m @@ -24,7 +24,7 @@ end Sino3D = Sino3D.*1000; Weights3D = single(data_raw3D); % weights for PW model -%clear data_raw3D +clear data_raw3D hdf5write('DendrData.h5', '/Weights3D', Weights3D) hdf5write('DendrData.h5', '/Sino3D', Sino3D, 'WriteMode', 'append') diff --git a/main_func/regularizers_CPU/FGP_TV_core.h b/main_func/regularizers_CPU/FGP_TV_core.h index 7e7229a..6430bf2 100644 --- a/main_func/regularizers_CPU/FGP_TV_core.h +++ b/main_func/regularizers_CPU/FGP_TV_core.h @@ -68,4 +68,4 @@ float Rupd_func3D(float *P1, float *P1_old, float *P2, float *P2_old, float *P3, float Obj_func_CALC3D(float *A, float *D, float *funcvalA, float lambda, int dimX, int dimY, int dimZ); #ifdef __cplusplus } -#endif +#endif
\ No newline at end of file diff --git a/main_func/regularizers_CPU/LLT_model.c b/main_func/regularizers_CPU/LLT_model.c index 0050654..0b07b47 100644 --- a/main_func/regularizers_CPU/LLT_model.c +++ b/main_func/regularizers_CPU/LLT_model.c @@ -24,13 +24,7 @@ limitations under the License. /* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty * * Input Parameters: -* 1. U0 - origanal noise image/volume -======= -/* C-OMP implementation of Lysaker, Lundervold and Tai (LLT) model of higher order regularization penalty -* -* Input Parameters: * 1. U0 - original noise image/volume ->>>>>>> fix typo and add include "matrix.h" * 2. lambda - regularization parameter * 3. tau - time-step for explicit scheme * 4. iter - iterations number diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py new file mode 100644 index 0000000..1e76815 --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +#from ccpi.reconstruction.parallelbeam import alg + +#from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry + self.pars['output_geometry'] = output_geometry + self.pars['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_og_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + + print (self.pars) + # handle optional input parameters (at instantiation) + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not 'ideal_image' in kwargs.keys(): + self.pars['ideal_image'] = None + + if not 'region_of_interest'in kwargs.keys() : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not 'regularizer' in kwargs.keys() : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + + for i in range(niter): + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 + + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + + s = numpy.linalg.norm(x1) + ### this line? + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + + #end + del proj_geomT + del vol_geomT + #plt.show() + else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location, nx): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py new file mode 100644 index 0000000..d29ac0d --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/Python/ccpi/fista/__init__.py diff --git a/src/Python/test/astra_test.py b/src/Python/test/astra_test.py new file mode 100644 index 0000000..42c375a --- /dev/null +++ b/src/Python/test/astra_test.py @@ -0,0 +1,85 @@ +import astra +import numpy +import filefun + + +# read in the same data as the DemoRD2 +angles = filefun.dlmread("DemoRD2/angles.csv") +darks_ar = filefun.dlmread("DemoRD2/darks_ar.csv", separator=",") +flats_ar = filefun.dlmread("DemoRD2/flats_ar.csv", separator=",") + +if True: + Sino3D = numpy.load("DemoRD2/Sino3D.npy") +else: + sino = filefun.dlmread("DemoRD2/sino_01.csv", separator=",") + a = map (lambda x:x, numpy.shape(sino)) + a.append(20) + + Sino3D = numpy.zeros(tuple(a), dtype="float") + + for i in range(1,numpy.shape(Sino3D)[2]+1): + print("Read file DemoRD2/sino_%02d.csv" % i) + sino = filefun.dlmread("DemoRD2/sino_%02d.csv" % i, separator=",") + Sino3D.T[i-1] = sino.T + +Weights3D = numpy.asarray(Sino3D, dtype="float") + +##angles_rad = angles*(pi/180); % conversion to radians +##size_det = size(data_raw3D,1); % detectors dim +##angSize = size(data_raw3D, 2); % angles dim +##slices_tot = size(data_raw3D, 3); % no of slices +##recon_size = 950; % reconstruction size + + +angles_rad = angles * numpy.pi /180. +size_det, angSize, slices_tot = numpy.shape(Sino3D) +size_det, angSize, slices_tot = [int(i) for i in numpy.shape(Sino3D)] +recon_size = 950 +Z_slices = 3; +det_row_count = Z_slices; + +#proj_geom = astra_create_proj_geom('parallel3d', 1, 1, +# det_row_count, size_det, angles_rad); + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX +proj_geom = astra.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + size_det, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +vol_geom = astra.create_vol_geom(recon_size,recon_size,Z_slices); + +sino = numpy.zeros((size_det, angSize, slices_tot), dtype="float") + +#weights = ones(size(sino)); +weights = numpy.ones(numpy.shape(sino)) + +##################################################################### +## PowerMethod for Lipschitz constant + +N = vol_geom['GridColCount'] +x1 = numpy.random.rand(1,N,N) +#sqweight = sqrt(weights(:,:,1)); +sqweight = numpy.sqrt(weights.T[0]).T +##proj_geomT = proj_geom; +proj_geomT = proj_geom.copy() +##proj_geomT.DetectorRowCount = 1; +proj_geomT['DetectorRowCount'] = 1 +##vol_geomT = vol_geom; +vol_geomT = vol_geom.copy() +##vol_geomT.GridSliceCount = 1; +vol_geomT['GridSliceCount'] = 1 + +##[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + +#sino_id, y = astra.create_sino3d_gpu(x1, proj_geomT, vol_geomT); +sino_id, y = astra.create_sino(x1, proj_geomT, vol_geomT); + +##y = sqweight.*y; +##astra_mex_data3d('delete', sino_id); + + diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py index b042341..10f25ed 100644 --- a/src/Python/test/readhd5.py +++ b/src/Python/test/readhd5.py @@ -25,6 +25,7 @@ angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] angles_rad = numpy.asarray(nx.get('/angles_rad')) recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] + slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] #from ccpi.viewer.CILViewer2D import CILViewer2D diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 0fd08f5..a4a622b 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -9,7 +9,7 @@ Based on DemoRD2.m import h5py import numpy -from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +from ccpi.fista.FISTAReconstructor import FISTAReconstructor import astra ##def getEntry(nx, location): @@ -23,10 +23,10 @@ nx = h5py.File(filename, "r") entries = [entry for entry in nx['/'].keys()] print (entries) -Sino3D = numpy.asarray(nx.get('/Sino3D')) -Weights3D = numpy.asarray(nx.get('/Weights3D')) +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] -angles_rad = numpy.asarray(nx.get('/angles_rad')) +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] @@ -58,21 +58,23 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, ## First pass the arguments to the FISTAReconstructor and test the ## Lipschitz constant -#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) +fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D , weights=Weights3D) +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) #N = params.vol_geom.GridColCount pars = dict() -pars['projector_geometry'] = proj_geom -pars['output_geometry'] = vol_geom -pars['input_sinogram'] = Sino3D +pars['projector_geometry'] = proj_geom.copy() +pars['output_geometry'] = vol_geom.copy() +pars['input_sinogram'] = Sino3D.copy() sliceZ , nangles , detectors = numpy.shape(Sino3D) pars['detectors'] = detectors pars['number_of_angles'] = nangles pars['SlicesZ'] = sliceZ -pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) - +#pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) +pars['weights'] = Weights3D.copy() + N = pars['output_geometry']['GridColCount'] proj_geom = pars['projector_geometry'] vol_geom = pars['output_geometry'] @@ -82,7 +84,7 @@ SlicesZ = pars['SlicesZ'] if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 15;# % number of iteration for the PM + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) @@ -96,7 +98,8 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); import matplotlib.pyplot as plt - fig = plt.figure() + fig = [] + props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) #a.set_title('Lipschitz') for i in range(niter): @@ -107,14 +110,27 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): # y = sqweight.*y; # astra_mex_data3d('delete', sino_id); # astra_mex_data3d('delete', id); - print ("iteration {0}".format(i)) + #print ("iteration {0}".format(i)) + fig.append(plt.figure()) + + a=fig[-1].add_subplot(1,2,1) + a.text(0.05, 0.95, "iteration {0}, x1".format(i), transform=a.transAxes, + fontsize=14,verticalalignment='top', bbox=props) + + imgplot = plt.imshow(x1[0].copy()) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT) - #a=fig.add_subplot(2,1,1) - #imgplot = plt.imshow(y[0]) + a=fig[-1].add_subplot(1,2,2) + a.text(0.05, 0.95, "iteration {0}, y".format(i), + transform=a.transAxes, fontsize=14,verticalalignment='top', + bbox=props) + + imgplot = plt.imshow(y[0].copy()) - y = sqweight * y # element wise multiplication + y = (sqweight * y) # element wise multiplication #b=fig.add_subplot(2,1,2) #imgplot = plt.imshow(x1[0]) @@ -122,15 +138,17 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #astra_mex_data3d('delete', sino_id); astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y), proj_geomT, - vol_geomT); - print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - print ("x1 {0}".format(x1.T[:4].T)) + x1 = (x1/s) # ### this line? # sino_id, y = astra.creators.create_sino3d_gpu(x1, @@ -138,10 +156,13 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): # vol_geomT); # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry print('Calculating Lipshitz constant for divergen beam geometry...') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 5804897..e76262c 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -5,15 +5,15 @@ Created on Fri Aug 4 11:10:05 2017 @author: ofn77899 """ -from ccpi.viewer.CILViewer2D import Converter -import vtk +#from ccpi.viewer.CILViewer2D import Converter +#import vtk import matplotlib.pyplot as plt import numpy as np import os from enum import Enum import timeit - +#from PIL import Image #from Regularizer import Regularizer from ccpi.imaging.Regularizer import Regularizer @@ -46,12 +46,21 @@ def nrmse(im1, im2): # u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" -reader = vtk.vtkTIFFReader() -reader.SetFileName(os.path.normpath(filename)) -reader.Update() -#vtk returns 3D images, let's take just the one slice there is as 2D -Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" +#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' + +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +Im = plt.imread(filename) +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + + + #imgplot = plt.imshow(Im) perc = 0.05 @@ -68,7 +77,7 @@ fig = plt.figure() a=fig.add_subplot(2,3,1) a.set_title('noise') -imgplot = plt.imshow(u0) +imgplot = plt.imshow(u0,cmap="gray") reg_output = [] ############################################################################## @@ -80,6 +89,7 @@ reg_output = [] use_object = True if use_object: reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + print (reg.pars) reg.setParameter(input=u0) reg.setParameter(regularization_parameter=10.) # or @@ -113,7 +123,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(plotme) +imgplot = plt.imshow(plotme,cmap="gray") ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -131,8 +141,12 @@ textstr = out2[-1] props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) + verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise @@ -149,13 +163,16 @@ pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) + textstr = out2[-1] + # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + # ###################### PatchBased_Regul ######################################### # # Quick 2D denoising example in Matlab: @@ -164,9 +181,9 @@ imgplot = plt.imshow(reg_output[-1][0]) # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) pars = out2[-2] reg_output.append(out2) @@ -180,20 +197,20 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") -###################### TGV_PD ######################################### -# Quick 2D denoising example in Matlab: -# Im = double(imread('lena_gray_256.tif'))/255; % loading image -# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - first_order_term=1.3, - second_order_term=1, - number_of_iterations=550) + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) pars = out2[-2] reg_output.append(out2) @@ -207,8 +224,8 @@ textstr = out2[-1] props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") plt.show() |