summaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authorWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2016-01-06 16:40:26 +0100
committerWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2016-01-19 16:50:17 +0100
commit76d182ace088b7ab6d64019079c049a0f9a29e69 (patch)
tree6056c1f266408eb13f0e43e1ac59c60b0887ad12 /python
parent687c5e244e46e51786afad77f5015cae9abad129 (diff)
downloadastra-76d182ace088b7ab6d64019079c049a0f9a29e69.tar.gz
astra-76d182ace088b7ab6d64019079c049a0f9a29e69.tar.bz2
astra-76d182ace088b7ab6d64019079c049a0f9a29e69.tar.xz
astra-76d182ace088b7ab6d64019079c049a0f9a29e69.zip
Change python set_gpu_index to match
Diffstat (limited to 'python')
-rw-r--r--python/astra/astra.py4
-rw-r--r--python/astra/astra_c.pyx21
2 files changed, 20 insertions, 5 deletions
diff --git a/python/astra/astra.py b/python/astra/astra.py
index 26b1ff0..9328b6b 100644
--- a/python/astra/astra.py
+++ b/python/astra/astra.py
@@ -49,10 +49,10 @@ def version(printToScreen=False):
"""
return a.version(printToScreen)
-def set_gpu_index(idx):
+def set_gpu_index(idx, memory=0):
"""Set default GPU index to use.
:param idx: GPU index
:type idx: :class:`int`
"""
- a.set_gpu_index(idx)
+ a.set_gpu_index(idx, memory)
diff --git a/python/astra/astra_c.pyx b/python/astra/astra_c.pyx
index 6b246b6..2a9c816 100644
--- a/python/astra/astra_c.pyx
+++ b/python/astra/astra_c.pyx
@@ -31,6 +31,7 @@ import six
from .utils import wrap_from_bytes
from libcpp.string cimport string
+from libcpp.vector cimport vector
from libcpp cimport bool
cdef extern from "astra/Globals.h" namespace "astra":
int getVersion()
@@ -43,6 +44,12 @@ IF HAVE_CUDA==True:
ELSE:
def setGPUIndex():
pass
+cdef extern from "astra/CompositeGeometryManager.h" namespace "astra":
+ cdef cppclass SGPUParams:
+ vector[int] GPUIndices
+ size_t memory
+cdef extern from "astra/CompositeGeometryManager.h" namespace "astra::CCompositeGeometryManager":
+ void setGlobalGPUParams(SGPUParams&)
def credits():
six.print_("""The ASTRA Toolbox has been developed at the University of Antwerp and CWI, Amsterdam by
@@ -70,8 +77,16 @@ def version(printToScreen=False):
else:
return getVersion()
-def set_gpu_index(idx):
+def set_gpu_index(idx, memory=0):
+ import types
+ import collections
+ cdef SGPUParams params
if use_cuda()==True:
- ret = setGPUIndex(idx)
+ if not isinstance(idx, collections.Iterable) or isinstance(idx, types.StringTypes):
+ idx = (idx,)
+ params.memory = memory
+ params.GPUIndices = idx
+ setGlobalGPUParams(params)
+ ret = setGPUIndex(params.GPUIndices[0])
if not ret:
- six.print_("Failed to set GPU " + str(idx))
+ six.print_("Failed to set GPU " + str(params.GPUIndices[0]))