diff options
| -rw-r--r-- | Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py | 115 | 
1 files changed, 60 insertions, 55 deletions
| diff --git a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py index 7e55ee8..fb2bfd8 100644 --- a/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py +++ b/Wrappers/Python/ccpi/optimisation/algorithms/PDHG.py @@ -5,6 +5,8 @@ Created on Mon Feb  4 16:18:06 2019  @author: evangelos  """ +from ccpi.optimisation.algorithms import Algorithm +  from ccpi.framework import ImageData  import numpy as np @@ -13,67 +15,70 @@ import time  from ccpi.optimisation.operators import BlockOperator  from ccpi.framework import BlockDataContainer -def PDHG(f, g, operator, tau = None, sigma = None, opt = None, **kwargs): -         -    # algorithmic parameters -    if opt is None:  -        opt = {'tol': 1e-6, 'niter': 500, 'show_iter': 100, \ -               'memopt': False}  -         -    if sigma is None and tau is None: -        raise ValueError('Need sigma*tau||K||^2<1')  -                 -    niter = opt['niter'] if 'niter' in opt.keys() else 1000 -    tol = opt['tol'] if 'tol' in opt.keys() else 1e-4 -    memopt = opt['memopt'] if 'memopt' in opt.keys() else False   -    show_iter = opt['show_iter'] if 'show_iter' in opt.keys() else False  -    stop_crit = opt['stop_crit'] if 'stop_crit' in opt.keys() else False  +class PDHG(Algorithm): +    '''Primal Dual Hybrid Gradient''' -    if isinstance(operator, BlockOperator): -        x_old = operator.domain_geometry().allocate() -        y_old = operator.range_geometry().allocate() -    else: -        x_old = operator.domain_geometry().allocate() -        y_old = operator.range_geometry().allocate()        -         -     -    xbar = x_old -    x_tmp = x_old -    x = x_old -     -    y_tmp = y_old -    y = y_tmp -         -    # relaxation parameter -    theta = 1 -     -    t = time.time() -     -    objective = [] +    def __init__(self, **kwargs): +        super(PDHG, self).__init__() +        self.f        = kwargs.get('f', None) +        self.operator = kwargs.get('operator', None) +        self.g        = kwargs.get('g', None) +        self.tau      = kwargs.get('tau', None) +        self.sigma    = kwargs.get('sigma', None) + +        if self.f is not None and self.operator is not None and \ +           self.g is not None: +            print ("Calling from creator") +            self.set_up(self.f, +                        self.operator, +                        self.g,  +                        self.tau,  +                        self.sigma) + +    def set_up(self, f, g, operator, tau = None, sigma = None, opt = None, **kwargs): +        # algorithmic parameters +             +        if sigma is None and tau is None: +            raise ValueError('Need sigma*tau||K||^2<1')  +                     -    for i in range(niter): +        self.x_old = self.operator.domain_geometry().allocate() +        self.y_old = self.operator.range_geometry().allocate() +        self.xbar = self.x_old.copy() +        #x_tmp = x_old +        self.x = self.x_old.copy() +        self.y = self.y_old.copy() +        #y_tmp = y_old +        #y = y_tmp +             +        # relaxation parameter +        self.theta = 1 + +    def update(self):          # Gradient descent, Dual problem solution -        y_tmp = y_old + sigma * operator.direct(xbar) -        y = f.proximal_conjugate(y_tmp, sigma) +        self.y_old += self.sigma * self.operator.direct(self.xbar) +        self.y = self.f.proximal_conjugate(self.y_old, self.sigma)          # Gradient ascent, Primal problem solution -        x_tmp = x_old - tau * operator.adjoint(y) -        x = g.proximal(x_tmp, tau) +        self.x_old -= self.tau * self.operator.adjoint(self.y) +        self.x = self.g.proximal(self.x_old, self.tau)          #Update -        xbar = x + theta * (x - x_old) -                                 -        x_old = x -        y_old = y    -         -#        if i%100==0: -# -#            plt.imshow(x.as_array()[100]) -#            plt.show() -#            print(f(operator.direct(x)) + g(x), i) -                          -    t_end = time.time()         -         -    return x, t_end - t, objective +        #xbar = x + theta * (x - x_old) +        self.xbar.fill(self.x) +        self.xbar -= self.x_old  +        self.xbar *= self.theta +        self.xbar += self.x +                         +        self.x_old.fill(self.x) +        self.y_old.fill(self.y) +        #self.y_old = y.copy() +        #self.y = self.y_old + +    def update_objective(self): +        self.loss.append([self.f(self.operator.direct(self.x)) + self.g(self.x), +            -(self.f.convex_conjugate(self.y) + self.g.convex_conjugate(- 1 * self.operator.adjoint(self.y))) +        ]) + | 
