summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Python/fista_module.cpp245
1 files changed, 146 insertions, 99 deletions
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp
index c2d9352..eacda3d 100644
--- a/src/Python/fista_module.cpp
+++ b/src/Python/fista_module.cpp
@@ -30,6 +30,7 @@ limitations under the License.
#include "FGP_TV_core.h"
#include "LLT_model_core.h"
#include "PatchBased_Regul_core.h"
+#include "TGV_PD_core.h"
#include "utils.h"
@@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th
enough free heap space to create the mxArray.
*/
-void mexErrMessageText(char* text) {
- std::cerr << text << std::endl;
-}
-
-/*
-double mxGetScalar(const mxArray *pm);
-args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
-Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double.
-*/
-
-template<typename T>
-double mxGetScalar(const np::ndarray plh) {
- return (double)bp::extract<T>(plh[0]);
-}
-
-
-
-template<typename T>
-T * mxGetData(const np::ndarray pm) {
- //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray.
- //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double.
- /*Access the numpy array pointer:
- char * get_data() const;
- Returns: Array’s raw data pointer as a char
- Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it.
- probably this would work.
- A = reinterpret_cast<float *>(prhs[0]);
- */
- return reinterpret_cast<T *>(prhs[0]);
-}
-
-template<typename T>
-np::ndarray zeros(int dims, int * dim_array, T el) {
- bp::tuple shape;
- if (dims == 3)
- shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]);
- else if (dims == 2)
- shape = bp::make_tuple(dim_array[0], dim_array[1]);
- np::dtype dtype = np::dtype::get_builtin<T>();
- np::ndarray zz = np::zeros(shape, dtype);
- return zz;
-}
-
-
-bp::list mexFunction(np::ndarray input) {
- int number_of_dims = input.get_nd();
- int dim_array[3];
-
- dim_array[0] = input.shape(0);
- dim_array[1] = input.shape(1);
- if (number_of_dims == 2) {
- dim_array[2] = -1;
- }
- else {
- dim_array[2] = input.shape(2);
- }
-
- /**************************************************************************/
- np::ndarray zz = zeros(3, dim_array, (int)0);
- np::ndarray fzz = zeros(3, dim_array, (float)0);
- /**************************************************************************/
-
- int * A = reinterpret_cast<int *>(input.get_data());
- int * B = reinterpret_cast<int *>(zz.get_data());
- float * C = reinterpret_cast<float *>(fzz.get_data());
-
- //Copy data and cast
- for (int i = 0; i < dim_array[0]; i++) {
- for (int j = 0; j < dim_array[1]; j++) {
- for (int k = 0; k < dim_array[2]; k++) {
- int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i;
- int val = (*(A + index));
- float fval = (float)val;
- std::memcpy(B + index, &val, sizeof(int));
- std::memcpy(C + index, &fval, sizeof(float));
- }
- }
- }
-
- bp::list result;
-
- result.append<int>(number_of_dims);
- result.append<int>(dim_array[0]);
- result.append<int>(dim_array[1]);
- result.append<int>(dim_array[2]);
- result.append<np::ndarray>(zz);
- result.append<np::ndarray>(fzz);
-
- //result.append<bp::tuple>(tup);
- return result;
-
-}
-
bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) {
// the result is in the following list
@@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me
np::ndarray npP1_old = np::zeros(shape, dtype);
np::ndarray npP2_old = np::zeros(shape, dtype);
np::ndarray npR1 = np::zeros(shape, dtype);
- np::ndarray npR2 = zeros(2, dim_array, (float)0);
+ np::ndarray npR2 = np::zeros(shape, dtype);
D = reinterpret_cast<float *>(npD.get_data());
D_old = reinterpret_cast<float *>(npD_old.get_data());
@@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d
}
-bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) {
+bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) {
// the result is in the following list
bp::list result;
@@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub
///*Handling inputs*/
//A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */
+ A = reinterpret_cast<float *>(input.get_data());
//SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */
//SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */
//h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */
@@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub
//if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0");
//if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0");
+ lambda = (float)d_lambda;
+ h = (float)d_h;
SearchW = SearchW_real + 2 * SimilW;
/* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */
@@ -918,7 +829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub
newsizeY = M + 2 * (padXY); /* the Y size of the padded array */
newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */
int N_dims[] = { newsizeX, newsizeY, newsizeZ };
-
/******************************2D case ****************************/
if (numdims == 2) {
///*Handling output*/
@@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub
/*Perform padding of image A to the size of [newsizeX * newsizeY] */
switchpad_crop = 0; /*padding*/
pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
-
+
/* Do PB regularization with the padded array */
PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda);
-
+
switchpad_crop = 1; /*cropping*/
pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop);
+
result.append<np::ndarray>(npB);
}
else
@@ -986,6 +897,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub
return result;
}
+bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) {
+ // the result is in the following list
+ bp::list result;
+ int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll;
+ //const int *dim_array;
+ float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0;
+
+ //number_of_dims = mxGetNumberOfDimensions(prhs[0]);
+ //dim_array = mxGetDimensions(prhs[0]);
+ number_of_dims = input.get_nd();
+ int dim_array[3];
+
+ dim_array[0] = input.shape(0);
+ dim_array[1] = input.shape(1);
+ if (number_of_dims == 2) {
+ dim_array[2] = -1;
+ }
+ else {
+ dim_array[2] = input.shape(2);
+ }
+ /*Handling Matlab input data*/
+ //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/
+ //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); }
+
+ A = reinterpret_cast<float *>(input.get_data());
+
+ //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/
+ //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/
+ //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/
+ //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/
+ //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations");
+ lambda = (float)d_lambda;
+ alpha1 = (float)d_alpha1;
+ alpha0 = (float)d_alpha0;
+
+ /*Handling Matlab output data*/
+ dimX = dim_array[0]; dimY = dim_array[1];
+
+ if (number_of_dims == 2) {
+ /*2D case*/
+ dimZ = 1;
+ bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]);
+ np::dtype dtype = np::dtype::get_builtin<float>();
+
+ np::ndarray npU = np::zeros(shape, dtype);
+ np::ndarray npP1 = np::zeros(shape, dtype);
+ np::ndarray npP2 = np::zeros(shape, dtype);
+ np::ndarray npQ1 = np::zeros(shape, dtype);
+ np::ndarray npQ2 = np::zeros(shape, dtype);
+ np::ndarray npQ3 = np::zeros(shape, dtype);
+ np::ndarray npV1 = np::zeros(shape, dtype);
+ np::ndarray npV1_old = np::zeros(shape, dtype);
+ np::ndarray npV2 = np::zeros(shape, dtype);
+ np::ndarray npV2_old = np::zeros(shape, dtype);
+ np::ndarray npU_old = np::zeros(shape, dtype);
+
+ U = reinterpret_cast<float *>(npU.get_data());
+ U_old = reinterpret_cast<float *>(npU_old.get_data());
+ P1 = reinterpret_cast<float *>(npP1.get_data());
+ P2 = reinterpret_cast<float *>(npP2.get_data());
+ Q1 = reinterpret_cast<float *>(npQ1.get_data());
+ Q2 = reinterpret_cast<float *>(npQ2.get_data());
+ Q3 = reinterpret_cast<float *>(npQ3.get_data());
+ V1 = reinterpret_cast<float *>(npV1.get_data());
+ V1_old = reinterpret_cast<float *>(npV1_old.get_data());
+ V2 = reinterpret_cast<float *>(npV2.get_data());
+ V2_old = reinterpret_cast<float *>(npV2_old.get_data());
+ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ /*dual variables*/
+ /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+
+ V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));
+ V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/
+ /*printf("%i \n", i);*/
+ L2 = 12.0; /*Lipshitz constant*/
+ tau = 1.0 / pow(L2, 0.5);
+ sigma = 1.0 / pow(L2, 0.5);
+
+ /*Copy A to U*/
+ copyIm(A, U, dimX, dimY, dimZ);
+ /* Here primal-dual iterations begin for 2D */
+ for (ll = 0; ll < iter; ll++) {
+
+ /* Calculate Dual Variable P */
+ DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma);
+
+ /*Projection onto convex set for P*/
+ ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1);
+
+ /* Calculate Dual Variable Q */
+ DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma);
+
+ /*Projection onto convex set for Q*/
+ ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0);
+
+ /*saving U into U_old*/
+ copyIm(U, U_old, dimX, dimY, dimZ);
+
+ /*adjoint operation -> divergence and projection of P*/
+ DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau);
+
+ /*get updated solution U*/
+ newU(U, U_old, dimX, dimY, dimZ);
+
+ /*saving V into V_old*/
+ copyIm(V1, V1_old, dimX, dimY, dimZ);
+ copyIm(V2, V2_old, dimX, dimY, dimZ);
+
+ /* upd V*/
+ UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau);
+
+ /*get new V*/
+ newU(V1, V1_old, dimX, dimY, dimZ);
+ newU(V2, V2_old, dimX, dimY, dimZ);
+ } /*end of iterations*/
+
+ result.append<np::ndarray>(npU);
+ }
+
+
+
+
+ return result;
+}
+
BOOST_PYTHON_MODULE(regularizers)
{
np::initialize();
@@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers)
np::dtype dt1 = np::dtype::get_builtin<uint8_t>();
np::dtype dt2 = np::dtype::get_builtin<uint16_t>();
- def("mexFunction", mexFunction);
def("SplitBregman_TV", SplitBregman_TV);
def("FGP_TV", FGP_TV);
def("LLT_model", LLT_model);
+ def("PatchBased_Regul", PatchBased_Regul);
+ def("TGV_PD", TGV_PD);
} \ No newline at end of file