diff options
Diffstat (limited to 'Wrappers/Python')
| -rw-r--r-- | Wrappers/Python/src/cpu_regularisers.pyx | 11 | ||||
| -rw-r--r-- | Wrappers/Python/src/gpu_regularisers.pyx | 7 | 
2 files changed, 15 insertions, 3 deletions
| diff --git a/Wrappers/Python/src/cpu_regularisers.pyx b/Wrappers/Python/src/cpu_regularisers.pyx index 4aa3251..33d6eb7 100644 --- a/Wrappers/Python/src/cpu_regularisers.pyx +++ b/Wrappers/Python/src/cpu_regularisers.pyx @@ -199,9 +199,16 @@ def TV_SB_3D(np.ndarray[np.float32_t, ndim=3, mode="c"] inputData,  #***************************************************************#  def TGV_CPU(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst):      if inputData.ndim == 2: -        return TGV_2D(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst) +        return TGV_2D(inputData, regularisation_parameter, alpha1, alpha0,  +                      iterations, LipshitzConst)      elif inputData.ndim == 3: -        return 0 +        shape = inputData.shape +        print (shape) +        out = inputData.copy() +        for i in range(shape[0]): +            out[i,:,:] = TGV_2D(inputData[i,:,:], regularisation_parameter,  +               alpha1, alpha0, iterations, LipshitzConst) +        return out  def TGV_2D(np.ndarray[np.float32_t, ndim=2, mode="c"] inputData,                        float regularisation_parameter, diff --git a/Wrappers/Python/src/gpu_regularisers.pyx b/Wrappers/Python/src/gpu_regularisers.pyx index 2b97865..47a6149 100644 --- a/Wrappers/Python/src/gpu_regularisers.pyx +++ b/Wrappers/Python/src/gpu_regularisers.pyx @@ -102,7 +102,12 @@ def TGV_GPU(inputData, regularisation_parameter, alpha1, alpha0, iterations, Lip      if inputData.ndim == 2:          return TGV2D(inputData, regularisation_parameter, alpha1, alpha0, iterations, LipshitzConst)      elif inputData.ndim == 3: -        return 0 +        shape = inputData.shape +        out = inputData.copy() +        for i in range(shape[0]): +            out[i,:,:] = TGV2D(inputData[i,:,:], regularisation_parameter,  +               alpha1, alpha0, iterations, LipshitzConst) +        return out  # Directional Total-variation Fast-Gradient-Projection (FGP)  def dTV_FGP_GPU(inputData,                       refdata, | 
