diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Python/ccpi/fista/FISTAReconstructor.py | 164 |
1 files changed, 34 insertions, 130 deletions
diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 1e76815..cbd27da 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -73,7 +73,8 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): # handle parmeters: # obligatory parameters self.pars = dict() @@ -98,6 +99,7 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha') + self.acceptedInputKeywords = kw # handle keyworded parameters if kwargs is not None: @@ -114,11 +116,14 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None @@ -127,7 +132,8 @@ class FISTAReconstructor(): if self.pars['ideal_image'] == None: pass else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + self.pars['region_of_interest'] = numpy.nonzero( + self.pars['ideal_image']>0.0) if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None @@ -140,7 +146,29 @@ class FISTAReconstructor(): + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'Reconstruction algorithm') + # setParameter + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' @@ -152,7 +180,8 @@ class FISTAReconstructor(): - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice #print('Calculating Lipshitz constant for parallel beam geometry...') niter = 5;# % number of iteration for the PM @@ -262,128 +291,3 @@ class FISTAReconstructor(): - - -def getEntry(location, nx): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## - |