author     Edoardo Pasca <edo.paskino@gmail.com>   2019-03-14 15:34:04 +0000
committer  Edoardo Pasca <edo.paskino@gmail.com>   2019-03-14 15:34:04 +0000
commit     f2b62709a1e4a9529dcee17b2cf7de87a5f02d2c (patch)
tree       eca0629ff88154b8b17df14f5f8208d419cee014
parent     9769759d3f7f1eab53631627474eade8e4c6f96a (diff)
removed alpha parameter from SimpleL2NormSq and L2NormSq; scaling is now handled by ScaledFunction
-rw-r--r--  Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py   | 39
-rwxr-xr-x  Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py  | 39
2 files changed, 45 insertions(+), 33 deletions(-)
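What the change amounts to: the weight alpha no longer lives inside the L2 function itself; a weighted term alpha * ||x||^2 is instead obtained by wrapping the plain function in ScaledFunction. A minimal sketch of that split, using plain NumPy arrays as stand-ins for the ccpi DataContainer API (the real classes operate on containers with .power()/.divide() methods, so this is illustrative only, not the library's code):

import numpy as np

class SimpleL2NormSq(object):
    '''f(x) = ||x||_2^2; after this commit the Lipschitz constant of the
    gradient is the fixed value L = 2, with no alpha baked in.'''
    def __init__(self):
        self.L = 2
    def __call__(self, x):
        return np.sum(x**2)
    def gradient(self, x):
        return 2 * x

class ScaledFunction(object):
    '''g(x) = scalar * f(x); the scalar now carries the weighting.'''
    def __init__(self, function, scalar):
        self.function = function
        self.scalar = scalar
    def __call__(self, x):
        return self.scalar * self.function(x)
    def gradient(self, x):
        return self.scalar * self.function.gradient(x)

f = SimpleL2NormSq()
g = ScaledFunction(f, 0.5)               # plays the role of the old alpha=0.5
x = np.array([1.0, 2.0, 3.0])
assert np.isclose(g(x), 0.5 * f(x))      # 7.0 == 0.5 * 14.0
assert np.allclose(g.gradient(x), x)     # 0.5 * 2x = x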
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index 5817317..54c947a 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -19,34 +19,41 @@ class SimpleL2NormSq(Function):
 
     def __init__(self, alpha=1):
         super(SimpleL2NormSq, self).__init__()
-        self.alpha = alpha
-        # Lispchitz constant of gradient
-        self.L = 2*self.alpha
+        self.L = 2
 
     def __call__(self, x):
-        return self.alpha * x.power(2).sum()
+        return x.power(2).sum()
 
-    def gradient(self,x):
-        return 2 * self.alpha * x
+    def gradient(self,x, out=None):
+        if out is None:
+            return 2 * x
+        else:
+            out.fill(2*x)
 
     def convex_conjugate(self,x):
-        return (1/(4*self.alpha)) * x.power(2).sum()
-
-    def proximal(self, x, tau):
-        return x.divide(1+2*tau*self.alpha)
+        return (1/4) * x.squared_norm()
+
+    def proximal(self, x, tau, out=None):
+        if out is None:
+            return x.divide(1+2*tau)
+        else:
+            x.divide(1+2*tau, out=out)
 
-    def proximal_conjugate(self, x, tau):
-        return x.divide(1 + tau/(2*self.alpha) )
+    def proximal_conjugate(self, x, tau, out=None):
+        if out is None:
+            return x.divide(1 + tau/2)
+        else:
+            x.divide(1+tau/2, out=out)
+
 
 ############################ L2NORM FUNCTIONS #############################
 class L2NormSq(SimpleL2NormSq):
 
-    def __init__(self, alpha, **kwargs):
+    def __init__(self, **kwargs):
 
-        super(L2NormSq, self).__init__(alpha)
-        self.alpha = alpha
+        super(L2NormSq, self).__init__()
         self.b = kwargs.get('b',None)
 
     def __call__(self, x):
@@ -59,9 +66,9 @@ class L2NormSq(SimpleL2NormSq):
 
     def gradient(self, x):
         if self.b is None:
-            return 2*self.alpha * x
+            return 2 * x
         else:
-            return 2*self.alpha * (x - self.b)
+            return 2 * (x - self.b)
 
     def convex_conjugate(self, x):
 
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
index f2e39fb..7e2f20a 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
@@ -4,22 +4,17 @@ import numpy
 
 class ScaledFunction(object):
     '''ScaledFunction
 
-    A class to represent the scalar multiplication of an Operator with a scalar.
-    It holds an operator and a scalar. Basically it returns the multiplication
-    of the result of direct and adjoint of the operator with the scalar.
-    For the rest it behaves like the operator it holds.
+    A class to represent the scalar multiplication of a Function with a scalar.
+    It holds a function and a scalar. Basically it returns the product of
+    the scalar with the results of the function's __call__, convex_conjugate
+    and gradient. For the rest it behaves like the function it holds.
 
     Args:
-        operator (Operator): a Operator or LinearOperator
+        function (Function): a Function or BlockOperator
         scalar (Number): a scalar multiplier
 
     Example:
-       The scaled operator behaves like the following:
-       sop = ScaledOperator(operator, scalar)
-       sop.direct(x) = scalar * operator.direct(x)
-       sop.adjoint(x) = scalar * operator.adjoint(x)
-       sop.norm() = operator.norm()
-       sop.range_geometry() = operator.range_geometry()
-       sop.domain_geometry() = operator.domain_geometry()
+       The scaled function behaves like the following:
+
     '''
     def __init__(self, function, scalar):
         super(ScaledFunction, self).__init__()
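The rewritten convex_conjugate in the second hunk below evaluates the conjugate of the wrapped function at x/scalar and multiplies the result by the scalar, i.e. the standard identity (alpha*f)^*(y) = alpha * f^*(y/alpha). A quick numerical check of that identity for f = ||.||^2, whose conjugate is f^*(y) = ||y||^2 / 4 (plain NumPy sketch, independent of the ccpi API):

import numpy as np

alpha = 0.5
y = np.array([1.0, -2.0, 3.0])

f_conj = lambda v: np.sum(v**2) / 4.0      # f^* for f(x) = ||x||^2
lhs = np.sum(y**2) / (4.0 * alpha)         # closed form of (alpha*f)^*(y)
rhs = alpha * f_conj(y / alpha)            # what the new code computes
assert np.isclose(lhs, rhs)                # both equal ||y||^2 / (4*alpha)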
@@ -30,31 +25,41 @@ class ScaledFunction(object):
         self.function = function
     def __call__(self,x, out=None):
+        '''Evaluates the function at x '''
         return self.scalar * self.function(x)
 
-    def call_adjoint(self, x, out=None):
-        return self.scalar * self.function.call_adjoint(x, out=out)
-
     def convex_conjugate(self, x, out=None):
-        return self.scalar * self.function.convex_conjugate(x, out=out)
+        '''returns the convex_conjugate of the scaled function '''
+        if out is None:
+            return self.scalar * self.function.convex_conjugate(x/self.scalar)
+        else:
+            out.fill(self.function.convex_conjugate(x/self.scalar))
+            out *= self.scalar
 
     def proximal_conjugate(self, x, tau, out = None):
-        '''TODO check if this is mathematically correct'''
+        '''This returns the proximal operator of the convex conjugate of the function at x, tau
+
+        TODO check if this is mathematically correct'''
         return self.function.proximal_conjugate(x, tau, out=out)
 
     def grad(self, x):
+        '''Alias of gradient(x,None)'''
         warnings.warn('''This method will disappear in following
         versions of the CIL. Use gradient instead''', DeprecationWarning)
         return self.gradient(x, out=None)
 
     def prox(self, x, tau):
+        '''Alias of proximal(x, tau, None)'''
         warnings.warn('''This method will disappear in following
         versions of the CIL. Use proximal instead''', DeprecationWarning)
-        return self.proximal(x, out=None)
+        return self.proximal(x, tau, out=None)
 
     def gradient(self, x, out=None):
+        '''Returns the gradient of the function at x, if the function is differentiable'''
         return self.scalar * self.function.gradient(x, out=out)
 
     def proximal(self, x, tau, out=None):
-        '''TODO check if this is mathematically correct'''
+        '''This returns the proximal operator for the function at x, tau
+
+        TODO check if this is mathematically correct'''
         return self.function.proximal(x, tau, out=out)
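The two TODO notes above can be settled numerically. For g = alpha*f the proximal operator satisfies prox_{tau*g}(x) = prox_{(alpha*tau)*f}(x), so forwarding tau to the wrapped function unchanged, as both pass-throughs above do, drops the scalar (correct only for scalar = 1). A brute-force check for f = ||.||^2, whose prox has the closed form x / (1 + 2*tau*alpha); the search is restricted to multiples of x because the minimiser of this separable quadratic is collinear with x (plain NumPy sketch, not the library's code):

import numpy as np

alpha, tau = 0.5, 0.3
x = np.array([1.0, -2.0, 3.0])

# minimise 0.5*||u - x||^2 + tau*alpha*||u||^2 over u = t*x
ts = np.linspace(0.0, 1.5, 100001)
objective = (0.5 * (ts - 1.0)**2 + tau * alpha * ts**2) * np.sum(x**2)
u_star = ts[np.argmin(objective)] * x

# matches the closed form with the *scaled* step alpha*tau, not tau alone
assert np.allclose(u_star, x / (1 + 2 * tau * alpha), atol=1e-4)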