diff options
| -rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py | 9 | 
1 files changed, 4 insertions, 5 deletions
| diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py index 0d3c8f5..6920829 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py +++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py @@ -52,8 +52,8 @@ class KullbackLeibler(Function):          ''' -        # TODO avoid scipy import ???? -        tmp = scipy.special.kl_div(self.b.as_array(), x.as_array())                 +        ind = x.as_array()>0 +        tmp = scipy.special.kl_div(self.b.as_array()[ind], x.as_array()[ind])                          return numpy.sum(tmp)  @@ -78,9 +78,8 @@ class KullbackLeibler(Function):      def convex_conjugate(self, x): -        # TODO avoid scipy import ???? -        xlogy = scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) -        return numpy.sum(-xlogy) +        xlogy = - scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) +        return numpy.sum(xlogy)      def proximal(self, x, tau, out=None): | 
