From 0f40ee8ad7d6e0b3b7059e5e1242d8ab97cd3caf Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 15:26:29 +0100 Subject: Added Python modules Matlab2Python_utils.cpp contains utilities for handling numpy arrays. Together with setup_test.py it creates a functional module for testing. fista_module.cpp and setup.py are meant for the real fista module. --- src/Python/Matlab2Python_utils.cpp | 206 ++++++++++++++++++++++++ src/Python/fista_module.cpp | 315 +++++++++++++++++++++++++++++++++++++ src/Python/setup.py | 58 +++++++ src/Python/setup_test.py | 58 +++++++ 4 files changed, 637 insertions(+) create mode 100644 src/Python/Matlab2Python_utils.cpp create mode 100644 src/Python/fista_module.cpp create mode 100644 src/Python/setup.py create mode 100644 src/Python/setup_test.py (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp new file mode 100644 index 0000000..138e8da --- /dev/null +++ b/src/Python/Matlab2Python_utils.cpp @@ -0,0 +1,206 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** + +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. 
+*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*/ + +template +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); +} + + + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. 
+ A = reinterpret_cast(prhs[0]); + */ + return reinterpret_cast(prhs[0]); +} + +template +np::ndarray zeros(int dims , int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list mexFunction( np::ndarray input ) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast( input.get_data() ); + int * B = reinterpret_cast( zz.get_data() ); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index , &val, sizeof(int)); + std::memcpy(C + index , &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + //import_array(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp new file mode 100644 index 0000000..5344083 --- /dev/null +++ b/src/Python/fista_module.cpp @@ -0,0 +1,315 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +// include the regularizers +#include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" +#include "SplitBregman_TV_core.h" +#include "TGV_PD_core.h" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** +mxGetData +args: pm Pointer to an mxArray +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. 
+ For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). + If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not + enough free heap space to create the mxArray. +*/ + +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { + /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) + * + * Input Parameters: + * 1. Noisy image/volume + * 2. lambda - regularization parameter + * 3. Number of iterations [OPTIONAL parameter] + * 4. eplsilon - tolerance constant [OPTIONAL parameter] + * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] + * + * Output: + * Filtered/regularized image + * + * All sanity checks and default values are set in Python + */ + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + const int dim_array[3]; + float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -11; + } + else { + dim_array[2] = input.shape(2); + } + + /*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. 
The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + //iter = 35; /* default iterations number */ + iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + methTV = TV_type; + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + lambda = 2.0f*mu; + count = 1; + re_old = 0.0f; + /*Handling Matlab output data*/ + dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2]; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /* + mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + + mxCreateNumericArray initializes all its real data elements to 0. 
+*/ + +/* + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +*/ + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U = A = reinterpret_castinput.get_data(); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + /*printf("%f %i %i \n", re, ll, count); */ + + /*copyIm(U_old, U, dimX, dimY, dimZ); */ + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + if (number_of_dims == 3) { + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + bp::list result; + return result; +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/setup.py b/src/Python/setup.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = 
os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup_test.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) 
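[Note: the following smoke test is a minimal sketch, not part of the committed sources. It assumes the extension from this first patch is built in place with "python setup_test.py build_ext --inplace", which compiles Matlab2Python_utils.cpp into the module named "fista" declared in BOOST_PYTHON_MODULE(fista) above. Per that mexFunction, the input buffer is reinterpreted as int (so an int32 array is assumed here) and the returned list is [ndims, dim0, dim1, dim2, int copy, float copy]; the array shape and variable names below are illustrative only.]

    import numpy as np
    import fista  # extension built from Matlab2Python_utils.cpp via setup_test.py

    # small 3D int32 test volume (mexFunction casts the raw buffer to int*)
    a = np.arange(24, dtype=np.int32).reshape(2, 3, 4)

    ndims, d0, d1, d2, int_copy, float_copy = fista.mexFunction(a)

    assert (ndims, d0, d1, d2) == (3, 2, 3, 4)              # dimensions echoed back
    assert np.array_equal(int_copy, a)                       # element-wise integer copy
    assert np.allclose(float_copy, a.astype(np.float32))     # same values cast to float32

[End of note; the later patches below replace this test module with the real regularizer wrappers.]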
-- cgit v1.2.3 From cf94b779bf8f11128ce0e4535ba1e12ccb2b50a1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 16:52:20 +0100 Subject: added FGP_TV wrapper --- src/Python/fista_module.cpp | 576 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 473 insertions(+), 103 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 5344083..2492884 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -26,12 +26,10 @@ limitations under the License. #include #include "boost/tuple/tuple.hpp" -// include the regularizers -#include "FGP_TV_core.h" -#include "LLT_model_core.h" -#include "PatchBased_Regul_core.h" #include "SplitBregman_TV_core.h" -#include "TGV_PD_core.h" +#include "FGP_TV_core.h" + + #if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) #include @@ -43,7 +41,6 @@ __if_not_exists(uint16_t) { typedef __int8 uint16_t; } namespace bp = boost::python; namespace np = boost::python::numpy; - /*! in the Matlab implementation this is called as void mexFunction( int nlhs, mxArray *plhs[], @@ -56,9 +53,7 @@ nlhs Array of pointers to the OUTPUT mxArrays plhs int number of OUTPUT mxArrays *********************************************************** -mxGetData -args: pm Pointer to an mxArray -Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. + *********************************************************** double mxGetScalar(const mxArray *pm); args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. @@ -92,77 +87,143 @@ double *mxGetPr(const mxArray *pm); args: pm Pointer to an mxArray of type double Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. **************************************************************** -mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. 
+For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). - If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not - enough free heap space to create the mxArray. +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. */ template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); } -bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { - /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) - * - * Input Parameters: - * 1. Noisy image/volume - * 2. lambda - regularization parameter - * 3. Number of iterations [OPTIONAL parameter] - * 4. eplsilon - tolerance constant [OPTIONAL parameter] - * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] - * - * Output: - * Filtered/regularized image - * - * All sanity checks and default values are set in Python + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. 
+ A = reinterpret_cast(prhs[0]); */ + return reinterpret_cast(prhs[0]); +} + + + + +bp::list mexFunction(np::ndarray input) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int dim_array[3]; + const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; - + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - number_of_dims = input.get_nd(); + + int number_of_dims = input.get_nd(); + int dim_array[3]; dim_array[0] = input.shape(0); dim_array[1] = input.shape(1); if (number_of_dims == 2) { - dim_array[2] = -11; + dim_array[2] = -1; } else { dim_array[2] = input.shape(2); } - /*Handling Matlab input data*/ + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); - /*Handling Matlab input data*/ + ///*Handling Matlab input data*/ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ A = reinterpret_cast(input.get_data()); - //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ mu = (float)d_mu; + //iter = 35; /* default iterations number */ - iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ epsil = (float)d_epsil; //methTV = 0; /* default isotropic TV penalty */ - methTV = TV_type; //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ //if (nrhs == 5) { @@ -182,34 +243,31 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, doub if (number_of_dims == 2) { dimZ = 1; /*2D case*/ - /* - mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. 
If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - - mxCreateNumericArray initializes all its real data elements to 0. -*/ - -/* - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); -*/ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U = A = reinterpret_castinput.get_data(); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ @@ -245,59 +303,370 @@ args: ndim: Number of dimensions. 
If you specify a value for ndim that is less /*printf("%f %i %i \n", re, ll, count); */ /*copyIm(U_old, U, dimX, dimY, dimZ); */ + result.append(npU); + result.append(ll); + } + //printf("SB iterations stopped at iteration: %i\n", ll); + if (number_of_dims == 3) { + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npDz = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npBz = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Dz = reinterpret_cast(npDz.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + Bz = reinterpret_cast(npBz.get_data()); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + //printf("SB iterations stopped at iteration: %i\n", ll); + result.append(npU); + result.append(ll); } - printf("SB iterations stopped at iteration: %i\n", ll); } - if (number_of_dims == 3) { - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + return result; - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ +} - /* begin outer SB iterations */ +bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + 
//dim_array = mxGetDimensions(prhs[0]); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + ///*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + + //iter = 35; /* default iterations number */ + + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + //plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); + bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray out1 = np::zeros(shape1, dtype); + + //float *funcvalA = (float *)mxGetData(plhs[1]); + float * funcvalA = reinterpret_cast(out1.get_data()); + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2]; + + tk = 1.0f; + tkp1 = 1.0f; + count = 1; + re_old = 0.0f; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + 
np::ndarray npR2 = zeros(2, dim_array, (float)0); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + + /* begin iterations */ for (ll = 0; ll 4) break; + if (count > 3) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j 2) { - if (re > re_old) break; + if (re > re_old) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); + } + if (number_of_dims == 3) { + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP3 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npP3_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = np::zeros(shape, dtype); + np::ndarray npR3 = np::zeros(shape, dtype); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P3 = reinterpret_cast(npP3.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + P3_old = reinterpret_cast(npP3_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + R2 = reinterpret_cast(npR3.get_data()); + /* begin iterations */ + for (ll = 0; ll 3) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j 2) { + if (re > re_old) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); } - bp::list result; + return result; } - BOOST_PYTHON_MODULE(fista) { @@ -310,6 +679,7 @@ BOOST_PYTHON_MODULE(fista) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); + def("SplitBregman_TV", SplitBregman_TV); + def("FGP_TV", FGP_TV); } \ No newline at end of file -- cgit v1.2.3 From 
16a2b514d191563eb7691ee24f063bd27f5ff12d Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:03 +0100 Subject: compilation fixes --- src/Python/setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index ffb9c02..a8feb1c 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -29,14 +29,14 @@ extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\env extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] extra_libraries = [] if platform.system() == 'Windows': - extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ] + extra_include_dirs += ["..\\..\\main_func\\regularizers_CPU\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + extra_include_dirs += ["../../main_func/regularizers_CPU","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,8 +47,12 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", - sources=[ "Matlab2Python_utils.cpp", + ext_modules = [Extension("regularizers", + sources=["fista_module.cpp", + "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 0859756c281f514d53b05a3cc9dc5035136d73a9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:17 +0100 Subject: test facility for regularizers --- src/Python/test_regularizers.py | 265 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 src/Python/test_regularizers.py (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py new file mode 100644 index 0000000..6abfba4 --- /dev/null +++ b/src/Python/test_regularizers.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug 4 11:10:05 2017 + +@author: ofn77899 +""" + +from ccpi.viewer.CILViewer2D import Converter +import vtk + +import regularizers +import matplotlib.pyplot as plt +import numpy as np +import os +from enum import Enum + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) + 4) + 5) + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. 
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm): + + self.algorithm = algorithm + self.pars = self.parsForAlgorithm(algorithm) + # __init__ + + def parsForAlgorithm(self, algorithm): + pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + return pars + # parsForAlgorithm + + def __call__(self, input, regularization_parameter, **kwargs): + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + self.pars['input'] = input + self.pars['regularization_parameter'] = regularization_parameter + #for key, value in self.pars.items(): + # print("{0} = {1}".format(key, value)) + + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if None in self.pars: + raise Exception("Not all parameters have been provided") + else: + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + reg = 
Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + return out + + +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +reader = vtk.vtkTIFFReader() +reader.SetFileName(os.path.normpath(filename)) +reader.Update() +#vtk returns 3D images, let's take just the one slice there is as 2D +Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +# plot +fig = plt.figure() +a=fig.add_subplot(2,3,1) +a.set_title('Original') +imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,2) +a.set_title('noise') +imgplot = plt.imshow(u0) + + +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) +pars = out2[2] + +a=fig.add_subplot(2,3,3) +a.set_title('SplitBregman_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,4) +a.set_title('FGP_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., + time_step=0.1, + tolerance_constant=1e-4, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,5) +a.set_title('LLT_model') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['time_step'] + ) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + + + -- cgit v1.2.3 From 3e96c0d80387225894a8e5f1456ea310cd7e797b Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:37 +0100 Subject: Added 3 regularizers SplitBregman_TV FGP_TV LLT_model --- src/Python/fista_module.cpp | 266 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 232 insertions(+), 34 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 2492884..d890b10 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,8 @@ limitations under the License. 
#include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "utils.h" @@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) { return reinterpret_cast(prhs[0]); } +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + @@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) { } } - bp::list result; result.append(number_of_dims); @@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int *dim_array; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); - np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU = np::zeros(shape, dtype); np::ndarray npU_old = np::zeros(shape, dtype); - np::ndarray npDx = np::zeros(shape, dtype); - np::ndarray npDy = np::zeros(shape, dtype); - np::ndarray npBx = np::zeros(shape, dtype); - np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); - U = reinterpret_cast(npU.get_data()); + U = reinterpret_cast(npU.get_data()); U_old = reinterpret_cast(npU_old.get_data()); - Dx = reinterpret_cast(npDx.get_data()); - Dy = reinterpret_cast(npDy.get_data()); - Bx = reinterpret_cast(npBx.get_data()); - By = reinterpret_cast(npBy.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ - for (ll = 0; ll(npU); - result.append(ll); + } //printf("SB iterations stopped at iteration: %i\n", ll); - if (number_of_dims == 3) { + result.append(npU); + result.append(ll); + } + if (number_of_dims == 3) { /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); @@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi result.append(npU); result.append(ll); } - } return result; -} + } + + bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, 
*D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j(input.get_data()); + lambda = (float)d_lambda; + tau = (float)d_tau; + // iter is passed as parameter + epsil = (float)d_epsil; + // switcher is passed as parameter + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1; + + if (number_of_dims == 2) { + /*2D case*/ + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der2D(U, D1, D2, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + + result.append(npU); + } + else if (number_of_dims == 3) { + /*3D case*/ + dimZ = dim_array[2]; + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + if (switcher != 0) { + Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + }*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + np::ndarray npD3 = np::zeros(shape, dtype); + np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin()); + Map = reinterpret_cast(npMap.get_data()); + if (switcher != 0) { + //Map = (unsigned short*)mxGetPr(plhs[1] = 
mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + + Map = reinterpret_cast(npMap.get_data()); + } + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + D3 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + + if (switcher == 1) { + /* apply restrictive smoothing */ + calcMap(U, Map, dimX, dimY, dimZ); + /*clear outliers */ + cleanMap(Map, dimX, dimY, dimZ); + } + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der3D(U, D1, D2, D3, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + result.append(npU); + if (switcher != 0) result.append(npMap); + + } + return result; +} + + +BOOST_PYTHON_MODULE(regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista) def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); + def("LLT_model", LLT_model); } \ No newline at end of file -- cgit v1.2.3 From fbaf7281141e0ddad5046b433ba0a72d360d09aa Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:53 +0100 Subject: minor change --- src/Python/Matlab2Python_utils.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 138e8da..6aaad90 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -128,7 +128,11 @@ T * mxGetData(const np::ndarray pm) { template np::ndarray zeros(int dims , int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); np::ndarray zz = np::zeros(shape, dtype); return zz; @@ -163,7 +167,7 @@ bp::list mexFunction( np::ndarray input ) { for (int k = 0; k < dim_array[2]; k++) { int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; int val = (*(A + index)); - float fval = (float)val; + float fval = sqrt((float)val); std::memcpy(B + index , &val, sizeof(int)); std::memcpy(C + index , &fval, sizeof(float)); } @@ -186,7 +190,7 @@ bp::list mexFunction( np::ndarray input ) { } -BOOST_PYTHON_MODULE(fista) +BOOST_PYTHON_MODULE(prova) { np::initialize(); -- cgit v1.2.3 From 9974b5a4bf88ac3e7929d7c33a911c90fcfcff29 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:17:18 +0100 Subject: test for general boost::python / numpy routines --- src/Python/test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/Python/test.py (limited to 'src') diff --git 
a/src/Python/test.py b/src/Python/test.py new file mode 100644 index 0000000..e283f89 --- /dev/null +++ b/src/Python/test.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Aug 3 14:08:09 2017 + +@author: ofn77899 +""" + +import fista +import numpy as np + +a = np.asarray([i for i in range(3*4*5)]) +a = a.reshape([3,4,5]) +print (a) +b = fista.mexFunction(a) +#print (b) +print (b[4].shape) +print (b[4]) +print (b[5]) \ No newline at end of file -- cgit v1.2.3 From 5a27224a373c12ba8e3af6e25a4c5eaec522b834 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:49:21 +0100 Subject: added PatchBased_Regul --- src/Python/fista_module.cpp | 123 +++++++++++++++++++++++++++++++++++++++++++- src/Python/setup.py | 1 + 2 files changed, 123 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index d890b10..c2d9352 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" #include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" #include "utils.h" @@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d if (switcher != 0) { Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); }*/ - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); np::dtype dtype = np::dtype::get_builtin(); @@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } +bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { + // the result is in the following list + bp::list result; + + int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop; + //const int *dims; + float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda; + + numdims = input.get_nd(); + int dims[3]; + + dims[0] = input.shape(0); + dims[1] = input.shape(1); + if (numdims == 2) { + dims[2] = -1; + } + else { + dims[2] = input.shape(2); + } + /*numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]);*/ + + N = dims[0]; + M = dims[1]; + Z = dims[2]; + + //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); } + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + + ///*Handling inputs*/ + //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ + //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ + //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */ + + //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); + //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + + SearchW = SearchW_real + 2 * SimilW; + + /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ + /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */ + + + padXY = SearchW + 2 * SimilW; /* padding sizes */ + newsizeX = N + 
2 * (padXY); /* the X size of the padded array */ + newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ + newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ + int N_dims[] = { newsizeX, newsizeY, newsizeZ }; + + /******************************2D case ****************************/ + if (numdims == 2) { + ///*Handling output*/ + //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + ///**************************************************************************/ + + bp::tuple shape = bp::make_tuple(N, M); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + + shape = bp::make_tuple(newsizeX, newsizeY); + np::ndarray npAp = np::zeros(shape, dtype); + np::ndarray npBp = np::zeros(shape, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + + /*Perform padding of image A to the size of [newsizeX * newsizeY] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); + } + else + { + /******************************3D case ****************************/ + ///*Handling output*/ + //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + /**************************************************************************/ + bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]); + bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + np::ndarray npAp = np::zeros(shape_AB, dtype); + np::ndarray npBp = np::zeros(shape_AB, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + result.append(npB); + } /*end else ndims*/ + + return result; +} + BOOST_PYTHON_MODULE(regularizers) { np::initialize(); diff --git a/src/Python/setup.py b/src/Python/setup.py index a8feb1c..a4eed14 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -52,6 +52,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + 
"..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 6589fa197d9f87f7a37f46943aa995d97f50bb46 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:12 +0100 Subject: added TGV_PD, removed useless code --- src/Python/fista_module.cpp | 245 ++++++++++++++++++++++++++------------------ 1 file changed, 146 insertions(+), 99 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index c2d9352..eacda3d 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "FGP_TV_core.h" #include "LLT_model_core.h" #include "PatchBased_Regul_core.h" +#include "TGV_PD_core.h" #include "utils.h" @@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th enough free heap space to create the mxArray. */ -void mexErrMessageText(char* text) { - std::cerr << text << std::endl; -} - -/* -double mxGetScalar(const mxArray *pm); -args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. -Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. -*/ - -template -double mxGetScalar(const np::ndarray plh) { - return (double)bp::extract(plh[0]); -} - - - -template -T * mxGetData(const np::ndarray pm) { - //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. - //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. - /*Access the numpy array pointer: - char * get_data() const; - Returns: Array’s raw data pointer as a char - Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. - probably this would work. 
- A = reinterpret_cast(prhs[0]); - */ - return reinterpret_cast(prhs[0]); -} - -template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape; - if (dims == 3) - shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - else if (dims == 2) - shape = bp::make_tuple(dim_array[0], dim_array[1]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; -} - - -bp::list mexFunction(np::ndarray input) { - int number_of_dims = input.get_nd(); - int dim_array[3]; - - dim_array[0] = input.shape(0); - dim_array[1] = input.shape(1); - if (number_of_dims == 2) { - dim_array[2] = -1; - } - else { - dim_array[2] = input.shape(2); - } - - /**************************************************************************/ - np::ndarray zz = zeros(3, dim_array, (int)0); - np::ndarray fzz = zeros(3, dim_array, (float)0); - /**************************************************************************/ - - int * A = reinterpret_cast(input.get_data()); - int * B = reinterpret_cast(zz.get_data()); - float * C = reinterpret_cast(fzz.get_data()); - - //Copy data and cast - for (int i = 0; i < dim_array[0]; i++) { - for (int j = 0; j < dim_array[1]; j++) { - for (int k = 0; k < dim_array[2]; k++) { - int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; - int val = (*(A + index)); - float fval = (float)val; - std::memcpy(B + index, &val, sizeof(int)); - std::memcpy(C + index, &fval, sizeof(float)); - } - } - } - - bp::list result; - - result.append(number_of_dims); - result.append(dim_array[0]); - result.append(dim_array[1]); - result.append(dim_array[2]); - result.append(zz); - result.append(fzz); - - //result.append(tup); - return result; - -} - bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list @@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me np::ndarray npP1_old = np::zeros(shape, dtype); np::ndarray npP2_old = np::zeros(shape, dtype); np::ndarray npR1 = np::zeros(shape, dtype); - np::ndarray npR2 = zeros(2, dim_array, (float)0); + np::ndarray npR2 = np::zeros(shape, dtype); D = reinterpret_cast(npD.get_data()); D_old = reinterpret_cast(npD_old.get_data()); @@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } -bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { +bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) { // the result is in the following list bp::list result; @@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub ///*Handling inputs*/ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + A = reinterpret_cast(input.get_data()); //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ @@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + lambda = (float)d_lambda; + h = (float)d_h; SearchW = SearchW_real + 2 * SimilW; /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ @@ -918,7 
+829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ int N_dims[] = { newsizeX, newsizeY, newsizeZ }; - /******************************2D case ****************************/ if (numdims == 2) { ///*Handling output*/ @@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub /*Perform padding of image A to the size of [newsizeX * newsizeY] */ switchpad_crop = 0; /*padding*/ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); - + /* Do PB regularization with the padded array */ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); - + switchpad_crop = 1; /*cropping*/ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); } else @@ -983,6 +894,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub result.append(npB); } /*end else ndims*/ + return result; +} + +bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) { + // the result is in the following list + bp::list result; + int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; + //const int *dim_array; + float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + A = reinterpret_cast(input.get_data()); + + //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ + //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/ + //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/ + //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/ + //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations"); + lambda = (float)d_lambda; + alpha1 = (float)d_alpha1; + alpha0 = (float)d_alpha0; + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; + + if (number_of_dims == 2) { + /*2D case*/ + dimZ = 1; + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npQ1 = np::zeros(shape, dtype); + np::ndarray npQ2 = np::zeros(shape, dtype); + np::ndarray npQ3 = np::zeros(shape, dtype); + np::ndarray npV1 = np::zeros(shape, dtype); + np::ndarray npV1_old = np::zeros(shape, dtype); + np::ndarray npV2 = np::zeros(shape, dtype); + np::ndarray npV2_old = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + Q1 = reinterpret_cast(npQ1.get_data()); + Q2 = 
reinterpret_cast(npQ2.get_data()); + Q3 = reinterpret_cast(npQ3.get_data()); + V1 = reinterpret_cast(npV1.get_data()); + V1_old = reinterpret_cast(npV1_old.get_data()); + V2 = reinterpret_cast(npV2.get_data()); + V2_old = reinterpret_cast(npV2_old.get_data()); + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + /*dual variables*/ + /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + /*printf("%i \n", i);*/ + L2 = 12.0; /*Lipshitz constant*/ + tau = 1.0 / pow(L2, 0.5); + sigma = 1.0 / pow(L2, 0.5); + + /*Copy A to U*/ + copyIm(A, U, dimX, dimY, dimZ); + /* Here primal-dual iterations begin for 2D */ + for (ll = 0; ll < iter; ll++) { + + /* Calculate Dual Variable P */ + DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for P*/ + ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1); + + /* Calculate Dual Variable Q */ + DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for Q*/ + ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0); + + /*saving U into U_old*/ + copyIm(U, U_old, dimX, dimY, dimZ); + + /*adjoint operation -> divergence and projection of P*/ + DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau); + + /*get updated solution U*/ + newU(U, U_old, dimX, dimY, dimZ); + + /*saving V into V_old*/ + copyIm(V1, V1_old, dimX, dimY, dimZ); + copyIm(V2, V2_old, dimX, dimY, dimZ); + + /* upd V*/ + UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau); + + /*get new V*/ + newU(V1, V1_old, dimX, dimY, dimZ); + newU(V2, V2_old, dimX, dimY, dimZ); + } /*end of iterations*/ + + result.append(npU); + } + + + + return result; } @@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); def("LLT_model", LLT_model); + def("PatchBased_Regul", PatchBased_Regul); + def("TGV_PD", TGV_PD); } \ No newline at end of file -- cgit v1.2.3 From db50cddf2cfe92c652ff16ce51a3bcecca96de68 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:54 +0100 Subject: added TGV_PD --- src/Python/setup.py | 1 + src/Python/test_regularizers.py | 195 ++++++++++++++++++++++++++++++++++------ 2 files changed, 168 insertions(+), 28 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index a4eed14..0468722 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -53,6 +53,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", + 
"..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6abfba4..6a34749 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -47,6 +47,8 @@ class Regularizer(): SplitBregman_TV = regularizers.SplitBregman_TV FGP_TV = regularizers.FGP_TV LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD # Algorithm class TotalVariationPenalty(Enum): @@ -55,13 +57,17 @@ class Regularizer(): # TotalVariationPenalty def __init__(self , algorithm): - + self.setAlgorithm ( algorithm ) + # __init__ + + def setAlgorithm(self, algorithm): self.algorithm = algorithm self.pars = self.parsForAlgorithm(algorithm) - # __init__ + # setAlgorithm def parsForAlgorithm(self, algorithm): pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -69,6 +75,7 @@ class Regularizer(): pars['number_of_iterations'] = 35 pars['tolerance_constant'] = 0.0001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -76,6 +83,7 @@ class Regularizer(): pars['number_of_iterations'] = 50 pars['tolerance_constant'] = 0.001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: pars['algorithm'] = algorithm pars['input'] = None @@ -85,6 +93,24 @@ class Regularizer(): pars['tolerance_constant'] = None pars['restrictive_Z_smoothing'] = 0 + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + + return pars # parsForAlgorithm @@ -98,6 +124,8 @@ class Regularizer(): self.pars['regularization_parameter'] = regularization_parameter #for key, value in self.pars.items(): # print("{0} = {1}".format(key, value)) + if None in self.pars: + raise Exception("Not all parameters have been provided") if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : return self.algorithm(input, regularization_parameter, @@ -112,15 +140,27 @@ class Regularizer(): elif self.algorithm == Regularizer.Algorithm.LLT_model : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - if None in self.pars: - raise Exception("Not all parameters have been provided") - else: - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + 
#LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # __call__ @@ -142,13 +182,40 @@ class Regularizer(): @staticmethod def LLT_model(input, regularization_parameter , time_step, number_of_iterations, tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) + reg = Regularizer(Regularizer.Algorithm.LLT_model) out = list( reg(input, regularization_parameter, time_step=time_step, number_of_iterations=number_of_iterations, tolerance_constant=tolerance_constant, restrictive_Z_smoothing=restrictive_Z_smoothing) ) out.append(reg.pars) return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + return out #Example: @@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im))) f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) u0 = f(u0).astype('float32') -# plot +## plot fig = plt.figure() -a=fig.add_subplot(2,3,1) -a.set_title('Original') -imgplot = plt.imshow(Im) +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) -a=fig.add_subplot(2,3,2) +a=fig.add_subplot(2,3,1) a.set_title('noise') imgplot = plt.imshow(u0) - +reg_output = [] ############################################################################## # Call regularizer @@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe TV_Penalty=Regularizer.TotalVariationPenalty.l1) out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) pars = out2[2] +reg_output.append(out2) -a=fig.add_subplot(2,3,3) +a=fig.add_subplot(2,3,2) a.set_title('SplitBregman_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, number_of_iterations=10) pars = out2[-1] -a=fig.add_subplot(2,3,4) +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) a.set_title('FGP_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise # [Den] = LLT_model(single(u0), 10, 0.1, 1); -out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., - time_step=0.1, - tolerance_constant=1e-4, - number_of_iterations=10) +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) pars = out2[-1] -a=fig.add_subplot(2,3,5) +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) a.set_title('LLT_model') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' textstr = textstr % (pars['regularization_parameter'], @@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) + +###################### PatchBased_Regul ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) +a.set_title('PatchBased_Regul') +textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['searching_window_ratio'], + pars['similarity_window_ratio'], + pars['PB_filtering_parameter']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text 
box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) + + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) +a.set_title('TGV_PD') +textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' +textstr = textstr % (pars['regularization_parameter'], + pars['first_order_term'], + pars['second_order_term'], + pars['number_of_iterations']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) -- cgit v1.2.3 From 62b30291105e8a48633629350fc9820b404da2ff Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 17 Aug 2017 16:33:09 +0100 Subject: initial revision --- src/Python/test/astra_test.py | 85 ++++++++++++++++++++++++++++++++++++ src/Python/test/simple_astra_test.py | 25 +++++++++++ 2 files changed, 110 insertions(+) create mode 100644 src/Python/test/astra_test.py create mode 100644 src/Python/test/simple_astra_test.py (limited to 'src') diff --git a/src/Python/test/astra_test.py b/src/Python/test/astra_test.py new file mode 100644 index 0000000..42c375a --- /dev/null +++ b/src/Python/test/astra_test.py @@ -0,0 +1,85 @@ +import astra +import numpy +import filefun + + +# read in the same data as the DemoRD2 +angles = filefun.dlmread("DemoRD2/angles.csv") +darks_ar = filefun.dlmread("DemoRD2/darks_ar.csv", separator=",") +flats_ar = filefun.dlmread("DemoRD2/flats_ar.csv", separator=",") + +if True: + Sino3D = numpy.load("DemoRD2/Sino3D.npy") +else: + sino = filefun.dlmread("DemoRD2/sino_01.csv", separator=",") + a = map (lambda x:x, numpy.shape(sino)) + a.append(20) + + Sino3D = numpy.zeros(tuple(a), dtype="float") + + for i in range(1,numpy.shape(Sino3D)[2]+1): + print("Read file DemoRD2/sino_%02d.csv" % i) + sino = filefun.dlmread("DemoRD2/sino_%02d.csv" % i, separator=",") + Sino3D.T[i-1] = sino.T + +Weights3D = numpy.asarray(Sino3D, dtype="float") + +##angles_rad = angles*(pi/180); % conversion to radians +##size_det = size(data_raw3D,1); % detectors dim +##angSize = size(data_raw3D, 2); % angles dim +##slices_tot = size(data_raw3D, 3); % no of slices +##recon_size = 950; % reconstruction size + + +angles_rad = angles * numpy.pi /180. 
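+# angles arrive in degrees (see the Matlab reference lines above), hence the pi/180 factor;
+# the sinogram shape is unpacked twice below and only the int-cast second line takes effect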
+size_det, angSize, slices_tot = numpy.shape(Sino3D) +size_det, angSize, slices_tot = [int(i) for i in numpy.shape(Sino3D)] +recon_size = 950 +Z_slices = 3; +det_row_count = Z_slices; + +#proj_geom = astra_create_proj_geom('parallel3d', 1, 1, +# det_row_count, size_det, angles_rad); + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX +proj_geom = astra.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + size_det, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +vol_geom = astra.create_vol_geom(recon_size,recon_size,Z_slices); + +sino = numpy.zeros((size_det, angSize, slices_tot), dtype="float") + +#weights = ones(size(sino)); +weights = numpy.ones(numpy.shape(sino)) + +##################################################################### +## PowerMethod for Lipschitz constant + +N = vol_geom['GridColCount'] +x1 = numpy.random.rand(1,N,N) +#sqweight = sqrt(weights(:,:,1)); +sqweight = numpy.sqrt(weights.T[0]).T +##proj_geomT = proj_geom; +proj_geomT = proj_geom.copy() +##proj_geomT.DetectorRowCount = 1; +proj_geomT['DetectorRowCount'] = 1 +##vol_geomT = vol_geom; +vol_geomT = vol_geom.copy() +##vol_geomT.GridSliceCount = 1; +vol_geomT['GridSliceCount'] = 1 + +##[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + +#sino_id, y = astra.create_sino3d_gpu(x1, proj_geomT, vol_geomT); +sino_id, y = astra.create_sino(x1, proj_geomT, vol_geomT); + +##y = sqweight.*y; +##astra_mex_data3d('delete', sino_id); + + diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 +1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From 97e0c63f883f62ed0cc84c969756517fe4bedfe8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:55:19 +0100 Subject: Regularizer.pyfirst commit --- src/Python/Regularizer.py | 322 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/Regularizer.py (limited to 'src') diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py new file mode 100644 index 0000000..15dbbb4 --- /dev/null +++ b/src/Python/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +import regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. 
+ reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) 
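+        a key that the current algorithm does not define raises an Exception,
+        e.g. reg.setParameter(not_a_key=1)  (not_a_key is only an illustrative name)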
+ ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output 
= [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 5ed47a3fc9839b1803731fe5f422d43689f66763 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:56:09 +0100 Subject: Test module for Boost Python currently can pass a function to the C++ layer to be evaluated. 
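
Both callbacks are optional on the C++ side: passing None for either argument simply
skips that hook. A minimal usage sketch, mirroring the calls in the updated test.py
below (the module name prova comes from setup_test.py; the helper names log_value and
shift, and the explicit int32 dtype, are illustrative choices, not part of the module
itself):

    import numpy as np
    import prova

    # the extension reinterprets the input buffer as C int, so int32 is the safe dtype
    a = np.arange(1 * 2 * 3, dtype=np.int32).reshape([1, 2, 3])

    def log_value(v):
        # receives each value the C++ loop reports: the square root of the element
        # and, when a calculate callback is also given, its transformed value too
        print("value: {0}".format(v))

    def shift(v):
        # element-wise transform evaluated from C++ for every input element
        return v - 7

    prova.doSomething(a, log_value, None)    # report sqrt of each element
    b = prova.doSomething(a, None, shift)    # b[5] is the float array of shift() results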
--- src/Python/Matlab2Python_utils.cpp | 68 +++++++++++++++++++++++++++++++++++++- src/Python/setup_test.py | 6 ++-- src/Python/test.py | 34 ++++++++++++++++--- 3 files changed, 99 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 6aaad90..e15d738 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -175,6 +175,71 @@ bp::list mexFunction( np::ndarray input ) { } + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} +bp::list doSomething(np::ndarray input, PyObject *pyobj , PyObject *pyobj2) { + + boost::python::object output(boost::python::handle<>(boost::python::borrowed(pyobj))); + int isOutput = !(output == boost::python::api::object()); + + boost::python::object calculate(boost::python::handle<>(boost::python::borrowed(pyobj2))); + int isCalculate = !(calculate == boost::python::api::object()); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = sqrt((float)val); + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + // if the PyObj is not None evaluate the function + if (isOutput) + output(fval); + if (isCalculate) { + float nfval = (float)bp::extract(calculate(val)); + if (isOutput) + output(nfval); + std::memcpy(C + index, &nfval, sizeof(float)); + } + } + } + } + + bp::list result; result.append(number_of_dims); @@ -196,7 +261,7 @@ BOOST_PYTHON_MODULE(prova) //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "prova"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -207,4 +272,5 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); //numpy_boost_python_register_type(); def("mexFunction", mexFunction); + def("doSomething", doSomething); } \ No newline at end of file diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py index ffb9c02..7c86175 100644 --- a/src/Python/setup_test.py +++ b/src/Python/setup_test.py @@ -30,13 +30,13 @@ extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x' extra_libraries = [] if platform.system() == 'Windows': extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + #extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] if sys.version_info.major == 3 : extra_libraries += 
['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + #extra_include_dirs += ["../ContourTree/", "../Core/","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", + ext_modules = [Extension("prova", sources=[ "Matlab2Python_utils.cpp", ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test.py b/src/Python/test.py index e283f89..db47380 100644 --- a/src/Python/test.py +++ b/src/Python/test.py @@ -5,14 +5,38 @@ Created on Thu Aug 3 14:08:09 2017 @author: ofn77899 """ -import fista +import prova import numpy as np -a = np.asarray([i for i in range(3*4*5)]) -a = a.reshape([3,4,5]) +a = np.asarray([i for i in range(1*2*3)]) +a = a.reshape([1,2,3]) print (a) -b = fista.mexFunction(a) +b = prova.mexFunction(a) #print (b) print (b[4].shape) print (b[4]) -print (b[5]) \ No newline at end of file +print (b[5]) + +def print_element(input): + print ("f: {0}".format(input)) + +prova.doSomething(a, print_element, None) + +c = [] +def append_to_list(input, shouldPrint=False): + c.append(input) + if shouldPrint: + print ("{0} appended to list {1}".format(input, c)) + +def element_wise_algebra(input, shouldPrint=True): + ret = input - 7 + if shouldPrint: + print ("element_wise {0}".format(ret)) + return ret + +prova.doSomething(a, append_to_list, None) +#print ("this is c: {0}".format(c)) + +b = prova.doSomething(a, None, element_wise_algebra) +#print (a) +print (b[5]) -- cgit v1.2.3 From a9274a7533b6d33a99810b2c1f1ad455768820ae Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:27:28 +0100 Subject: updated test for regularizer API --- src/Python/test_regularizers.py | 590 ++++++++++++++++++++-------------------- 1 file changed, 290 insertions(+), 300 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6a34749..5d25f02 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -8,216 +8,37 @@ Created on Fri Aug 4 11:10:05 2017 from ccpi.viewer.CILViewer2D import Converter import vtk -import regularizers import matplotlib.pyplot as plt import numpy as np import os from enum import Enum - -class Regularizer(): - '''Class to handle regularizer algorithms to be used during reconstruction - - Currently 5 regularization algorithms are available: - - 1) SplitBregman_TV - 2) FGP_TV - 3) - 4) - 5) - - Usage: - the regularizer can be invoked as object or as static method - Depending on the actual regularizer the input parameter may vary, and - a different default setting is defined. - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - - out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., - number_of_iterations=30, tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - A number of optional parameters can be passed or skipped - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) - - ''' - class Algorithm(Enum): - SplitBregman_TV = regularizers.SplitBregman_TV - FGP_TV = regularizers.FGP_TV - LLT_model = regularizers.LLT_model - PatchBased_Regul = regularizers.PatchBased_Regul - TGV_PD = regularizers.TGV_PD - # Algorithm - - class TotalVariationPenalty(Enum): - isotropic = 0 - l1 = 1 - # TotalVariationPenalty - - def __init__(self , algorithm): - self.setAlgorithm ( algorithm ) - # __init__ - - def setAlgorithm(self, algorithm): - self.algorithm = algorithm - self.pars = self.parsForAlgorithm(algorithm) - # setAlgorithm - - def parsForAlgorithm(self, algorithm): - pars = dict() - - if algorithm == Regularizer.Algorithm.SplitBregman_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 35 - pars['tolerance_constant'] = 0.0001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.FGP_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 50 - pars['tolerance_constant'] = 0.001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.LLT_model: - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['time_step'] = None - pars['number_of_iterations'] = None - pars['tolerance_constant'] = None - pars['restrictive_Z_smoothing'] = 0 - - elif algorithm == Regularizer.Algorithm.PatchBased_Regul: - pars['algorithm'] = algorithm - pars['input'] = None - pars['searching_window_ratio'] = None - pars['similarity_window_ratio'] = None - pars['PB_filtering_parameter'] = None - pars['regularization_parameter'] = None - - elif algorithm == Regularizer.Algorithm.TGV_PD: - pars['algorithm'] = algorithm - pars['input'] = None - pars['first_order_term'] = None - pars['second_order_term'] = None - pars['number_of_iterations'] = None - pars['regularization_parameter'] = None - - - - return pars - # parsForAlgorithm - - def __call__(self, input, regularization_parameter, **kwargs): - - if kwargs is not None: - for key, value in kwargs.items(): - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - self.pars['input'] = input - self.pars['regularization_parameter'] = regularization_parameter - #for key, value in self.pars.items(): - # print("{0} = {1}".format(key, value)) - if None in self.pars: - raise Exception("Not all parameters have been provided") - - if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.LLT_model : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) - elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return 
self.algorithm(input, regularization_parameter, - self.pars['searching_window_ratio'] , - self.pars['similarity_window_ratio'] , - self.pars['PB_filtering_parameter']) - elif self.algorithm == Regularizer.Algorithm.TGV_PD : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - - - - # __call__ - - @staticmethod - def SplitBregman_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def FGP_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def LLT_model(input, regularization_parameter , time_step, number_of_iterations, - tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.LLT_model) - out = list( reg(input, regularization_parameter, time_step=time_step, - number_of_iterations=number_of_iterations, - tolerance_constant=tolerance_constant, - restrictive_Z_smoothing=restrictive_Z_smoothing) ) - out.append(reg.pars) - return out - - @staticmethod - def PatchBased_Regul(input, regularization_parameter, - searching_window_ratio, - similarity_window_ratio, - PB_filtering_parameter): - reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) - out = list( reg(input, - regularization_parameter, - searching_window_ratio=searching_window_ratio, - similarity_window_ratio=similarity_window_ratio, - PB_filtering_parameter=PB_filtering_parameter ) - ) - out.append(reg.pars) - return out - - @staticmethod - def TGV_PD(input, regularization_parameter , first_order_term, - second_order_term, number_of_iterations): - - reg = Regularizer(Regularizer.Algorithm.TGV_PD) - out = list( reg(input, regularization_parameter, - first_order_term=first_order_term, - second_order_term=second_order_term, - number_of_iterations=number_of_iterations) ) - out.append(reg.pars) - return out - - +import timeit + +from Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE: +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. 
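+# A small worked example of the formula above (values are illustrative only):
+# for im1 = [[0., 2.]] and im2 = [[1., 3.]], RMSE = sqrt(((1-0)**2 + (3-2)**2) / 2) = 1.0,
+# maxval = 3.0 and minval = 0.0, so nrmse returns 1 - 1.0 / (3.0 - 0.0) ~ 0.667;
+# identical images give 1.0.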
+ +def nrmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) + max_val = max(np.max(im1), np.max(im2)) + min_val = min(np.min(im1), np.min(im2)) + return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +# 2D Regularizers +# +############################################################################### #Example: # figure; # Im = double(imread('lena_gray_256.tif'))/255; % loading image @@ -255,49 +76,55 @@ reg_output = [] ####################### SplitBregman_TV ##################################### # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +use_object = True +if use_object: + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + # or + # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + #TV_Penalty=Regularizer.TotalVariationPenalty.l1) + plotme = reg() [0] + pars = reg.pars + textstr = reg.printParametersToString() + + #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + # TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, - #tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) -pars = out2[2] -reg_output.append(out2) +else: + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. 
) + pars = out2[2] + reg_output.append(out2) + plotme = reg_output[-1][0] + textstr = out2[-1] a=fig.add_subplot(2,3,2) -a.set_title('SplitBregman_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(plotme) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); -out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, - number_of_iterations=10) -pars = out2[-1] +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, + number_of_iterations=200) +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,3) -a.set_title('FGP_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) @@ -316,50 +143,12 @@ out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, time_step=0.0003, tolerance_constant=0.0001, number_of_iterations=300) -pars = out2[-1] +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) -a.set_title('LLT_model') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['time_step'] - ) - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - -###################### PatchBased_Regul ######################################### -# Quick 2D denoising example in Matlab: -# Im = double(imread('lena_gray_256.tif'))/255; % loading image -# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); - -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) -a.set_title('PatchBased_Regul') -textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['searching_window_ratio'], - pars['similarity_window_ratio'], - pars['PB_filtering_parameter']) - - - - +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords @@ -367,6 +156,215 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# ###################### PatchBased_Regul 
######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + # searching_window_ratio=3, + # similarity_window_ratio=1, + # PB_filtering_parameter=0.08) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,5) + + +# textstr = out2[-1] + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + # first_order_term=1.3, + # second_order_term=1, + # number_of_iterations=550) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,6) + + +# textstr = out2[-1] + + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +plt.show() + +################################################################################ +## +## 3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +# reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV 
##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +## #tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +##input, regularization_parameter , time_step, number_of_iterations, +## tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +# time_step=0.0003, +# tolerance_constant=0.0001, +# number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +# searching_window_ratio=3, +# similarity_window_ratio=1, +# PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# ###################### TGV_PD ######################################### # Quick 2D 
denoising example in Matlab: @@ -375,30 +373,22 @@ imgplot = plt.imshow(reg_output[-1][0]) # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - first_order_term=1.3, - second_order_term=1, - number_of_iterations=550) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,6) -a.set_title('TGV_PD') -textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' -textstr = textstr % (pars['regularization_parameter'], - pars['first_order_term'], - pars['second_order_term'], - pars['number_of_iterations']) - - - - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - - - +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) -- cgit v1.2.3 From 73fed4964d81f1f47a0b6ecbe66517f569327b27 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:31:16 +0100 Subject: initial commit of Reconstructor.py --- src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = 
self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT 
term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one 
slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
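+                # NOTE: the '### this line?' markers above flag the renormalisation
+                # step of what appears to be a power iteration on A^T W A:
+                # backproject the weighted projection of x1, take its norm s and
+                # rescale x1 to unit norm.  With ||x1|| kept at 1, s should converge
+                # to the largest eigenvalue of A^T W A, i.e. the Lipschitz constant
+                # of the weighted least-squares gradient, which is the value this
+                # method returns.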
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 
15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") 
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() -- cgit v1.2.3 From a15873ea24734c9a2a7c71eed4d106968b406a07 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:56:08 +0100 Subject: module rename to cpu_regularizers --- src/Python/setup.py | 4 ++-- src/Python/test_regularizers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index 0468722..94467c4 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("regularizers", + ext_modules = [Extension("ccpi.imaging.cpu_regularizers", sources=["fista_module.cpp", "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", @@ -60,5 +60,5 @@ setup( ], zip_safe = False, - packages = {'ccpi','ccpi.reconstruction'}, + packages = {'ccpi','ccpi.fistareconstruction'}, ) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 5d25f02..755804a 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -14,7 +14,8 @@ import os from enum import Enum import timeit -from Regularizer import Regularizer +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer ############################################################################### #https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 -- cgit v1.2.3 From 74e2b61d107f5871b8e78de3f3d0a503c494e64c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:58:11 +0100 Subject: changed the backward slash to forward --- src/Python/setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index 0468722..e6c2dc6 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -49,12 +49,12 @@ setup( cmdclass = {'build_ext': build_ext}, ext_modules = [Extension("regularizers", sources=["fista_module.cpp", - "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", - "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", - "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", - "..\\..\\main_func\\regularizers_CPU\\utils.c" + "../../main_func/regularizers_CPU/FGP_TV_core.c", + "../../main_func/regularizers_CPU/SplitBregman_TV_core.c", + "../../main_func/regularizers_CPU/LLT_model_core.c", + "../../main_func/regularizers_CPU/PatchBased_Regul_core.c", + "../../main_func/regularizers_CPU/TGV_PD_core.c", + "../../main_func/regularizers_CPU/utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 105b57a1e98c2bb7b3bf94c43b6c669925ebb1b9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:01:28 +0100 Subject: added viewer for testing --- src/Python/ccpi/viewer/CILViewer.py | 361 +++++++ src/Python/ccpi/viewer/CILViewer2D.py | 1126 ++++++++++++++++++++ src/Python/ccpi/viewer/QVTKWidget.py | 340 ++++++ src/Python/ccpi/viewer/QVTKWidget2.py | 84 ++ src/Python/ccpi/viewer/__init__.py | 1 + .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 0 -> 10542 bytes 
.../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 0 -> 35633 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 0 -> 10099 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 0 -> 1316 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 210 bytes src/Python/ccpi/viewer/embedvtk.py | 75 ++ 11 files changed, 1987 insertions(+) create mode 100644 src/Python/ccpi/viewer/CILViewer.py create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py create mode 100644 src/Python/ccpi/viewer/__init__.py create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py new file mode 100644 index 0000000..efcf8be --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
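+# Typical usage (a sketch mirroring the commented-out example at the end of
+# Reconstructor.py in this patch set; it assumes the setInputAsNumpy method
+# referenced there is available on this class):
+#   viewer = CILViewer()
+#   viewer.setInputAsNumpy(volume)   # volume: a 3D numpy array
+#   viewer.displaySliceActor(0)      # display slice 0
+#   viewer.startRenderLoop()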
+ +import vtk +import numpy +import math +from vtk.util import numpy_support + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + + + +class CILViewer(): + '''Simple 3D Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600): + '''creates the rendering pipeline''' + + # create a rendering window and renderer + self.ren = vtk.vtkRenderer() + self.renWin = vtk.vtkRenderWindow() + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + # img 3D as slice + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceActor = None + self.voi = None + self.wl = None + self.ia = None + self.sliceActorNo = 0 + # create a renderwindowinteractor + self.iren = vtk.vtkRenderWindowInteractor() + self.iren.SetRenderWindow(self.renWin) + + self.style = vtk.vtkInteractorStyleTrackballCamera() + self.iren.SetInteractorStyle(self.style) + + self.ren.SetBackground(.1, .2, .4) + + self.actors = {} + self.iren.RemoveObservers('MouseWheelForwardEvent') + self.iren.RemoveObservers('MouseWheelBackwardEvent') + + self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) + self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) + + self.iren.RemoveObservers('KeyPressEvent') + self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) + + + self.iren.Initialize() + + + + def getRenderer(self): + '''returns the renderer''' + return self.ren + + def getRenderWindow(self): + '''returns the render window''' + return self.renWin + + def getInteractor(self): + '''returns the render window interactor''' + return self.iren + + def getCamera(self): + '''returns the active camera''' + return self.ren.GetActiveCamera() + + def createPolyDataActor(self, polydata): + '''returns an actor for a given polydata''' + mapper = vtk.vtkPolyDataMapper() + if vtk.VTK_MAJOR_VERSION <= 5: + mapper.SetInput(polydata) + else: + mapper.SetInputData(polydata) + + # actor + actor = vtk.vtkActor() + actor.SetMapper(mapper) + #actor.GetProperty().SetOpacity(0.8) + return actor + + def setPolyDataActor(self, actor): + '''displays the given polydata''' + + self.ren.AddActor(actor) + + self.actors[len(self.actors)+1] = [actor, True] + self.iren.Initialize() + self.renWin.Render() + + def displayPolyData(self, polydata): + self.setPolyDataActor(self.createPolyDataActor(polydata)) + + def hideActor(self, actorno): + '''Hides an actor identified by its number in the list of actors''' + try: + if self.actors[actorno][1]: + self.ren.RemoveActor(self.actors[actorno][0]) + self.actors[actorno][1] = False + except KeyError as ke: + print ("Warning Actor not present") + + def showActor(self, actorno, actor = None): + '''Shows hidden actor identified by its number in the list of actors''' + try: + if not self.actors[actorno][1]: + self.ren.AddActor(self.actors[actorno][0]) + self.actors[actorno][1] = True + return actorno + except KeyError as ke: + # adds it to the actors if not there already + if actor != None: + self.ren.AddActor(actor) + self.actors[len(self.actors)+1] = [actor, True] + return len(self.actors) + + def addActor(self, actor): + '''Adds an actor to the render''' + return self.showActor(0, actor) + + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % 
(filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + + def startRenderLoop(self): + self.iren.Start() + + + def setupObservers(self, interactor): + interactor.RemoveObservers('LeftButtonPressEvent') + interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) + interactor.Initialize() + + + def mouseInteraction(self, interactor, event): + if event == 'MouseWheelForwardEvent': + maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] + if (self.sliceno + 1 < maxSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno + 1 + self.displaySliceActor(self.sliceno) + else: + minSlice = 0 + if (self.sliceno - 1 > minSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno - 1 + self.displaySliceActor(self.sliceno) + + + def keyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "x": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_YZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "y": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "z": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceno = int(self.img3D.GetDimensions()[2] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + if interactor.GetKeyCode() == "X": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("x") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("y") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("z") + self.keyPress(interactor, 
event) + else : + print ("Unhandled event %s" % interactor.GetKeyCode()) + + + + def setInput3DData(self, imageData): + self.img3D = imageData + + def setInputAsNumpy(self, numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + + def displaySliceActor(self, sliceno = 0): + self.sliceno = sliceno + first = False + + self.sliceActor , self.voi, self.wl , self.ia = \ + self.getSliceActor(self.img3D, + sliceno, + self.sliceActor, + self.voi, + self.wl, + self.ia) + no = self.showActor(self.sliceActorNo, self.sliceActor) + self.sliceActorNo = no + + self.iren.Initialize() + self.renWin.Render() + + return self.sliceActorNo + + + def getSliceActor(self, + imageData , + sliceno=0, + imageActor=None , + voi=None, + windowLevel=None, + imageAccumulate=None): + '''Slices a 3D volume and then creates an actor to be rendered''' + if (voi==None): + voi = vtk.vtkExtractVOI() + #voi = vtk.vtkImageClip() + voi.SetInputData(imageData) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = sliceno + extent[self.sliceOrientation * 2 + 1] = sliceno + voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + voi.Update() + # set window/level for all slices + if imageAccumulate == None: + imageAccumulate = vtk.vtkImageAccumulate() + + if (windowLevel == None): + windowLevel = vtk.vtkImageMapToWindowLevelColors() + imageAccumulate.SetInputData(imageData) + imageAccumulate.Update() + cmax = imageAccumulate.GetMax()[0] + cmin = imageAccumulate.GetMin()[0] + windowLevel.SetLevel((cmax+cmin)/2) + windowLevel.SetWindow(cmax-cmin) + + windowLevel.SetInputData(voi.GetOutput()) + windowLevel.Update() + + if imageActor == None: + imageActor = vtk.vtkImageActor() + imageActor.SetInputData(windowLevel.GetOutput()) + imageActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + imageActor.Update() + return (imageActor , voi, windowLevel, imageAccumulate) + + + # Set interpolation on + def setInterpolateOn(self): + self.sliceActor.SetInterpolate(True) + self.renWin.Render() + + # Set interpolation off + def setInterpolateOff(self): + self.sliceActor.SetInterpolate(False) + self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py new file mode 100644 index 0000000..c1629af --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer2D.py @@ -0,0 +1,1126 @@ +# 
-*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +from vtk.util import numpy_support , vtkImageImportFromArray +from enum import Enum + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + +CONTROL_KEY = 8 +SHIFT_KEY = 4 +ALT_KEY = -128 + + +# Converter class +class Converter(): + + # Utility functions to transform numpy arrays to vtkImageData and viceversa + @staticmethod + def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Creates a vtkImageImportFromArray object and returns it. + + It handles the different axis order from numpy to VTK''' + importer = vtkImageImportFromArray.vtkImageImportFromArray() + importer.SetArray(numpy.transpose(nparray).copy()) + importer.SetDataSpacing(spacing) + importer.SetDataOrigin(origin) + return importer + + @staticmethod + def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Converts a 3D numpy array to a vtkImageData''' + importer = Converter.numpy2vtkImporter(nparray, spacing, origin) + importer.Update() + return importer.GetOutput() + + @staticmethod + def vtk2numpy(imgdata): + '''Converts the VTK data to 3D numpy array''' + img_data = numpy_support.vtk_to_numpy( + imgdata.GetPointData().GetScalars()) + + dims = imgdata.GetDimensions() + dims = (dims[2],dims[1],dims[0]) + data3d = numpy.reshape(img_data, dims) + + return numpy.transpose(data3d).copy() + + @staticmethod + def tiffStack2numpy(filename, indices, + extent = None , sampleRate = None ,\ + flatField = None, darkField = None): + '''Converts a stack of TIFF files to numpy array. + + filename must contain the whole path. The filename is supposed to be named and + have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif + + indices are the suffix, generally an increasing number + + Optionally extracts only a selection of the 2D images and (optionally) + normalizes. 
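        For example, with a hypothetical file pattern and a hypothetical range of indices:

            stack = Converter.tiffStack2numpy("/path/to/projection_%03d.tif",
                                              indices=range(0, 180))

        extent (a 6-element VOI extent passed to vtkExtractVOI) and sampleRate (a
        per-axis sampling rate) restrict and downsample each slice; when both
        flatField and darkField are supplied, every slice is flat/dark-field
        normalized via Converter.normalize.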
+ ''' + + stack = vtk.vtkImageData() + reader = vtk.vtkTIFFReader() + voi = vtk.vtkExtractVOI() + + #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" + + stack_image = numpy.asarray([]) + nreduced = len(indices) + + for num in range(len(indices)): + fn = filename % indices[num] + print ("resampling %s" % ( fn ) ) + reader.SetFileName(fn) + reader.Update() + print (reader.GetOutput().GetScalarTypeAsString()) + if num == 0: + if (extent == None): + sliced = reader.GetOutput().GetExtent() + stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) + else: + sliced = extent + voi.SetVOI(extent) + + if sampleRate is not None: + voi.SetSampleRate(sampleRate) + ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) + print ("ext {0}".format(ext)) + stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) + else: + stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) + if (flatField != None and darkField != None): + stack.AllocateScalars(vtk.VTK_FLOAT, 1) + else: + stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) + print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) + stack_image = Converter.vtk2numpy(stack) + print ("Stack shape %s" % str(numpy.shape(stack_image))) + + if extent!=None: + voi.SetInputData(reader.GetOutput()) + voi.Update() + img = voi.GetOutput() + else: + img = reader.GetOutput() + + theSlice = Converter.vtk2numpy(img).T[0] + if darkField != None and flatField != None: + print("Try to normalize") + #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): + theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) + print (theSlice.dtype) + + + print ("Slice shape %s" % str(numpy.shape(theSlice))) + stack_image.T[num] = theSlice.copy() + + return stack_image + + @staticmethod + def normalize(projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + + +## Utility functions to transform numpy arrays to vtkImageData and viceversa +#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtkImporter(nparray, spacing, origin) +# +#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtk(nparray, spacing, origin) +# +#def vtk2numpy(imgdata): +# return Converter.vtk2numpy(imgdata) +# +#def tiffStack2numpy(filename, indices): +# return Converter.tiffStack2numpy(filename, indices) + +class ViewerEvent(Enum): + # left button + PICK_EVENT = 0 + # alt + right button + move + WINDOW_LEVEL_EVENT = 1 + # shift + right button + ZOOM_EVENT = 2 + # control + right button + PAN_EVENT = 3 + # control + left button + CREATE_ROI_EVENT = 4 + # alt + left button + DELETE_ROI_EVENT = 5 + # release button + NO_EVENT = -1 + + +#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): +class CILInteractorStyle(vtk.vtkInteractorStyleImage): + + def __init__(self, callback): + vtk.vtkInteractorStyleImage.__init__(self) + self.callback = callback + self._viewer = callback + priority = 1.0 + +# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) +# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) +# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) +# 
self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) +# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) +# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) +# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) +# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) + + self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) + self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) + self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) + self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) + self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) + self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) + self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) + self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) + + self.InitialEventPosition = (0,0) + + + def SetInitialEventPosition(self, xy): + self.InitialEventPosition = xy + + def GetInitialEventPosition(self): + return self.InitialEventPosition + + def GetKeyCode(self): + return self.GetInteractor().GetKeyCode() + + def SetKeyCode(self, keycode): + self.GetInteractor().SetKeyCode(keycode) + + def GetControlKey(self): + return self.GetInteractor().GetControlKey() == CONTROL_KEY + + def GetShiftKey(self): + return self.GetInteractor().GetShiftKey() == SHIFT_KEY + + def GetAltKey(self): + return self.GetInteractor().GetAltKey() == ALT_KEY + + def GetEventPosition(self): + return self.GetInteractor().GetEventPosition() + + def GetEventPositionInWorldCoordinates(self): + pass + + def GetDeltaEventPosition(self): + x,y = self.GetInteractor().GetEventPosition() + return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) + + def Dolly(self, factor): + self.callback.camera.Dolly(factor) + self.callback.ren.ResetCameraClippingRange() + + def GetDimensions(self): + return self._viewer.img3D.GetDimensions() + + def GetInputData(self): + return self._viewer.img3D + + def GetSliceOrientation(self): + return self._viewer.sliceOrientation + + def SetSliceOrientation(self, orientation): + self._viewer.sliceOrientation = orientation + + def GetActiveSlice(self): + return self._viewer.sliceno + + def SetActiveSlice(self, sliceno): + self._viewer.sliceno = sliceno + + def UpdatePipeline(self, reset = False): + self._viewer.updatePipeline(reset) + + def GetActiveCamera(self): + return self._viewer.ren.GetActiveCamera() + + def SetActiveCamera(self, camera): + self._viewer.ren.SetActiveCamera(camera) + + def ResetCamera(self): + self._viewer.ren.ResetCamera() + + def Render(self): + self._viewer.renWin.Render() + + def UpdateSliceActor(self): + self._viewer.sliceActor.Update() + + def AdjustCamera(self): + self._viewer.AdjustCamera() + + def SaveRender(self, filename): + self._viewer.SaveRender(filename) + + def GetRenderWindow(self): + return self._viewer.renWin + + def GetRenderer(self): + return self._viewer.ren + + def GetROIWidget(self): + return self._viewer.ROIWidget + + def SetViewerEvent(self, event): + self._viewer.event = event + + def GetViewerEvent(self): + return self._viewer.event + + def SetInitialCameraPosition(self, position): + self._viewer.InitialCameraPosition = position + + def GetInitialCameraPosition(self): + return self._viewer.InitialCameraPosition + + def 
SetInitialLevel(self, level): + self._viewer.InitialLevel = level + + def GetInitialLevel(self): + return self._viewer.InitialLevel + + def SetInitialWindow(self, window): + self._viewer.InitialWindow = window + + def GetInitialWindow(self): + return self._viewer.InitialWindow + + def GetWindowLevel(self): + return self._viewer.wl + + def SetROI(self, roi): + self._viewer.ROI = roi + + def GetROI(self): + return self._viewer.ROI + + def UpdateCornerAnnotation(self, text, corner): + self._viewer.updateCornerAnnotation(text, corner) + + def GetPicker(self): + return self._viewer.picker + + def GetCornerAnnotation(self): + return self._viewer.cornerAnnotation + + def UpdateROIHistogram(self): + self._viewer.updateROIHistogram() + + + ############### Handle events + def OnMouseWheelForward(self, interactor, event): + maxSlice = self.GetDimensions()[self.GetSliceOrientation()] + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + + if (self.GetActiveSlice() + advance < maxSlice): + self.SetActiveSlice(self.GetActiveSlice() + advance) + + self.UpdatePipeline() + else: + print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) + + def OnMouseWheelBackward(self, interactor, event): + minSlice = 0 + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + if (self.GetActiveSlice() - advance >= minSlice): + self.SetActiveSlice( self.GetActiveSlice() - advance) + self.UpdatePipeline() + else: + print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) + + def OnKeyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "X": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) + self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Y": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Z": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) + self.UpdatePipeline(True) + if interactor.GetKeyCode() == "x": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("X") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("Y") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "z": + # Change 
the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,1,0) + self.SetActiveCamera(camera) + self.ResetCamera() + self.Render() + interactor.SetKeyCode("Z") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "a": + # reset color/window + cmax = self._viewer.ia.GetMax()[0] + cmin = self._viewer.ia.GetMin()[0] + + self.SetInitialLevel( (cmax+cmin)/2 ) + self.SetInitialWindow( cmax-cmin ) + + self.GetWindowLevel().SetLevel(self.GetInitialLevel()) + self.GetWindowLevel().SetWindow(self.GetInitialWindow()) + + self.GetWindowLevel().Update() + + self.UpdateSliceActor() + self.AdjustCamera() + self.Render() + + elif interactor.GetKeyCode() == "s": + filename = "current_render" + self.SaveRender(filename) + elif interactor.GetKeyCode() == "q": + print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) + interactor.SetKeyCode("e") + self.OnKeyPress(interactor, event) + else : + #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) + pass + + def OnLeftButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + if ctrl and not (alt and shift): + self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) + wsize = self.GetRenderWindow().GetSize() + position = interactor.GetEventPosition() + self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) + self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) + + self.GetROIWidget().On() + self.SetDisplayHistogram(True) + self.Render() + print ("Event %s is CREATE_ROI_EVENT" % (event)) + elif alt and not (shift and ctrl): + self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) + self.GetROIWidget().Off() + self._viewer.updateCornerAnnotation("", 1, False) + self.SetDisplayHistogram(False) + self.Render() + print ("Event %s is DELETE_ROI_EVENT" % (event)) + elif not (ctrl and alt and shift): + self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) + self.HandlePickEvent(interactor, event) + print ("Event %s is PICK_EVENT" % (event)) + + + def SetDisplayHistogram(self, display): + if display: + if (self._viewer.displayHistogram == 0): + self.GetRenderer().AddActor(self._viewer.histogramPlotActor) + self.firstHistogram = 1 + self.Render() + + self._viewer.histogramPlotActor.VisibilityOn() + self._viewer.displayHistogram = True + else: + self._viewer.histogramPlotActor.VisibilityOff() + self._viewer.displayHistogram = False + + + def OnLeftButtonReleaseEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: + #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() + #print (bc.GetValue()) + self.OnROIModifiedEvent(interactor, event) + + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def 
OnRightButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + + if alt and not (ctrl and shift): + self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif shift and not (ctrl and alt): + self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) + self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) + print ("Event %s is ZOOM_EVENT" % (event)) + elif ctrl and not (shift and alt): + self.SetViewerEvent (ViewerEvent.PAN_EVENT ) + self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) + print ("Event %s is PAN_EVENT" % (event)) + + def OnRightButtonReleaseEvent(self, interactor, event): + print (event) + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) + self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ + self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.SetInitialCameraPosition( () ) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + + def OnROIModifiedEvent(self, interactor, event): + + #print ("ROI EVENT " + event) + p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() + p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() + wsize = self.GetRenderWindow().GetSize() + + #print (p1.GetValue()) + #print (p2.GetValue()) + pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] + pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] + vox1 = self.viewport2imageCoordinate(pp1) + vox2 = self.viewport2imageCoordinate(pp2) + + self.SetROI( (vox1 , vox2) ) + roi = self.GetROI() + print ("Pixel1 %d,%d,%d Value %f" % vox1 ) + print ("Pixel2 %d,%d,%d Value %f" % vox2 ) + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][2] - roi[0][2]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + x = abs(roi[1][1] - roi[0][1]) + y = abs(roi[1][2] - roi[0][2]) + + text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) 
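            # (x, y) above are the ROI edge lengths, in voxels, within the plane of the
            # current slice orientation, and the area is reported in units of 1024
            # voxels ("kp"). Note that the diagnostic print in each branch above labels
            # the orientation as "XY" regardless of which branch was taken.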
+ print (text) + self.UpdateCornerAnnotation(text, 1) + self.UpdateROIHistogram() + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.GetPicker().GetPickPosition()) + pickPosition[self.GetSliceOrientation()] = \ + self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ + self.GetInputData().GetOrigin()[self.GetSliceOrientation()] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.GetInputData().GetDimensions() + print (dims) + spac = self.GetInputData().GetSpacing() + orig = self.GetInputData().GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + + def OnMouseMoveEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: + self.HandleZoomEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.HandlePanEvent(interactor, event) + + + def HandleZoomEvent(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + size = self.GetRenderWindow().GetSize() + dy = - 4 * dy / size[1] + + print ("distance: " + str(self.GetActiveCamera().GetDistance())) + + print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) + + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) + newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) + #print ("new position " + str(newposition)) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + print ("distance after: " + str(self.GetActiveCamera().GetDistance())) + + def HandlePanEvent(self, interactor, event): + x,y = interactor.GetEventPosition() + x0,y0 = interactor.GetInitialEventPosition() + + ic = self.viewport2imageCoordinate((x,y)) + ic0 = self.viewport2imageCoordinate((x0,y0)) + + dx = 4 *( ic[0] - ic0[0]) + dy = 4* (ic[1] - ic0[1]) + + camera = vtk.vtkCamera() + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + newposition[0] -= dx + newposition[1] -= dy + newfocalpoint[0] = newposition[0] 
+ newfocalpoint[1] = newposition[1] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[0] -= dx + newposition[2] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[2] = newposition[2] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[1] -= dx + newposition[2] -= dy + newfocalpoint[2] = newposition[2] + newfocalpoint[1] = newposition[1] + #print ("new position " + str(newposition)) + camera.SetFocalPoint(newfocalpoint) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + def HandleWindowLevel(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + print ("Event delta %d %d" % (dx,dy)) + size = self.GetRenderWindow().GetSize() + + dx = 4 * dx / size[0] + dy = 4 * dy / size[1] + window = self.GetInitialWindow() + level = self.GetInitialLevel() + + if abs(window) > 0.01: + dx = dx * window + else: + dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); + + if abs(level) > 0.01: + dy = dy * level + else: + dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) + + + # Abs so that direction does not flip + + if window < 0.0: + dx = -1*dx + if level < 0.0: + dy = -1*dy + + # Compute new window level + + newWindow = dx + window + newLevel = level - dy + + # Stay away from zero and really + + if abs(newWindow) < 0.01: + newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) + + if abs(newLevel) < 0.01: + newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) + + self.GetWindowLevel().SetWindow(newWindow) + self.GetWindowLevel().SetLevel(newLevel) + + self.GetWindowLevel().Update() + self.UpdateSliceActor() + self.AdjustCamera() + + self.Render() + + def HandlePickEvent(self, interactor, event): + position = interactor.GetEventPosition() + #print ("PICK " + str(position)) + vox = self.viewport2imageCoordinate(position) + #print ("Pixel %d,%d,%d Value %f" % vox ) + self._viewer.cornerAnnotation.VisibilityOn() + self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) + self.Render() + +############################################################################### + + + +class CILViewer2D(): + '''Simple Interactive Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): + '''creates the rendering pipeline''' + # create a rendering window and renderer + if ren == None: + self.ren = vtk.vtkRenderer() + else: + self.ren = ren + if renWin == None: + self.renWin = vtk.vtkRenderWindow() + else: + self.renWin = renWin + if iren == None: + self.iren = vtk.vtkRenderWindowInteractor() + else: + self.iren = iren + + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + self.style = CILInteractorStyle(self) + + self.iren.SetInteractorStyle(self.style) + self.iren.SetRenderWindow(self.renWin) + self.iren.Initialize() + self.ren.SetBackground(.1, .2, .4) + + self.camera = vtk.vtkCamera() + self.camera.ParallelProjectionOn() + self.ren.SetActiveCamera(self.camera) + + # data + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + + #Actors + self.sliceActor = vtk.vtkImageActor() + self.voi = vtk.vtkExtractVOI() + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia = vtk.vtkImageAccumulate() + self.sliceActorNo = 0 + + #initial Window/Level + self.InitialLevel = 0 + self.InitialWindow = 0 + + #ViewerEvent + self.event = ViewerEvent.NO_EVENT + + # ROI Widget + self.ROIWidget = vtk.vtkBorderWidget() + self.ROIWidget.SetInteractor(self.iren) + self.ROIWidget.CreateDefaultRepresentation() + 
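        # The border widget provides the interactive ROI rectangle: its border colour is
        # set to green below, and its Select event is forwarded to the interactor style's
        # OnROIModifiedEvent, which converts the widget corners into image voxel
        # coordinates and updates the corner annotation and the ROI histogram.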
self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) + self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) + + # edge points of the ROI + self.ROI = () + + #picker + self.picker = vtk.vtkPropPicker() + self.picker.PickFromListOn() + self.picker.AddPickList(self.sliceActor) + + self.iren.SetPicker(self.picker) + + # corner annotation + self.cornerAnnotation = vtk.vtkCornerAnnotation() + self.cornerAnnotation.SetMaximumFontSize(12); + self.cornerAnnotation.PickableOff(); + self.cornerAnnotation.VisibilityOff(); + self.cornerAnnotation.GetTextProperty().ShadowOn(); + self.cornerAnnotation.SetLayerNumber(1); + + + + # cursor doesn't show up + self.cursor = vtk.vtkCursor2D() + self.cursorMapper = vtk.vtkPolyDataMapper2D() + self.cursorActor = vtk.vtkActor2D() + self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) + self.cursor.SetFocalPoint(0, 0, 0) + self.cursor.AllOff() + self.cursor.AxesOn() + self.cursorActor.PickableOff() + self.cursorActor.VisibilityOn() + self.cursorActor.GetProperty().SetColor(1, 1, 1) + self.cursorActor.SetLayerNumber(1) + self.cursorMapper.SetInputData(self.cursor.GetOutput()) + self.cursorActor.SetMapper(self.cursorMapper) + + # Zoom + self.InitialCameraPosition = () + + # XY Plot actor for histogram + self.displayHistogram = False + self.firstHistogram = 0 + self.roiIA = vtk.vtkImageAccumulate() + self.roiVOI = vtk.vtkExtractVOI() + self.histogramPlotActor = vtk.vtkXYPlotActor() + self.histogramPlotActor.ExchangeAxesOff(); + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetAdjustXLabels(3) + self.histogramPlotActor.SetXTitle( "Level" ) + self.histogramPlotActor.SetYTitle( "N" ) + self.histogramPlotActor.SetXValuesToValue() + self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) + self.histogramPlotActor.SetPosition(0.6,0.6) + self.histogramPlotActor.SetPosition2(0.4,0.4) + + + + def GetInteractor(self): + return self.iren + + def GetRenderer(self): + return self.ren + + def setInput3DData(self, imageData): + self.img3D = imageData + self.installPipeline() + + def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), + rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): + importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) + importer.Update() + + if rescale: + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(importer.GetOutput()) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + if (iMax - iMin == 0): + scale = 1 + else: + if dtype == vtk.VTK_UNSIGNED_SHORT: + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + elif dtype == vtk.VTK_UNSIGNED_INT: + scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(importer.GetOutput()) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(-iMin) + shiftScaler.SetOutputScalarType(dtype) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + else: + self.img3D = importer.GetOutput() + + self.installPipeline() + + def displaySlice(self, sliceno = 0): + self.sliceno = sliceno + + self.updatePipeline() + + self.renWin.Render() + + return self.sliceActorNo + + def updatePipeline(self, resetcamera = False): + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], 
extent[5]) + + self.voi.Update() + self.ia.Update() + self.wl.Update() + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + + self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) + + if self.displayHistogram: + self.updateROIHistogram() + + self.AdjustCamera(resetcamera) + + self.renWin.Render() + + + def installPipeline(self): + '''Slices a 3D volume and then creates an actor to be rendered''' + + self.ren.AddViewProp(self.cornerAnnotation) + + self.voi.SetInputData(self.img3D) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + # set window/level for current slices + + + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia.SetInputData(self.voi.GetOutput()) + self.ia.Update() + cmax = self.ia.GetMax()[0] + cmin = self.ia.GetMin()[0] + + self.InitialLevel = (cmax+cmin)/2 + self.InitialWindow = cmax-cmin + + + self.wl.SetLevel(self.InitialLevel) + self.wl.SetWindow(self.InitialWindow) + + self.wl.SetInputData(self.voi.GetOutput()) + self.wl.Update() + + self.sliceActor.SetInputData(self.wl.GetOutput()) + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + self.sliceActor.SetInterpolate(False) + self.ren.AddActor(self.sliceActor) + self.ren.ResetCamera() + self.ren.Render() + + self.AdjustCamera() + + self.ren.AddViewProp(self.cursorActor) + self.cursorActor.VisibilityOn() + + self.iren.Initialize() + self.renWin.Render() + #self.iren.Start() + + def AdjustCamera(self, resetcamera = False): + self.ren.ResetCameraClippingRange() + if resetcamera: + self.ren.ResetCamera() + + + def getROI(self): + return self.ROI + + def getROIExtent(self): + p0 = self.ROI[0] + p1 = self.ROI[1] + return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) + + ############### Handle events are moved to the interactor style + + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.picker.GetPickPosition()) + pickPosition[self.sliceOrientation] = \ + self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ + self.img3D.GetOrigin()[self.sliceOrientation] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.img3D.GetDimensions() + print (dims) + spac = self.img3D.GetSpacing() + orig = self.img3D.GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + def GetRenderWindow(self): + return self.renWin + + + def startRenderLoop(self): + self.iren.Start() + + def GetSliceOrientation(self): + return self.sliceOrientation + + def GetActiveSlice(self): + return self.sliceno + + def updateCornerAnnotation(self, text , idx=0, visibility=True): + if visibility: + self.cornerAnnotation.VisibilityOn() + else: + self.cornerAnnotation.VisibilityOff() + + self.cornerAnnotation.SetText(idx, text) + self.iren.Render() + + def saveRender(self, 
filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + def updateROIHistogram(self): + + extent = [0 for i in range(6)] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + extent[4] = self.GetActiveSlice() + extent[5] = self.GetActiveSlice()+1 + #y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + #x = abs(roi[1][0] - roi[0][0]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[2] = self.GetActiveSlice() + extent[3] = self.GetActiveSlice()+1 + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + #x = abs(roi[1][1] - roi[0][1]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[0] = self.GetActiveSlice() + extent[1] = self.GetActiveSlice()+1 + + self.roiVOI.SetVOI(extent) + self.roiVOI.SetInputData(self.img3D) + self.roiVOI.Update() + irange = self.roiVOI.GetOutput().GetScalarRange() + + self.roiIA.SetInputData(self.roiVOI.GetOutput()) + self.roiIA.IgnoreZeroOff() + self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) + self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); + self.roiIA.SetComponentSpacing( 1,0,0 ); + self.roiIA.Update() + + self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) + self.histogramPlotActor.SetXRange(irange[0],irange[1]) + + self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) + + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py new file mode 100644 index 0000000..906786b --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget.py @@ -0,0 +1,340 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter + +class QVTKWidget(QtWidgets.QWidget): + + """ A QVTKWidget for Python and Qt.""" + + # Map between VTK and Qt cursors. + _CURSOR_MAP = { + 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT + 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW + 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE + 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE + 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW + 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE + 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS + 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE + 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL + 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND + 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR + } + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + # the current button + self._ActiveButton = QtCore.Qt.NoButton + + # private attributes + self.__oldFocus = None + self.__saveX = 0 + self.__saveY = 0 + self.__saveModifiers = QtCore.Qt.NoModifier + self.__saveButtons = QtCore.Qt.NoButton + self.__timeframe = 0 + + # create qt-level widget + QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D() + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + + self._Iren.Register(self._RenderWindow) + self._Iren.SetRenderWindow(self._RenderWindow) + self._RenderWindow.SetWindowInfo(str(int(self.winId()))) + + # do all the necessary qt setup + self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) + self.setAttribute(QtCore.Qt.WA_PaintOnScreen) + self.setMouseTracking(True) # get all mouse events + self.setFocusPolicy(QtCore.Qt.WheelFocus) + self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) + + self._Timer = QtCore.QTimer(self) + #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) + + self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) + self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) + self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', + self.CursorChangedEvent) + + # Destructor + def __del__(self): + self._Iren.UnRegister(self._RenderWindow) + #QtWidgets.QWidget.__del__(self) + + 
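    # Typical use of the widget (see QVTKExample at the end of this file): add the
    # widget to a Qt layout, then pass a vtkImageData to SetInput(), for example the
    # output of a vtkMetaImageReader.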
# Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + + # GetInteractor + def GetInteractor(self): + return self._Iren + + # Display image data + def GetPyveViewer(self): + return self._PyveViewer + + def __getattr__(self, attr): + """Makes the object behave like a vtkGenericRenderWindowInteractor""" + print (attr) + if attr == '__vtk__': + return lambda t=self._Iren: t + elif hasattr(self._Iren, attr): + return getattr(self._Iren, attr) +# else: +# raise AttributeError( self.__class__.__name__ + \ +# " has no attribute named " + attr ) + + def CreateTimer(self, obj, evt): + self._Timer.start(10) + + def DestroyTimer(self, obj, evt): + self._Timer.stop() + return 1 + + def TimerEvent(self): + self._Iren.InvokeEvent("TimerEvent") + + def CursorChangedEvent(self, obj, evt): + """Called when the CursorChangedEvent fires on the render window.""" + # This indirection is needed since when the event fires, the current + # cursor is not yet set so we defer this by which time the current + # cursor should have been set. + QtCore.QTimer.singleShot(0, self.ShowCursor) + + def HideCursor(self): + """Hides the cursor.""" + self.setCursor(QtCore.Qt.BlankCursor) + + def ShowCursor(self): + """Shows the cursor.""" + vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() + qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) + self.setCursor(qt_cursor) + + def sizeHint(self): + return QtCore.QSize(400, 400) + + def paintEngine(self): + return None + + def paintEvent(self, ev): + self._RenderWindow.Render() + + def resizeEvent(self, ev): + self._RenderWindow.Render() + w = self.width() + h = self.height() + + self._RenderWindow.SetSize(w, h) + self._Iren.SetSize(w, h) + + def _GetCtrlShiftAlt(self, ev): + ctrl = shift = alt = False + + if hasattr(ev, 'modifiers'): + if ev.modifiers() & QtCore.Qt.ShiftModifier: + shift = True + if ev.modifiers() & QtCore.Qt.ControlModifier: + ctrl = True + if ev.modifiers() & QtCore.Qt.AltModifier: + alt = True + else: + if self.__saveModifiers & QtCore.Qt.ShiftModifier: + shift = True + if self.__saveModifiers & QtCore.Qt.ControlModifier: + ctrl = True + if self.__saveModifiers & QtCore.Qt.AltModifier: + alt = True + + return ctrl, shift, alt + + def enterEvent(self, ev): + if not self.hasFocus(): + self.__oldFocus = self.focusWidget() + self.setFocus() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("EnterEvent") + + def leaveEvent(self, ev): + if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: + self.__oldFocus.setFocus() + self.__oldFocus = None + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("LeaveEvent") + + def mousePressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + repeat = 0 + if ev.type() == QtCore.QEvent.MouseButtonDblClick: + repeat = 1 + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), repeat, None) + + self._Iren.SetAltKey(alt) + self._ActiveButton = ev.button() + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + 
self._Iren.InvokeEvent("MiddleButtonPressEvent") + + def mouseReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonReleaseEvent") + + def mouseMoveEvent(self, ev): + self.__saveModifiers = ev.modifiers() + self.__saveButtons = ev.buttons() + self.__saveX = ev.x() + self.__saveY = ev.y() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("MouseMoveEvent") + + def keyPressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = str(ev.text()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyPressEvent") + self._Iren.InvokeEvent("CharEvent") + + def keyReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = chr(ev.key()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyReleaseEvent") + + def wheelEvent(self, ev): + print ("angleDeltaX %d" % ev.angleDelta().x()) + print ("angleDeltaY %d" % ev.angleDelta().y()) + if ev.angleDelta().y() >= 0: + self._Iren.InvokeEvent("MouseWheelForwardEvent") + else: + self._Iren.InvokeEvent("MouseWheelBackwardEvent") + + def GetRenderWindow(self): + return self._RenderWindow + + def Render(self): + self.update() + + +def QVTKExample(): + """A simple example that uses the QVTKWidget class.""" + + # every QT app needs an app + app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) + page_VTK = QtWidgets.QWidget() + page_VTK.resize(500,500) + layout = QtWidgets.QVBoxLayout(page_VTK) + # create the widget + widget = QVTKWidget(parent=None) + layout.addWidget(widget) + + #reader = vtk.vtkPNGReader() + #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + widget.SetInput(reader.GetOutput()) + + # show the widget + page_VTK.show() + # start event processing + app.exec_() + +if __name__ == "__main__": + QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py new file mode 100644 index 0000000..e32e1c2 --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget2.py @@ -0,0 +1,84 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor + +class QVTKWidget(QVTKRenderWindowInteractor): + + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + kw = dict() + super().__init__(parent, **kw) + + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D(400,400) + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + kw['iren'] = self._Iren + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + kw['rw'] = self._RenderWindow + + + + + def GetInteractor(self): + return self._Iren + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py new file mode 100644 index 0000000..946188b --- /dev/null +++ b/src/Python/ccpi/viewer/__init__.py @@ -0,0 +1 @@ +from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc new file mode 100644 index 0000000..711f77a Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc new file mode 100644 index 0000000..77c2ca8 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ diff --git 
a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc new file mode 100644 index 0000000..3d11b87 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc new file mode 100644 index 0000000..2fa2eaf Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..fcea537 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py new file mode 100644 index 0000000..b5eb0a7 --- /dev/null +++ b/src/Python/ccpi/viewer/embedvtk.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 27 12:18:58 2017 + +@author: ofn77899 +""" + +#!/usr/bin/env python + +import sys +import vtk +from PyQt5 import QtCore, QtWidgets +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor +import QVTKWidget2 + +class MainWindow(QtWidgets.QMainWindow): + + def __init__(self, parent = None): + QtWidgets.QMainWindow.__init__(self, parent) + + self.frame = QtWidgets.QFrame() + + self.vl = QtWidgets.QVBoxLayout() +# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) + + self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) + self.iren = self.vtkWidget.GetInteractor() + self.vl.addWidget(self.vtkWidget) + + + + + self.ren = vtk.vtkRenderer() + self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) +# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() +# +# # Create source +# source = vtk.vtkSphereSource() +# source.SetCenter(0, 0, 0) +# source.SetRadius(5.0) +# +# # Create a mapper +# mapper = vtk.vtkPolyDataMapper() +# mapper.SetInputConnection(source.GetOutputPort()) +# +# # Create an actor +# actor = vtk.vtkActor() +# actor.SetMapper(mapper) +# +# self.ren.AddActor(actor) +# +# self.ren.ResetCamera() +# + self.frame.setLayout(self.vl) + self.setCentralWidget(self.frame) + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + self.vtkWidget.SetInput(reader.GetOutput()) + + #self.vktWidget.Initialize() + #self.vktWidget.Start() + + self.show() + #self.iren.Initialize() + + +if __name__ == "__main__": + + app = QtWidgets.QApplication(sys.argv) + + window = MainWindow() + + sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From ad62962697509d977087c25d24a3ff083d9c4308 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:08:32 +0100 Subject: initial revision --- src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/ccpi/imaging/Regularizer.py (limited to 'src') diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to 
handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = cpu_regularizers.SplitBregman_TV + FGP_TV = cpu_regularizers.FGP_TV + LLT_model = cpu_regularizers.LLT_model + PatchBased_Regul = cpu_regularizers.PatchBased_Regul + TGV_PD = cpu_regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) 
+ + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + 
self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From a2ca45848e354f376c53ecd3fed946d64c1ff3aa Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:10:04 +0100 Subject: module rename --- src/Python/fista_module.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp 
index eacda3d..c36329e 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -1032,13 +1032,13 @@ bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_al return result; } -BOOST_PYTHON_MODULE(regularizers) +BOOST_PYTHON_MODULE(cpu_regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "regularizers"; + package.attr("__path__") = "cpu_regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); -- cgit v1.2.3 From cc3a464ec587e95ddfd421cd3836a7677dfb9744 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:51:18 +0100 Subject: export/import data from hdf5 Added file to export the data from DemoRD2.m to HDF5 to pass it to Python. Added file to import the data from DemoRD2.m from HDF5. --- src/Python/test/readhd5.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/Python/test/readhd5.py (limited to 'src') diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py new file mode 100644 index 0000000..1e19e14 --- /dev/null +++ b/src/Python/test/readhd5.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +""" + +import h5py +import numpy + +def getEntry(nx, location): + for item in nx[location].keys(): + print (item) + +filename = r'C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\Demos\DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file -- cgit v1.2.3 From 56915cc00ded38d24c23b9ab1a0717d52d430ddd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:54:59 +0100 Subject: initial revision for testing --- .../ccpi/reconstruction/FISTAReconstructor.py | 354 +++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..ea96b53 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. 
Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
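+                    # x1 is normalised here and re-projected below, so each pass of the
+                    # loop applies A^T * diag(weights) * A to a unit-norm volume: this is
+                    # the power iteration, and s converges to the largest eigenvalue of
+                    # that operator, i.e. the Lipschitz constant of the gradient of the
+                    # weighted least-squares data term.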
+ sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def 
normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + -- cgit v1.2.3 From ed5737df1e9a613ad881d3b61c62c2627027faa4 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:39:37 +0100 Subject: bugfix --- src/Python/Matlab2Python_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index e15d738..ee76bc7 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -123,7 +123,7 @@ T * mxGetData(const np::ndarray pm) { probably this would work. 
A = reinterpret_cast(prhs[0]); */ - return reinterpret_cast(prhs[0]); + //return reinterpret_cast(prhs[0]); } template @@ -273,4 +273,4 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); def("mexFunction", mexFunction); def("doSomething", doSomething); -} \ No newline at end of file +} -- cgit v1.2.3 From 82d2db8a7514c850887c18143626539b7ca8b794 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:41:10 +0100 Subject: initial facility to test the FISTA --- src/Python/test_reconstructor.py | 179 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 src/Python/test_reconstructor.py (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py new file mode 100644 index 0000000..0fd08f5 --- /dev/null +++ b/src/Python/test_reconstructor.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +import astra + +##def getEntry(nx, location): +## for item in nx[location].keys(): +## print (item) + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 3 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) + #N = params.vol_geom.GridColCount + +pars = dict() +pars['projector_geometry'] = proj_geom +pars['output_geometry'] = vol_geom +pars['input_sinogram'] = Sino3D +sliceZ , nangles , detectors = numpy.shape(Sino3D) +pars['detectors'] = detectors +pars['number_of_angles'] = nangles +pars['SlicesZ'] = sliceZ + + +pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) + +N = pars['output_geometry']['GridColCount'] +proj_geom = pars['projector_geometry'] +vol_geom = pars['output_geometry'] +weights = pars['weights'] +SlicesZ = pars['SlicesZ'] + +if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + 
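+    # As the comment above notes, for parallel-beam data every slice sees the
+    # same projector, so the power method is run on single-slice copies of the
+    # geometries (DetectorRowCount = 1, GridSliceCount = 1 below) to keep the
+    # iteration cheap.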
proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + import matplotlib.pyplot as plt + fig = plt.figure() + + #a.set_title('Lipschitz') + for i in range(niter): +# [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); +# s = norm(x1(:)); +# x1 = x1/s; +# [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +# y = sqweight.*y; +# astra_mex_data3d('delete', sino_id); +# astra_mex_data3d('delete', id); + print ("iteration {0}".format(i)) + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + #a=fig.add_subplot(2,1,1) + #imgplot = plt.imshow(y[0]) + + y = sqweight * y # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geomT, + vol_geomT); + print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + print ("x1 {0}".format(x1.T[:4].T)) + +# ### this line? +# sino_id, y = astra.creators.create_sino3d_gpu(x1, +# proj_geomT, +# vol_geomT); +# y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT +else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
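+        # Same normalise-then-reproject power-iteration step as in the parallel
+        # branch, but applied to the full (SlicesZ, N, N) volume: divergent-beam
+        # rays couple the slices, so the geometry cannot be reduced to one row.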
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 -- cgit v1.2.3 From 7111d98258becca09e4c93e3c66edb7d524d6463 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:05 +0100 Subject: initial revision --- src/Python/ccpi/imaging/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/imaging/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 3f26b1d8ab3a632ceca97bdf04225008f9163684 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:27 +0100 Subject: initial revision --- src/Python/ccpi/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 447756a338dfa993e2969298af19f1f9707a409a Mon Sep 17 00:00:00 2001 From: algol Date: Fri, 25 Aug 2017 15:58:52 +0100 Subject: removed vtk dependency --- src/Python/test_regularizers.py | 124 +++++++++++++++++++++++++++------------- 1 file changed, 85 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 755804a..86849eb 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -5,15 +5,15 @@ Created on Fri Aug 4 11:10:05 2017 @author: ofn77899 """ -from ccpi.viewer.CILViewer2D import Converter -import vtk +#from ccpi.viewer.CILViewer2D import Converter +#import vtk import matplotlib.pyplot as plt import numpy as np import os from enum import Enum import timeit - +#from PIL import Image #from Regularizer import Regularizer from ccpi.imaging.Regularizer import Regularizer @@ -46,12 +46,20 @@ def nrmse(im1, im2): # u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" -reader = vtk.vtkTIFFReader() -reader.SetFileName(os.path.normpath(filename)) -reader.Update() +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() #vtk returns 3D images, let's take just the one slice there is as 2D -Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 +#Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 +filename = '/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' +Im = plt.imread(filename) +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + + + #imgplot = plt.imshow(Im) perc = 0.05 @@ -80,6 +88,7 @@ reg_output = [] use_object = True if use_object: reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + print (reg.pars) reg.setParameter(input=u0) reg.setParameter(regularization_parameter=10.) 
# or @@ -113,7 +122,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(plotme) +imgplot = plt.imshow(plotme,cmap="gray") ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -125,14 +134,32 @@ reg_output.append(out2) a=fig.add_subplot(2,3,3) +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) + verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise @@ -149,13 +176,32 @@ pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") # ###################### PatchBased_Regul ######################################### # # Quick 2D denoising example in Matlab: @@ -163,24 +209,24 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - # searching_window_ratio=3, - # similarity_window_ratio=1, - # PB_filtering_parameter=0.08) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,5) +a=fig.add_subplot(2,3,5) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box 
in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") # ###################### TGV_PD ######################################### @@ -190,25 +236,25 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - # first_order_term=1.3, - # second_order_term=1, - # number_of_iterations=550) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,6) +a=fig.add_subplot(2,3,6) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") plt.show() -- cgit v1.2.3 From fb5e0ad0ad94f5b919b17f3223834380dce683d4 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 16:38:52 +0100 Subject: The calculation of the Lipschitz constant works --- src/Python/test_reconstructor.py | 60 ++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 20 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 0fd08f5..76ce3ac 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -23,10 +23,10 @@ nx = h5py.File(filename, "r") entries = [entry for entry in nx['/'].keys()] print (entries) -Sino3D = numpy.asarray(nx.get('/Sino3D')) -Weights3D = numpy.asarray(nx.get('/Weights3D')) +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] -angles_rad = numpy.asarray(nx.get('/angles_rad')) +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] @@ -62,17 +62,18 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, #N = params.vol_geom.GridColCount pars = dict() -pars['projector_geometry'] = proj_geom -pars['output_geometry'] = vol_geom -pars['input_sinogram'] = Sino3D +pars['projector_geometry'] = proj_geom.copy() +pars['output_geometry'] = vol_geom.copy() +pars['input_sinogram'] = Sino3D.copy() sliceZ , nangles , detectors = numpy.shape(Sino3D) pars['detectors'] = detectors pars['number_of_angles'] = nangles pars['SlicesZ'] = sliceZ 
-pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) - +#pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) +pars['weights'] = Weights3D.copy() + N = pars['output_geometry']['GridColCount'] proj_geom = pars['projector_geometry'] vol_geom = pars['output_geometry'] @@ -82,7 +83,7 @@ SlicesZ = pars['SlicesZ'] if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 15;# % number of iteration for the PM + niter = 16;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) @@ -96,7 +97,8 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); import matplotlib.pyplot as plt - fig = plt.figure() + fig = [] + props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) #a.set_title('Lipschitz') for i in range(niter): @@ -107,14 +109,27 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): # y = sqweight.*y; # astra_mex_data3d('delete', sino_id); # astra_mex_data3d('delete', id); - print ("iteration {0}".format(i)) + #print ("iteration {0}".format(i)) + fig.append(plt.figure()) + + a=fig[-1].add_subplot(1,2,1) + a.text(0.05, 0.95, "iteration {0}, x1".format(i), transform=a.transAxes, + fontsize=14,verticalalignment='top', bbox=props) + + imgplot = plt.imshow(x1[0].copy()) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT) - #a=fig.add_subplot(2,1,1) - #imgplot = plt.imshow(y[0]) + a=fig[-1].add_subplot(1,2,2) + a.text(0.05, 0.95, "iteration {0}, y".format(i), + transform=a.transAxes, fontsize=14,verticalalignment='top', + bbox=props) + + imgplot = plt.imshow(y[0].copy()) - y = sqweight * y # element wise multiplication + y = (sqweight * y).copy() # element wise multiplication #b=fig.add_subplot(2,1,2) #imgplot = plt.imshow(x1[0]) @@ -122,15 +137,17 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #astra_mex_data3d('delete', sino_id); astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), proj_geomT, - vol_geomT); - print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - print ("x1 {0}".format(x1.T[:4].T)) + x1 = (x1/s).copy(); # ### this line? 
# sino_id, y = astra.creators.create_sino3d_gpu(x1, @@ -138,10 +155,13 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): # vol_geomT); # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry print('Calculating Lipshitz constant for divergen beam geometry...') -- cgit v1.2.3 From 391473269674bc98697eabac0b4fb2bd89f5d85e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 16:55:48 +0100 Subject: Reorganized code with new fista package name --- src/Python/ccpi/fista/FISTAReconstructor.py | 389 ++++++++++++++ src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 0 -> 3804 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ++++++++++++ src/Python/ccpi/fista/Reconstructor.py | 425 +++++++++++++++ src/Python/ccpi/fista/Reconstructor.py~ | 598 +++++++++++++++++++++ src/Python/ccpi/fista/__init__.py | 0 src/Python/ccpi/fista/__init__.pyc | Bin 0 -> 189 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 0 -> 3641 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 185 bytes src/Python/test_reconstructor.py | 11 +- 10 files changed, 1767 insertions(+), 5 deletions(-) create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ create mode 100644 src/Python/ccpi/fista/Reconstructor.py create mode 100644 src/Python/ccpi/fista/Reconstructor.py~ create mode 100644 src/Python/ccpi/fista/__init__.py create mode 100644 src/Python/ccpi/fista/__init__.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py new file mode 100644 index 0000000..1e76815 --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+############################################################################### + + + +import numpy +#from ccpi.reconstruction.parallelbeam import alg + +#from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. 
Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry + self.pars['output_geometry'] = output_geometry + self.pars['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_og_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + + print (self.pars) + # handle optional input parameters (at instantiation) + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not 'ideal_image' in kwargs.keys(): + self.pars['ideal_image'] = None + + if not 'region_of_interest'in kwargs.keys() : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not 'regularizer' in kwargs.keys() : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + + for i in range(niter): + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise 
multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 + + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + + s = numpy.linalg.norm(x1) + ### this line? + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + + #end + del proj_geomT + del vol_geomT + #plt.show() + else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location, nx): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, 
len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc new file mode 100644 index 0000000..ecc4d7d Binary files /dev/null and b/src/Python/ccpi/fista/FISTAReconstructor.pyc differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ new file mode 100644 index 0000000..6c7024d --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py~ @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. 
Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
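+            # Power-method step: x1 has just been re-normalised to unit norm;
+            # re-projecting it below continues the iteration. Each pass applies
+            # the weighted normal operator A^T W A (weighted forward projection
+            # followed by weighted back-projection), so s converges towards its
+            # largest eigenvalue, i.e. the Lipschitz constant of the gradient of
+            # the weighted least-squares data term used by FISTA.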
+ sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to 
not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + + diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py new file mode 100644 index 0000000..d29ac0d --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. 
Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
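+            # NB: the call below keeps the MATLAB mex name astra_create_sino3d_cuda,
+            # which is not defined in this module; the Python ASTRA equivalent is
+            # presumably astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT),
+            # as used earlier in this method.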
+ sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to 
not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, 
shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py~ @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = 
self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT 
term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one 
slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? 
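+                # s holds the current power-method estimate of the largest
+                # eigenvalue of the weighted normal operator; the normalised
+                # x1 is re-projected below to continue the iteration.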
+ #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 
15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") 
+#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc new file mode 100644 index 0000000..719e264 Binary files /dev/null and b/src/Python/ccpi/fista/__init__.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc new file mode 100644 index 0000000..84f16e2 Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..90c23ff Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 76ce3ac..6f46e96 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -58,7 +58,8 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, ## First pass the arguments to the FISTAReconstructor and test the ## Lipschitz constant -#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) +fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D , weights=Weights3D) +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) #N = params.vol_geom.GridColCount pars = dict() @@ -83,7 +84,7 @@ SlicesZ = pars['SlicesZ'] if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 16;# % number of iteration for the PM + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) @@ -129,7 +130,7 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): imgplot = plt.imshow(y[0].copy()) - y = (sqweight * y).copy() # element wise multiplication + y = (sqweight * y) # element wise multiplication #b=fig.add_subplot(2,1,2) #imgplot = plt.imshow(x1[0]) @@ -139,7 +140,7 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): astra.matlab.data3d('delete', sino_id) del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y), proj_geomT, vol_geomT) del y @@ -147,7 +148,7 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): s = numpy.linalg.norm(x1) ### this line? - x1 = (x1/s).copy(); + x1 = (x1/s) # ### this line? 
# sino_id, y = astra.creators.create_sino3d_gpu(x1, -- cgit v1.2.3 From 64c0b9e7a1bcfc54e6ed8b57274d53c3ed9bb950 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 17:03:17 +0100 Subject: use refactore code --- src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 3804 -> 0 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ------------ src/Python/ccpi/fista/Reconstructor.py~ | 598 --------------------- src/Python/ccpi/fista/__init__.pyc | Bin 189 -> 0 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 3641 -> 0 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 185 -> 0 bytes src/Python/test_reconstructor.py | 2 +- 7 files changed, 1 insertion(+), 948 deletions(-) delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ delete mode 100644 src/Python/ccpi/fista/Reconstructor.py~ delete mode 100644 src/Python/ccpi/fista/__init__.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc deleted file mode 100644 index ecc4d7d..0000000 Binary files a/src/Python/ccpi/fista/FISTAReconstructor.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ deleted file mode 100644 index 6c7024d..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py~ +++ /dev/null @@ -1,349 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. 
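# A minimal NumPy-only sketch, not part of the patch, of the power iteration the
# scripts use to estimate the Lipschitz constant of the weighted least-squares
# data term, i.e. the largest eigenvalue of A' W A. Here `forward_project` and
# `back_project` are hypothetical stand-ins for the ASTRA create_sino3d_gpu /
# create_backprojection3d_gpu calls used in the code, and `sqweight` is
# sqrt(weights) as in the scripts.
import numpy

def estimate_lipschitz(forward_project, back_project, sqweight, x_shape, niter=5):
    x = numpy.random.rand(*x_shape)
    s = 1.0
    for _ in range(niter):
        y = sqweight * forward_project(x)   # y = sqrt(W) A x
        x = back_project(sqweight * y)      # x = A' W A x
        s = numpy.linalg.norm(x)            # current estimate of the largest eigenvalue
        x = x / s                           # renormalise for the next iteration
    return s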
-############################################################################### - - - -import numpy -import h5py -#from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. 
Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? 
- sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to 
not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - - diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ deleted file mode 100644 index ba67327..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py~ +++ /dev/null @@ -1,598 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. 
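# A vectorised sketch, not part of the patch, of the flat-/dark-field
# normalisation used by the scripts: norm = (projection - dark) / (flat - dark),
# with non-finite pixels (from 0/0 or x/0) replaced by a default value. It
# assumes projections stacked as (angles, rows, cols) with dark and flat already
# averaged to (rows, cols).
import numpy

def normalize_stack(projections, dark, flat, def_val=0.1):
    with numpy.errstate(divide='ignore', invalid='ignore'):
        norm = (projections - dark) / (flat - dark)
    norm[~numpy.isfinite(norm)] = def_val
    return norm.astype(numpy.float32)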
-############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - -class Reconstructor: - - class Algorithm(Enum): - CGLS = alg.cgls - CGLS_CONV = alg.cgls_conv - SIRT = alg.sirt - MLEM = alg.mlem - CGLS_TICHONOV = alg.cgls_tikhonov - CGLS_TVREG = alg.cgls_TVreg - FISTA = 'fista' - - def __init__(self, algorithm = None, projection_data = None, - angles = None, center_of_rotation = None , - flat_field = None, dark_field = None, - iterations = None, resolution = None, isLogScale = False, threads = None, - normalized_projection = None): - - self.pars = dict() - self.pars['algorithm'] = algorithm - self.pars['projection_data'] = projection_data - self.pars['normalized_projection'] = normalized_projection - self.pars['angles'] = angles - self.pars['center_of_rotation'] = numpy.double(center_of_rotation) - self.pars['flat_field'] = flat_field - self.pars['iterations'] = iterations - self.pars['dark_field'] = dark_field - self.pars['resolution'] = resolution - self.pars['isLogScale'] = isLogScale - self.pars['threads'] = threads - if (iterations != None): - self.pars['iterationValues'] = numpy.zeros((iterations)) - - if projection_data != None and dark_field != None and flat_field != None: - norm = self.normalize(projection_data, dark_field, flat_field, 0.1) - self.pars['normalized_projection'] = norm - - - def setPars(self, parameters): - keys = ['algorithm','projection_data' ,'normalized_projection', \ - 'angles' , 'center_of_rotation' , 'flat_field', \ - 'iterations','dark_field' , 'resolution', 'isLogScale' , \ - 'threads' , 'iterationValues', 'regularize'] - - for k in keys: - if k not in parameters.keys(): - self.pars[k] = None - else: - self.pars[k] = parameters[k] - - - def sanityCheck(self): - projection_data = self.pars['projection_data'] - dark_field = self.pars['dark_field'] - flat_field = self.pars['flat_field'] - angles = self.pars['angles'] - - if projection_data != None and dark_field != None and \ - angles != None and flat_field != None: - data_shape = numpy.shape(projection_data) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - - if data_shape[1:] != numpy.shape(flat_field): - #raise Exception('Projection and flat field dimensions do not match') - return (False , 'Projection and flat field dimensions do not match') - if data_shape[1:] != numpy.shape(dark_field): - #raise Exception('Projection and dark field dimensions do not match') - return (False , 'Projection and dark field dimensions do not match') - - return (True , '' ) - elif self.pars['normalized_projection'] != None: - data_shape = numpy.shape(self.pars['normalized_projection']) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - else: - return (True , '' ) - else: - return (False , 'Not enough data') - - def reconstruct(self, parameters = None): - if parameters != None: - self.setPars(parameters) - - go , reason = 
self.sanityCheck() - if go: - return self._reconstruct() - else: - raise Exception(reason) - - - def _reconstruct(self, parameters=None): - if parameters!=None: - self.setPars(parameters) - parameters = self.pars - - if parameters['algorithm'] != None and \ - parameters['normalized_projection'] != None and \ - parameters['angles'] != None and \ - parameters['center_of_rotation'] != None and \ - parameters['iterations'] != None and \ - parameters['resolution'] != None and\ - parameters['threads'] != None and\ - parameters['isLogScale'] != None: - - - if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, - Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): - #store parameters - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['isLogScale'] - ) - return result - elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, - Reconstructor.Algorithm.CGLS_TICHONOV, - Reconstructor.Algorithm.CGLS_TVREG) : - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['regularize'], - numpy.zeros((parameters['iterations'])), - parameters['isLogScale'] - ) - - elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: - pass - - else: - if parameters['projection_data'] != None and \ - parameters['dark_field'] != None and \ - parameters['flat_field'] != None: - norm = self.normalize(parameters['projection_data'], - parameters['dark_field'], - parameters['flat_field'], 0.1) - self.pars['normalized_projection'] = norm - return self._reconstruct(parameters) - - - - def _normalize(self, projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - def normalize(self, projections, dark, flat, def_val=0): - norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] - return numpy.asarray (norm, dtype=numpy.float32) - - - -class FISTA(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT 
term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one 
slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? 
- #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 
15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") 
-#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc deleted file mode 100644 index 719e264..0000000 Binary files a/src/Python/ccpi/fista/__init__.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc deleted file mode 100644 index 84f16e2..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index 90c23ff..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc and /dev/null differ diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 6f46e96..a4a622b 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -9,7 +9,7 @@ Based on DemoRD2.m import h5py import numpy -from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +from ccpi.fista.FISTAReconstructor import FISTAReconstructor import astra ##def getEntry(nx, location): -- cgit v1.2.3 From 83250cee1deff04c34d5ed9ad2d9dbde09cadcf6 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 17:04:44 +0100 Subject: minor changes --- src/Python/test_regularizers.py | 60 +++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 29 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 755804a..d0bccaf 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -46,7 +46,9 @@ def nrmse(im1, im2): # u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" + reader = vtk.vtkTIFFReader() reader.SetFileName(os.path.normpath(filename)) reader.Update() @@ -163,24 +165,24 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - # searching_window_ratio=3, - # similarity_window_ratio=1, - # PB_filtering_parameter=0.08) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,5) +a=fig.add_subplot(2,3,5) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + 
verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) # ###################### TGV_PD ######################################### @@ -190,25 +192,25 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - # first_order_term=1.3, - # second_order_term=1, - # number_of_iterations=550) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,6) +a=fig.add_subplot(2,3,6) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) plt.show() -- cgit v1.2.3 From f7e1cf04f791898737bc15b0eb437abc2c5d9305 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 13 Sep 2017 10:41:14 +0100 Subject: cleaned up code --- src/Python/test_regularizers.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index d2fbca6..665a077 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -76,7 +76,7 @@ fig = plt.figure() a=fig.add_subplot(2,3,1) a.set_title('noise') -imgplot = plt.imshow(u0) +imgplot = plt.imshow(u0,cmap="gray") reg_output = [] ############################################################################## @@ -134,20 +134,6 @@ reg_output.append(out2) a=fig.add_subplot(2,3,3) -textstr = out2[-1] - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-2] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) - - textstr = out2[-1] # these are matplotlib.patch.Patch properties @@ -176,15 +162,6 @@ pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-2] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) - textstr = out2[-1] @@ -195,13 +172,6 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0],cmap="gray") -textstr = out2[-1] -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0],cmap="gray") # ###################### PatchBased_Regul ######################################### 
# # Quick 2D denoising example in Matlab: -- cgit v1.2.3
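# A sketch, not part of the patches above, of the subplot-plus-text-box pattern
# repeated for each regularizer in test_regularizers.py. It assumes the output
# convention used there: out2[0] is the filtered image and out2[-1] is the
# printable parameter summary.
import matplotlib.pyplot as plt

def show_result(fig, position, out2, title=None):
    a = fig.add_subplot(2, 3, position)
    if title is not None:
        a.set_title(title)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    a.text(0.05, 0.95, out2[-1], transform=a.transAxes, fontsize=14,
           verticalalignment='top', bbox=props)
    a.imshow(out2[0], cmap="gray")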