summaryrefslogtreecommitdiffstats
path: root/samples/python
diff options
context:
space:
mode:
authorWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2016-10-06 12:30:18 +0200
committerWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2016-10-06 12:30:18 +0200
commit0cec258c5079cc065fa75f82ae8d785986ebdf18 (patch)
treee7ca39da75ad5c9d728698295ac9c8ec32e4e499 /samples/python
parentc2cdbc312196481edd202baa3bd668396e78534c (diff)
parent7bb42ddd9e26fc7c01734d26bc114b5a935d9110 (diff)
downloadastra-0cec258c5079cc065fa75f82ae8d785986ebdf18.tar.gz
astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.tar.bz2
astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.tar.xz
astra-0cec258c5079cc065fa75f82ae8d785986ebdf18.zip
Merge branch 'master' into FDK
Diffstat (limited to 'samples/python')
-rw-r--r--samples/python/s009_projection_matrix.py2
-rw-r--r--samples/python/s015_fp_bp.py6
-rw-r--r--samples/python/s017_OpTomo.py2
-rw-r--r--samples/python/s018_plugin.py34
4 files changed, 26 insertions, 18 deletions
diff --git a/samples/python/s009_projection_matrix.py b/samples/python/s009_projection_matrix.py
index c4c4557..e20d58c 100644
--- a/samples/python/s009_projection_matrix.py
+++ b/samples/python/s009_projection_matrix.py
@@ -46,7 +46,7 @@ W = astra.matrix.get(matrix_id)
# Manually use this projection matrix to do a projection:
import scipy.io
P = scipy.io.loadmat('phantom.mat')['phantom256']
-s = W.dot(P.flatten())
+s = W.dot(P.ravel())
s = np.reshape(s, (len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount']))
import pylab
diff --git a/samples/python/s015_fp_bp.py b/samples/python/s015_fp_bp.py
index fa0bf86..ff0b30a 100644
--- a/samples/python/s015_fp_bp.py
+++ b/samples/python/s015_fp_bp.py
@@ -46,12 +46,12 @@ class astra_wrap(object):
def matvec(self,v):
sid, s = astra.create_sino(np.reshape(v,(vol_geom['GridRowCount'],vol_geom['GridColCount'])),self.proj_id)
astra.data2d.delete(sid)
- return s.flatten()
+ return s.ravel()
def rmatvec(self,v):
bid, b = astra.create_backprojection(np.reshape(v,(len(proj_geom['ProjectionAngles']),proj_geom['DetectorCount'],)),self.proj_id)
astra.data2d.delete(bid)
- return b.flatten()
+ return b.ravel()
vol_geom = astra.create_vol_geom(256, 256)
proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False))
@@ -65,7 +65,7 @@ proj_id = astra.create_projector('cuda',proj_geom,vol_geom)
sinogram_id, sinogram = astra.create_sino(P, proj_id)
# Reshape the sinogram into a vector
-b = sinogram.flatten()
+b = sinogram.ravel()
# Call lsqr with ASTRA FP and BP
import scipy.sparse.linalg
diff --git a/samples/python/s017_OpTomo.py b/samples/python/s017_OpTomo.py
index 967fa64..214e9a7 100644
--- a/samples/python/s017_OpTomo.py
+++ b/samples/python/s017_OpTomo.py
@@ -50,7 +50,7 @@ pylab.figure(2)
pylab.imshow(sinogram)
# Run the lsqr linear solver
-output = scipy.sparse.linalg.lsqr(W, sinogram.flatten(), iter_lim=150)
+output = scipy.sparse.linalg.lsqr(W, sinogram.ravel(), iter_lim=150)
rec = output[0].reshape([256, 256])
pylab.figure(3)
diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
index 31cca95..85b5486 100644
--- a/samples/python/s018_plugin.py
+++ b/samples/python/s018_plugin.py
@@ -30,30 +30,38 @@ import six
# Define the plugin class (has to subclass astra.plugin.base)
# Note that usually, these will be defined in a separate package/module
-class SIRTPlugin(astra.plugin.base):
- """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm.
+class LandweberPlugin(astra.plugin.base):
+ """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm.
Options:
- 'rel_factor': relaxation factor (optional)
+ 'Relaxation': relaxation factor (optional)
"""
# The astra_name variable defines the name to use to
# call the plugin from ASTRA
- astra_name = "SIRT-PLUGIN"
+ astra_name = "LANDWEBER-PLUGIN"
- def initialize(self,cfg, rel_factor = 1):
+ def initialize(self,cfg, Relaxation = 1):
self.W = astra.OpTomo(cfg['ProjectorId'])
self.vid = cfg['ReconstructionDataId']
self.sid = cfg['ProjectionDataId']
- self.rel = rel_factor
+ self.rel = Relaxation
def run(self, its):
v = astra.data2d.get_shared(self.vid)
s = astra.data2d.get_shared(self.sid)
+ tv = np.zeros(v.shape, dtype=np.float32)
+ ts = np.zeros(s.shape, dtype=np.float32)
W = self.W
for i in range(its):
- v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size
+ W.FP(v,out=ts)
+ ts -= s # ts = W*v - s
+
+ W.BP(ts,out=tv)
+ tv *= self.rel / s.size
+
+ v -= tv # v = v - rel * W'*(W*v-s) / s.size
if __name__=='__main__':
@@ -75,20 +83,20 @@ if __name__=='__main__':
# First we import the package that contains the plugin
import s018_plugin
# Then, we register the plugin class with ASTRA
- astra.plugin.register(s018_plugin.SIRTPlugin)
+ astra.plugin.register(s018_plugin.LandweberPlugin)
# Get a list of registered plugins
six.print_(astra.plugin.get_registered())
# To get help on a registered plugin, use get_help
- six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
+ six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN'))
# Create data structures
sid = astra.data2d.create('-sino', proj_geom, sinogram)
vid = astra.data2d.create('-vol', vol_geom)
# Create config using plugin name
- cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg = astra.astra_dict('LANDWEBER-PLUGIN')
cfg['ProjectorId'] = proj_id
cfg['ProjectionDataId'] = sid
cfg['ReconstructionDataId'] = vid
@@ -103,18 +111,18 @@ if __name__=='__main__':
rec = astra.data2d.get(vid)
# Options for the plugin go in cfg['option']
- cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg = astra.astra_dict('LANDWEBER-PLUGIN')
cfg['ProjectorId'] = proj_id
cfg['ProjectionDataId'] = sid
cfg['ReconstructionDataId'] = vid
cfg['option'] = {}
- cfg['option']['rel_factor'] = 1.5
+ cfg['option']['Relaxation'] = 1.5
alg_id_rel = astra.algorithm.create(cfg)
astra.algorithm.run(alg_id_rel, 100)
rec_rel = astra.data2d.get(vid)
# We can also use OpTomo to call the plugin
- rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5})
+ rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5})
import pylab as pl
pl.gray()