From 3765455f542b627450e36a76863922d955f52292 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 4 Mar 2019 13:12:56 +0000 Subject: added dot product between DataContainer s (#215) * added dot product between datacontainers closes #208 implements dot product by flattening the data in a vector and calculating the inner product on the vectors added unittest * use more efficient ravel than flatten --- Wrappers/Python/ccpi/framework.py | 26 ++++++++++++++++++++++++++ Wrappers/Python/test/test_DataContainer.py | 22 ++++++++++++++++++++++ 2 files changed, 48 insertions(+) (limited to 'Wrappers') diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py index 71a9f3f..69a17dc 100644 --- a/Wrappers/Python/ccpi/framework.py +++ b/Wrappers/Python/ccpi/framework.py @@ -746,6 +746,13 @@ class DataContainer(object): def norm(self): '''return the euclidean norm of the DataContainer viewed as a vector''' return numpy.sqrt(self.squared_norm()) + def dot(self, other, *args, **kwargs): + '''return the inner product of 2 DataContainers viewed as vectors''' + if self.shape == other.shape: + return numpy.dot(self.as_array().ravel(), other.as_array().ravel()) + else: + raise ValueError('Shapes are not aligned: {} != {}'.format(self.shape, other.shape)) + @@ -1265,3 +1272,22 @@ if __name__ == '__main__': sino = AcquisitionData(geometry=sgeometry) sino2 = sino.clone() + a0 = numpy.asarray([i for i in range(2*3*4)]) + a1 = numpy.asarray([2*i for i in range(2*3*4)]) + + + ds0 = DataContainer(numpy.reshape(a0,(2,3,4))) + ds1 = DataContainer(numpy.reshape(a1,(2,3,4))) + + numpy.testing.assert_equal(ds0.dot(ds1), a0.dot(a1)) + + a2 = numpy.asarray([2*i for i in range(2*3*5)]) + ds2 = DataContainer(numpy.reshape(a2,(2,3,5))) + +# # it should fail if the shape is wrong +# try: +# ds2.dot(ds0) +# self.assertTrue(False) +# except ValueError as ve: +# self.assertTrue(True) + diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py index 3def054..f23179c 100755 --- a/Wrappers/Python/test/test_DataContainer.py +++ b/Wrappers/Python/test/test_DataContainer.py @@ -425,6 +425,28 @@ class TestDataContainer(unittest.TestCase): res = False print(err) self.assertTrue(res) + + def test_dot(self): + a0 = numpy.asarray([i for i in range(2*3*4)]) + a1 = numpy.asarray([2*i for i in range(2*3*4)]) + + + ds0 = DataContainer(numpy.reshape(a0,(2,3,4))) + ds1 = DataContainer(numpy.reshape(a1,(2,3,4))) + + numpy.testing.assert_equal(ds0.dot(ds1), a0.dot(a1)) + + a2 = numpy.asarray([2*i for i in range(2*3*5)]) + ds2 = DataContainer(numpy.reshape(a2,(2,3,5))) + + # it should fail if the shape is wrong + try: + ds2.dot(ds0) + self.assertTrue(False) + except ValueError as ve: + self.assertTrue(True) + + def test_ImageData(self): # create ImageData from geometry -- cgit v1.2.3