summaryrefslogtreecommitdiffstats
path: root/Wrappers
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-03-04 13:12:56 +0000
committerGitHub <noreply@github.com>2019-03-04 13:12:56 +0000
commit3765455f542b627450e36a76863922d955f52292 (patch)
tree17dea78698fda41ecc4a18acad34417304744abc /Wrappers
parent6d609d54f828882ec46e11af4d3e09fc83a20535 (diff)
downloadframework-3765455f542b627450e36a76863922d955f52292.tar.gz
framework-3765455f542b627450e36a76863922d955f52292.tar.bz2
framework-3765455f542b627450e36a76863922d955f52292.tar.xz
framework-3765455f542b627450e36a76863922d955f52292.zip
added dot product between DataContainer s (#215)
* added dot product between datacontainers closes #208 implements dot product by flattening the data in a vector and calculating the inner product on the vectors added unittest * use more efficient ravel than flatten
Diffstat (limited to 'Wrappers')
-rw-r--r--Wrappers/Python/ccpi/framework.py26
-rwxr-xr-xWrappers/Python/test/test_DataContainer.py22
2 files changed, 48 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py
index 71a9f3f..69a17dc 100644
--- a/Wrappers/Python/ccpi/framework.py
+++ b/Wrappers/Python/ccpi/framework.py
@@ -746,6 +746,13 @@ class DataContainer(object):
def norm(self):
'''return the euclidean norm of the DataContainer viewed as a vector'''
return numpy.sqrt(self.squared_norm())
+ def dot(self, other, *args, **kwargs):
+ '''return the inner product of 2 DataContainers viewed as vectors'''
+ if self.shape == other.shape:
+ return numpy.dot(self.as_array().ravel(), other.as_array().ravel())
+ else:
+ raise ValueError('Shapes are not aligned: {} != {}'.format(self.shape, other.shape))
+
@@ -1265,3 +1272,22 @@ if __name__ == '__main__':
sino = AcquisitionData(geometry=sgeometry)
sino2 = sino.clone()
+ a0 = numpy.asarray([i for i in range(2*3*4)])
+ a1 = numpy.asarray([2*i for i in range(2*3*4)])
+
+
+ ds0 = DataContainer(numpy.reshape(a0,(2,3,4)))
+ ds1 = DataContainer(numpy.reshape(a1,(2,3,4)))
+
+ numpy.testing.assert_equal(ds0.dot(ds1), a0.dot(a1))
+
+ a2 = numpy.asarray([2*i for i in range(2*3*5)])
+ ds2 = DataContainer(numpy.reshape(a2,(2,3,5)))
+
+# # it should fail if the shape is wrong
+# try:
+# ds2.dot(ds0)
+# self.assertTrue(False)
+# except ValueError as ve:
+# self.assertTrue(True)
+
diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py
index 3def054..f23179c 100755
--- a/Wrappers/Python/test/test_DataContainer.py
+++ b/Wrappers/Python/test/test_DataContainer.py
@@ -425,6 +425,28 @@ class TestDataContainer(unittest.TestCase):
res = False
print(err)
self.assertTrue(res)
+
+ def test_dot(self):
+ a0 = numpy.asarray([i for i in range(2*3*4)])
+ a1 = numpy.asarray([2*i for i in range(2*3*4)])
+
+
+ ds0 = DataContainer(numpy.reshape(a0,(2,3,4)))
+ ds1 = DataContainer(numpy.reshape(a1,(2,3,4)))
+
+ numpy.testing.assert_equal(ds0.dot(ds1), a0.dot(a1))
+
+ a2 = numpy.asarray([2*i for i in range(2*3*5)])
+ ds2 = DataContainer(numpy.reshape(a2,(2,3,5)))
+
+ # it should fail if the shape is wrong
+ try:
+ ds2.dot(ds0)
+ self.assertTrue(False)
+ except ValueError as ve:
+ self.assertTrue(True)
+
+
def test_ImageData(self):
# create ImageData from geometry