diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2019-03-04 13:12:56 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-03-04 13:12:56 +0000 |
commit | 3765455f542b627450e36a76863922d955f52292 (patch) | |
tree | 17dea78698fda41ecc4a18acad34417304744abc /Wrappers | |
parent | 6d609d54f828882ec46e11af4d3e09fc83a20535 (diff) | |
download | framework-3765455f542b627450e36a76863922d955f52292.tar.gz framework-3765455f542b627450e36a76863922d955f52292.tar.bz2 framework-3765455f542b627450e36a76863922d955f52292.tar.xz framework-3765455f542b627450e36a76863922d955f52292.zip |
added dot product between DataContainer s (#215)
* added dot product between datacontainers
closes #208
implements dot product by flattening the data in a vector and calculating
the inner product on the vectors
added unittest
* use more efficient ravel than flatten
Diffstat (limited to 'Wrappers')
-rw-r--r-- | Wrappers/Python/ccpi/framework.py | 26 | ||||
-rwxr-xr-x | Wrappers/Python/test/test_DataContainer.py | 22 |
2 files changed, 48 insertions, 0 deletions
diff --git a/Wrappers/Python/ccpi/framework.py b/Wrappers/Python/ccpi/framework.py index 71a9f3f..69a17dc 100644 --- a/Wrappers/Python/ccpi/framework.py +++ b/Wrappers/Python/ccpi/framework.py @@ -746,6 +746,13 @@ class DataContainer(object): def norm(self): '''return the euclidean norm of the DataContainer viewed as a vector''' return numpy.sqrt(self.squared_norm()) + def dot(self, other, *args, **kwargs): + '''return the inner product of 2 DataContainers viewed as vectors''' + if self.shape == other.shape: + return numpy.dot(self.as_array().ravel(), other.as_array().ravel()) + else: + raise ValueError('Shapes are not aligned: {} != {}'.format(self.shape, other.shape)) + @@ -1265,3 +1272,22 @@ if __name__ == '__main__': sino = AcquisitionData(geometry=sgeometry) sino2 = sino.clone() + a0 = numpy.asarray([i for i in range(2*3*4)]) + a1 = numpy.asarray([2*i for i in range(2*3*4)]) + + + ds0 = DataContainer(numpy.reshape(a0,(2,3,4))) + ds1 = DataContainer(numpy.reshape(a1,(2,3,4))) + + numpy.testing.assert_equal(ds0.dot(ds1), a0.dot(a1)) + + a2 = numpy.asarray([2*i for i in range(2*3*5)]) + ds2 = DataContainer(numpy.reshape(a2,(2,3,5))) + +# # it should fail if the shape is wrong +# try: +# ds2.dot(ds0) +# self.assertTrue(False) +# except ValueError as ve: +# self.assertTrue(True) + diff --git a/Wrappers/Python/test/test_DataContainer.py b/Wrappers/Python/test/test_DataContainer.py index 3def054..f23179c 100755 --- a/Wrappers/Python/test/test_DataContainer.py +++ b/Wrappers/Python/test/test_DataContainer.py @@ -425,6 +425,28 @@ class TestDataContainer(unittest.TestCase): res = False print(err) self.assertTrue(res) + + def test_dot(self): + a0 = numpy.asarray([i for i in range(2*3*4)]) + a1 = numpy.asarray([2*i for i in range(2*3*4)]) + + + ds0 = DataContainer(numpy.reshape(a0,(2,3,4))) + ds1 = DataContainer(numpy.reshape(a1,(2,3,4))) + + numpy.testing.assert_equal(ds0.dot(ds1), a0.dot(a1)) + + a2 = numpy.asarray([2*i for i in range(2*3*5)]) + ds2 = DataContainer(numpy.reshape(a2,(2,3,5))) + + # it should fail if the shape is wrong + try: + ds2.dot(ds0) + self.assertTrue(False) + except ValueError as ve: + self.assertTrue(True) + + def test_ImageData(self): # create ImageData from geometry |