summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJakob Jorgensen <jakob.jorgensen@manchester.ac.uk>2018-04-10 07:30:43 +0100
committerJakob Jorgensen <jakob.jorgensen@manchester.ac.uk>2018-04-10 07:30:43 +0100
commit783e5842055bc6991a1963c649fd96d012c27c98 (patch)
tree4f37dbbbc1c087546aacab7f4b3b31c5f87d3818
parentdc958934237623296deca7ff04f25451ced25055 (diff)
downloadframework-783e5842055bc6991a1963c649fd96d012c27c98.tar.gz
framework-783e5842055bc6991a1963c649fd96d012c27c98.tar.bz2
framework-783e5842055bc6991a1963c649fd96d012c27c98.tar.xz
framework-783e5842055bc6991a1963c649fd96d012c27c98.zip
Replaces .fun in funcs by .__call__
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py4
-rwxr-xr-xWrappers/Python/ccpi/optimisation/funcs.py10
2 files changed, 7 insertions, 7 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index 5942055..6cae65a 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -84,7 +84,7 @@ def FISTA(x_init, f=None, g=None, opt=None):
# time and criterion
timing[it] = time.time() - time0
- criter[it] = f.fun(x) + g.fun(x);
+ criter[it] = f(x) + g(x);
# stopping rule
#if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10:
@@ -156,7 +156,7 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None):
# time and criterion
timing[it] = time.time() - t
- criter[it] = f.fun(x) + g.fun(x) + h.fun(h.op.direct(x));
+ criter[it] = f(x) + g(x) + h(h.op.direct(x));
# stopping rule
#if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10:
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 0dbe28f..da6a3de 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -24,7 +24,7 @@ import numpy
class BaseFunction(object):
def __init__(self):
self.op = Identity()
- def fun(self,x): return 0
+ def __call__(self,x): return 0
def grad(self,x): return 0
def prox(self,x,tau): return x
@@ -37,7 +37,7 @@ class Norm2(BaseFunction):
self.gamma = gamma;
self.direction = direction;
- def fun(self, x):
+ def __call__(self, x):
xx = numpy.sqrt(numpy.sum(numpy.square(x.as_array()), self.direction,
keepdims=True))
@@ -93,7 +93,7 @@ class Norm2sq(BaseFunction):
#return 2*self.c*self.A.adjoint( self.A.direct(x) - self.b )
return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b )
- def fun(self,x):
+ def __call__(self,x):
#return self.c* np.sum(np.square((self.A.direct(x) - self.b).ravel()))
return self.c*( ( (self.A.direct(x)-self.b)**2).sum() )
@@ -105,7 +105,7 @@ class ZeroFun(BaseFunction):
self.L = L
super(ZeroFun, self).__init__()
- def fun(self,x):
+ def __call__(self,x):
return 0
def prox(self,x,tau):
@@ -121,7 +121,7 @@ class Norm1(BaseFunction):
self.L = 1
super(Norm1, self).__init__()
- def fun(self,x):
+ def __call__(self,x):
return self.gamma*(x.abs().sum())
def prox(self,x,tau):