summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Python/fista_module.cpp123
-rw-r--r--src/Python/setup.py1
2 files changed, 123 insertions, 1 deletions
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index d890b10..c2d9352 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -29,6 +29,7 @@ limitations under the License.
#include "SplitBregman_TV_core.h"
#include "FGP_TV_core.h"
#include "LLT_model_core.h"
+#include "PatchBased_Regul_core.h"
#include "utils.h"
@@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
if (switcher != 0) {
Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL));
}*/
- bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
np::dtype dtype = np::dtype::get_builtin<float>();
@@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
}
+bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) {
+ // the result is in the following list
+ bp::list result;
+
+ int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop;
+ //const int *dims;
+ float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda;
+
+ numdims = input.get_nd();
+ int dims[3];
+
+ dims[0] = input.shape(0);
+ dims[1] = input.shape(1);
+ if (numdims == 2) {
+ dims[2] = -1;
+ }
+ else {
+ dims[2] = input.shape(2);
+ }
+ /*numdims = mxGetNumberOfDimensions(prhs[0]);
+ dims = mxGetDimensions(prhs[0]);*/
+
+ N = dims[0];
+ M = dims[1];
+ Z = dims[2];
+
+ //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); }
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+
+ //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter");
+
+ ///*Handling inputs*/
+ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */
+ //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */
+ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */
+ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */
+ //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */
+
+ //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0");
+ //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0");
+
+ SearchW = SearchW_real + 2 * SimilW;
+
+ /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */
+ /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */
+
+
+ padXY = SearchW + 2 * SimilW; /* padding sizes */
+ newsizeX = N + 2 * (padXY); /* the X size of the padded array */
+ newsizeY = M + 2 * (padXY); /* the Y size of the padded array */
+ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */
+ int N_dims[] = { newsizeX, newsizeY, newsizeZ };
+
+ /******************************2D case ****************************/
+ if (numdims == 2) {
+ ///*Handling output*/
+ //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL));
+ ///*allocating memory for the padded arrays */
+ //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL));
+ //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL));
+ ///**************************************************************************/
+
+ bp::tuple shape = bp::make_tuple(N, M);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npB = np::zeros(shape, dtype);
+
+ shape = bp::make_tuple(newsizeX, newsizeY);
+ np::ndarray npAp = np::zeros(shape, dtype);
+ np::ndarray npBp = np::zeros(shape, dtype);
+ B = reinterpret_cast<float *>(npB.get_data());
+ Ap = reinterpret_cast<float *>(npAp.get_data());
+ Bp = reinterpret_cast<float *>(npBp.get_data());
+
+ /*Perform padding of image A to the size of [newsizeX * newsizeY] */
+ switchpad_crop = 0; /*padding*/
+ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+
+ /* Do PB regularization with the padded array */
+ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda);
+
+ switchpad_crop = 1; /*cropping*/
+ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+ result.append<np::ndarray>(npB);
+ }
+ else
+ {
+ /******************************3D case ****************************/
+ ///*Handling output*/
+ //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL));
+ ///*allocating memory for the padded arrays */
+ //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL));
+ //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL));
+ /**************************************************************************/
+ bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]);
+ bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npB = np::zeros(shape, dtype);
+ np::ndarray npAp = np::zeros(shape_AB, dtype);
+ np::ndarray npBp = np::zeros(shape_AB, dtype);
+ B = reinterpret_cast<float *>(npB.get_data());
+ Ap = reinterpret_cast<float *>(npAp.get_data());
+ Bp = reinterpret_cast<float *>(npBp.get_data());
+ /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */
+ switchpad_crop = 0; /*padding*/
+ pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop);
+
+ /* Do PB regularization with the padded array */
+ PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda);
+
+ switchpad_crop = 1; /*cropping*/
+ pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop);
+
+ result.append<np::ndarray>(npB);
+ } /*end else ndims*/
+
+ return result;
+}
+
BOOST_PYTHON_MODULE(regularizers)
{
np::initialize();
diff --git a/src/Python/setup.py b/src/Python/setup.py
index a8feb1c..a4eed14 100644
--- a/src/Python/setup.py
+++ b/src/Python/setup.py
@@ -52,6 +52,7 @@ setup(
"..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c",
"..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c",
"..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c",
+ "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c",
"..\\..\\main_func\\regularizers_CPU\\utils.c"
],
include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ),