diff options
| -rw-r--r-- | build/linux/Makefile.in | 16 | ||||
| -rw-r--r-- | include/astra/AstraObjectFactory.h | 13 | ||||
| -rw-r--r-- | include/astra/PluginAlgorithm.h | 85 | ||||
| -rw-r--r-- | matlab/mex/astra_mex_plugin_c.cpp | 139 | ||||
| -rw-r--r-- | python/astra/__init__.py | 1 | ||||
| -rw-r--r-- | python/astra/plugin.py | 95 | ||||
| -rw-r--r-- | python/astra/plugin_c.pyx | 59 | ||||
| -rw-r--r-- | python/astra/utils.pyx | 72 | ||||
| -rw-r--r-- | python/docSRC/index.rst | 1 | ||||
| -rw-r--r-- | python/docSRC/plugins.rst | 8 | ||||
| -rw-r--r-- | samples/python/s018_plugin.py | 138 | ||||
| -rw-r--r-- | src/PluginAlgorithm.cpp | 294 | 
12 files changed, 851 insertions, 70 deletions
diff --git a/build/linux/Makefile.in b/build/linux/Makefile.in index 2d862f2..e209fa7 100644 --- a/build/linux/Makefile.in +++ b/build/linux/Makefile.in @@ -50,11 +50,17 @@ LDFLAGS+=-fopenmp  endif  ifeq ($(python),yes) -PYCPPFLAGS  = ${CPPFLAGS} +PYTHON      = @PYTHON@ +PYLIBDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_config_var; import six; six.print_(get_config_var("LIBDIR"))') +PYINCDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; import six; six.print_(get_python_inc())') +PYLIBVER = `basename $(PYINCDIR)` +CPPFLAGS += -DASTRA_PYTHON -I$(PYINCDIR) +PYCPPFLAGS  = $(CPPFLAGS)  PYCPPFLAGS  += -I../include -PYLDFLAGS = ${LDFLAGS} +PYLDFLAGS = $(LDFLAGS)  PYLDFLAGS   += -L../build/linux/.libs -PYTHON      = @PYTHON@ +LIBS		+= -l$(PYLIBVER) +LDFLAGS += -L$(PYLIBDIR)  endif  BOOST_CPPFLAGS= @@ -234,6 +240,10 @@ MATLAB_MEX=\  	matlab/mex/astra_mex_log_c.$(MEXSUFFIX) \  	matlab/mex/astra_mex_data3d_c.$(MEXSUFFIX) +ifeq ($(python),yes) +ALL_OBJECTS+=src/PluginAlgorithm.lo +MATLAB_MEX+=matlab/mex/astra_mex_plugin_c.$(MEXSUFFIX) +endif  OBJECT_DIRS = src/ tests/ cuda/2d/ cuda/3d/ matlab/mex/ ./  DEPDIRS = $(addsuffix $(DEPDIR),$(OBJECT_DIRS)) diff --git a/include/astra/AstraObjectFactory.h b/include/astra/AstraObjectFactory.h index 356acf9..325989e 100644 --- a/include/astra/AstraObjectFactory.h +++ b/include/astra/AstraObjectFactory.h @@ -40,6 +40,10 @@ $Id$  #include "AlgorithmTypelist.h" +#ifdef ASTRA_PYTHON +#include "PluginAlgorithm.h" +#endif +  namespace astra { @@ -147,6 +151,15 @@ T* CAstraObjectFactory<T, TypeList>::create(const Config& _cfg)  */  class _AstraExport CAlgorithmFactory : public CAstraObjectFactory<CAlgorithm, AlgorithmTypeList> {}; +#ifdef ASTRA_PYTHON +template <> +inline CAlgorithm* CAstraObjectFactory<CAlgorithm, AlgorithmTypeList>::findPlugin(std::string _sType) +	{ +		CPluginAlgorithmFactory *fac = CPluginAlgorithmFactory::getSingletonPtr(); +		return fac->getPlugin(_sType); +	} +#endif +  /**   * Class used to create 2D projectors from a string or a config object  */ diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h new file mode 100644 index 0000000..7d6c64a --- /dev/null +++ b/include/astra/PluginAlgorithm.h @@ -0,0 +1,85 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +           2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. + +----------------------------------------------------------------------- +$Id$ +*/ + +#ifndef _INC_ASTRA_PLUGINALGORITHM +#define _INC_ASTRA_PLUGINALGORITHM + +#ifdef ASTRA_PYTHON + +#include <Python.h> +#include "bytesobject.h" +#include "astra/Algorithm.h" +#include "astra/Singleton.h" +#include "astra/XMLDocument.h" +#include "astra/XMLNode.h" + +namespace astra { +class _AstraExport CPluginAlgorithm : public CAlgorithm { + +public: + +    CPluginAlgorithm(PyObject* pyclass); +    ~CPluginAlgorithm(); + +    bool initialize(const Config& _cfg); +    void run(int _iNrIterations); + +private: +    PyObject * instance; + +}; + +class _AstraExport CPluginAlgorithmFactory : public Singleton<CPluginAlgorithmFactory> { + +public: + +    CPluginAlgorithmFactory(); +    ~CPluginAlgorithmFactory(); + +    CPluginAlgorithm * getPlugin(std::string name); + +    bool registerPlugin(std::string name, std::string className); +    bool registerPluginClass(std::string name, PyObject * className); +     +    PyObject * getRegistered(); +     +    std::string getHelp(std::string name); + +private: +    PyObject * pluginDict; +    PyObject *ospath, *inspect, *six, *astra; +    std::vector<std::string> getPluginPathList(); +}; + +PyObject* XMLNode2dict(XMLNode node); + +} + +#endif + +#endif
\ No newline at end of file diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp new file mode 100644 index 0000000..2d9b9a0 --- /dev/null +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -0,0 +1,139 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +           2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. + +----------------------------------------------------------------------- +$Id$ +*/ + +/** \file astra_mex_plugin_c.cpp + * + *  \brief Manages Python plugins. + */ + +#include <mex.h> +#include "mexHelpFunctions.h" +#include "mexInitFunctions.h" + +#include "astra/PluginAlgorithm.h" + +#include "Python.h" +#include "bytesobject.h" + +using namespace std; +using namespace astra; + + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('get_registered'); + * + * Print registered plugins. + */ +void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ +    astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +    PyObject *dict = fact->getRegistered(); +    PyObject *key, *value; +    Py_ssize_t pos = 0; +    while (PyDict_Next(dict, &pos, &key, &value)) { +        mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value)); +    } +    Py_DECREF(dict); +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('register', name, class_name); + * + * Register plugin. + */ +void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ +    if (3 <= nrhs) { +        string name = mexToString(prhs[1]); +        string class_name = mexToString(prhs[2]); +        astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +        fact->registerPlugin(name, class_name); +    }else{ +        mexPrintf("astra_mex_plugin('register', name, class_name);\n"); +    } +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('get_help', name); + * + * Get help about plugin. + */ +void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ +    if (2 <= nrhs) { +        string name = mexToString(prhs[1]); +        astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); +        mexPrintf((fact->getHelp(name)+"\n").c_str()); +    }else{ +        mexPrintf("astra_mex_plugin('get_help', name);\n"); +    } +} + + +//----------------------------------------------------------------------------------------- + +static void printHelp() +{ +	mexPrintf("Please specify a mode of operation.\n"); +	mexPrintf("   Valid modes: register, get_registered, get_help\n"); +} + +//----------------------------------------------------------------------------------------- +/** + * ... = astra_mex(type,...); + */ +void mexFunction(int nlhs, mxArray* plhs[], +				 int nrhs, const mxArray* prhs[]) +{ + +	// INPUT0: Mode +	string sMode = ""; +	if (1 <= nrhs) { +		sMode = mexToString(prhs[0]);	 +	} else { +		printHelp(); +		return; +	} + +	initASTRAMex(); + +	// SWITCH (MODE) +	if (sMode ==  std::string("get_registered")) {  +		astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);  +    }else if (sMode ==  std::string("get_help")) {  +        astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);  +    }else if (sMode ==  std::string("register")) {  +		astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);  +	} else { +		printHelp(); +	} + +	return; +} + + diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 6c15d30..10ed74d 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -34,6 +34,7 @@ from . import algorithm  from . import projector  from . import projector3d  from . import matrix +from . import plugin  from . import log  from .optomo import OpTomo diff --git a/python/astra/plugin.py b/python/astra/plugin.py new file mode 100644 index 0000000..ccdb2cb --- /dev/null +++ b/python/astra/plugin.py @@ -0,0 +1,95 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +from . import plugin_c as p +from . import log + +class base(object): + +    def astra_init(self, cfg): +        try: +            try: +                req = self.required_options +            except AttributeError: +                log.warn("Plugin '" + self.__class__.__name__ + "' does not specify required options") +                req = {} + +            try: +                opt = self.optional_options +            except AttributeError: +                log.warn("Plugin '" + self.__class__.__name__ + "' does not specify optional options") +                opt = {} + +            try: +                optDict = cfg['options'] +            except KeyError: +                optDict = {} + +            cfgKeys = set(optDict.keys()) +            reqKeys = set(req) +            optKeys = set(opt) + +            if not reqKeys.issubset(cfgKeys): +                for key in reqKeys.difference(cfgKeys): +                    log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") +                raise ValueError("Missing required options") + +            if not cfgKeys.issubset(reqKeys | optKeys): +                log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + +            self.initialize(cfg) +        except Exception as e: +            log.error(str(e)) +            raise + +def register(name, className): +    """Register plugin with ASTRA. +     +    :param name: Plugin name to register +    :type name: :class:`str` +    :param className: Class name or class object to register +    :type className: :class:`str` or :class:`class` +     +    """ +    p.register(name,className) + +def get_registered(): +    """Get dictionary of registered plugins. +     +    :returns: :class:`dict` -- Registered plugins. +     +    """ +    return p.get_registered() + +def get_help(name): +    """Get help for registered plugin. +     +    :param name: Plugin name to get help for +    :type name: :class:`str` +    :returns: :class:`str` -- Help string (docstring). +     +    """ +    return p.get_help(name)
\ No newline at end of file diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx new file mode 100644 index 0000000..91b3cd5 --- /dev/null +++ b/python/astra/plugin_c.pyx @@ -0,0 +1,59 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import six +import inspect + +from libcpp.string cimport string +from libcpp cimport bool + +cdef CPluginAlgorithmFactory *fact = getSingletonPtr() + +from . import utils + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": +    cdef cppclass CPluginAlgorithmFactory: +        bool registerPlugin(string name, string className) +        bool registerPluginClass(string name, object className) +        object getRegistered() +        string getHelp(string name) + +cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": +    cdef CPluginAlgorithmFactory* getSingletonPtr() + +def register(name, className): +    if inspect.isclass(className): +        fact.registerPluginClass(six.b(name), className) +    else: +        fact.registerPlugin(six.b(name), six.b(className)) + +def get_registered(): +    return fact.getRegistered() + +def get_help(name): +    return utils.wrap_from_bytes(fact.getHelp(six.b(name))) diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx index ddb37aa..3746b8e 100644 --- a/python/astra/utils.pyx +++ b/python/astra/utils.pyx @@ -30,7 +30,6 @@ cimport numpy as np  import numpy as np  import six  from libcpp.string cimport string -from libcpp.list cimport list  from libcpp.vector cimport vector  from cython.operator cimport dereference as deref, preincrement as inc  from cpython.version cimport PY_MAJOR_VERSION @@ -40,6 +39,9 @@ from .PyXMLDocument cimport XMLDocument  from .PyXMLDocument cimport XMLNode  from .PyIncludes cimport * +cdef extern from "astra/PluginAlgorithm.h" namespace "astra": +    object XMLNode2dict(XMLNode) +  cdef Config * dictToConfig(string rootname, dc):      cdef Config * cfg = new Config() @@ -91,6 +93,8 @@ cdef void readDict(XMLNode root, _dc):      dc = convert_item(_dc)      for item in dc:          val = dc[item] +        if isinstance(val, list): +            val = np.array(val,dtype=np.float64)          if isinstance(val, np.ndarray):              if val.size == 0:                  break @@ -142,69 +146,3 @@ cdef void readOptions(XMLNode node, dc):  cdef configToDict(Config *cfg):      return XMLNode2dict(cfg.self) -def castString3(input): -    return input.decode('utf-8') - -def castString2(input): -    return input - -if six.PY3: -    castString = castString3 -else: -    castString = castString2 - -def stringToPythonValue(inputIn): -    input = castString(inputIn) -    # matrix -    if ';' in input: -        row_strings = input.split(';') -        col_strings = row_strings[0].split(',') -        nRows = len(row_strings) -        nCols = len(col_strings) - -        out = np.empty((nRows,nCols)) -        for ridx, row in enumerate(row_strings): -            col_strings = row.split(',') -            for cidx, col in enumerate(col_strings): -                out[ridx,cidx] = float(col) -        return out - -    # vector -    if ',' in input: -        items = input.split(',') -        out = np.empty(len(items)) -        for idx,item in enumerate(items): -            out[idx] = float(item) -        return out - -    try: -        # integer -        return int(input) -    except ValueError: -        try: -            #float -            return float(input) -        except ValueError: -            # string -            return str(input) - - -cdef XMLNode2dict(XMLNode node): -    cdef XMLNode subnode -    cdef list[XMLNode] nodes -    cdef list[XMLNode].iterator it -    dct = {} -    opts = {} -    if node.hasAttribute(six.b('type')): -        dct['type'] = castString(node.getAttribute(six.b('type'))) -    nodes = node.getNodes() -    it = nodes.begin() -    while it != nodes.end(): -        subnode = deref(it) -        if castString(subnode.getName())=="Option": -            opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value')) -        else: -            dct[castString(subnode.getName())] = stringToPythonValue(subnode.getContent()) -        inc(it) -    if len(opts)>0: dct['options'] = opts -    return dct diff --git a/python/docSRC/index.rst b/python/docSRC/index.rst index b7cc6d6..dcc6590 100644 --- a/python/docSRC/index.rst +++ b/python/docSRC/index.rst @@ -19,6 +19,7 @@ Contents:     creators     functions     operator +   plugins     matlab     astra  .. astra diff --git a/python/docSRC/plugins.rst b/python/docSRC/plugins.rst new file mode 100644 index 0000000..dc7c607 --- /dev/null +++ b/python/docSRC/plugins.rst @@ -0,0 +1,8 @@ +Plugins: the :mod:`plugin` module +========================================= + +.. automodule:: astra.plugin +    :members: +    :undoc-members: +    :show-inheritance: + diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py new file mode 100644 index 0000000..6677930 --- /dev/null +++ b/samples/python/s018_plugin.py @@ -0,0 +1,138 @@ +#----------------------------------------------------------------------- +#Copyright 2015 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- + +import astra +import numpy as np +import six + +# Define the plugin class (has to subclass astra.plugin.base) +# Note that usually, these will be defined in a separate package/module +class SIRTPlugin(astra.plugin.base): +    """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. + +    Optional options: + +    'rel_factor': relaxation factor +    """ +    required_options=[] +    optional_options=['rel_factor'] + +    def initialize(self,cfg): +        self.W = astra.OpTomo(cfg['ProjectorId']) +        self.vid = cfg['ReconstructionDataId'] +        self.sid = cfg['ProjectionDataId'] +        try: +            self.rel = cfg['option']['rel_factor'] +        except KeyError: +            self.rel = 1 + +    def run(self, its): +        v = astra.data2d.get_shared(self.vid) +        s = astra.data2d.get_shared(self.sid) +        W = self.W +        for i in range(its): +            v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + +if __name__=='__main__': + +    vol_geom = astra.create_vol_geom(256, 256) +    proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False)) + +    # As before, create a sinogram from a phantom +    import scipy.io +    P = scipy.io.loadmat('phantom.mat')['phantom256'] +    proj_id = astra.create_projector('cuda',proj_geom,vol_geom) + +    # construct the OpTomo object +    W = astra.OpTomo(proj_id) + +    sinogram = W * P +    sinogram = sinogram.reshape([180, 384]) + +    # Register the plugin with ASTRA +    # A default set of plugins to load can be defined in: +    #     - /etc/astra-toolbox/plugins.txt +    #     - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt +    #     - [USER_HOME_PATH]/.astra-toolbox/plugins.txt +    #     - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt +    # In these files, create a separate line for each plugin with: +    # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] +    # +    # So in this case, it would be a line: +    # SIRT-PLUGIN s018_plugin.SIRTPlugin +    # +    astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') + +    # To get help on a registered plugin, use get_help +    six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + +    # Create data structures +    sid = astra.data2d.create('-sino', proj_geom, sinogram) +    vid = astra.data2d.create('-vol', vol_geom) + +    # Create config using plugin name +    cfg = astra.astra_dict('SIRT-PLUGIN') +    cfg['ProjectorId'] = proj_id +    cfg['ProjectionDataId'] = sid +    cfg['ReconstructionDataId'] = vid + +    # Create algorithm object +    alg_id = astra.algorithm.create(cfg) + +    # Run algorithm for 100 iterations +    astra.algorithm.run(alg_id, 100) + +    # Get reconstruction +    rec = astra.data2d.get(vid) + +    # Options for the plugin go in cfg['option'] +    cfg = astra.astra_dict('SIRT-PLUGIN') +    cfg['ProjectorId'] = proj_id +    cfg['ProjectionDataId'] = sid +    cfg['ReconstructionDataId'] = vid +    cfg['option'] = {} +    cfg['option']['rel_factor'] = 1.5 +    alg_id_rel = astra.algorithm.create(cfg) +    astra.algorithm.run(alg_id_rel, 100) +    rec_rel = astra.data2d.get(vid) + +    # We can also use OpTomo to call the plugin +    rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + +    import pylab as pl +    pl.gray() +    pl.figure(1) +    pl.imshow(rec,vmin=0,vmax=1) +    pl.figure(2) +    pl.imshow(rec_rel,vmin=0,vmax=1) +    pl.figure(3) +    pl.imshow(rec_op,vmin=0,vmax=1) +    pl.show() + +    # Clean up. +    astra.projector.delete(proj_id) +    astra.algorithm.delete([alg_id, alg_id_rel]) +    astra.data2d.delete([vid, sid]) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp new file mode 100644 index 0000000..df13f31 --- /dev/null +++ b/src/PluginAlgorithm.cpp @@ -0,0 +1,294 @@ +/* +----------------------------------------------------------------------- +Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp +           2014-2015, CWI, Amsterdam + +Contact: astra@uantwerpen.be +Website: http://sf.net/projects/astra-toolbox + +This file is part of the ASTRA Toolbox. + + +The ASTRA Toolbox is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +The ASTRA Toolbox is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. + +----------------------------------------------------------------------- +$Id$ +*/ + +#ifdef ASTRA_PYTHON + +#include "astra/PluginAlgorithm.h" +#include <boost/algorithm/string.hpp> +#include <boost/algorithm/string/split.hpp> +#include <boost/lexical_cast.hpp> +#include <iostream> +#include <fstream> +#include <string> + +namespace astra { + +CPluginAlgorithm::CPluginAlgorithm(PyObject* pyclass){ +    instance = PyObject_CallObject(pyclass, NULL); +} + +CPluginAlgorithm::~CPluginAlgorithm(){ +    if(instance!=NULL){ +        Py_DECREF(instance); +        instance = NULL; +    } +} + +bool CPluginAlgorithm::initialize(const Config& _cfg){ +    if(instance==NULL) return false; +    PyObject *cfgDict = XMLNode2dict(_cfg.self); +    PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict); +    Py_DECREF(cfgDict); +    if(retVal==NULL) return false; +    m_bIsInitialized = true; +    Py_DECREF(retVal); +    return m_bIsInitialized; +} + +void CPluginAlgorithm::run(int _iNrIterations){ +    if(instance==NULL) return; +    PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); +    if(retVal==NULL) return; +    Py_DECREF(retVal); +} + +const char ps = +#ifdef _WIN32 +                            '\\'; +#else +                            '/'; +#endif + +std::vector<std::string> CPluginAlgorithmFactory::getPluginPathList(){ +    std::vector<std::string> list; +    list.push_back("/etc/astra-toolbox"); +    PyObject *ret, *retb; +    ret = PyObject_CallMethod(inspect,"getfile","O",astra); +    if(ret!=NULL){ +        retb = PyObject_CallMethod(six,"b","O",ret); +        Py_DECREF(ret); +        if(retb!=NULL){ +            std::string astra_inst (PyBytes_AsString(retb)); +            Py_DECREF(retb); +            ret = PyObject_CallMethod(ospath,"dirname","s",astra_inst.c_str()); +            if(ret!=NULL){ +                retb = PyObject_CallMethod(six,"b","O",ret); +                Py_DECREF(ret); +                if(retb!=NULL){ +                    list.push_back(std::string(PyBytes_AsString(retb))); +                    Py_DECREF(retb); +                } +            } +        } +    } +    ret = PyObject_CallMethod(ospath,"expanduser","s","~"); +    if(ret!=NULL){ +        retb = PyObject_CallMethod(six,"b","O",ret); +        Py_DECREF(ret); +        if(retb!=NULL){ +            list.push_back(std::string(PyBytes_AsString(retb)) + ps + ".astra-toolbox"); +            Py_DECREF(retb); +        } +    } +    const char *envval = getenv("ASTRA_PLUGIN_PATH"); +    if(envval!=NULL){ +        list.push_back(std::string(envval)); +    } +    return list; +} + +CPluginAlgorithmFactory::CPluginAlgorithmFactory(){ +    Py_Initialize(); +    pluginDict = PyDict_New(); +    ospath = PyImport_ImportModule("os.path"); +    inspect = PyImport_ImportModule("inspect"); +    six = PyImport_ImportModule("six"); +    astra = PyImport_ImportModule("astra"); +    std::vector<std::string> fls = getPluginPathList(); +    std::vector<std::string> items; +    for(unsigned int i=0;i<fls.size();i++){ +        std::ifstream fs ((fls[i]+ps+"plugins.txt").c_str()); +        if(!fs.is_open()) continue; +        std::string line; +        while (std::getline(fs,line)){ +            boost::split(items, line, boost::is_any_of(" ")); +            if(items.size()<2) continue; +            PyObject *str = PyBytes_FromString(items[1].c_str()); +            PyDict_SetItemString(pluginDict,items[0].c_str(),str); +            Py_DECREF(str); +        } +        fs.close(); +    } +} + +CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){ +    if(pluginDict!=NULL){ +        Py_DECREF(pluginDict); +    } +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ +    PyObject *str = PyBytes_FromString(className.c_str()); +    PyDict_SetItemString(pluginDict, name.c_str(), str); +    Py_DECREF(str); +    return true; +} + +bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ +    PyDict_SetItemString(pluginDict, name.c_str(), className); +    return true; +} + +PyObject * getClassFromString(std::string str){ +    std::vector<std::string> items; +    boost::split(items, str, boost::is_any_of(".")); +    PyObject *pyclass = PyImport_ImportModule(items[0].c_str()); +    if(pyclass==NULL) return NULL; +    PyObject *submod = pyclass; +    for(unsigned int i=1;i<items.size();i++){ +        submod = PyObject_GetAttrString(submod,items[i].c_str()); +        Py_DECREF(pyclass); +        pyclass = submod; +        if(pyclass==NULL) return NULL; +    } +    return pyclass; +} + +CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){ +    PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); +    if(className==NULL) return NULL; +    CPluginAlgorithm *alg = NULL; +    if(PyBytes_Check(className)){ +        std::string str = std::string(PyBytes_AsString(className)); +    	PyObject *pyclass = getClassFromString(str); +        if(pyclass!=NULL){ +            alg = new CPluginAlgorithm(pyclass); +            Py_DECREF(pyclass); +        } +    }else{ +        alg = new CPluginAlgorithm(className); +    } +    return alg; +} + +PyObject * CPluginAlgorithmFactory::getRegistered(){ +    Py_INCREF(pluginDict); +    return pluginDict; +} + +std::string CPluginAlgorithmFactory::getHelp(std::string name){ +    PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); +    if(className==NULL) return ""; +    std::string str = std::string(PyBytes_AsString(className)); +    std::string ret = ""; +    PyObject *pyclass = getClassFromString(str); +    if(pyclass==NULL) return ""; +    PyObject *module = PyImport_ImportModule("inspect"); +    if(module!=NULL){ +        PyObject *retVal = PyObject_CallMethod(module,"getdoc","O",pyclass); +        if(retVal!=NULL){ +            PyObject *retb = PyObject_CallMethod(six,"b","O",retVal); +            Py_DECREF(retVal); +            if(retVal!=NULL){ +                ret = std::string(PyBytes_AsString(retb)); +                Py_DECREF(retb); +            } +        } +        Py_DECREF(module); +    } +    Py_DECREF(pyclass); +    return ret; +} + +DEFINE_SINGLETON(CPluginAlgorithmFactory); + +#if PY_MAJOR_VERSION >= 3 +PyObject * pyStringFromString(std::string str){ +    return PyUnicode_FromString(str.c_str()); +} +#else +PyObject * pyStringFromString(std::string str){ +    return PyBytes_FromString(str.c_str()); +} +#endif + +PyObject* stringToPythonValue(std::string str){ +    if(str.find(";")!=std::string::npos){ +        std::vector<std::string> rows, row; +        boost::split(rows, str, boost::is_any_of(";")); +        PyObject *mat = PyList_New(rows.size()); +        for(unsigned int i=0; i<rows.size(); i++){ +            boost::split(row, rows[i], boost::is_any_of(",")); +            PyObject *rowlist = PyList_New(row.size()); +            for(unsigned int j=0;j<row.size();j++){ +                PyList_SetItem(rowlist, j, PyFloat_FromDouble(boost::lexical_cast<double>(row[j]))); +            } +            PyList_SetItem(mat, i, rowlist); +        } +        return mat; +    } +    if(str.find(",")!=std::string::npos){ +        std::vector<std::string> vec; +        boost::split(vec, str, boost::is_any_of(",")); +        PyObject *veclist = PyList_New(vec.size()); +        for(unsigned int i=0;i<vec.size();i++){ +            PyList_SetItem(veclist, i, PyFloat_FromDouble(boost::lexical_cast<double>(vec[i]))); +        } +        return veclist; +    } +    try{ +        return PyLong_FromLong(boost::lexical_cast<long>(str)); +    }catch(const boost::bad_lexical_cast &){ +        try{ +            return PyFloat_FromDouble(boost::lexical_cast<double>(str)); +        }catch(const boost::bad_lexical_cast &){ +            return pyStringFromString(str); +        } +    } +} + +PyObject* XMLNode2dict(XMLNode node){ +    PyObject *dct = PyDict_New(); +    PyObject *opts = PyDict_New(); +    if(node.hasAttribute("type")){ +        PyObject *obj = pyStringFromString(node.getAttribute("type").c_str()); +        PyDict_SetItemString(dct, "type", obj); +        Py_DECREF(obj); +    } +    std::list<XMLNode> nodes = node.getNodes(); +    std::list<XMLNode>::iterator it = nodes.begin(); +    while(it!=nodes.end()){ +        XMLNode subnode = *it; +        if(subnode.getName()=="Option"){ +            PyObject *obj = stringToPythonValue(subnode.getAttribute("value")); +            PyDict_SetItemString(opts, subnode.getAttribute("key").c_str(), obj); +            Py_DECREF(obj); +        }else{ +            PyObject *obj = stringToPythonValue(subnode.getContent()); +            PyDict_SetItemString(dct, subnode.getName().c_str(), obj); +            Py_DECREF(obj); +        } +        ++it; +    } +    PyDict_SetItemString(dct, "options", opts); +    Py_DECREF(opts); +    return dct; +} + +} +#endif
\ No newline at end of file  | 
