summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
authorjakobsj <jakobsj@users.noreply.github.com>2018-04-11 13:49:28 +0100
committerGitHub <noreply@github.com>2018-04-11 13:49:28 +0100
commit761ddaacfb4b84f78639f79e4bacdeac15247a5b (patch)
tree49ab911e2f57e23ffcc9c9a7b9f09bcd89c5a702 /Wrappers/Python
parenta193b0766d5821f0c3699c757495db6c3bf1face (diff)
parentac789a915849b6814f1bc9587c24b1320b532950 (diff)
downloadframework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.tar.gz
framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.tar.bz2
framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.tar.xz
framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.zip
Merge pull request #90 from vais-ral/doc_opti
First opti doc
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-xWrappers/Python/ccpi/optimisation/algs.py16
-rwxr-xr-xWrappers/Python/ccpi/optimisation/funcs.py20
2 files changed, 18 insertions, 18 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py
index 5942055..a45100c 100755
--- a/Wrappers/Python/ccpi/optimisation/algs.py
+++ b/Wrappers/Python/ccpi/optimisation/algs.py
@@ -20,7 +20,7 @@
import numpy
import time
-from ccpi.optimisation.funcs import BaseFunction
+from ccpi.optimisation.funcs import Function
def FISTA(x_init, f=None, g=None, opt=None):
'''Fast Iterative Shrinkage-Thresholding Algorithm
@@ -37,8 +37,8 @@ def FISTA(x_init, f=None, g=None, opt=None):
opt: additional algorithm
'''
# default inputs
- if f is None: f = BaseFunction()
- if g is None: g = BaseFunction()
+ if f is None: f = Function()
+ if g is None: g = Function()
# algorithmic parameters
if opt is None:
@@ -84,7 +84,7 @@ def FISTA(x_init, f=None, g=None, opt=None):
# time and criterion
timing[it] = time.time() - time0
- criter[it] = f.fun(x) + g.fun(x);
+ criter[it] = f(x) + g(x);
# stopping rule
#if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10:
@@ -108,9 +108,9 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None):
opt: additional algorithm
'''
# default inputs
- if f is None: f = BaseFunction()
- if g is None: g = BaseFunction()
- if h is None: h = BaseFunction()
+ if f is None: f = Function()
+ if g is None: g = Function()
+ if h is None: h = Function()
# algorithmic parameters
if opt is None:
@@ -156,7 +156,7 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None):
# time and criterion
timing[it] = time.time() - t
- criter[it] = f.fun(x) + g.fun(x) + h.fun(h.op.direct(x));
+ criter[it] = f(x) + g(x) + h(h.op.direct(x));
# stopping rule
#if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10:
diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py
index 0dbe28f..d11d6c3 100755
--- a/Wrappers/Python/ccpi/optimisation/funcs.py
+++ b/Wrappers/Python/ccpi/optimisation/funcs.py
@@ -21,14 +21,14 @@ from ccpi.optimisation.ops import Identity, FiniteDiff2D
import numpy
-class BaseFunction(object):
+class Function(object):
def __init__(self):
self.op = Identity()
- def fun(self,x): return 0
+ def __call__(self,x): return 0
def grad(self,x): return 0
def prox(self,x,tau): return x
-class Norm2(BaseFunction):
+class Norm2(Function):
def __init__(self,
gamma=1.0,
@@ -37,7 +37,7 @@ class Norm2(BaseFunction):
self.gamma = gamma;
self.direction = direction;
- def fun(self, x):
+ def __call__(self, x):
xx = numpy.sqrt(numpy.sum(numpy.square(x.as_array()), self.direction,
keepdims=True))
@@ -63,7 +63,7 @@ class TV2D(Norm2):
# Define a class for squared 2-norm
-class Norm2sq(BaseFunction):
+class Norm2sq(Function):
'''
f(x) = c*||A*x-b||_2^2
@@ -93,19 +93,19 @@ class Norm2sq(BaseFunction):
#return 2*self.c*self.A.adjoint( self.A.direct(x) - self.b )
return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b )
- def fun(self,x):
+ def __call__(self,x):
#return self.c* np.sum(np.square((self.A.direct(x) - self.b).ravel()))
return self.c*( ( (self.A.direct(x)-self.b)**2).sum() )
-class ZeroFun(BaseFunction):
+class ZeroFun(Function):
def __init__(self,gamma=0,L=1):
self.gamma = gamma
self.L = L
super(ZeroFun, self).__init__()
- def fun(self,x):
+ def __call__(self,x):
return 0
def prox(self,x,tau):
@@ -113,7 +113,7 @@ class ZeroFun(BaseFunction):
# A more interesting example, least squares plus 1-norm minimization.
# Define class to represent 1-norm including prox function
-class Norm1(BaseFunction):
+class Norm1(Function):
def __init__(self,gamma):
# Do nothing
@@ -121,7 +121,7 @@ class Norm1(BaseFunction):
self.L = 1
super(Norm1, self).__init__()
- def fun(self,x):
+ def __call__(self,x):
return self.gamma*(x.abs().sum())
def prox(self,x,tau):