From 4c1cdbe58a441a27e46da0d6983851740db94939 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Thu, 10 Jan 2019 14:50:36 +0000
Subject: allows 3D TGV by a for loop on slices

---
 Wrappers/Python/src/cpu_regularisers.pyx | 11 +++++++++--
 Wrappers/Python/src/gpu_regularisers.pyx |  7 ++++++-
 2 files changed, 15 insertions(+), 3 deletions(-)

(limited to 'Wrappers/Python/src')

diff --git a/Wrappers/Python/src/cpu_regularisers.pyx b/Wrappers/Python/src/cpu_regularisers.pyx
index 4aa3251..33d6eb7 100644
--- a/Wrappers/Python/src/cpu_regularisers.pyx
+++ b/Wrappers/Python/src/cpu_regularisers.pyx
@@ -199,9 +199,16 @@ def TV_SB_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData,
 #***************************************************************#
 def TGV_CPU(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst):
     if inputData.ndim == 2:
-        return TGV_2D(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst)
+        return TGV_2D(inputData, regularisation_parameter, alpha1, alpha0, 
+                      iterations, LipshitzConst)
     elif inputData.ndim == 3:
-        return 0
+        shape = inputData.shape
+        print (shape)
+        out = inputData.copy()
+        for i in range(shape[0]):
+            out[i,:,:] = TGV_2D(inputData[i,:,:], regularisation_parameter, 
+               alpha1, alpha0, iterations, LipshitzConst)
+        return out
 
 def TGV_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData, 
                      float regularisation_parameter,
diff --git a/Wrappers/Python/src/gpu_regularisers.pyx b/Wrappers/Python/src/gpu_regularisers.pyx
index 2b97865..47a6149 100644
--- a/Wrappers/Python/src/gpu_regularisers.pyx
+++ b/Wrappers/Python/src/gpu_regularisers.pyx
@@ -102,7 +102,12 @@ def TGV_GPU(inputData, regularisation_parameter, alpha1, alpha0, iterations, Lip
     if inputData.ndim == 2:
         return TGV2D(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst)
     elif inputData.ndim == 3:
-        return 0
+        shape = inputData.shape
+        out = inputData.copy()
+        for i in range(shape[0]):
+            out[i,:,:] = TGV2D(inputData[i,:,:], regularisation_parameter, 
+               alpha1, alpha0, iterations, LipshitzConst)
+        return out
 # Directional Total-variation Fast-Gradient-Projection (FGP)
 def dTV_FGP_GPU(inputData,
                      refdata,
-- 
cgit v1.2.3