From 1b34498aaa93b95925991258fe542b62a9155aff Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 21 Mar 2019 15:16:39 +0000 Subject: BlockDataContainer can do algebra with DataContainers --- .../Python/ccpi/framework/BlockDataContainer.py | 27 ++++++++++++++----- Wrappers/Python/test/test_BlockDataContainer.py | 30 ++++++++++++++++++++-- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/Wrappers/Python/ccpi/framework/BlockDataContainer.py b/Wrappers/Python/ccpi/framework/BlockDataContainer.py index 358ba2d..f29f839 100755 --- a/Wrappers/Python/ccpi/framework/BlockDataContainer.py +++ b/Wrappers/Python/ccpi/framework/BlockDataContainer.py @@ -12,6 +12,7 @@ from __future__ import unicode_literals import numpy from numbers import Number import functools +from ccpi.framework import DataContainer #from ccpi.framework import AcquisitionData, ImageData #from ccpi.optimisation.operators import Operator, LinearOperator @@ -64,6 +65,8 @@ class BlockDataContainer(object): return len(self.containers) == len(other) elif isinstance(other, numpy.ndarray): return self.shape == other.shape + elif issubclass(other.__class__, DataContainer): + return self.get_item(0).shape == other.shape return len(self.containers) == len(other.containers) def get_item(self, row): @@ -75,24 +78,33 @@ class BlockDataContainer(object): return self.get_item(row) def add(self, other, *args, **kwargs): - assert self.is_compatible(other) + if not self.is_compatible(other): + raise ValueError('Incompatible for add') out = kwargs.get('out', None) #print ("args" , *args) if isinstance(other, Number): return type(self)(*[ el.add(other, *args, **kwargs) for el in self.containers], shape=self.shape) elif isinstance(other, list) or isinstance(other, numpy.ndarray): - return type(self)(*[ el.add(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) + return type(self)(*[ el.add(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) + elif issubclass(other.__class__, DataContainer): + # try to do algebra with one DataContainer. Will raise error if not compatible + return type(self)(*[ el.add(other, *args, **kwargs) for el in self.containers], shape=self.shape) + return type(self)( *[ el.add(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other.containers)], shape=self.shape) def subtract(self, other, *args, **kwargs): - assert self.is_compatible(other) + if not self.is_compatible(other): + raise ValueError('Incompatible for add') out = kwargs.get('out', None) if isinstance(other, Number): return type(self)(*[ el.subtract(other, out, *args, **kwargs) for el in self.containers], shape=self.shape) elif isinstance(other, list) or isinstance(other, numpy.ndarray): return type(self)(*[ el.subtract(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) + elif issubclass(other.__class__, DataContainer): + # try to do algebra with one DataContainer. Will raise error if not compatible + return type(self)(*[ el.subtract(other, *args, **kwargs) for el in self.containers], shape=self.shape) return type(self)(*[ el.subtract(ot, out, *args, **kwargs) for el,ot in zip(self.containers,other.containers)], shape=self.shape) @@ -105,6 +117,9 @@ class BlockDataContainer(object): return type(self)(*[ el.multiply(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) elif isinstance(other, numpy.ndarray): return type(self)(*[ el.multiply(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) + elif issubclass(other.__class__, DataContainer): + # try to do algebra with one DataContainer. Will raise error if not compatible + return type(self)(*[ el.multiply(other, *args, **kwargs) for el in self.containers], shape=self.shape) return type(self)(*[ el.multiply(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)], shape=self.shape) @@ -115,6 +130,9 @@ class BlockDataContainer(object): return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) elif isinstance(other, list) or isinstance(other, numpy.ndarray): return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other)], shape=self.shape) + elif issubclass(other.__class__, DataContainer): + # try to do algebra with one DataContainer. Will raise error if not compatible + return type(self)(*[ el.divide(other, *args, **kwargs) for el in self.containers], shape=self.shape) return type(self)(*[ el.divide(ot, *args, **kwargs) for el,ot in zip(self.containers,other.containers)], shape=self.shape) @@ -138,13 +156,10 @@ class BlockDataContainer(object): ## unary operations def abs(self, *args, **kwargs): - out = kwargs.get('out', None) return type(self)(*[ el.abs(*args, **kwargs) for el in self.containers], shape=self.shape) def sign(self, *args, **kwargs): - out = kwargs.get('out', None) return type(self)(*[ el.sign(*args, **kwargs) for el in self.containers], shape=self.shape) def sqrt(self, *args, **kwargs): - out = kwargs.get('out', None) return type(self)(*[ el.sqrt(*args, **kwargs) for el in self.containers], shape=self.shape) def conjugate(self, out=None): return type(self)(*[el.conjugate() for el in self.containers], shape=self.shape) diff --git a/Wrappers/Python/test/test_BlockDataContainer.py b/Wrappers/Python/test/test_BlockDataContainer.py index 6c0bede..51d07fa 100755 --- a/Wrappers/Python/test/test_BlockDataContainer.py +++ b/Wrappers/Python/test/test_BlockDataContainer.py @@ -95,7 +95,7 @@ class TestBlockDataContainer(unittest.TestCase): def test_BlockDataContainer(self): print ("test block data container") ig0 = ImageGeometry(2,3,4) - ig1 = ImageGeometry(2,3,4) + ig1 = ImageGeometry(2,3,5) data0 = ImageData(geometry=ig0) data1 = ImageData(geometry=ig1) + 1 @@ -105,7 +105,33 @@ class TestBlockDataContainer(unittest.TestCase): cp0 = BlockDataContainer(data0,data1) cp1 = BlockDataContainer(data2,data3) - # + + cp2 = BlockDataContainer(data0+1, data2+1) + d = cp2 + data0 + self.assertEqual(d.get_item(0).as_array()[0][0][0], 1) + try: + d = cp2 + data1 + self.assertTrue(False) + except ValueError as ve: + print (ve) + self.assertTrue(True) + d = cp2 - data0 + self.assertEqual(d.get_item(0).as_array()[0][0][0], 1) + try: + d = cp2 - data1 + self.assertTrue(False) + except ValueError as ve: + print (ve) + self.assertTrue(True) + d = cp2 * data2 + self.assertEqual(d.get_item(0).as_array()[0][0][0], 2) + try: + d = cp2 * data1 + self.assertTrue(False) + except ValueError as ve: + print (ve) + self.assertTrue(True) + a = [ (el, ot) for el,ot in zip(cp0.containers,cp1.containers)] print (a[0][0].shape) #cp2 = BlockDataContainer(*a) -- cgit v1.2.3