diff options
| author | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-03 22:51:43 +0100 | 
|---|---|---|
| committer | epapoutsellis <epapoutsellis@gmail.com> | 2019-06-03 22:51:43 +0100 | 
| commit | 14e06f7ad88202114b22ed478ba6efab952fa30b (patch) | |
| tree | e5a990b77bcf2ca3649f942e876d0f3f85f70154 | |
| parent | 1aa94932776f3a95b02304b1dfd8a18459d7e37c (diff) | |
| download | framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.gz framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.bz2 framework-14e06f7ad88202114b22ed478ba6efab952fa30b.tar.xz framework-14e06f7ad88202114b22ed478ba6efab952fa30b.zip | |
fix call kl div
| -rw-r--r-- | Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py | 9 | 
1 files changed, 4 insertions, 5 deletions
| diff --git a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py index 0d3c8f5..6920829 100644 --- a/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py +++ b/Wrappers/Python/ccpi/optimisation/functions/KullbackLeibler.py @@ -52,8 +52,8 @@ class KullbackLeibler(Function):          ''' -        # TODO avoid scipy import ???? -        tmp = scipy.special.kl_div(self.b.as_array(), x.as_array())                 +        ind = x.as_array()>0 +        tmp = scipy.special.kl_div(self.b.as_array()[ind], x.as_array()[ind])                          return numpy.sum(tmp)  @@ -78,9 +78,8 @@ class KullbackLeibler(Function):      def convex_conjugate(self, x): -        # TODO avoid scipy import ???? -        xlogy = scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) -        return numpy.sum(-xlogy) +        xlogy = - scipy.special.xlogy(self.b.as_array(), 1 - x.as_array()) +        return numpy.sum(xlogy)      def proximal(self, x, tau, out=None): | 
