diff options
Diffstat (limited to 'python/astra')
-rw-r--r-- | python/astra/ASTRAProjector.py | 135 | ||||
-rw-r--r-- | python/astra/PyIncludes.pxd | 4 | ||||
-rw-r--r-- | python/astra/PyXMLDocument.pxd | 10 | ||||
-rw-r--r-- | python/astra/__init__.py | 2 | ||||
-rw-r--r-- | python/astra/algorithm_c.pyx | 4 | ||||
-rw-r--r-- | python/astra/data2d_c.pyx | 10 | ||||
-rw-r--r-- | python/astra/data3d_c.pyx | 11 | ||||
-rw-r--r-- | python/astra/functions.py | 31 | ||||
-rw-r--r-- | python/astra/log_c.pyx | 21 | ||||
-rw-r--r-- | python/astra/matrix_c.pyx | 7 | ||||
-rw-r--r-- | python/astra/optomo.py | 219 | ||||
-rw-r--r-- | python/astra/projector3d_c.pyx | 10 | ||||
-rw-r--r-- | python/astra/projector_c.pyx | 10 | ||||
-rw-r--r-- | python/astra/pythonutils.py | 63 | ||||
-rw-r--r-- | python/astra/utils.pyx | 71 |
15 files changed, 383 insertions, 225 deletions
diff --git a/python/astra/ASTRAProjector.py b/python/astra/ASTRAProjector.py deleted file mode 100644 index f282618..0000000 --- a/python/astra/ASTRAProjector.py +++ /dev/null @@ -1,135 +0,0 @@ -#----------------------------------------------------------------------- -#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam -# -#Author: Daniel M. Pelt -#Contact: D.M.Pelt@cwi.nl -#Website: http://dmpelt.github.io/pyastratoolbox/ -# -# -#This file is part of the Python interface to the -#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). -# -#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify -#it under the terms of the GNU General Public License as published by -#the Free Software Foundation, either version 3 of the License, or -#(at your option) any later version. -# -#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, -#but WITHOUT ANY WARRANTY; without even the implied warranty of -#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -#GNU General Public License for more details. -# -#You should have received a copy of the GNU General Public License -#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. -# -#----------------------------------------------------------------------- - -import math -from . import creators as ac -from . import data2d - - -class ASTRAProjector2DTranspose(): - """Implements the ``proj.T`` functionality. - - Do not use directly, since it can be accessed as member ``.T`` of - an :class:`ASTRAProjector2D` object. - - """ - def __init__(self, parentProj): - self.parentProj = parentProj - - def __mul__(self, data): - return self.parentProj.backProject(data) - - -class ASTRAProjector2D(object): - """Helps with various common ASTRA Toolbox 2D operations. - - This class can perform several often used toolbox operations, such as: - - * Forward projecting - * Back projecting - * Reconstructing - - Note that this class has a some computational overhead, because it - copies a lot of data. If you use many repeated operations, directly - using the PyAstraToolbox methods directly is faster. - - You can use this class as an abstracted weight matrix :math:`W`: multiplying an instance - ``proj`` of this class by an image results in a forward projection of the image, and multiplying - ``proj.T`` by a sinogram results in a backprojection of the sinogram:: - - proj = ASTRAProjector2D(...) - fp = proj*image - bp = proj.T*sinogram - - :param proj_geom: The projection geometry. - :type proj_geom: :class:`dict` - :param vol_geom: The volume geometry. - :type vol_geom: :class:`dict` - :param proj_type: Projector type, such as ``'line'``, ``'linear'``, ... - :type proj_type: :class:`string` - """ - - def __init__(self, proj_geom, vol_geom, proj_type): - self.vol_geom = vol_geom - self.recSize = vol_geom['GridColCount'] - self.angles = proj_geom['ProjectionAngles'] - self.nDet = proj_geom['DetectorCount'] - nexpow = int(pow(2, math.ceil(math.log(2 * self.nDet, 2)))) - self.filterSize = nexpow / 2 + 1 - self.nProj = self.angles.shape[0] - self.proj_geom = proj_geom - self.proj_id = ac.create_projector(proj_type, proj_geom, vol_geom) - self.T = ASTRAProjector2DTranspose(self) - - def backProject(self, data): - """Backproject a sinogram. - - :param data: The sinogram data or ID. - :type data: :class:`numpy.ndarray` or :class:`int` - :returns: :class:`numpy.ndarray` -- The backprojection. - - """ - vol_id, vol = ac.create_backprojection( - data, self.proj_id, returnData=True) - data2d.delete(vol_id) - return vol - - def forwardProject(self, data): - """Forward project an image. - - :param data: The image data or ID. - :type data: :class:`numpy.ndarray` or :class:`int` - :returns: :class:`numpy.ndarray` -- The forward projection. - - """ - sin_id, sino = ac.create_sino(data, self.proj_id, returnData=True) - data2d.delete(sin_id) - return sino - - def reconstruct(self, data, method, **kwargs): - """Reconstruct an image from a sinogram. - - :param data: The sinogram data or ID. - :type data: :class:`numpy.ndarray` or :class:`int` - :param method: Name of the reconstruction algorithm. - :type method: :class:`string` - :param kwargs: Additional named parameters to pass to :func:`astra.creators.create_reconstruction`. - :returns: :class:`numpy.ndarray` -- The reconstruction. - - Example of a SIRT reconstruction using CUDA:: - - proj = ASTRAProjector2D(...) - rec = proj.reconstruct(sinogram,'SIRT_CUDA',iterations=1000) - - """ - kwargs['returnData'] = True - rec_id, rec = ac.create_reconstruction( - method, self.proj_id, data, **kwargs) - data2d.delete(rec_id) - return rec - - def __mul__(self, data): - return self.forwardProject(data) diff --git a/python/astra/PyIncludes.pxd b/python/astra/PyIncludes.pxd index 1d8285b..35dea5f 100644 --- a/python/astra/PyIncludes.pxd +++ b/python/astra/PyIncludes.pxd @@ -43,7 +43,7 @@ cdef extern from "astra/Config.h" namespace "astra": cdef cppclass Config: Config() void initialize(string rootname) - XMLNode *self + XMLNode self cdef extern from "astra/VolumeGeometry2D.h" namespace "astra": cdef cppclass CVolumeGeometry2D: @@ -143,7 +143,7 @@ cdef extern from "astra/Float32ProjectionData2D.h" namespace "astra": cdef extern from "astra/Algorithm.h" namespace "astra": cdef cppclass CAlgorithm: bool initialize(Config) - void run(int) + void run(int) nogil bool isInitialized() cdef extern from "astra/ReconstructionAlgorithm2D.h" namespace "astra": diff --git a/python/astra/PyXMLDocument.pxd b/python/astra/PyXMLDocument.pxd index 69781f1..033b8ef 100644 --- a/python/astra/PyXMLDocument.pxd +++ b/python/astra/PyXMLDocument.pxd @@ -44,22 +44,24 @@ cdef extern from "astra/Globals.h" namespace "astra": cdef extern from "astra/XMLNode.h" namespace "astra": cdef cppclass XMLNode: string getName() - XMLNode *addChildNode(string name) - XMLNode *addChildNode(string, string) + XMLNode addChildNode(string name) + XMLNode addChildNode(string, string) void addAttribute(string, string) void addAttribute(string, float32) void addOption(string, string) bool hasOption(string) string getAttribute(string) - list[XMLNode *] getNodes() + list[XMLNode] getNodes() vector[float32] getContentNumericalArray() + void setContent(double*, int, int, bool) + void setContent(double*, int) string getContent() bool hasAttribute(string) cdef extern from "astra/XMLDocument.h" namespace "astra": cdef cppclass XMLDocument: void saveToFile(string sFilename) - XMLNode *getRootNode() + XMLNode getRootNode() cdef extern from "astra/XMLDocument.h" namespace "astra::XMLDocument": cdef XMLDocument *createDocument(string rootname) diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 063dc16..6c15d30 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -27,7 +27,6 @@ from . import matlab as m from .creators import astra_dict,create_vol_geom, create_proj_geom, create_backprojection, create_sino, create_reconstruction, create_projector,create_sino3d_gpu, create_backprojection3d_gpu from .functions import data_op, add_noise_to_sino, clear, move_vol_geom from .extrautils import clipCircle -from .ASTRAProjector import ASTRAProjector2D from . import data2d from . import astra from . import data3d @@ -36,6 +35,7 @@ from . import projector from . import projector3d from . import matrix from . import log +from .optomo import OpTomo import os try: diff --git a/python/astra/algorithm_c.pyx b/python/astra/algorithm_c.pyx index 966d3d7..3231c1f 100644 --- a/python/astra/algorithm_c.pyx +++ b/python/astra/algorithm_c.pyx @@ -73,7 +73,9 @@ cdef CAlgorithm * getAlg(i) except NULL: def run(i, iterations=0): cdef CAlgorithm * alg = getAlg(i) - alg.run(iterations) + cdef int its = iterations + with nogil: + alg.run(its) def get_res_norm(i): diff --git a/python/astra/data2d_c.pyx b/python/astra/data2d_c.pyx index ac54898..4919bf2 100644 --- a/python/astra/data2d_c.pyx +++ b/python/astra/data2d_c.pyx @@ -47,6 +47,12 @@ from .PyIncludes cimport * cimport utils from .utils import wrap_from_bytes +from .pythonutils import geom_size + +import operator + +from six.moves import reduce + cdef CData2DManager * man2d = <CData2DManager * >PyData2DManager.getSingletonPtr() cdef extern from "CFloat32CustomPython.h": @@ -71,6 +77,10 @@ def create(datatype, geometry, data=None, link=False): cdef CProjectionGeometry2D * ppGeometry cdef CFloat32Data2D * pDataObject2D cdef CFloat32CustomMemory * pCustom + + if link and data.shape!=geom_size(geometry): + raise Exception("The dimensions of the data do not match those specified in the geometry.") + if datatype == '-vol': cfg = utils.dictToConfig(six.b('VolumeGeometry'), geometry) pGeometry = new CVolumeGeometry2D() diff --git a/python/astra/data3d_c.pyx b/python/astra/data3d_c.pyx index 84472c1..3b27ab7 100644 --- a/python/astra/data3d_c.pyx +++ b/python/astra/data3d_c.pyx @@ -45,6 +45,13 @@ from .PyXMLDocument cimport XMLDocument cimport utils from .utils import wrap_from_bytes +from .pythonutils import geom_size + +import operator + +from six.moves import reduce + + cdef CData3DManager * man3d = <CData3DManager * >PyData3DManager.getSingletonPtr() cdef extern from *: @@ -61,6 +68,10 @@ def create(datatype,geometry,data=None, link=False): cdef CFloat32Data3DMemory * pDataObject3D cdef CConeProjectionGeometry3D* pppGeometry cdef CFloat32CustomMemory * pCustom + + if link and data.shape!=geom_size(geometry): + raise Exception("The dimensions of the data do not match those specified in the geometry.") + if datatype == '-vol': cfg = utils.dictToConfig(six.b('VolumeGeometry'), geometry) pGeometry = new CVolumeGeometry3D() diff --git a/python/astra/functions.py b/python/astra/functions.py index 4025468..e38b5bc 100644 --- a/python/astra/functions.py +++ b/python/astra/functions.py @@ -32,12 +32,17 @@ from . import creators as ac import numpy as np -from six.moves import range +try: + from six.moves import range +except ImportError: + # six 1.3.0 + from six.moves import xrange as range from . import data2d from . import data3d from . import projector from . import algorithm +from . import pythonutils @@ -158,29 +163,7 @@ def geom_size(geom, dim=None): :param dim: Optional axis index to return :type dim: :class:`int` """ - - if 'GridSliceCount' in geom: - # 3D Volume geometry? - s = (geom['GridSliceCount'], geom[ - 'GridRowCount'], geom['GridColCount']) - elif 'GridColCount' in geom: - # 2D Volume geometry? - s = (geom['GridRowCount'], geom['GridColCount']) - elif geom['type'] == 'parallel' or geom['type'] == 'fanflat': - s = (len(geom['ProjectionAngles']), geom['DetectorCount']) - elif geom['type'] == 'parallel3d' or geom['type'] == 'cone': - s = (geom['DetectorRowCount'], len( - geom['ProjectionAngles']), geom['DetectorColCount']) - elif geom['type'] == 'fanflat_vec': - s = (geom['Vectors'].shape[0], geom['DetectorCount']) - elif geom['type'] == 'parallel3d_vec' or geom['type'] == 'cone_vec': - s = (geom['DetectorRowCount'], geom[ - 'Vectors'].shape[0], geom['DetectorColCount']) - - if dim != None: - s = s[dim] - - return s + return pythonutils.geom_size(geom,dim) def geom_2vec(proj_geom): diff --git a/python/astra/log_c.pyx b/python/astra/log_c.pyx index 969cc06..55c63e6 100644 --- a/python/astra/log_c.pyx +++ b/python/astra/log_c.pyx @@ -52,16 +52,20 @@ cdef extern from "astra/Logging.h" namespace "astra::CLogger": void setFormatScreen(const char *fmt) def log_debug(sfile, sline, message): - debug(six.b(sfile),sline,six.b(message)) + cstr = list(map(six.b,(sfile,message))) + debug(cstr[0],sline,"%s",<char*>cstr[1]) def log_info(sfile, sline, message): - info(six.b(sfile),sline,six.b(message)) + cstr = list(map(six.b,(sfile,message))) + info(cstr[0],sline,"%s",<char*>cstr[1]) def log_warn(sfile, sline, message): - warn(six.b(sfile),sline,six.b(message)) + cstr = list(map(six.b,(sfile,message))) + warn(cstr[0],sline,"%s",<char*>cstr[1]) def log_error(sfile, sline, message): - error(six.b(sfile),sline,six.b(message)) + cstr = list(map(six.b,(sfile,message))) + error(cstr[0],sline,"%s",<char*>cstr[1]) def log_enable(): enable() @@ -82,10 +86,12 @@ def log_disableFile(): disableFile() def log_setFormatFile(fmt): - setFormatFile(six.b(fmt)) + cstr = six.b(fmt) + setFormatFile(cstr) def log_setFormatScreen(fmt): - setFormatScreen(six.b(fmt)) + cstr = six.b(fmt) + setFormatScreen(cstr) enumList = [LOG_DEBUG,LOG_INFO,LOG_WARN,LOG_ERROR] @@ -93,4 +99,5 @@ def log_setOutputScreen(fd, level): setOutputScreen(fd, enumList[level]) def log_setOutputFile(filename, level): - setOutputFile(six.b(filename), enumList[level])
\ No newline at end of file + cstr = six.b(filename) + setOutputFile(cstr, enumList[level]) diff --git a/python/astra/matrix_c.pyx b/python/astra/matrix_c.pyx index b0d8bc4..d099a75 100644 --- a/python/astra/matrix_c.pyx +++ b/python/astra/matrix_c.pyx @@ -27,7 +27,12 @@ # distutils: libraries = astra import six -from six.moves import range +try: + from six.moves import range +except ImportError: + # six 1.3.0 + from six.moves import xrange as range + import numpy as np import scipy.sparse as ss diff --git a/python/astra/optomo.py b/python/astra/optomo.py new file mode 100644 index 0000000..4a64150 --- /dev/null +++ b/python/astra/optomo.py @@ -0,0 +1,219 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +from . import data2d +from . import data3d +from . import projector +from . import projector3d +from . import creators +from . import algorithm +from . import functions +import numpy as np +from six.moves import reduce +try: + from six.moves import range +except ImportError: + # six 1.3.0 + from six.moves import xrange as range + +import operator +import scipy.sparse.linalg + +class OpTomo(scipy.sparse.linalg.LinearOperator): + """Object that imitates a projection matrix with a given projector. + + This object can do forward projection by using the ``*`` operator:: + + W = astra.OpTomo(proj_id) + fp = W*image + bp = W.T*sinogram + + It can also be used in minimization methods of the :mod:`scipy.sparse.linalg` module:: + + W = astra.OpTomo(proj_id) + output = scipy.sparse.linalg.lsqr(W,sinogram) + + :param proj_id: ID to a projector. + :type proj_id: :class:`int` + """ + + def __init__(self,proj_id): + self.dtype = np.float32 + try: + self.vg = projector.volume_geometry(proj_id) + self.pg = projector.projection_geometry(proj_id) + self.data_mod = data2d + self.appendString = "" + if projector.is_cuda(proj_id): + self.appendString += "_CUDA" + except Exception: + self.vg = projector3d.volume_geometry(proj_id) + self.pg = projector3d.projection_geometry(proj_id) + self.data_mod = data3d + self.appendString = "3D" + if projector3d.is_cuda(proj_id): + self.appendString += "_CUDA" + + self.vshape = functions.geom_size(self.vg) + self.vsize = reduce(operator.mul,self.vshape) + self.sshape = functions.geom_size(self.pg) + self.ssize = reduce(operator.mul,self.sshape) + + self.shape = (self.ssize, self.vsize) + + self.proj_id = proj_id + + self.transposeOpTomo = OpTomoTranspose(self) + try: + self.T = self.transposeOpTomo + except AttributeError: + # Scipy >= 0.16 defines self.T using self._transpose() + pass + + def _transpose(self): + return self.transposeOpTomo + + def __checkArray(self, arr, shp): + if len(arr.shape)==1: + arr = arr.reshape(shp) + if arr.dtype != np.float32: + arr = arr.astype(np.float32) + if arr.flags['C_CONTIGUOUS']==False: + arr = np.ascontiguousarray(arr) + return arr + + def _matvec(self,v): + """Implements the forward operator. + + :param v: Volume to forward project. + :type v: :class:`numpy.ndarray` + """ + v = self.__checkArray(v, self.vshape) + vid = self.data_mod.link('-vol',self.vg,v) + s = np.zeros(self.sshape,dtype=np.float32) + sid = self.data_mod.link('-sino',self.pg,s) + + cfg = creators.astra_dict('FP'+self.appendString) + cfg['ProjectionDataId'] = sid + cfg['VolumeDataId'] = vid + cfg['ProjectorId'] = self.proj_id + fp_id = algorithm.create(cfg) + algorithm.run(fp_id) + + algorithm.delete(fp_id) + self.data_mod.delete([vid,sid]) + return s.flatten() + + def rmatvec(self,s): + """Implements the transpose operator. + + :param s: The projection data. + :type s: :class:`numpy.ndarray` + """ + s = self.__checkArray(s, self.sshape) + sid = self.data_mod.link('-sino',self.pg,s) + v = np.zeros(self.vshape,dtype=np.float32) + vid = self.data_mod.link('-vol',self.vg,v) + + cfg = creators.astra_dict('BP'+self.appendString) + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + cfg['ProjectorId'] = self.proj_id + bp_id = algorithm.create(cfg) + algorithm.run(bp_id) + + algorithm.delete(bp_id) + self.data_mod.delete([vid,sid]) + return v.flatten() + + def __mul__(self,v): + """Provides easy forward operator by *. + + :param v: Volume to forward project. + :type v: :class:`numpy.ndarray` + """ + # Catch the case of a forward projection of a 2D/3D image + if isinstance(v, np.ndarray) and v.shape==self.vshape: + return self._matvec(v) + return scipy.sparse.linalg.LinearOperator.__mul__(self, v) + + def reconstruct(self, method, s, iterations=1, extraOptions = {}): + """Reconstruct an object. + + :param method: Method to use for reconstruction. + :type method: :class:`string` + :param s: The projection data. + :type s: :class:`numpy.ndarray` + :param iterations: Number of iterations to use. + :type iterations: :class:`int` + :param extraOptions: Extra options to use during reconstruction (i.e. for cfg['option']). + :type extraOptions: :class:`dict` + """ + s = self.__checkArray(s, self.sshape) + sid = self.data_mod.link('-sino',self.pg,s) + v = np.zeros(self.vshape,dtype=np.float32) + vid = self.data_mod.link('-vol',self.vg,v) + cfg = creators.astra_dict(method) + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + cfg['ProjectorId'] = self.proj_id + cfg['option'] = extraOptions + alg_id = algorithm.create(cfg) + algorithm.run(alg_id,iterations) + algorithm.delete(alg_id) + self.data_mod.delete([vid,sid]) + return v + +class OpTomoTranspose(scipy.sparse.linalg.LinearOperator): + """This object provides the transpose operation (``.T``) of the OpTomo object. + + Do not use directly, since it can be accessed as member ``.T`` of + an :class:`OpTomo` object. + """ + def __init__(self,parent): + self.parent = parent + self.dtype = np.float32 + self.shape = (parent.shape[1], parent.shape[0]) + try: + self.T = self.parent + except AttributeError: + # Scipy >= 0.16 defines self.T using self._transpose() + pass + + def _matvec(self, s): + return self.parent.rmatvec(s) + + def rmatvec(self, v): + return self.parent.matvec(v) + + def _transpose(self): + return self.parent + + def __mul__(self,s): + # Catch the case of a backprojection of 2D/3D data + if isinstance(s, np.ndarray) and s.shape==self.parent.sshape: + return self._matvec(s) + return scipy.sparse.linalg.LinearOperator.__mul__(self, s) diff --git a/python/astra/projector3d_c.pyx b/python/astra/projector3d_c.pyx index 8b978d7..aec9cde 100644 --- a/python/astra/projector3d_c.pyx +++ b/python/astra/projector3d_c.pyx @@ -87,12 +87,18 @@ cdef CProjector3D * getObject(i) except NULL: def projection_geometry(i): cdef CProjector3D * proj = getObject(i) - return utils.configToDict(proj.getProjectionGeometry().getConfiguration()) + cdef Config * cfg = proj.getProjectionGeometry().getConfiguration() + dct = utils.configToDict(cfg) + del cfg + return dct def volume_geometry(i): cdef CProjector3D * proj = getObject(i) - return utils.configToDict(proj.getVolumeGeometry().getConfiguration()) + cdef Config * cfg = proj.getVolumeGeometry().getConfiguration() + dct = utils.configToDict(cfg) + del cfg + return dct def weights_single_ray(i, projection_index, detector_index): diff --git a/python/astra/projector_c.pyx b/python/astra/projector_c.pyx index 9aa868e..77c64a4 100644 --- a/python/astra/projector_c.pyx +++ b/python/astra/projector_c.pyx @@ -91,12 +91,18 @@ cdef CProjector2D * getObject(i) except NULL: def projection_geometry(i): cdef CProjector2D * proj = getObject(i) - return utils.configToDict(proj.getProjectionGeometry().getConfiguration()) + cdef Config * cfg = proj.getProjectionGeometry().getConfiguration() + dct = utils.configToDict(cfg) + del cfg + return dct def volume_geometry(i): cdef CProjector2D * proj = getObject(i) - return utils.configToDict(proj.getVolumeGeometry().getConfiguration()) + cdef Config * cfg = proj.getVolumeGeometry().getConfiguration() + dct = utils.configToDict(cfg) + del cfg + return dct def weights_single_ray(i, projection_index, detector_index): diff --git a/python/astra/pythonutils.py b/python/astra/pythonutils.py new file mode 100644 index 0000000..8ea4af5 --- /dev/null +++ b/python/astra/pythonutils.py @@ -0,0 +1,63 @@ +#----------------------------------------------------------------------- +# Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +# Author: Daniel M. Pelt +# Contact: D.M.Pelt@cwi.nl +# Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +# This file is part of the Python interface to the +# All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +# The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +# The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- +"""Additional purely Python functions for PyAstraToolbox. + +.. moduleauthor:: Daniel M. Pelt <D.M.Pelt@cwi.nl> + + +""" + +def geom_size(geom, dim=None): + """Returns the size of a volume or sinogram, based on the projection or volume geometry. + + :param geom: Geometry to calculate size from + :type geometry: :class:`dict` + :param dim: Optional axis index to return + :type dim: :class:`int` + """ + + if 'GridSliceCount' in geom: + # 3D Volume geometry? + s = (geom['GridSliceCount'], geom[ + 'GridRowCount'], geom['GridColCount']) + elif 'GridColCount' in geom: + # 2D Volume geometry? + s = (geom['GridRowCount'], geom['GridColCount']) + elif geom['type'] == 'parallel' or geom['type'] == 'fanflat': + s = (len(geom['ProjectionAngles']), geom['DetectorCount']) + elif geom['type'] == 'parallel3d' or geom['type'] == 'cone': + s = (geom['DetectorRowCount'], len( + geom['ProjectionAngles']), geom['DetectorColCount']) + elif geom['type'] == 'fanflat_vec': + s = (geom['Vectors'].shape[0], geom['DetectorCount']) + elif geom['type'] == 'parallel3d_vec' or geom['type'] == 'cone_vec': + s = (geom['DetectorRowCount'], geom[ + 'Vectors'].shape[0], geom['DetectorColCount']) + + if dim != None: + s = s[dim] + + return s diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx index 0439f1b..260c308 100644 --- a/python/astra/utils.pyx +++ b/python/astra/utils.pyx @@ -26,6 +26,7 @@ # distutils: language = c++ # distutils: libraries = astra +cimport numpy as np import numpy as np import six from libcpp.string cimport string @@ -80,11 +81,12 @@ def wrap_from_bytes(value): return s -cdef void readDict(XMLNode * root, _dc): - cdef XMLNode * listbase - cdef XMLNode * itm +cdef void readDict(XMLNode root, _dc): + cdef XMLNode listbase + cdef XMLNode itm cdef int i cdef int j + cdef double* data dc = convert_item(_dc) for item in dc: @@ -93,45 +95,32 @@ cdef void readDict(XMLNode * root, _dc): if val.size == 0: break listbase = root.addChildNode(item) - listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) - index = 0 + contig_data = np.ascontiguousarray(val,dtype=np.float64) + data = <double*>np.PyArray_DATA(contig_data) if val.ndim == 2: - for i in range(val.shape[0]): - for j in range(val.shape[1]): - itm = listbase.addChildNode(six.b('ListItem')) - itm.addAttribute(< string > six.b('index'), < float32 > index) - itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) - index += 1 - del itm + listbase.setContent(data, val.shape[1], val.shape[0], False) elif val.ndim == 1: - for i in range(val.shape[0]): - itm = listbase.addChildNode(six.b('ListItem')) - itm.addAttribute(< string > six.b('index'), < float32 > index) - itm.addAttribute(< string > six.b('value'), < float32 > val[i]) - index += 1 - del itm + listbase.setContent(data, val.shape[0]) else: raise Exception("Only 1 or 2 dimensions are allowed") - del listbase elif isinstance(val, dict): if item == six.b('option') or item == six.b('options') or item == six.b('Option') or item == six.b('Options'): readOptions(root, val) else: itm = root.addChildNode(item) readDict(itm, val) - del itm else: if item == six.b('type'): root.addAttribute(< string > six.b('type'), <string> wrap_to_bytes(val)) else: itm = root.addChildNode(item, wrap_to_bytes(val)) - del itm -cdef void readOptions(XMLNode * node, dc): - cdef XMLNode * listbase - cdef XMLNode * itm +cdef void readOptions(XMLNode node, dc): + cdef XMLNode listbase + cdef XMLNode itm cdef int i cdef int j + cdef double* data for item in dc: val = dc[item] if node.hasOption(item): @@ -141,26 +130,14 @@ cdef void readOptions(XMLNode * node, dc): break listbase = node.addChildNode(six.b('Option')) listbase.addAttribute(< string > six.b('key'), < string > item) - listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) - index = 0 + contig_data = np.ascontiguousarray(val,dtype=np.float64) + data = <double*>np.PyArray_DATA(contig_data) if val.ndim == 2: - for i in range(val.shape[0]): - for j in range(val.shape[1]): - itm = listbase.addChildNode(six.b('ListItem')) - itm.addAttribute(< string > six.b('index'), < float32 > index) - itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) - index += 1 - del itm + listbase.setContent(data, val.shape[1], val.shape[0], False) elif val.ndim == 1: - for i in range(val.shape[0]): - itm = listbase.addChildNode(six.b('ListItem')) - itm.addAttribute(< string > six.b('index'), < float32 > index) - itm.addAttribute(< string > six.b('value'), < float32 > val[i]) - index += 1 - del itm + listbase.setContent(data, val.shape[0]) else: raise Exception("Only 1 or 2 dimensions are allowed") - del listbase else: node.addOption(item, wrap_to_bytes(val)) @@ -214,10 +191,10 @@ def stringToPythonValue(inputIn): return str(input) -cdef XMLNode2dict(XMLNode * node): - cdef XMLNode * subnode - cdef list[XMLNode * ] nodes - cdef list[XMLNode * ].iterator it +cdef XMLNode2dict(XMLNode node): + cdef XMLNode subnode + cdef list[XMLNode] nodes + cdef list[XMLNode].iterator it dct = {} opts = {} if node.hasAttribute(six.b('type')): @@ -227,10 +204,12 @@ cdef XMLNode2dict(XMLNode * node): while it != nodes.end(): subnode = deref(it) if castString(subnode.getName())=="Option": - opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value')) + if subnode.hasAttribute('value'): + opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value')) + else: + opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getContent()) else: dct[castString(subnode.getName())] = stringToPythonValue(subnode.getContent()) - del subnode inc(it) if len(opts)>0: dct['options'] = opts return dct |