summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py39
-rwxr-xr-xWrappers/Python/ccpi/optimisation/functions/ScaledFunction.py39
2 files changed, 45 insertions, 33 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
index 5817317..54c947a 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/L2NormSquared.py
@@ -19,34 +19,41 @@ class SimpleL2NormSq(Function):
def __init__(self, alpha=1):
super(SimpleL2NormSq, self).__init__()
- self.alpha = alpha
-
# Lispchitz constant of gradient
- self.L = 2*self.alpha
+ self.L = 2
def __call__(self, x):
return self.alpha * x.power(2).sum()
- def gradient(self,x):
- return 2 * self.alpha * x
+ def gradient(self,x, out=None):
+ if out is None:
+ return 2 * x
+ else:
+ out.fill(2*x)
def convex_conjugate(self,x):
- return (1/(4*self.alpha)) * x.power(2).sum()
-
- def proximal(self, x, tau):
- return x.divide(1+2*tau*self.alpha)
+ return (1/4) * x.squared_norm()
+
+ def proximal(self, x, tau, out=None):
+ if out is None:
+ return x.divide(1+2*tau)
+ else:
+ x.divide(1+2*tau, out=out)
- def proximal_conjugate(self, x, tau):
- return x.divide(1 + tau/(2*self.alpha) )
+ def proximal_conjugate(self, x, tau, out=None):
+ if out is None:
+ return x.divide(1 + tau/2)
+ else:
+ x.divide(1+tau/2, out=out)
+
############################ L2NORM FUNCTIONS #############################
class L2NormSq(SimpleL2NormSq):
- def __init__(self, alpha, **kwargs):
+ def __init__(self, **kwargs):
- super(L2NormSq, self).__init__(alpha)
- self.alpha = alpha
+ super(L2NormSq, self).__init__()
self.b = kwargs.get('b',None)
def __call__(self, x):
@@ -59,9 +66,9 @@ class L2NormSq(SimpleL2NormSq):
def gradient(self, x):
if self.b is None:
- return 2*self.alpha * x
+ return 2 * x
else:
- return 2*self.alpha * (x - self.b)
+ return 2 * (x - self.b)
def convex_conjugate(self, x):
diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
index f2e39fb..7e2f20a 100755
--- a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py
@@ -4,22 +4,17 @@ import numpy
class ScaledFunction(object):
'''ScaledFunction
- A class to represent the scalar multiplication of an Operator with a scalar.
- It holds an operator and a scalar. Basically it returns the multiplication
- of the result of direct and adjoint of the operator with the scalar.
- For the rest it behaves like the operator it holds.
+ A class to represent the scalar multiplication of an Function with a scalar.
+ It holds a function and a scalar. Basically it returns the multiplication
+ of the product of the function __call__, convex_conjugate and gradient with the scalar.
+ For the rest it behaves like the function it holds.
Args:
- operator (Operator): a Operator or LinearOperator
+ function (Function): a Function or BlockOperator
scalar (Number): a scalar multiplier
Example:
The scaled operator behaves like the following:
- sop = ScaledOperator(operator, scalar)
- sop.direct(x) = scalar * operator.direct(x)
- sop.adjoint(x) = scalar * operator.adjoint(x)
- sop.norm() = operator.norm()
- sop.range_geometry() = operator.range_geometry()
- sop.domain_geometry() = operator.domain_geometry()
+
'''
def __init__(self, function, scalar):
super(ScaledFunction, self).__init__()
@@ -30,31 +25,41 @@ class ScaledFunction(object):
self.function = function
def __call__(self,x, out=None):
+ '''Evaluates the function at x '''
return self.scalar * self.function(x)
- def call_adjoint(self, x, out=None):
- return self.scalar * self.function.call_adjoint(x, out=out)
-
def convex_conjugate(self, x, out=None):
- return self.scalar * self.function.convex_conjugate(x, out=out)
+ '''returns the convex_conjugate of the scaled function '''
+ if out is None:
+ return self.scalar * self.function.convex_conjugate(x/self.scalar, out=out)
+ else:
+ out.fill(self.function.convex_conjugate(x/self.scalar))
+ out *= self.scalar
def proximal_conjugate(self, x, tau, out = None):
- '''TODO check if this is mathematically correct'''
+ '''This returns the proximal operator for the function at x, tau
+
+ TODO check if this is mathematically correct'''
return self.function.proximal_conjugate(x, tau, out=out)
def grad(self, x):
+ '''Alias of gradient(x,None)'''
warnings.warn('''This method will disappear in following
versions of the CIL. Use gradient instead''', DeprecationWarning)
return self.gradient(x, out=None)
def prox(self, x, tau):
+ '''Alias of proximal(x, tau, None)'''
warnings.warn('''This method will disappear in following
versions of the CIL. Use proximal instead''', DeprecationWarning)
return self.proximal(x, out=None)
def gradient(self, x, out=None):
+ '''Returns the gradient of the function at x, if the function is differentiable'''
return self.scalar * self.function.gradient(x, out=out)
def proximal(self, x, tau, out=None):
- '''TODO check if this is mathematically correct'''
+ '''This returns the proximal operator for the function at x, tau
+
+ TODO check if this is mathematically correct'''
return self.function.proximal(x, tau, out=out)