From 3cae1d138c53a3fd042de3d2c9d9a07cf0650e0f Mon Sep 17 00:00:00 2001 From: "Daniel M. Pelt" Date: Tue, 24 Feb 2015 12:35:45 +0100 Subject: Added Python interface --- python/astra/matrix_c.pyx | 116 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 python/astra/matrix_c.pyx (limited to 'python/astra/matrix_c.pyx') diff --git a/python/astra/matrix_c.pyx b/python/astra/matrix_c.pyx new file mode 100644 index 0000000..b0d8bc4 --- /dev/null +++ b/python/astra/matrix_c.pyx @@ -0,0 +1,116 @@ +#----------------------------------------------------------------------- +#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam +# +#Author: Daniel M. Pelt +#Contact: D.M.Pelt@cwi.nl +#Website: http://dmpelt.github.io/pyastratoolbox/ +# +# +#This file is part of the Python interface to the +#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox"). +# +#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify +#it under the terms of the GNU General Public License as published by +#the Free Software Foundation, either version 3 of the License, or +#(at your option) any later version. +# +#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful, +#but WITHOUT ANY WARRANTY; without even the implied warranty of +#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +#GNU General Public License for more details. +# +#You should have received a copy of the GNU General Public License +#along with the Python interface to the ASTRA Toolbox. If not, see . +# +#----------------------------------------------------------------------- +# distutils: language = c++ +# distutils: libraries = astra + +import six +from six.moves import range +import numpy as np +import scipy.sparse as ss + +from libcpp cimport bool + +cimport PyMatrixManager +from .PyMatrixManager cimport CMatrixManager +from .PyIncludes cimport * +from .utils import wrap_from_bytes + +cdef CMatrixManager * manM = PyMatrixManager.getSingletonPtr() + + +def delete(ids): + try: + for i in ids: + manM.remove(i) + except TypeError: + manM.remove(ids) + +def clear(): + manM.clear() + +cdef int csr_matrix_to_astra(data,CSparseMatrix *mat) except -1: + if isinstance(data,ss.csr_matrix): + csrD = data + else: + csrD = data.tocsr() + if not mat.isInitialized(): + raise Exception("Couldn't initialize data object.") + if csrD.nnz > mat.m_lSize or csrD.shape[0] > mat.m_iHeight: + raise Exception("Matrix too large to store in this object.") + for i in range(len(csrD.indptr)): + mat.m_plRowStarts[i] = csrD.indptr[i] + for i in range(csrD.nnz): + mat.m_piColIndices[i] = csrD.indices[i] + mat.m_pfValues[i] = csrD.data[i] + +cdef astra_to_csr_matrix(CSparseMatrix *mat): + indptr = np.zeros(mat.m_iHeight+1,dtype=np.int) + indices = np.zeros(mat.m_plRowStarts[mat.m_iHeight],dtype=np.int) + data = np.zeros(mat.m_plRowStarts[mat.m_iHeight]) + for i in range(mat.m_iHeight+1): + indptr[i] = mat.m_plRowStarts[i] + for i in range(mat.m_plRowStarts[mat.m_iHeight]): + indices[i] = mat.m_piColIndices[i] + data[i] = mat.m_pfValues[i] + return ss.csr_matrix((data,indices,indptr),shape=(mat.m_iHeight,mat.m_iWidth)) + +def create(data): + cdef CSparseMatrix* pMatrix + pMatrix = new CSparseMatrix(data.shape[0], data.shape[1], data.nnz) + if not pMatrix.isInitialized(): + del pMatrix + raise Exception("Couldn't initialize data object.") + try: + csr_matrix_to_astra(data,pMatrix) + except: + del pMatrix + raise Exception("Failed to create data object.") + + return manM.store(pMatrix) + +cdef CSparseMatrix * getObject(i) except NULL: + cdef CSparseMatrix * pDataObject = manM.get(i) + if pDataObject == NULL: + raise Exception("Data object not found") + if not pDataObject.isInitialized(): + raise Exception("Data object not initialized properly.") + return pDataObject + + +def store(i,data): + cdef CSparseMatrix * pDataObject = getObject(i) + csr_matrix_to_astra(data,pDataObject) + +def get_size(i): + cdef CSparseMatrix * pDataObject = getObject(i) + return (pDataObject.m_iHeight,pDataObject.m_iWidth) + +def get(i): + cdef CSparseMatrix * pDataObject = getObject(i) + return astra_to_csr_matrix(pDataObject) + +def info(): + six.print_(wrap_from_bytes(manM.info())) -- cgit v1.2.3