From 783e5842055bc6991a1963c649fd96d012c27c98 Mon Sep 17 00:00:00 2001 From: Jakob Jorgensen Date: Tue, 10 Apr 2018 07:30:43 +0100 Subject: Replaces .fun in funcs by .__call__ --- Wrappers/Python/ccpi/optimisation/algs.py | 4 ++-- Wrappers/Python/ccpi/optimisation/funcs.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py index 5942055..6cae65a 100755 --- a/Wrappers/Python/ccpi/optimisation/algs.py +++ b/Wrappers/Python/ccpi/optimisation/algs.py @@ -84,7 +84,7 @@ def FISTA(x_init, f=None, g=None, opt=None): # time and criterion timing[it] = time.time() - time0 - criter[it] = f.fun(x) + g.fun(x); + criter[it] = f(x) + g(x); # stopping rule #if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10: @@ -156,7 +156,7 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None): # time and criterion timing[it] = time.time() - t - criter[it] = f.fun(x) + g.fun(x) + h.fun(h.op.direct(x)); + criter[it] = f(x) + g(x) + h(h.op.direct(x)); # stopping rule #if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10: diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py index 0dbe28f..da6a3de 100755 --- a/Wrappers/Python/ccpi/optimisation/funcs.py +++ b/Wrappers/Python/ccpi/optimisation/funcs.py @@ -24,7 +24,7 @@ import numpy class BaseFunction(object): def __init__(self): self.op = Identity() - def fun(self,x): return 0 + def __call__(self,x): return 0 def grad(self,x): return 0 def prox(self,x,tau): return x @@ -37,7 +37,7 @@ class Norm2(BaseFunction): self.gamma = gamma; self.direction = direction; - def fun(self, x): + def __call__(self, x): xx = numpy.sqrt(numpy.sum(numpy.square(x.as_array()), self.direction, keepdims=True)) @@ -93,7 +93,7 @@ class Norm2sq(BaseFunction): #return 2*self.c*self.A.adjoint( self.A.direct(x) - self.b ) return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b ) - def fun(self,x): + def __call__(self,x): #return self.c* np.sum(np.square((self.A.direct(x) - self.b).ravel())) return self.c*( ( (self.A.direct(x)-self.b)**2).sum() ) @@ -105,7 +105,7 @@ class ZeroFun(BaseFunction): self.L = L super(ZeroFun, self).__init__() - def fun(self,x): + def __call__(self,x): return 0 def prox(self,x,tau): @@ -121,7 +121,7 @@ class Norm1(BaseFunction): self.L = 1 super(Norm1, self).__init__() - def fun(self,x): + def __call__(self,x): return self.gamma*(x.abs().sum()) def prox(self,x,tau): -- cgit v1.2.3