diff options
author | Edoardo Pasca <edo.paskino@gmail.com> | 2017-08-04 16:16:37 +0100 |
---|---|---|
committer | Edoardo Pasca <edo.paskino@gmail.com> | 2017-10-11 15:38:09 +0100 |
commit | fd496731c8e9d4975864d76dbb6574cbeee7cf98 (patch) | |
tree | 0afe24b1a3ac3479007885abf443c912d137dbc8 /src/Python | |
parent | 36e4c296223f67bb917511089ec59533460f1695 (diff) | |
download | regularization-fd496731c8e9d4975864d76dbb6574cbeee7cf98.tar.gz regularization-fd496731c8e9d4975864d76dbb6574cbeee7cf98.tar.bz2 regularization-fd496731c8e9d4975864d76dbb6574cbeee7cf98.tar.xz regularization-fd496731c8e9d4975864d76dbb6574cbeee7cf98.zip |
Added 3 regularizers
SplitBregman_TV
FGP_TV
LLT_model
Diffstat (limited to 'src/Python')
-rw-r--r-- | src/Python/fista_module.cpp | 266 |
1 files changed, 232 insertions, 34 deletions
diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 2492884..d890b10 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,8 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "utils.h" @@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) { return reinterpret_cast<T *>(prhs[0]); } +template<typename T> +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<T>(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + @@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) { } } - bp::list result; result.append<int>(number_of_dims); @@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int *dim_array; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin<float>(); - np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU = np::zeros(shape, dtype); np::ndarray npU_old = np::zeros(shape, dtype); - np::ndarray npDx = np::zeros(shape, dtype); - np::ndarray npDy = np::zeros(shape, dtype); - np::ndarray npBx = np::zeros(shape, dtype); - np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); - U = reinterpret_cast<float *>(npU.get_data()); + U = reinterpret_cast<float *>(npU.get_data()); U_old = reinterpret_cast<float *>(npU_old.get_data()); - Dx = reinterpret_cast<float *>(npDx.get_data()); - Dy = reinterpret_cast<float *>(npDy.get_data()); - Bx = reinterpret_cast<float *>(npBx.get_data()); - By = reinterpret_cast<float *>(npBy.get_data()); + Dx = reinterpret_cast<float *>(npDx.get_data()); + Dy = reinterpret_cast<float *>(npDy.get_data()); + Bx = reinterpret_cast<float *>(npBx.get_data()); + By = reinterpret_cast<float *>(npBy.get_data()); + - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ - for (ll = 0; ll<iter; ll++) { + for (ll = 0; ll < iter; ll++) { /*storing old values*/ copyIm(U, U_old, dimX, dimY, dimZ); @@ -286,7 +299,7 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi /* calculate norm to terminate earlier */ re = 0.0f; re1 = 0.0f; - for (j = 0; j<dimX*dimY*dimZ; j++) + for (j = 0; j < dimX*dimY*dimZ; j++) { re += pow(U_old[j] - U[j], 2); re1 += pow(U_old[j], 2); @@ -303,11 +316,13 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi /*printf("%f %i %i \n", re, ll, count); */ /*copyIm(U_old, U, dimX, dimY, dimZ); */ - result.append<np::ndarray>(npU); - result.append<int>(ll); + } //printf("SB iterations stopped at iteration: %i\n", ll); - if (number_of_dims == 3) { + result.append<np::ndarray>(npU); + result.append<int>(ll); + } + if (number_of_dims == 3) { /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); @@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi result.append<np::ndarray>(npU); result.append<int>(ll); } - } return result; -} + } + + bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); //funcvalA[0] = sqrt(funcval); float fv = sqrt(funcval); - std::memcpy(funcvalA, &fv), sizeof(float)); + std::memcpy(funcvalA, &fv, sizeof(float)); break; } @@ -524,7 +540,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); //funcvalA[0] = sqrt(funcval); float fv = sqrt(funcval); - std::memcpy(funcvalA, &fv), sizeof(float)); + std::memcpy(funcvalA, &fv, sizeof(float)); break; } } @@ -544,7 +560,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); //funcvalA[0] = sqrt(funcval); float fv = sqrt(funcval); - std::memcpy(funcvalA, &fv), sizeof(float)); + std::memcpy(funcvalA, &fv, sizeof(float)); } } //printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]); @@ -622,7 +638,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); //funcvalA[0] = sqrt(funcval); float fv = sqrt(funcval); - std::memcpy(funcvalA, &fv), sizeof(float)); + std::memcpy(funcvalA, &fv, sizeof(float)); break; } @@ -634,7 +650,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); //funcvalA[0] = sqrt(funcval); float fv = sqrt(funcval); - std::memcpy(funcvalA, &fv), sizeof(float)); + std::memcpy(funcvalA, &fv, sizeof(float)); break; } } @@ -655,7 +671,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); //funcvalA[0] = sqrt(funcval); float fv = sqrt(funcval); - std::memcpy(funcvalA, &fv), sizeof(float)); + std::memcpy(funcvalA, &fv, sizeof(float)); } } @@ -668,13 +684,194 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me return result; } -BOOST_PYTHON_MODULE(fista) +bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) { + // the result is in the following list + bp::list result; + + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; + float *U0, *U = NULL, *U_old = NULL, *D1 = NULL, *D2 = NULL, *D3 = NULL, lambda, tau, re, re1, epsil, re_old; + unsigned short *Map = NULL; + + number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + ///*Handling Matlab input data*/ + //U0 = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ + //tau = (float)mxGetScalar(prhs[2]); /* time-step */ + //iter = (int)mxGetScalar(prhs[3]); /*iterations number*/ + //epsil = (float)mxGetScalar(prhs[4]); /* tolerance constant */ + //switcher = (int)mxGetScalar(prhs[5]); /*switch on (1) restrictive smoothing in Z dimension*/ + + U0 = reinterpret_cast<float *>(input.get_data()); + lambda = (float)d_lambda; + tau = (float)d_tau; + // iter is passed as parameter + epsil = (float)d_epsil; + // switcher is passed as parameter + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1; + + if (number_of_dims == 2) { + /*2D case*/ + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + + + U = reinterpret_cast<float *>(npU.get_data()); + U_old = reinterpret_cast<float *>(npU_old.get_data()); + D1 = reinterpret_cast<float *>(npD1.get_data()); + D2 = reinterpret_cast<float *>(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der2D(U, D1, D2, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) + { + re += pow(U_old[j] - U[j], 2); + re1 += pow(U_old[j], 2); + } + re = sqrt(re) / sqrt(re1); + if (re < epsil) count++; + if (count > 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + + result.append<np::ndarray>(npU); + } + else if (number_of_dims == 3) { + /*3D case*/ + dimZ = dim_array[2]; + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + if (switcher != 0) { + Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + }*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin<float>(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + np::ndarray npD3 = np::zeros(shape, dtype); + np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin<unsigned short>()); + Map = reinterpret_cast<unsigned short *>(npMap.get_data()); + if (switcher != 0) { + //Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + + Map = reinterpret_cast<unsigned short *>(npMap.get_data()); + } + + U = reinterpret_cast<float *>(npU.get_data()); + U_old = reinterpret_cast<float *>(npU_old.get_data()); + D1 = reinterpret_cast<float *>(npD1.get_data()); + D2 = reinterpret_cast<float *>(npD2.get_data()); + D3 = reinterpret_cast<float *>(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + + if (switcher == 1) { + /* apply restrictive smoothing */ + calcMap(U, Map, dimX, dimY, dimZ); + /*clear outliers */ + cleanMap(Map, dimX, dimY, dimZ); + } + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der3D(U, D1, D2, D3, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j<dimX*dimY*dimZ; j++) + { + re += pow(U_old[j] - U[j], 2); + re1 += pow(U_old[j], 2); + } + re = sqrt(re) / sqrt(re1); + if (re < epsil) count++; + if (count > 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + result.append<np::ndarray>(npU); + if (switcher != 0) result.append<np::ndarray>(npMap); + + } + return result; +} + + +BOOST_PYTHON_MODULE(regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "regularizers"; np::dtype dt1 = np::dtype::get_builtin<uint8_t>(); np::dtype dt2 = np::dtype::get_builtin<uint16_t>(); @@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista) def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); + def("LLT_model", LLT_model); }
\ No newline at end of file |