summaryrefslogtreecommitdiffstats
path: root/python/astra/ASTRAProjector.py
blob: f2826181fc7c8410de2db180d939bfef3e261d04 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#-----------------------------------------------------------------------
#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam
#
#Author: Daniel M. Pelt
#Contact: D.M.Pelt@cwi.nl
#Website: http://dmpelt.github.io/pyastratoolbox/
#
#
#This file is part of the Python interface to the
#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox").
#
#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation, either version 3 of the License, or
#(at your option) any later version.
#
#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#GNU General Public License for more details.
#
#You should have received a copy of the GNU General Public License
#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
#
#-----------------------------------------------------------------------

import math
from . import creators as ac
from . import data2d


class ASTRAProjector2DTranspose(object):
    """Implements the ``proj.T`` functionality.

    Do not use directly, since it can be accessed as member ``.T`` of
    an :class:`ASTRAProjector2D` object.

    Multiplying an instance by some data calls ``backProject`` on the
    parent projector that was passed to the constructor.

    :param parentProj: The projector whose ``backProject`` method is used.
    :type parentProj: :class:`ASTRAProjector2D`
    """

    def __init__(self, parentProj):
        # Only a reference is stored; the parent remains the owner of
        # whatever resources it holds.
        self.parentProj = parentProj

    def __mul__(self, data):
        """Return ``self.parentProj.backProject(data)``.

        :param data: Passed unchanged to the parent's ``backProject``.
        :returns: Whatever the parent's ``backProject`` returns.
        """
        return self.parentProj.backProject(data)


class ASTRAProjector2D(object):
    """Helps with various common ASTRA Toolbox 2D operations.

    This class can perform several often used toolbox operations, such as:

    * Forward projecting
    * Back projecting
    * Reconstructing

    Note that this class has some computational overhead, because it
    copies a lot of data. If you use many repeated operations, using the
    PyAstraToolbox methods directly is faster.

    You can use this class as an abstracted weight matrix :math:`W`: multiplying an instance
    ``proj`` of this class by an image results in a forward projection of the image, and multiplying
    ``proj.T`` by a sinogram results in a backprojection of the sinogram::

        proj = ASTRAProjector2D(...)
        fp = proj*image
        bp = proj.T*sinogram

    :param proj_geom: The projection geometry.
    :type proj_geom: :class:`dict`
    :param vol_geom: The volume geometry.
    :type vol_geom: :class:`dict`
    :param proj_type: Projector type, such as ``'line'``, ``'linear'``, ...
    :type proj_type: :class:`string`
    """

    def __init__(self, proj_geom, vol_geom, proj_type):
        self.vol_geom = vol_geom
        self.recSize = vol_geom['GridColCount']
        self.angles = proj_geom['ProjectionAngles']
        self.nDet = proj_geom['DetectorCount']
        self.filterSize = self._computeFilterSize(self.nDet)
        # ``angles`` must expose ``.shape`` (e.g. a numpy array).
        self.nProj = self.angles.shape[0]
        self.proj_geom = proj_geom
        self.proj_id = ac.create_projector(proj_type, proj_geom, vol_geom)
        # Lets callers write ``proj.T * sinogram`` for a backprojection.
        self.T = ASTRAProjector2DTranspose(self)

    @staticmethod
    def _computeFilterSize(nDet):
        """Return ``nexpow // 2 + 1`` for the given detector count.

        ``nexpow`` is the smallest power of two that is not smaller than
        ``2 * nDet``. Floor division keeps the result an integer on both
        Python 2 and Python 3 (true division would yield a float on
        Python 3).

        NOTE(review): presumably this is the number of frequency bins of a
        real FFT of length ``nexpow`` -- confirm against users of
        ``filterSize``.

        :param nDet: Number of detector elements.
        :type nDet: :class:`int`
        :returns: :class:`int` -- The filter size.
        """
        nexpow = int(pow(2, math.ceil(math.log(2 * nDet, 2))))
        return nexpow // 2 + 1

    def backProject(self, data):
        """Backproject a sinogram.

        :param data: The sinogram data or ID.
        :type data: :class:`numpy.ndarray` or :class:`int`
        :returns: :class:`numpy.ndarray` -- The backprojection.

        """
        vol_id, vol = ac.create_backprojection(
            data, self.proj_id, returnData=True)
        # Only the returned data is needed; delete the toolbox object
        # behind the ID.
        data2d.delete(vol_id)
        return vol

    def forwardProject(self, data):
        """Forward project an image.

        :param data: The image data or ID.
        :type data: :class:`numpy.ndarray` or :class:`int`
        :returns: :class:`numpy.ndarray` -- The forward projection.

        """
        sin_id, sino = ac.create_sino(data, self.proj_id, returnData=True)
        # Only the returned data is needed; delete the toolbox object
        # behind the ID.
        data2d.delete(sin_id)
        return sino

    def reconstruct(self, data, method, **kwargs):
        """Reconstruct an image from a sinogram.

        :param data: The sinogram data or ID.
        :type data: :class:`numpy.ndarray` or :class:`int`
        :param method: Name of the reconstruction algorithm.
        :type method: :class:`string`
        :param kwargs: Additional named parameters to pass to :func:`astra.creators.create_reconstruction`.
        :returns: :class:`numpy.ndarray` -- The reconstruction.

        Example of a SIRT reconstruction using CUDA::

            proj = ASTRAProjector2D(...)
            rec = proj.reconstruct(sinogram,'SIRT_CUDA',iterations=1000)

        """
        # Always request the data, overriding any caller-supplied
        # ``returnData``: the unpacking below needs an (id, data) pair.
        kwargs['returnData'] = True
        rec_id, rec = ac.create_reconstruction(
            method, self.proj_id, data, **kwargs)
        data2d.delete(rec_id)
        return rec

    def __mul__(self, data):
        """Return ``self.forwardProject(data)`` so that ``proj * image`` works."""
        return self.forwardProject(data)