summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEdoardo Pasca <edo.paskino@gmail.com>2019-01-10 14:50:36 +0000
committerEdoardo Pasca <edo.paskino@gmail.com>2019-01-10 14:50:36 +0000
commit4c1cdbe58a441a27e46da0d6983851740db94939 (patch)
tree8fc0faee0cf549214053b24b1a27a417f02ba9c7
parent9f943b6ee5018e917a032eeebfa938855502934c (diff)
downloadregularization-4c1cdbe58a441a27e46da0d6983851740db94939.tar.gz
regularization-4c1cdbe58a441a27e46da0d6983851740db94939.tar.bz2
regularization-4c1cdbe58a441a27e46da0d6983851740db94939.tar.xz
regularization-4c1cdbe58a441a27e46da0d6983851740db94939.zip
allows 3D TGV by a for loop on slices
-rw-r--r--Wrappers/Python/src/cpu_regularisers.pyx11
-rw-r--r--Wrappers/Python/src/gpu_regularisers.pyx7
2 files changed, 15 insertions, 3 deletions
diff --git a/Wrappers/Python/src/cpu_regularisers.pyx b/Wrappers/Python/src/cpu_regularisers.pyx
index 4aa3251..33d6eb7 100644
--- a/Wrappers/Python/src/cpu_regularisers.pyx
+++ b/Wrappers/Python/src/cpu_regularisers.pyx
@@ -199,9 +199,16 @@ def TV_SB_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData,
#***************************************************************#
def TGV_CPU(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst):
if inputData.ndim == 2:
- return TGV_2D(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst)
+ return TGV_2D(inputData, regularisation_parameter, alpha1, alpha0,
+ iterations, LipshitzConst)
elif inputData.ndim == 3:
- return 0
+ shape = inputData.shape
+ print (shape)
+ out = inputData.copy()
+ for i in range(shape[0]):
+ out[i,:,:] = TGV_2D(inputData[i,:,:], regularisation_parameter,
+ alpha1, alpha0, iterations, LipshitzConst)
+ return out
def TGV_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData,
float regularisation_parameter,
diff --git a/Wrappers/Python/src/gpu_regularisers.pyx b/Wrappers/Python/src/gpu_regularisers.pyx
index 2b97865..47a6149 100644
--- a/Wrappers/Python/src/gpu_regularisers.pyx
+++ b/Wrappers/Python/src/gpu_regularisers.pyx
@@ -102,7 +102,12 @@ def TGV_GPU(inputData, regularisation_parameter, alpha1, alpha0, iterations, Lip
if inputData.ndim == 2:
return TGV2D(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst)
elif inputData.ndim == 3:
- return 0
+ shape = inputData.shape
+ out = inputData.copy()
+ for i in range(shape[0]):
+ out[i,:,:] = TGV2D(inputData[i,:,:], regularisation_parameter,
+ alpha1, alpha0, iterations, LipshitzConst)
+ return out
# Directional Total-variation Fast-Gradient-Projection (FGP)
def dTV_FGP_GPU(inputData,
refdata,