diff options
author | jakobsj <jakobsj@users.noreply.github.com> | 2018-04-11 13:49:28 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-11 13:49:28 +0100 |
commit | 761ddaacfb4b84f78639f79e4bacdeac15247a5b (patch) | |
tree | 49ab911e2f57e23ffcc9c9a7b9f09bcd89c5a702 /Wrappers/Python | |
parent | a193b0766d5821f0c3699c757495db6c3bf1face (diff) | |
parent | ac789a915849b6814f1bc9587c24b1320b532950 (diff) | |
download | framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.tar.gz framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.tar.bz2 framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.tar.xz framework-761ddaacfb4b84f78639f79e4bacdeac15247a5b.zip |
Merge pull request #90 from vais-ral/doc_opti
First opti doc
Diffstat (limited to 'Wrappers/Python')
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/algs.py | 16 | ||||
-rwxr-xr-x | Wrappers/Python/ccpi/optimisation/funcs.py | 20 |
2 files changed, 18 insertions, 18 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/algs.py b/Wrappers/Python/ccpi/optimisation/algs.py index 5942055..a45100c 100755 --- a/Wrappers/Python/ccpi/optimisation/algs.py +++ b/Wrappers/Python/ccpi/optimisation/algs.py @@ -20,7 +20,7 @@ import numpy import time -from ccpi.optimisation.funcs import BaseFunction +from ccpi.optimisation.funcs import Function def FISTA(x_init, f=None, g=None, opt=None): '''Fast Iterative Shrinkage-Thresholding Algorithm @@ -37,8 +37,8 @@ def FISTA(x_init, f=None, g=None, opt=None): opt: additional algorithm ''' # default inputs - if f is None: f = BaseFunction() - if g is None: g = BaseFunction() + if f is None: f = Function() + if g is None: g = Function() # algorithmic parameters if opt is None: @@ -84,7 +84,7 @@ def FISTA(x_init, f=None, g=None, opt=None): # time and criterion timing[it] = time.time() - time0 - criter[it] = f.fun(x) + g.fun(x); + criter[it] = f(x) + g(x); # stopping rule #if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10: @@ -108,9 +108,9 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None): opt: additional algorithm ''' # default inputs - if f is None: f = BaseFunction() - if g is None: g = BaseFunction() - if h is None: h = BaseFunction() + if f is None: f = Function() + if g is None: g = Function() + if h is None: h = Function() # algorithmic parameters if opt is None: @@ -156,7 +156,7 @@ def FBPD(x_init, f=None, g=None, h=None, opt=None): # time and criterion timing[it] = time.time() - t - criter[it] = f.fun(x) + g.fun(x) + h.fun(h.op.direct(x)); + criter[it] = f(x) + g(x) + h(h.op.direct(x)); # stopping rule #if np.linalg.norm(x - x_old) < tol * np.linalg.norm(x_old) and it > 10: diff --git a/Wrappers/Python/ccpi/optimisation/funcs.py b/Wrappers/Python/ccpi/optimisation/funcs.py index 0dbe28f..d11d6c3 100755 --- a/Wrappers/Python/ccpi/optimisation/funcs.py +++ b/Wrappers/Python/ccpi/optimisation/funcs.py @@ -21,14 +21,14 @@ from ccpi.optimisation.ops import Identity, FiniteDiff2D import numpy -class BaseFunction(object): +class Function(object): def __init__(self): self.op = Identity() - def fun(self,x): return 0 + def __call__(self,x): return 0 def grad(self,x): return 0 def prox(self,x,tau): return x -class Norm2(BaseFunction): +class Norm2(Function): def __init__(self, gamma=1.0, @@ -37,7 +37,7 @@ class Norm2(BaseFunction): self.gamma = gamma; self.direction = direction; - def fun(self, x): + def __call__(self, x): xx = numpy.sqrt(numpy.sum(numpy.square(x.as_array()), self.direction, keepdims=True)) @@ -63,7 +63,7 @@ class TV2D(Norm2): # Define a class for squared 2-norm -class Norm2sq(BaseFunction): +class Norm2sq(Function): ''' f(x) = c*||A*x-b||_2^2 @@ -93,19 +93,19 @@ class Norm2sq(BaseFunction): #return 2*self.c*self.A.adjoint( self.A.direct(x) - self.b ) return 2.0*self.c*self.A.adjoint( self.A.direct(x) - self.b ) - def fun(self,x): + def __call__(self,x): #return self.c* np.sum(np.square((self.A.direct(x) - self.b).ravel())) return self.c*( ( (self.A.direct(x)-self.b)**2).sum() ) -class ZeroFun(BaseFunction): +class ZeroFun(Function): def __init__(self,gamma=0,L=1): self.gamma = gamma self.L = L super(ZeroFun, self).__init__() - def fun(self,x): + def __call__(self,x): return 0 def prox(self,x,tau): @@ -113,7 +113,7 @@ class ZeroFun(BaseFunction): # A more interesting example, least squares plus 1-norm minimization. # Define class to represent 1-norm including prox function -class Norm1(BaseFunction): +class Norm1(Function): def __init__(self,gamma): # Do nothing @@ -121,7 +121,7 @@ class Norm1(BaseFunction): self.L = 1 super(Norm1, self).__init__() - def fun(self,x): + def __call__(self,x): return self.gamma*(x.abs().sum()) def prox(self,x,tau): |