summaryrefslogtreecommitdiffstats
path: root/Wrappers/Python
diff options
context:
space:
mode:
Diffstat (limited to 'Wrappers/Python')
-rw-r--r--Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
index 0d3c8f5..6920829 100644
--- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
+++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py
@@ -52,8 +52,8 @@ class KullbackLeibler(Function):
'''
- # TODO avoid scipy import ????
- tmp = scipy.special.kl_div(self.b.as_array(), x.as_array())
+ ind = x.as_array()>0
+ tmp = scipy.special.kl_div(self.b.as_array()[ind], x.as_array()[ind])
return numpy.sum(tmp)
@@ -78,9 +78,8 @@ class KullbackLeibler(Function):
def convex_conjugate(self, x):
- # TODO avoid scipy import ????
- xlogy = scipy.special.xlogy(self.b.as_array(), 1 - x.as_array())
- return numpy.sum(-xlogy)
+ xlogy = - scipy.special.xlogy(self.b.as_array(), 1 - x.as_array())
+ return numpy.sum(xlogy)
def proximal(self, x, tau, out=None):