diff options
Diffstat (limited to 'python/astra')
-rw-r--r-- | python/astra/PyIncludes.pxd | 2 | ||||
-rw-r--r-- | python/astra/__init__.py | 1 | ||||
-rw-r--r-- | python/astra/data2d_c.pyx | 27 | ||||
-rw-r--r-- | python/astra/plugin.py | 121 | ||||
-rw-r--r-- | python/astra/plugin_c.pyx | 67 | ||||
-rw-r--r-- | python/astra/utils.pyx | 78 |
6 files changed, 224 insertions, 72 deletions
diff --git a/python/astra/PyIncludes.pxd b/python/astra/PyIncludes.pxd index 35dea5f..77346b0 100644 --- a/python/astra/PyIncludes.pxd +++ b/python/astra/PyIncludes.pxd @@ -62,6 +62,7 @@ cdef extern from "astra/VolumeGeometry2D.h" namespace "astra": float32 getWindowMaxX() float32 getWindowMaxY() Config* getConfiguration() + bool isEqual(CVolumeGeometry2D*) cdef extern from "astra/Float32Data2D.h" namespace "astra": cdef cppclass CFloat32CustomMemory: @@ -89,6 +90,7 @@ cdef extern from "astra/ProjectionGeometry2D.h" namespace "astra": float32 getProjectionAngle(int) float32 getDetectorWidth() Config* getConfiguration() + bool isEqual(CProjectionGeometry2D*) cdef extern from "astra/Float32Data2D.h" namespace "astra::CFloat32Data2D": cdef enum TWOEDataType "astra::CFloat32Data2D::EDataType": diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 6c15d30..10ed74d 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -34,6 +34,7 @@ from . import algorithm from . import projector from . import projector3d from . import matrix +from . import plugin from . import log from .optomo import OpTomo diff --git a/python/astra/data2d_c.pyx b/python/astra/data2d_c.pyx index 4919bf2..801fd8e 100644 --- a/python/astra/data2d_c.pyx +++ b/python/astra/data2d_c.pyx @@ -34,6 +34,9 @@ from cython cimport view cimport PyData2DManager from .PyData2DManager cimport CData2DManager +cimport PyProjector2DManager +from .PyProjector2DManager cimport CProjector2DManager + cimport PyXMLDocument from .PyXMLDocument cimport XMLDocument @@ -54,6 +57,8 @@ import operator from six.moves import reduce cdef CData2DManager * man2d = <CData2DManager * >PyData2DManager.getSingletonPtr() +cdef CProjector2DManager * manProj = <CProjector2DManager * >PyProjector2DManager.getSingletonPtr() + cdef extern from "CFloat32CustomPython.h": cdef cppclass CFloat32CustomPython: @@ -164,7 +169,6 @@ def store(i, data): cdef CFloat32Data2D * pDataObject = getObject(i) fillDataObject(pDataObject, data) - def get_geometry(i): cdef CFloat32Data2D * pDataObject = getObject(i) cdef CFloat32ProjectionData2D * pDataObject2 @@ -179,6 +183,27 @@ def get_geometry(i): raise Exception("Not a known data object") return geom +cdef CProjector2D * getProjector(i) except NULL: + cdef CProjector2D * proj = manProj.get(i) + if proj == NULL: + raise Exception("Projector not initialized.") + if not proj.isInitialized(): + raise Exception("Projector not initialized.") + return proj + +def check_compatible(i, proj_id): + cdef CProjector2D * proj = getProjector(proj_id) + cdef CFloat32Data2D * pDataObject = getObject(i) + cdef CFloat32ProjectionData2D * pDataObject2 + cdef CFloat32VolumeData2D * pDataObject3 + if pDataObject.getType() == TWOPROJECTION: + pDataObject2 = <CFloat32ProjectionData2D * >pDataObject + return pDataObject2.getGeometry().isEqual(proj.getProjectionGeometry()) + elif pDataObject.getType() == TWOVOLUME: + pDataObject3 = <CFloat32VolumeData2D * >pDataObject + return pDataObject3.getGeometry().isEqual(proj.getVolumeGeometry()) + else: + raise Exception("Not a known data object") def change_geometry(i, geom): cdef Config *cfg diff --git a/python/astra/plugin.py b/python/astra/plugin.py new file mode 100644 index 0000000..3e3528d --- /dev/null +++ b/python/astra/plugin.py @@ -0,0 +1,121 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +from . import plugin_c as p +from . import log +from . import data2d +from . import data2d_c +from . import data3d +from . import projector +import inspect +import traceback + +class base(object): + + def astra_init(self, cfg): + args, varargs, varkw, defaults = inspect.getargspec(self.initialize) + if not defaults is None: + nopt = len(defaults) + else: + nopt = 0 + if nopt>0: + req = args[2:-nopt] + opt = args[-nopt:] + else: + req = args[2:] + opt = [] + + try: + optDict = cfg['options'] + except KeyError: + optDict = {} + + cfgKeys = set(optDict.keys()) + reqKeys = set(req) + optKeys = set(opt) + + if not reqKeys.issubset(cfgKeys): + for key in reqKeys.difference(cfgKeys): + log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") + raise ValueError("Missing required options") + + if not cfgKeys.issubset(reqKeys | optKeys): + log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + + args = [optDict[k] for k in req] + kwargs = dict((k,optDict[k]) for k in opt if k in optDict) + self.initialize(cfg, *args, **kwargs) + +class ReconstructionAlgorithm2D(base): + + def astra_init(self, cfg): + self.pid = cfg['ProjectorId'] + self.s = data2d.get_shared(cfg['ProjectionDataId']) + self.v = data2d.get_shared(cfg['ReconstructionDataId']) + self.vg = projector.volume_geometry(self.pid) + self.pg = projector.projection_geometry(self.pid) + if not data2d_c.check_compatible(cfg['ProjectionDataId'], self.pid): + raise ValueError("Projection data and projector not compatible") + if not data2d_c.check_compatible(cfg['ReconstructionDataId'], self.pid): + raise ValueError("Reconstruction data and projector not compatible") + super(ReconstructionAlgorithm2D,self).astra_init(cfg) + +class ReconstructionAlgorithm3D(base): + + def astra_init(self, cfg): + self.pid = cfg['ProjectorId'] + self.s = data3d.get_shared(cfg['ProjectionDataId']) + self.v = data3d.get_shared(cfg['ReconstructionDataId']) + self.vg = data3d.get_geometry(cfg['ReconstructionDataId']) + self.pg = data3d.get_geometry(cfg['ProjectionDataId']) + super(ReconstructionAlgorithm3D,self).astra_init(cfg) + +def register(className): + """Register plugin with ASTRA. + + :param className: Class name or class object to register + :type className: :class:`str` or :class:`class` + + """ + p.register(className) + +def get_registered(): + """Get dictionary of registered plugins. + + :returns: :class:`dict` -- Registered plugins. + + """ + return p.get_registered() + +def get_help(name): + """Get help for registered plugin. + + :param name: Plugin name to get help for + :type name: :class:`str` + :returns: :class:`str` -- Help string (docstring). + + """ + return p.get_help(name)
\ No newline at end of file diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx new file mode 100644 index 0000000..8d6816b --- /dev/null +++ b/python/astra/plugin_c.pyx @@ -0,0 +1,67 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import six +import inspect + +from libcpp.string cimport string +from libcpp cimport bool + +cdef CPluginAlgorithmFactory *fact = getSingletonPtr() + +from . import utils + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": + cdef cppclass CPluginAlgorithmFactory: + bool registerPlugin(string className) + bool registerPlugin(string name, string className) + bool registerPluginClass(object className) + bool registerPluginClass(string name, object className) + object getRegistered() + string getHelp(string name) + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": + cdef CPluginAlgorithmFactory* getSingletonPtr() + +def register(className, name=None): + if inspect.isclass(className): + if name==None: + fact.registerPluginClass(className) + else: + fact.registerPluginClass(six.b(name), className) + else: + if name==None: + fact.registerPlugin(six.b(className)) + else: + fact.registerPlugin(six.b(name), six.b(className)) + +def get_registered(): + return fact.getRegistered() + +def get_help(name): + return utils.wrap_from_bytes(fact.getHelp(six.b(name))) diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx index 260c308..9871ac6 100644 --- a/python/astra/utils.pyx +++ b/python/astra/utils.pyx @@ -30,7 +30,6 @@ cimport numpy as np import numpy as np import six from libcpp.string cimport string -from libcpp.list cimport list from libcpp.vector cimport vector from cython.operator cimport dereference as deref, preincrement as inc from cpython.version cimport PY_MAJOR_VERSION @@ -40,6 +39,9 @@ from .PyXMLDocument cimport XMLDocument from .PyXMLDocument cimport XMLNode from .PyIncludes cimport * +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": + object XMLNode2dict(XMLNode) + cdef Config * dictToConfig(string rootname, dc): cdef Config * cfg = new Config() @@ -91,6 +93,8 @@ cdef void readDict(XMLNode root, _dc): dc = convert_item(_dc) for item in dc: val = dc[item] + if isinstance(val, list) or isinstance(val, tuple): + val = np.array(val,dtype=np.float64) if isinstance(val, np.ndarray): if val.size == 0: break @@ -125,6 +129,8 @@ cdef void readOptions(XMLNode node, dc): val = dc[item] if node.hasOption(item): raise Exception('Duplicate Option: %s' % item) + if isinstance(val, list) or isinstance(val, tuple): + val = np.array(val,dtype=np.float64) if isinstance(val, np.ndarray): if val.size == 0: break @@ -143,73 +149,3 @@ cdef void readOptions(XMLNode node, dc): cdef configToDict(Config *cfg): return XMLNode2dict(cfg.self) - -def castString3(input): - return input.decode('utf-8') - -def castString2(input): - return input - -if six.PY3: - castString = castString3 -else: - castString = castString2 - -def stringToPythonValue(inputIn): - input = castString(inputIn) - # matrix - if ';' in input: - row_strings = input.split(';') - col_strings = row_strings[0].split(',') - nRows = len(row_strings) - nCols = len(col_strings) - - out = np.empty((nRows,nCols)) - for ridx, row in enumerate(row_strings): - col_strings = row.split(',') - for cidx, col in enumerate(col_strings): - out[ridx,cidx] = float(col) - return out - - # vector - if ',' in input: - items = input.split(',') - out = np.empty(len(items)) - for idx,item in enumerate(items): - out[idx] = float(item) - return out - - try: - # integer - return int(input) - except ValueError: - try: - #float - return float(input) - except ValueError: - # string - return str(input) - - -cdef XMLNode2dict(XMLNode node): - cdef XMLNode subnode - cdef list[XMLNode] nodes - cdef list[XMLNode].iterator it - dct = {} - opts = {} - if node.hasAttribute(six.b('type')): - dct['type'] = castString(node.getAttribute(six.b('type'))) - nodes = node.getNodes() - it = nodes.begin() - while it != nodes.end(): - subnode = deref(it) - if castString(subnode.getName())=="Option": - if subnode.hasAttribute('value'): - opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value')) - else: - opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getContent()) - else: - dct[castString(subnode.getName())] = stringToPythonValue(subnode.getContent()) - inc(it) - if len(opts)>0: dct['options'] = opts - return dct |