From d91b51f6d58003de84a9d6dd8189fceba0e81a5a Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Mon, 20 Jul 2015 14:07:21 +0200 Subject: Allow registering plugins without explicit name, and fix exception handling when running in Matlab --- include/astra/PluginAlgorithm.h | 3 ++ matlab/mex/astra_mex_plugin_c.cpp | 23 ++++------ python/astra/plugin.py | 71 ++++++++++++----------------- python/astra/plugin_c.pyx | 14 ++++-- samples/python/s018_plugin.py | 23 +++++----- src/PluginAlgorithm.cpp | 95 +++++++++++++++++++++++++++++++-------- 6 files changed, 138 insertions(+), 91 deletions(-) diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h index a82c579..b56228e 100644 --- a/include/astra/PluginAlgorithm.h +++ b/include/astra/PluginAlgorithm.h @@ -64,9 +64,12 @@ public: CPluginAlgorithm * getPlugin(std::string name); bool registerPlugin(std::string name, std::string className); + bool registerPlugin(std::string className); bool registerPluginClass(std::string name, PyObject * className); + bool registerPluginClass(PyObject * className); PyObject * getRegistered(); + std::map getRegisteredMap(); std::string getHelp(std::string name); diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp index 2d9b9a0..177fcf4 100644 --- a/matlab/mex/astra_mex_plugin_c.cpp +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -37,9 +37,6 @@ $Id$ #include "astra/PluginAlgorithm.h" -#include "Python.h" -#include "bytesobject.h" - using namespace std; using namespace astra; @@ -52,29 +49,25 @@ using namespace astra; void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); - PyObject *dict = fact->getRegistered(); - PyObject *key, *value; - Py_ssize_t pos = 0; - while (PyDict_Next(dict, &pos, &key, &value)) { - mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value)); + std::map mp = fact->getRegisteredMap(); + for(std::map::iterator it=mp.begin();it!=mp.end();it++){ + mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str()); } - Py_DECREF(dict); } //----------------------------------------------------------------------------------------- -/** astra_mex_plugin('register', name, class_name); +/** astra_mex_plugin('register', class_name); * * Register plugin. */ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { - if (3 <= nrhs) { - string name = mexToString(prhs[1]); - string class_name = mexToString(prhs[2]); + if (2 <= nrhs) { + string class_name = mexToString(prhs[1]); astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); - fact->registerPlugin(name, class_name); + fact->registerPlugin(class_name); }else{ - mexPrintf("astra_mex_plugin('register', name, class_name);\n"); + mexPrintf("astra_mex_plugin('register', class_name);\n"); } } diff --git a/python/astra/plugin.py b/python/astra/plugin.py index f8fc3bd..4b32e6e 100644 --- a/python/astra/plugin.py +++ b/python/astra/plugin.py @@ -32,60 +32,47 @@ import traceback class base(object): def astra_init(self, cfg): - try: - args, varargs, varkw, defaults = inspect.getargspec(self.initialize) - if not defaults is None: - nopt = len(defaults) - else: - nopt = 0 - if nopt>0: - req = args[2:-nopt] - opt = args[-nopt:] - else: - req = args[2:] - opt = [] + args, varargs, varkw, defaults = inspect.getargspec(self.initialize) + if not defaults is None: + nopt = len(defaults) + else: + nopt = 0 + if nopt>0: + req = args[2:-nopt] + opt = args[-nopt:] + else: + req = args[2:] + opt = [] - try: - optDict = cfg['options'] - except KeyError: - optDict = {} + try: + optDict = cfg['options'] + except KeyError: + optDict = {} - cfgKeys = set(optDict.keys()) - reqKeys = set(req) - optKeys = set(opt) + cfgKeys = set(optDict.keys()) + reqKeys = set(req) + optKeys = set(opt) - if not reqKeys.issubset(cfgKeys): - for key in reqKeys.difference(cfgKeys): - log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") - raise ValueError("Missing required options") + if not reqKeys.issubset(cfgKeys): + for key in reqKeys.difference(cfgKeys): + log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified") + raise ValueError("Missing required options") - if not cfgKeys.issubset(reqKeys | optKeys): - log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) + if not cfgKeys.issubset(reqKeys | optKeys): + log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys)))) - args = [optDict[k] for k in req] - kwargs = dict((k,optDict[k]) for k in opt if k in optDict) - self.initialize(cfg, *args, **kwargs) - except Exception: - log.error(traceback.format_exc().replace("%","%%")) - raise + args = [optDict[k] for k in req] + kwargs = dict((k,optDict[k]) for k in opt if k in optDict) + self.initialize(cfg, *args, **kwargs) - def astra_run(self, its): - try: - self.run(its) - except Exception: - log.error(traceback.format_exc().replace("%","%%")) - raise - -def register(name, className): +def register(className): """Register plugin with ASTRA. - :param name: Plugin name to register - :type name: :class:`str` :param className: Class name or class object to register :type className: :class:`str` or :class:`class` """ - p.register(name,className) + p.register(className) def get_registered(): """Get dictionary of registered plugins. diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx index 91b3cd5..8d6816b 100644 --- a/python/astra/plugin_c.pyx +++ b/python/astra/plugin_c.pyx @@ -38,7 +38,9 @@ from . import utils cdef extern from "astra/PluginAlgorithm.h" namespace "astra": cdef cppclass CPluginAlgorithmFactory: + bool registerPlugin(string className) bool registerPlugin(string name, string className) + bool registerPluginClass(object className) bool registerPluginClass(string name, object className) object getRegistered() string getHelp(string name) @@ -46,11 +48,17 @@ cdef extern from "astra/PluginAlgorithm.h" namespace "astra": cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory": cdef CPluginAlgorithmFactory* getSingletonPtr() -def register(name, className): +def register(className, name=None): if inspect.isclass(className): - fact.registerPluginClass(six.b(name), className) + if name==None: + fact.registerPluginClass(className) + else: + fact.registerPluginClass(six.b(name), className) else: - fact.registerPlugin(six.b(name), six.b(className)) + if name==None: + fact.registerPlugin(six.b(className)) + else: + fact.registerPlugin(six.b(name), six.b(className)) def get_registered(): return fact.getRegistered() diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 90e09ac..31cca95 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -38,6 +38,10 @@ class SIRTPlugin(astra.plugin.base): 'rel_factor': relaxation factor (optional) """ + # The astra_name variable defines the name to use to + # call the plugin from ASTRA + astra_name = "SIRT-PLUGIN" + def initialize(self,cfg, rel_factor = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] @@ -68,18 +72,13 @@ if __name__=='__main__': sinogram = sinogram.reshape([180, 384]) # Register the plugin with ASTRA - # A default set of plugins to load can be defined in: - # - /etc/astra-toolbox/plugins.txt - # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt - # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt - # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt - # In these files, create a separate line for each plugin with: - # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS] - # - # So in this case, it would be a line: - # SIRT-PLUGIN s018_plugin.SIRTPlugin - # - astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin') + # First we import the package that contains the plugin + import s018_plugin + # Then, we register the plugin class with ASTRA + astra.plugin.register(s018_plugin.SIRTPlugin) + + # Get a list of registered plugins + six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help six.print_(astra.plugin.get_help('SIRT-PLUGIN')) diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp index d6cf731..7f7ff61 100644 --- a/src/PluginAlgorithm.cpp +++ b/src/PluginAlgorithm.cpp @@ -100,7 +100,10 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){ PyObject *cfgDict = XMLNode2dict(_cfg.self); PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict); Py_DECREF(cfgDict); - if(retVal==NULL) return false; + if(retVal==NULL){ + logPythonError(); + return false; + } m_bIsInitialized = true; Py_DECREF(retVal); return m_bIsInitialized; @@ -108,8 +111,11 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){ void CPluginAlgorithm::run(int _iNrIterations){ if(instance==NULL) return; - PyObject *retVal = PyObject_CallMethod(instance, "astra_run", "i",_iNrIterations); - if(retVal==NULL) return; + PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations); + if(retVal==NULL){ + logPythonError(); + return; + } Py_DECREF(retVal); } @@ -157,18 +163,6 @@ CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){ if(six!=NULL) Py_DECREF(six); } -bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ - PyObject *str = PyBytes_FromString(className.c_str()); - PyDict_SetItemString(pluginDict, name.c_str(), str); - Py_DECREF(str); - return true; -} - -bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ - PyDict_SetItemString(pluginDict, name.c_str(), className); - return true; -} - PyObject * getClassFromString(std::string str){ std::vector items; boost::split(items, str, boost::is_any_of(".")); @@ -190,6 +184,43 @@ PyObject * getClassFromString(std::string str){ return pyclass; } +bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){ + PyObject *str = PyBytes_FromString(className.c_str()); + PyDict_SetItemString(pluginDict, name.c_str(), str); + Py_DECREF(str); + return true; +} + +bool CPluginAlgorithmFactory::registerPlugin(std::string className){ + PyObject *pyclass = getClassFromString(className); + if(pyclass==NULL) return false; + bool ret = registerPluginClass(pyclass); + Py_DECREF(pyclass); + return ret; +} + +bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){ + PyDict_SetItemString(pluginDict, name.c_str(), className); + return true; +} + +bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){ + PyObject *astra_name = PyObject_GetAttrString(className,"astra_name"); + if(astra_name==NULL){ + logPythonError(); + return false; + } + PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name); + if(retb!=NULL){ + PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className); + Py_DECREF(retb); + }else{ + logPythonError(); + } + Py_DECREF(astra_name); + return true; +} + CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); if(className==NULL) return NULL; @@ -212,12 +243,34 @@ PyObject * CPluginAlgorithmFactory::getRegistered(){ return pluginDict; } +std::map CPluginAlgorithmFactory::getRegisteredMap(){ + std::map ret; + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(pluginDict, &pos, &key, &value)) { + PyObject * keyb = PyObject_Bytes(key); + PyObject * valb = PyObject_Bytes(value); + ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb); + Py_DECREF(keyb); + Py_DECREF(valb); + } + return ret; +} + std::string CPluginAlgorithmFactory::getHelp(std::string name){ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str()); - if(className==NULL) return ""; - std::string str = std::string(PyBytes_AsString(className)); + if(className==NULL){ + ASTRA_ERROR("Plugin %s not found!",name.c_str()); + return ""; + } std::string ret = ""; - PyObject *pyclass = getClassFromString(str); + PyObject *pyclass; + if(PyBytes_Check(className)){ + std::string str = std::string(PyBytes_AsString(className)); + pyclass = getClassFromString(str); + }else{ + pyclass = className; + } if(pyclass==NULL) return ""; if(inspect!=NULL && six!=NULL){ PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass); @@ -228,9 +281,13 @@ std::string CPluginAlgorithmFactory::getHelp(std::string name){ ret = std::string(PyBytes_AsString(retb)); Py_DECREF(retb); } + }else{ + logPythonError(); } } - Py_DECREF(pyclass); + if(PyBytes_Check(className)){ + Py_DECREF(pyclass); + } return ret; } -- cgit v1.2.3