summaryrefslogtreecommitdiffstats
path: root/python/astra
diff options
context:
space:
mode:
authorDaan Pelt <daan.pelt@gmail.com>2015-03-26 16:40:38 +0100
committerDaniel M. Pelt <D.M.Pelt@cwi.nl>2015-05-04 14:21:04 +0200
commit2bc0d98c413fee4108115f26aa337f65337eec55 (patch)
tree5f77f278ab33ad7b1da819407c3f472edde8bb4b /python/astra
parentbf31003d74f538a9096ef5999b31b0daa58c38c9 (diff)
downloadastra-2bc0d98c413fee4108115f26aa337f65337eec55.tar.gz
astra-2bc0d98c413fee4108115f26aa337f65337eec55.tar.bz2
astra-2bc0d98c413fee4108115f26aa337f65337eec55.tar.xz
astra-2bc0d98c413fee4108115f26aa337f65337eec55.zip
Add SPOT-like object for Python (overrides `__mul__` and works with scipy.sparse.linalg)
Diffstat (limited to 'python/astra')
-rw-r--r--python/astra/ASTRAProjector.py135
-rw-r--r--python/astra/__init__.py2
-rw-r--r--python/astra/operator.py208
3 files changed, 209 insertions, 136 deletions
diff --git a/python/astra/ASTRAProjector.py b/python/astra/ASTRAProjector.py
deleted file mode 100644
index f282618..0000000
--- a/python/astra/ASTRAProjector.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#-----------------------------------------------------------------------
-#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam
-#
-#Author: Daniel M. Pelt
-#Contact: D.M.Pelt@cwi.nl
-#Website: http://dmpelt.github.io/pyastratoolbox/
-#
-#
-#This file is part of the Python interface to the
-#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox").
-#
-#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify
-#it under the terms of the GNU General Public License as published by
-#the Free Software Foundation, either version 3 of the License, or
-#(at your option) any later version.
-#
-#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful,
-#but WITHOUT ANY WARRANTY; without even the implied warranty of
-#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-#GNU General Public License for more details.
-#
-#You should have received a copy of the GNU General Public License
-#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
-#
-#-----------------------------------------------------------------------
-
-import math
-from . import creators as ac
-from . import data2d
-
-
-class ASTRAProjector2DTranspose():
- """Implements the ``proj.T`` functionality.
-
- Do not use directly, since it can be accessed as member ``.T`` of
- an :class:`ASTRAProjector2D` object.
-
- """
- def __init__(self, parentProj):
- self.parentProj = parentProj
-
- def __mul__(self, data):
- return self.parentProj.backProject(data)
-
-
-class ASTRAProjector2D(object):
- """Helps with various common ASTRA Toolbox 2D operations.
-
- This class can perform several often used toolbox operations, such as:
-
- * Forward projecting
- * Back projecting
- * Reconstructing
-
- Note that this class has a some computational overhead, because it
- copies a lot of data. If you use many repeated operations, directly
- using the PyAstraToolbox methods directly is faster.
-
- You can use this class as an abstracted weight matrix :math:`W`: multiplying an instance
- ``proj`` of this class by an image results in a forward projection of the image, and multiplying
- ``proj.T`` by a sinogram results in a backprojection of the sinogram::
-
- proj = ASTRAProjector2D(...)
- fp = proj*image
- bp = proj.T*sinogram
-
- :param proj_geom: The projection geometry.
- :type proj_geom: :class:`dict`
- :param vol_geom: The volume geometry.
- :type vol_geom: :class:`dict`
- :param proj_type: Projector type, such as ``'line'``, ``'linear'``, ...
- :type proj_type: :class:`string`
- """
-
- def __init__(self, proj_geom, vol_geom, proj_type):
- self.vol_geom = vol_geom
- self.recSize = vol_geom['GridColCount']
- self.angles = proj_geom['ProjectionAngles']
- self.nDet = proj_geom['DetectorCount']
- nexpow = int(pow(2, math.ceil(math.log(2 * self.nDet, 2))))
- self.filterSize = nexpow / 2 + 1
- self.nProj = self.angles.shape[0]
- self.proj_geom = proj_geom
- self.proj_id = ac.create_projector(proj_type, proj_geom, vol_geom)
- self.T = ASTRAProjector2DTranspose(self)
-
- def backProject(self, data):
- """Backproject a sinogram.
-
- :param data: The sinogram data or ID.
- :type data: :class:`numpy.ndarray` or :class:`int`
- :returns: :class:`numpy.ndarray` -- The backprojection.
-
- """
- vol_id, vol = ac.create_backprojection(
- data, self.proj_id, returnData=True)
- data2d.delete(vol_id)
- return vol
-
- def forwardProject(self, data):
- """Forward project an image.
-
- :param data: The image data or ID.
- :type data: :class:`numpy.ndarray` or :class:`int`
- :returns: :class:`numpy.ndarray` -- The forward projection.
-
- """
- sin_id, sino = ac.create_sino(data, self.proj_id, returnData=True)
- data2d.delete(sin_id)
- return sino
-
- def reconstruct(self, data, method, **kwargs):
- """Reconstruct an image from a sinogram.
-
- :param data: The sinogram data or ID.
- :type data: :class:`numpy.ndarray` or :class:`int`
- :param method: Name of the reconstruction algorithm.
- :type method: :class:`string`
- :param kwargs: Additional named parameters to pass to :func:`astra.creators.create_reconstruction`.
- :returns: :class:`numpy.ndarray` -- The reconstruction.
-
- Example of a SIRT reconstruction using CUDA::
-
- proj = ASTRAProjector2D(...)
- rec = proj.reconstruct(sinogram,'SIRT_CUDA',iterations=1000)
-
- """
- kwargs['returnData'] = True
- rec_id, rec = ac.create_reconstruction(
- method, self.proj_id, data, **kwargs)
- data2d.delete(rec_id)
- return rec
-
- def __mul__(self, data):
- return self.forwardProject(data)
diff --git a/python/astra/__init__.py b/python/astra/__init__.py
index 063dc16..8c1740c 100644
--- a/python/astra/__init__.py
+++ b/python/astra/__init__.py
@@ -27,7 +27,6 @@ from . import matlab as m
from .creators import astra_dict,create_vol_geom, create_proj_geom, create_backprojection, create_sino, create_reconstruction, create_projector,create_sino3d_gpu, create_backprojection3d_gpu
from .functions import data_op, add_noise_to_sino, clear, move_vol_geom
from .extrautils import clipCircle
-from .ASTRAProjector import ASTRAProjector2D
from . import data2d
from . import astra
from . import data3d
@@ -36,6 +35,7 @@ from . import projector
from . import projector3d
from . import matrix
from . import log
+from .operator import OpTomo
import os
try:
diff --git a/python/astra/operator.py b/python/astra/operator.py
new file mode 100644
index 0000000..a3abd5a
--- /dev/null
+++ b/python/astra/operator.py
@@ -0,0 +1,208 @@
+#-----------------------------------------------------------------------
+#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam
+#
+#Author: Daniel M. Pelt
+#Contact: D.M.Pelt@cwi.nl
+#Website: http://dmpelt.github.io/pyastratoolbox/
+#
+#
+#This file is part of the Python interface to the
+#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox").
+#
+#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify
+#it under the terms of the GNU General Public License as published by
+#the Free Software Foundation, either version 3 of the License, or
+#(at your option) any later version.
+#
+#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful,
+#but WITHOUT ANY WARRANTY; without even the implied warranty of
+#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+#GNU General Public License for more details.
+#
+#You should have received a copy of the GNU General Public License
+#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+#
+#-----------------------------------------------------------------------
+
+from . import data2d
+from . import data3d
+from . import projector
+from . import projector3d
+from . import creators
+from . import algorithm
+from . import functions
+import numpy as np
+from six.moves import range, reduce
+import operator
+import scipy.sparse.linalg
+
+class OpTomo(scipy.sparse.linalg.LinearOperator):
+ """Object that imitates a projection matrix with a given projector.
+
+ This object can do forward projection by using the ``*`` operator::
+
+ W = astra.OpTomo(proj_id)
+ fp = W*image
+ bp = W.T*sinogram
+
+ It can also be used in minimization methods of the :mod:`scipy.sparse.linalg` module::
+
+ W = astra.OpTomo(proj_id)
+ output = scipy.sparse.linalg.lsqr(W,sinogram)
+
+ :param proj_id: ID to a projector.
+ :type proj_id: :class:`int`
+ """
+
+ def __init__(self,proj_id):
+ self.dtype = np.float32
+ try:
+ self.vg = projector.volume_geometry(proj_id)
+ self.pg = projector.projection_geometry(proj_id)
+ self.data_mod = data2d
+ self.appendString = ""
+ if projector.is_cuda(proj_id):
+ self.appendString += "_CUDA"
+ except Exception:
+ self.vg = projector3d.volume_geometry(proj_id)
+ self.pg = projector3d.projection_geometry(proj_id)
+ self.data_mod = data3d
+ self.appendString = "3D"
+ if projector3d.is_cuda(proj_id):
+ self.appendString += "_CUDA"
+
+ self.vshape = functions.geom_size(self.vg)
+ self.vsize = reduce(operator.mul,self.vshape)
+ self.sshape = functions.geom_size(self.pg)
+ self.ssize = reduce(operator.mul,self.sshape)
+
+ self.shape = (self.ssize, self.vsize)
+
+ self.proj_id = proj_id
+
+ self.T = OpTomoTranspose(self)
+
+ def __checkArray(self, arr, shp):
+ if len(arr.shape)==1:
+ arr = arr.reshape(shp)
+ if arr.dtype != np.float32:
+ arr = arr.astype(np.float32)
+ if arr.flags['C_CONTIGUOUS']==False:
+ arr = np.ascontiguousarray(arr)
+ return arr
+
+ def matvec(self,v):
+ """Implements the forward operator.
+
+ :param v: Volume to forward project.
+ :type v: :class:`numpy.ndarray`
+ """
+ v = self.__checkArray(v, self.vshape)
+ vid = self.data_mod.link('-vol',self.vg,v)
+ s = np.zeros(self.sshape,dtype=np.float32)
+ sid = self.data_mod.link('-sino',self.pg,s)
+
+ cfg = creators.astra_dict('FP'+self.appendString)
+ cfg['ProjectionDataId'] = sid
+ cfg['VolumeDataId'] = vid
+ cfg['ProjectorId'] = self.proj_id
+ fp_id = algorithm.create(cfg)
+ algorithm.run(fp_id)
+
+ algorithm.delete(fp_id)
+ self.data_mod.delete([vid,sid])
+ return s.flatten()
+
+ def rmatvec(self,s):
+ """Implements the transpose operator.
+
+ :param s: The projection data.
+ :type s: :class:`numpy.ndarray`
+ """
+ s = self.__checkArray(s, self.sshape)
+ sid = self.data_mod.link('-sino',self.pg,s)
+ v = np.zeros(self.vshape,dtype=np.float32)
+ vid = self.data_mod.link('-vol',self.vg,v)
+
+ cfg = creators.astra_dict('BP'+self.appendString)
+ cfg['ProjectionDataId'] = sid
+ cfg['ReconstructionDataId'] = vid
+ cfg['ProjectorId'] = self.proj_id
+ bp_id = algorithm.create(cfg)
+ algorithm.run(bp_id)
+
+ algorithm.delete(bp_id)
+ self.data_mod.delete([vid,sid])
+ return v.flatten()
+
+ def matmat(self,m):
+ """Implements the forward operator with a matrix.
+
+ :param m: Volumes to forward project, arranged in columns.
+ :type m: :class:`numpy.ndarray`
+ """
+ out = np.zeros((self.ssize,m.shape[1]),dtype=np.float32)
+ for i in range(m.shape[1]):
+ out[:,i] = self.matvec(m[:,i].flatten())
+ return out
+
+ def __mul__(self,v):
+ """Provides easy forward operator by *.
+
+ :param v: Volume to forward project.
+ :type v: :class:`numpy.ndarray`
+ """
+ return self.matvec(v)
+
+ def reconstruct(self, method, s, iterations=1, extraOptions = {}):
+ """Reconstruct an object.
+
+ :param method: Method to use for reconstruction.
+ :type method: :class:`string`
+ :param s: The projection data.
+ :type s: :class:`numpy.ndarray`
+ :param iterations: Number of iterations to use.
+ :type iterations: :class:`int`
+ :param extraOptions: Extra options to use during reconstruction (i.e. for cfg['option']).
+ :type extraOptions: :class:`dict`
+ """
+ self.__checkArray(s, self.sshape)
+ sid = self.data_mod.link('-sino',self.pg,s)
+ v = np.zeros(self.vshape,dtype=np.float32)
+ vid = self.data_mod.link('-vol',self.vg,v)
+ cfg = creators.astra_dict(method)
+ cfg['ProjectionDataId'] = sid
+ cfg['ReconstructionDataId'] = vid
+ cfg['ProjectorId'] = self.proj_id
+ cfg['option'] = extraOptions
+ alg_id = algorithm.create(cfg)
+ algorithm.run(alg_id,iterations)
+ algorithm.delete(alg_id)
+ self.data_mod.delete([vid,sid])
+ return v
+
+class OpTomoTranspose(scipy.sparse.linalg.LinearOperator):
+ """This object provides the transpose operation (``.T``) of the OpTomo object.
+
+ Do not use directly, since it can be accessed as member ``.T`` of
+ an :class:`OpTomo` object.
+ """
+ def __init__(self,parent):
+ self.parent = parent
+ self.dtype = np.float32
+ self.shape = (parent.shape[1], parent.shape[0])
+
+ def matvec(self, s):
+ return self.parent.rmatvec(s)
+
+ def rmatvec(self, v):
+ return self.parent.matvec(v)
+
+ def matmat(self, m):
+ out = np.zeros((self.vsize,m.shape[1]),dtype=np.float32)
+ for i in range(m.shape[1]):
+ out[:,i] = self.matvec(m[:,i].flatten())
+ return out
+
+ def __mul__(self,v):
+ return self.matvec(v)