From 2bc9cce049c6ae588562ac88e089553a3dcc6d19 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 14 Mar 2019 14:51:11 +0000 Subject: added ScaledFunction --- .../Python/ccpi/optimisation/functions/Function.py | 48 +++++++++++++++++ .../ccpi/optimisation/functions/ScaledFunction.py | 60 ++++++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100755 Wrappers/Python/ccpi/optimisation/functions/Function.py create mode 100755 Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py (limited to 'Wrappers/Python') diff --git a/Wrappers/Python/ccpi/optimisation/functions/Function.py b/Wrappers/Python/ccpi/optimisation/functions/Function.py new file mode 100755 index 0000000..43ce900 --- /dev/null +++ b/Wrappers/Python/ccpi/optimisation/functions/Function.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# This work is part of the Core Imaging Library developed by +# Visual Analytics and Imaging System Group of the Science Technology +# Facilities Council, STFC + +# Copyright 2018-2019 Jakob Jorgensen, Daniil Kazantsev and Edoardo Pasca + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +class Function(object): + '''Abstract class representing a function + + + ''' + def __init__(self): + self.L = None + def __call__(self,x, out=None): + raise NotImplementedError + def call_adjoint(self, x, out=None): + raise NotImplementedError + def convex_conjugate(self, x, out=None): + raise NotImplementedError + def proximal_conjugate(self, x, tau, out = None): + raise NotImplementedError + def grad(self, x): + warnings.warn('''This method will disappear in following + versions of the CIL. Use gradient instead''', DeprecationWarning) + return self.gradient(x, out=None) + def prox(self, x, tau): + warnings.warn('''This method will disappear in following + versions of the CIL. Use proximal instead''', DeprecationWarning) + return self.proximal(x, out=None) + def gradient(self, x, out=None): + raise NotImplementedError + def proximal(self, x, tau, out=None): + raise NotImplementedError \ No newline at end of file diff --git a/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py new file mode 100755 index 0000000..f2e39fb --- /dev/null +++ b/Wrappers/Python/ccpi/optimisation/functions/ScaledFunction.py @@ -0,0 +1,60 @@ +from numbers import Number +import numpy + +class ScaledFunction(object): + '''ScaledFunction + + A class to represent the scalar multiplication of an Operator with a scalar. + It holds an operator and a scalar. Basically it returns the multiplication + of the result of direct and adjoint of the operator with the scalar. + For the rest it behaves like the operator it holds. + + Args: + operator (Operator): a Operator or LinearOperator + scalar (Number): a scalar multiplier + Example: + The scaled operator behaves like the following: + sop = ScaledOperator(operator, scalar) + sop.direct(x) = scalar * operator.direct(x) + sop.adjoint(x) = scalar * operator.adjoint(x) + sop.norm() = operator.norm() + sop.range_geometry() = operator.range_geometry() + sop.domain_geometry() = operator.domain_geometry() + ''' + def __init__(self, function, scalar): + super(ScaledFunction, self).__init__() + self.L = None + if not isinstance (scalar, Number): + raise TypeError('expected scalar: got {}'.format(type(scalar))) + self.scalar = scalar + self.function = function + + def __call__(self,x, out=None): + return self.scalar * self.function(x) + + def call_adjoint(self, x, out=None): + return self.scalar * self.function.call_adjoint(x, out=out) + + def convex_conjugate(self, x, out=None): + return self.scalar * self.function.convex_conjugate(x, out=out) + + def proximal_conjugate(self, x, tau, out = None): + '''TODO check if this is mathematically correct''' + return self.function.proximal_conjugate(x, tau, out=out) + + def grad(self, x): + warnings.warn('''This method will disappear in following + versions of the CIL. Use gradient instead''', DeprecationWarning) + return self.gradient(x, out=None) + + def prox(self, x, tau): + warnings.warn('''This method will disappear in following + versions of the CIL. Use proximal instead''', DeprecationWarning) + return self.proximal(x, out=None) + + def gradient(self, x, out=None): + return self.scalar * self.function.gradient(x, out=out) + + def proximal(self, x, tau, out=None): + '''TODO check if this is mathematically correct''' + return self.function.proximal(x, tau, out=out) -- cgit v1.2.3