diff options
| author | Daniel M. Pelt <D.M.Pelt@cwi.nl> | 2015-02-24 12:35:45 +0100 | 
|---|---|---|
| committer | Daniel M. Pelt <D.M.Pelt@cwi.nl> | 2015-02-24 12:35:45 +0100 | 
| commit | 3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f (patch) | |
| tree | 72f1bf197b33cfb64f259089830910a9754e5893 /python/astra/utils.pyx | |
| parent | e212ab0d4f84adafa0a2fe11f5e16f856504769a (diff) | |
| download | astra-3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f.tar.gz astra-3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f.tar.bz2 astra-3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f.tar.xz astra-3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f.zip  | |
Added Python interface
Diffstat (limited to 'python/astra/utils.pyx')
| -rw-r--r-- | python/astra/utils.pyx | 260 | 
1 files changed, 260 insertions, 0 deletions
diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx new file mode 100644 index 0000000..53e84a9 --- /dev/null +++ b/python/astra/utils.pyx @@ -0,0 +1,260 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>. +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import numpy as np +import six +from libcpp.string cimport string +from libcpp.list cimport list +from libcpp.vector cimport vector +from cython.operator cimport dereference as deref +from cpython.version cimport PY_MAJOR_VERSION + +cimport PyXMLDocument +from .PyXMLDocument cimport XMLDocument +from .PyXMLDocument cimport XMLNode +from .PyIncludes cimport * + + +cdef XMLDocument * dict2XML(string rootname, dc): +    cdef XMLDocument * doc = PyXMLDocument.createDocument(rootname) +    cdef XMLNode * node = doc.getRootNode() +    try: +        readDict(node, dc) +    except: +        six.print_('Error reading XML') +        del doc +        doc = NULL +    finally: +        del node +    return doc + +def convert_item(item): +    if isinstance(item, six.string_types): +        return item.encode('ascii') + +    if type(item) is not dict: +        return item + +    out_dict = {} +    for k in item: +        out_dict[convert_item(k)] = convert_item(item[k]) +    return out_dict + + +def wrap_to_bytes(value): +    if isinstance(value, six.binary_type): +        return value +    s = str(value) +    if PY_MAJOR_VERSION == 3: +        s = s.encode('ascii') +    return s + + +def wrap_from_bytes(value): +    s = value +    if PY_MAJOR_VERSION == 3: +        s = s.decode('ascii') +    return s + + +cdef void readDict(XMLNode * root, _dc): +    cdef XMLNode * listbase +    cdef XMLNode * itm +    cdef int i +    cdef int j + +    dc = convert_item(_dc) +    for item in dc: +        val = dc[item] +        if isinstance(val, np.ndarray): +            if val.size == 0: +                break +            listbase = root.addChildNode(item) +            listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) +            index = 0 +            if val.ndim == 2: +                for i in range(val.shape[0]): +                    for j in range(val.shape[1]): +                        itm = listbase.addChildNode(six.b('ListItem')) +                        itm.addAttribute(< string > six.b('index'), < float32 > index) +                        itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) +                        index += 1 +                        del itm +            elif val.ndim == 1: +                for i in range(val.shape[0]): +                    itm = listbase.addChildNode(six.b('ListItem')) +                    itm.addAttribute(< string > six.b('index'), < float32 > index) +                    itm.addAttribute(< string > six.b('value'), < float32 > val[i]) +                    index += 1 +                    del itm +            else: +                raise Exception("Only 1 or 2 dimensions are allowed") +            del listbase +        elif isinstance(val, dict): +            if item == six.b('option') or item == six.b('options') or item == six.b('Option') or item == six.b('Options'): +                readOptions(root, val) +            else: +                itm = root.addChildNode(item) +                readDict(itm, val) +                del itm +        else: +            if item == six.b('type'): +                root.addAttribute(< string > six.b('type'), <string> wrap_to_bytes(val)) +            else: +                itm = root.addChildNode(item, wrap_to_bytes(val)) +                del itm + +cdef void readOptions(XMLNode * node, dc): +    cdef XMLNode * listbase +    cdef XMLNode * itm +    cdef int i +    cdef int j +    for item in dc: +        val = dc[item] +        if node.hasOption(item): +            raise Exception('Duplicate Option: %s' % item) +        if isinstance(val, np.ndarray): +            if val.size == 0: +                break +            listbase = node.addChildNode(six.b('Option')) +            listbase.addAttribute(< string > six.b('key'), < string > item) +            listbase.addAttribute(< string > six.b('listsize'), < float32 > val.size) +            index = 0 +            if val.ndim == 2: +                for i in range(val.shape[0]): +                    for j in range(val.shape[1]): +                        itm = listbase.addChildNode(six.b('ListItem')) +                        itm.addAttribute(< string > six.b('index'), < float32 > index) +                        itm.addAttribute( < string > six.b('value'), < float32 > val[i, j]) +                        index += 1 +                        del itm +            elif val.ndim == 1: +                for i in range(val.shape[0]): +                    itm = listbase.addChildNode(six.b('ListItem')) +                    itm.addAttribute(< string > six.b('index'), < float32 > index) +                    itm.addAttribute(< string > six.b('value'), < float32 > val[i]) +                    index += 1 +                    del itm +            else: +                raise Exception("Only 1 or 2 dimensions are allowed") +            del listbase +        else: +            node.addOption(item, wrap_to_bytes(val)) + +cdef vectorToNumpy(vector[float32] inp): +    cdef int i +    cdef int sz = inp.size() +    ret = np.empty(sz) +    for i in range(sz): +        ret[i] = inp[i] +    return ret + +cdef XMLNode2dict(XMLNode * node): +    cdef XMLNode * subnode +    cdef list[XMLNode * ] nodes +    cdef list[XMLNode * ].iterator it +    dct = {} +    if node.hasAttribute(six.b('type')): +        dct['type'] = node.getAttribute(six.b('type')) +    nodes = node.getNodes() +    it = nodes.begin() +    while it != nodes.end(): +        subnode = deref(it) +        if subnode.hasAttribute(six.b('listsize')): +            dct[subnode.getName( +                )] = vectorToNumpy(subnode.getContentNumericalArray()) +        else: +            dct[subnode.getName()] = subnode.getContent() +        del subnode +    return dct + +cdef XML2dict(XMLDocument * xml): +    cdef XMLNode * node = xml.getRootNode() +    dct = XMLNode2dict(node) +    del node; +    return dct; + +cdef createProjectionGeometryStruct(CProjectionGeometry2D * geom): +    cdef int i +    cdef CFanFlatVecProjectionGeometry2D * fanvecGeom +    # cdef SFanProjection* p +    dct = {} +    dct['DetectorCount'] = geom.getDetectorCount() +    if not geom.isOfType(< string > six.b('fanflat_vec')): +        dct['DetectorWidth'] = geom.getDetectorWidth() +        angles = np.empty(geom.getProjectionAngleCount()) +        for i in range(geom.getProjectionAngleCount()): +            angles[i] = geom.getProjectionAngle(i) +        dct['ProjectionAngles'] = angles +    else: +        raise Exception("Not yet implemented") +        # fanvecGeom = <CFanFlatVecProjectionGeometry2D*> geom +        # vecs = np.empty(fanvecGeom.getProjectionAngleCount()*6) +        # iDetCount = pVecGeom.getDetectorCount() +        # for i in range(fanvecGeom.getProjectionAngleCount()): +        #	p = &fanvecGeom.getProjectionVectors()[i]; +        #	out[6*i + 0] = p.fSrcX +        #	out[6*i + 1] = p.fSrcY +        #	out[6*i + 2] = p.fDetSX + 0.5f*iDetCount*p.fDetUX +        #	out[6*i + 3] = p.fDetSY + 0.5f*iDetCount*p.fDetUY +        #	out[6*i + 4] = p.fDetUX +        #	out[6*i + 5] = p.fDetUY +        # dct['Vectors'] = vecs +    if (geom.isOfType(< string > six.b('parallel'))): +        dct["type"] = "parallel" +    elif (geom.isOfType(< string > six.b('fanflat'))): +        raise Exception("Not yet implemented") +        # astra::CFanFlatProjectionGeometry2D* pFanFlatGeom = dynamic_cast<astra::CFanFlatProjectionGeometry2D*>(_pProjGeom) +        # mGeometryInfo["DistanceOriginSource"] = mxCreateDoubleScalar(pFanFlatGeom->getOriginSourceDistance()) +        # mGeometryInfo["DistanceOriginDetector"] = +        # mxCreateDoubleScalar(pFanFlatGeom->getOriginDetectorDistance()) +        dct["type"] = "fanflat" +    elif (geom.isOfType(< string > six.b('sparse_matrix'))): +        raise Exception("Not yet implemented") +        # astra::CSparseMatrixProjectionGeometry2D* pSparseMatrixGeom = +        # dynamic_cast<astra::CSparseMatrixProjectionGeometry2D*>(_pProjGeom); +        dct["type"] = "sparse_matrix" +        # dct["MatrixID"] = +        # mxCreateDoubleScalar(CMatrixManager::getSingleton().getIndex(pSparseMatrixGeom->getMatrix())) +    elif(geom.isOfType(< string > six.b('fanflat_vec'))): +        dct["type"] = "fanflat_vec" +    return dct + +cdef createVolumeGeometryStruct(CVolumeGeometry2D * geom): +    mGeometryInfo = {} +    mGeometryInfo["GridColCount"] = geom.getGridColCount() +    mGeometryInfo["GridRowCount"] = geom.getGridRowCount() + +    mGeometryOptions = {} +    mGeometryOptions["WindowMinX"] = geom.getWindowMinX() +    mGeometryOptions["WindowMaxX"] = geom.getWindowMaxX() +    mGeometryOptions["WindowMinY"] = geom.getWindowMinY() +    mGeometryOptions["WindowMaxY"] = geom.getWindowMaxY() + +    mGeometryInfo["option"] = mGeometryOptions +    return mGeometryInfo  | 
