summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--python/astra/optomo.py96
-rw-r--r--samples/python/s018_plugin.py34
2 files changed, 87 insertions, 43 deletions
diff --git a/python/astra/optomo.py b/python/astra/optomo.py
index 5a92998..dde719e 100644
--- a/python/astra/optomo.py
+++ b/python/astra/optomo.py
@@ -111,21 +111,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):
:param v: Volume to forward project.
:type v: :class:`numpy.ndarray`
"""
- v = self.__checkArray(v, self.vshape)
- vid = self.data_mod.link('-vol',self.vg,v)
- s = np.zeros(self.sshape,dtype=np.float32)
- sid = self.data_mod.link('-sino',self.pg,s)
-
- cfg = creators.astra_dict('FP'+self.appendString)
- cfg['ProjectionDataId'] = sid
- cfg['VolumeDataId'] = vid
- cfg['ProjectorId'] = self.proj_id
- fp_id = algorithm.create(cfg)
- algorithm.run(fp_id)
-
- algorithm.delete(fp_id)
- self.data_mod.delete([vid,sid])
- return s.ravel()
+ return self.FP(v, out=None).ravel()
def rmatvec(self,s):
"""Implements the transpose operator.
@@ -133,21 +119,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):
:param s: The projection data.
:type s: :class:`numpy.ndarray`
"""
- s = self.__checkArray(s, self.sshape)
- sid = self.data_mod.link('-sino',self.pg,s)
- v = np.zeros(self.vshape,dtype=np.float32)
- vid = self.data_mod.link('-vol',self.vg,v)
-
- cfg = creators.astra_dict('BP'+self.appendString)
- cfg['ProjectionDataId'] = sid
- cfg['ReconstructionDataId'] = vid
- cfg['ProjectorId'] = self.proj_id
- bp_id = algorithm.create(cfg)
- algorithm.run(bp_id)
-
- algorithm.delete(bp_id)
- self.data_mod.delete([vid,sid])
- return v.ravel()
+ return self.BP(s, out=None).ravel()
def __mul__(self,v):
"""Provides easy forward operator by *.
@@ -189,6 +161,70 @@ class OpTomo(scipy.sparse.linalg.LinearOperator):
self.data_mod.delete([vid,sid])
return v
+ def FP(self,v,out=None):
+ """Perform forward projection.
+
+ Output must have the right 2D/3D shape. Input may also be flattened.
+
+ Output must also be contiguous and float32. This isn't required for the
+ input, but it is more efficient if it is.
+
+ :param v: Volume to forward project.
+ :type v: :class:`numpy.ndarray`
+ :param out: Array to store result in.
+ :type out: :class:`numpy.ndarray`
+ """
+
+ v = self.__checkArray(v, self.vshape)
+ vid = self.data_mod.link('-vol',self.vg,v)
+ if out is None:
+ out = np.zeros(self.sshape,dtype=np.float32)
+ sid = self.data_mod.link('-sino',self.pg,out)
+
+ cfg = creators.astra_dict('FP'+self.appendString)
+ cfg['ProjectionDataId'] = sid
+ cfg['VolumeDataId'] = vid
+ cfg['ProjectorId'] = self.proj_id
+ fp_id = algorithm.create(cfg)
+ algorithm.run(fp_id)
+
+ algorithm.delete(fp_id)
+ self.data_mod.delete([vid,sid])
+ return out
+
+ def BP(self,s,out=None):
+ """Perform backprojection.
+
+ Output must have the right 2D/3D shape. Input may also be flattened.
+
+ Output must also be contiguous and float32. This isn't required for the
+ input, but it is more efficient if it is.
+
+ :param : The projection data.
+ :type s: :class:`numpy.ndarray`
+ :param out: Array to store result in.
+ :type out: :class:`numpy.ndarray`
+ """
+ s = self.__checkArray(s, self.sshape)
+ sid = self.data_mod.link('-sino',self.pg,s)
+ if out is None:
+ out = np.zeros(self.vshape,dtype=np.float32)
+ vid = self.data_mod.link('-vol',self.vg,out)
+
+ cfg = creators.astra_dict('BP'+self.appendString)
+ cfg['ProjectionDataId'] = sid
+ cfg['ReconstructionDataId'] = vid
+ cfg['ProjectorId'] = self.proj_id
+ bp_id = algorithm.create(cfg)
+ algorithm.run(bp_id)
+
+ algorithm.delete(bp_id)
+ self.data_mod.delete([vid,sid])
+ return out
+
+
+
+
class OpTomoTranspose(scipy.sparse.linalg.LinearOperator):
"""This object provides the transpose operation (``.T``) of the OpTomo object.
diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
index 31cca95..85b5486 100644
--- a/samples/python/s018_plugin.py
+++ b/samples/python/s018_plugin.py
@@ -30,30 +30,38 @@ import six
# Define the plugin class (has to subclass astra.plugin.base)
# Note that usually, these will be defined in a separate package/module
-class SIRTPlugin(astra.plugin.base):
- """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm.
+class LandweberPlugin(astra.plugin.base):
+ """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm.
Options:
- 'rel_factor': relaxation factor (optional)
+ 'Relaxation': relaxation factor (optional)
"""
# The astra_name variable defines the name to use to
# call the plugin from ASTRA
- astra_name = "SIRT-PLUGIN"
+ astra_name = "LANDWEBER-PLUGIN"
- def initialize(self,cfg, rel_factor = 1):
+ def initialize(self,cfg, Relaxation = 1):
self.W = astra.OpTomo(cfg['ProjectorId'])
self.vid = cfg['ReconstructionDataId']
self.sid = cfg['ProjectionDataId']
- self.rel = rel_factor
+ self.rel = Relaxation
def run(self, its):
v = astra.data2d.get_shared(self.vid)
s = astra.data2d.get_shared(self.sid)
+ tv = np.zeros(v.shape, dtype=np.float32)
+ ts = np.zeros(s.shape, dtype=np.float32)
W = self.W
for i in range(its):
- v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size
+ W.FP(v,out=ts)
+ ts -= s # ts = W*v - s
+
+ W.BP(ts,out=tv)
+ tv *= self.rel / s.size
+
+ v -= tv # v = v - rel * W'*(W*v-s) / s.size
if __name__=='__main__':
@@ -75,20 +83,20 @@ if __name__=='__main__':
# First we import the package that contains the plugin
import s018_plugin
# Then, we register the plugin class with ASTRA
- astra.plugin.register(s018_plugin.SIRTPlugin)
+ astra.plugin.register(s018_plugin.LandweberPlugin)
# Get a list of registered plugins
six.print_(astra.plugin.get_registered())
# To get help on a registered plugin, use get_help
- six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
+ six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN'))
# Create data structures
sid = astra.data2d.create('-sino', proj_geom, sinogram)
vid = astra.data2d.create('-vol', vol_geom)
# Create config using plugin name
- cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg = astra.astra_dict('LANDWEBER-PLUGIN')
cfg['ProjectorId'] = proj_id
cfg['ProjectionDataId'] = sid
cfg['ReconstructionDataId'] = vid
@@ -103,18 +111,18 @@ if __name__=='__main__':
rec = astra.data2d.get(vid)
# Options for the plugin go in cfg['option']
- cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg = astra.astra_dict('LANDWEBER-PLUGIN')
cfg['ProjectorId'] = proj_id
cfg['ProjectionDataId'] = sid
cfg['ReconstructionDataId'] = vid
cfg['option'] = {}
- cfg['option']['rel_factor'] = 1.5
+ cfg['option']['Relaxation'] = 1.5
alg_id_rel = astra.algorithm.create(cfg)
astra.algorithm.run(alg_id_rel, 100)
rec_rel = astra.data2d.get(vid)
# We can also use OpTomo to call the plugin
- rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5})
+ rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5})
import pylab as pl
pl.gray()