From 0f40ee8ad7d6e0b3b7059e5e1242d8ab97cd3caf Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 15:26:29 +0100 Subject: Added Python modules Matlab2Python_utils.cpp contains utilities for handling numpy arrays. Together with setup_test.py it creates a functional module for testing. fista_module.cpp and setup.py are meant for the real fista module. --- src/Python/Matlab2Python_utils.cpp | 206 ++++++++++++++++++++++++ src/Python/fista_module.cpp | 315 +++++++++++++++++++++++++++++++++++++ src/Python/setup.py | 58 +++++++ src/Python/setup_test.py | 58 +++++++ 4 files changed, 637 insertions(+) create mode 100644 src/Python/Matlab2Python_utils.cpp create mode 100644 src/Python/fista_module.cpp create mode 100644 src/Python/setup.py create mode 100644 src/Python/setup_test.py (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp new file mode 100644 index 0000000..138e8da --- /dev/null +++ b/src/Python/Matlab2Python_utils.cpp @@ -0,0 +1,206 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** + +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*/ + +template +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); +} + + + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. + A = reinterpret_cast(prhs[0]); + */ + return reinterpret_cast(prhs[0]); +} + +template +np::ndarray zeros(int dims , int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list mexFunction( np::ndarray input ) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast( input.get_data() ); + int * B = reinterpret_cast( zz.get_data() ); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index , &val, sizeof(int)); + std::memcpy(C + index , &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + //import_array(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp new file mode 100644 index 0000000..5344083 --- /dev/null +++ b/src/Python/fista_module.cpp @@ -0,0 +1,315 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +// include the regularizers +#include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" +#include "SplitBregman_TV_core.h" +#include "TGV_PD_core.h" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** +mxGetData +args: pm Pointer to an mxArray +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). + If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not + enough free heap space to create the mxArray. +*/ + +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { + /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) + * + * Input Parameters: + * 1. Noisy image/volume + * 2. lambda - regularization parameter + * 3. Number of iterations [OPTIONAL parameter] + * 4. eplsilon - tolerance constant [OPTIONAL parameter] + * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] + * + * Output: + * Filtered/regularized image + * + * All sanity checks and default values are set in Python + */ + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + const int dim_array[3]; + float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -11; + } + else { + dim_array[2] = input.shape(2); + } + + /*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + //iter = 35; /* default iterations number */ + iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + methTV = TV_type; + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + lambda = 2.0f*mu; + count = 1; + re_old = 0.0f; + /*Handling Matlab output data*/ + dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2]; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /* + mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + + mxCreateNumericArray initializes all its real data elements to 0. +*/ + +/* + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +*/ + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U = A = reinterpret_castinput.get_data(); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + /*printf("%f %i %i \n", re, ll, count); */ + + /*copyIm(U_old, U, dimX, dimY, dimZ); */ + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + if (number_of_dims == 3) { + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + bp::list result; + return result; +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/setup.py b/src/Python/setup.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup_test.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) -- cgit v1.2.3 From cf94b779bf8f11128ce0e4535ba1e12ccb2b50a1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 16:52:20 +0100 Subject: added FGP_TV wrapper --- src/Python/fista_module.cpp | 576 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 473 insertions(+), 103 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 5344083..2492884 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -26,12 +26,10 @@ limitations under the License. #include #include "boost/tuple/tuple.hpp" -// include the regularizers -#include "FGP_TV_core.h" -#include "LLT_model_core.h" -#include "PatchBased_Regul_core.h" #include "SplitBregman_TV_core.h" -#include "TGV_PD_core.h" +#include "FGP_TV_core.h" + + #if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) #include @@ -43,7 +41,6 @@ __if_not_exists(uint16_t) { typedef __int8 uint16_t; } namespace bp = boost::python; namespace np = boost::python::numpy; - /*! in the Matlab implementation this is called as void mexFunction( int nlhs, mxArray *plhs[], @@ -56,9 +53,7 @@ nlhs Array of pointers to the OUTPUT mxArrays plhs int number of OUTPUT mxArrays *********************************************************** -mxGetData -args: pm Pointer to an mxArray -Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. + *********************************************************** double mxGetScalar(const mxArray *pm); args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. @@ -92,77 +87,143 @@ double *mxGetPr(const mxArray *pm); args: pm Pointer to an mxArray of type double Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. **************************************************************** -mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). - If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not - enough free heap space to create the mxArray. +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. */ template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); } -bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { - /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) - * - * Input Parameters: - * 1. Noisy image/volume - * 2. lambda - regularization parameter - * 3. Number of iterations [OPTIONAL parameter] - * 4. eplsilon - tolerance constant [OPTIONAL parameter] - * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] - * - * Output: - * Filtered/regularized image - * - * All sanity checks and default values are set in Python + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. + A = reinterpret_cast(prhs[0]); */ + return reinterpret_cast(prhs[0]); +} + + + + +bp::list mexFunction(np::ndarray input) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int dim_array[3]; + const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; - + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - number_of_dims = input.get_nd(); + + int number_of_dims = input.get_nd(); + int dim_array[3]; dim_array[0] = input.shape(0); dim_array[1] = input.shape(1); if (number_of_dims == 2) { - dim_array[2] = -11; + dim_array[2] = -1; } else { dim_array[2] = input.shape(2); } - /*Handling Matlab input data*/ + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); - /*Handling Matlab input data*/ + ///*Handling Matlab input data*/ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ A = reinterpret_cast(input.get_data()); - //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ mu = (float)d_mu; + //iter = 35; /* default iterations number */ - iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ epsil = (float)d_epsil; //methTV = 0; /* default isotropic TV penalty */ - methTV = TV_type; //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ //if (nrhs == 5) { @@ -182,34 +243,31 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, doub if (number_of_dims == 2) { dimZ = 1; /*2D case*/ - /* - mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - - mxCreateNumericArray initializes all its real data elements to 0. -*/ - -/* - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); -*/ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U = A = reinterpret_castinput.get_data(); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ @@ -245,59 +303,370 @@ args: ndim: Number of dimensions. If you specify a value for ndim that is less /*printf("%f %i %i \n", re, ll, count); */ /*copyIm(U_old, U, dimX, dimY, dimZ); */ + result.append(npU); + result.append(ll); + } + //printf("SB iterations stopped at iteration: %i\n", ll); + if (number_of_dims == 3) { + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npDz = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npBz = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Dz = reinterpret_cast(npDz.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + Bz = reinterpret_cast(npBz.get_data()); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + //printf("SB iterations stopped at iteration: %i\n", ll); + result.append(npU); + result.append(ll); } - printf("SB iterations stopped at iteration: %i\n", ll); } - if (number_of_dims == 3) { - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + return result; - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ +} - /* begin outer SB iterations */ +bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + ///*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + + //iter = 35; /* default iterations number */ + + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + //plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); + bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray out1 = np::zeros(shape1, dtype); + + //float *funcvalA = (float *)mxGetData(plhs[1]); + float * funcvalA = reinterpret_cast(out1.get_data()); + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2]; + + tk = 1.0f; + tkp1 = 1.0f; + count = 1; + re_old = 0.0f; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = zeros(2, dim_array, (float)0); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + + /* begin iterations */ for (ll = 0; ll 4) break; + if (count > 3) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j 2) { - if (re > re_old) break; + if (re > re_old) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); + } + if (number_of_dims == 3) { + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP3 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npP3_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = np::zeros(shape, dtype); + np::ndarray npR3 = np::zeros(shape, dtype); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P3 = reinterpret_cast(npP3.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + P3_old = reinterpret_cast(npP3_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + R2 = reinterpret_cast(npR3.get_data()); + /* begin iterations */ + for (ll = 0; ll 3) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j 2) { + if (re > re_old) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); } - bp::list result; + return result; } - BOOST_PYTHON_MODULE(fista) { @@ -310,6 +679,7 @@ BOOST_PYTHON_MODULE(fista) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); + def("SplitBregman_TV", SplitBregman_TV); + def("FGP_TV", FGP_TV); } \ No newline at end of file -- cgit v1.2.3 From 16a2b514d191563eb7691ee24f063bd27f5ff12d Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:03 +0100 Subject: compilation fixes --- src/Python/setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index ffb9c02..a8feb1c 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -29,14 +29,14 @@ extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\env extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] extra_libraries = [] if platform.system() == 'Windows': - extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ] + extra_include_dirs += ["..\\..\\main_func\\regularizers_CPU\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + extra_include_dirs += ["../../main_func/regularizers_CPU","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,8 +47,12 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", - sources=[ "Matlab2Python_utils.cpp", + ext_modules = [Extension("regularizers", + sources=["fista_module.cpp", + "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 0859756c281f514d53b05a3cc9dc5035136d73a9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:17 +0100 Subject: test facility for regularizers --- src/Python/test_regularizers.py | 265 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 src/Python/test_regularizers.py (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py new file mode 100644 index 0000000..6abfba4 --- /dev/null +++ b/src/Python/test_regularizers.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug 4 11:10:05 2017 + +@author: ofn77899 +""" + +from ccpi.viewer.CILViewer2D import Converter +import vtk + +import regularizers +import matplotlib.pyplot as plt +import numpy as np +import os +from enum import Enum + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) + 4) + 5) + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm): + + self.algorithm = algorithm + self.pars = self.parsForAlgorithm(algorithm) + # __init__ + + def parsForAlgorithm(self, algorithm): + pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + return pars + # parsForAlgorithm + + def __call__(self, input, regularization_parameter, **kwargs): + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + self.pars['input'] = input + self.pars['regularization_parameter'] = regularization_parameter + #for key, value in self.pars.items(): + # print("{0} = {1}".format(key, value)) + + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if None in self.pars: + raise Exception("Not all parameters have been provided") + else: + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + return out + + +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +reader = vtk.vtkTIFFReader() +reader.SetFileName(os.path.normpath(filename)) +reader.Update() +#vtk returns 3D images, let's take just the one slice there is as 2D +Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +# plot +fig = plt.figure() +a=fig.add_subplot(2,3,1) +a.set_title('Original') +imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,2) +a.set_title('noise') +imgplot = plt.imshow(u0) + + +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +pars = out2[2] + +a=fig.add_subplot(2,3,3) +a.set_title('SplitBregman_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,4) +a.set_title('FGP_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., + time_step=0.1, + tolerance_constant=1e-4, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,5) +a.set_title('LLT_model') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['time_step'] + ) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + + + -- cgit v1.2.3 From 3e96c0d80387225894a8e5f1456ea310cd7e797b Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:37 +0100 Subject: Added 3 regularizers SplitBregman_TV FGP_TV LLT_model --- src/Python/fista_module.cpp | 266 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 232 insertions(+), 34 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 2492884..d890b10 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,8 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "utils.h" @@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) { return reinterpret_cast(prhs[0]); } +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + @@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) { } } - bp::list result; result.append(number_of_dims); @@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int *dim_array; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); - np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU = np::zeros(shape, dtype); np::ndarray npU_old = np::zeros(shape, dtype); - np::ndarray npDx = np::zeros(shape, dtype); - np::ndarray npDy = np::zeros(shape, dtype); - np::ndarray npBx = np::zeros(shape, dtype); - np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); - U = reinterpret_cast(npU.get_data()); + U = reinterpret_cast(npU.get_data()); U_old = reinterpret_cast(npU_old.get_data()); - Dx = reinterpret_cast(npDx.get_data()); - Dy = reinterpret_cast(npDy.get_data()); - Bx = reinterpret_cast(npBx.get_data()); - By = reinterpret_cast(npBy.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ - for (ll = 0; ll(npU); - result.append(ll); + } //printf("SB iterations stopped at iteration: %i\n", ll); - if (number_of_dims == 3) { + result.append(npU); + result.append(ll); + } + if (number_of_dims == 3) { /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); @@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi result.append(npU); result.append(ll); } - } return result; -} + } + + bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j(input.get_data()); + lambda = (float)d_lambda; + tau = (float)d_tau; + // iter is passed as parameter + epsil = (float)d_epsil; + // switcher is passed as parameter + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1; + + if (number_of_dims == 2) { + /*2D case*/ + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der2D(U, D1, D2, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + + result.append(npU); + } + else if (number_of_dims == 3) { + /*3D case*/ + dimZ = dim_array[2]; + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + if (switcher != 0) { + Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + }*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + np::ndarray npD3 = np::zeros(shape, dtype); + np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin()); + Map = reinterpret_cast(npMap.get_data()); + if (switcher != 0) { + //Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + + Map = reinterpret_cast(npMap.get_data()); + } + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + D3 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + + if (switcher == 1) { + /* apply restrictive smoothing */ + calcMap(U, Map, dimX, dimY, dimZ); + /*clear outliers */ + cleanMap(Map, dimX, dimY, dimZ); + } + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der3D(U, D1, D2, D3, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + result.append(npU); + if (switcher != 0) result.append(npMap); + + } + return result; +} + + +BOOST_PYTHON_MODULE(regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista) def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); + def("LLT_model", LLT_model); } \ No newline at end of file -- cgit v1.2.3 From fbaf7281141e0ddad5046b433ba0a72d360d09aa Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:53 +0100 Subject: minor change --- src/Python/Matlab2Python_utils.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 138e8da..6aaad90 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -128,7 +128,11 @@ T * mxGetData(const np::ndarray pm) { template np::ndarray zeros(int dims , int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); np::ndarray zz = np::zeros(shape, dtype); return zz; @@ -163,7 +167,7 @@ bp::list mexFunction( np::ndarray input ) { for (int k = 0; k < dim_array[2]; k++) { int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; int val = (*(A + index)); - float fval = (float)val; + float fval = sqrt((float)val); std::memcpy(B + index , &val, sizeof(int)); std::memcpy(C + index , &fval, sizeof(float)); } @@ -186,7 +190,7 @@ bp::list mexFunction( np::ndarray input ) { } -BOOST_PYTHON_MODULE(fista) +BOOST_PYTHON_MODULE(prova) { np::initialize(); -- cgit v1.2.3 From 9974b5a4bf88ac3e7929d7c33a911c90fcfcff29 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:17:18 +0100 Subject: test for general boost::python / numpy routines --- src/Python/test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/Python/test.py (limited to 'src') diff --git a/src/Python/test.py b/src/Python/test.py new file mode 100644 index 0000000..e283f89 --- /dev/null +++ b/src/Python/test.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Aug 3 14:08:09 2017 + +@author: ofn77899 +""" + +import fista +import numpy as np + +a = np.asarray([i for i in range(3*4*5)]) +a = a.reshape([3,4,5]) +print (a) +b = fista.mexFunction(a) +#print (b) +print (b[4].shape) +print (b[4]) +print (b[5]) \ No newline at end of file -- cgit v1.2.3 From 5a27224a373c12ba8e3af6e25a4c5eaec522b834 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:49:21 +0100 Subject: added PatchBased_Regul --- src/Python/fista_module.cpp | 123 +++++++++++++++++++++++++++++++++++++++++++- src/Python/setup.py | 1 + 2 files changed, 123 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index d890b10..c2d9352 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" #include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" #include "utils.h" @@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d if (switcher != 0) { Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); }*/ - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); np::dtype dtype = np::dtype::get_builtin(); @@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } +bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { + // the result is in the following list + bp::list result; + + int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop; + //const int *dims; + float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda; + + numdims = input.get_nd(); + int dims[3]; + + dims[0] = input.shape(0); + dims[1] = input.shape(1); + if (numdims == 2) { + dims[2] = -1; + } + else { + dims[2] = input.shape(2); + } + /*numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]);*/ + + N = dims[0]; + M = dims[1]; + Z = dims[2]; + + //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); } + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + + ///*Handling inputs*/ + //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ + //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ + //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */ + + //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); + //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + + SearchW = SearchW_real + 2 * SimilW; + + /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ + /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */ + + + padXY = SearchW + 2 * SimilW; /* padding sizes */ + newsizeX = N + 2 * (padXY); /* the X size of the padded array */ + newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ + newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ + int N_dims[] = { newsizeX, newsizeY, newsizeZ }; + + /******************************2D case ****************************/ + if (numdims == 2) { + ///*Handling output*/ + //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + ///**************************************************************************/ + + bp::tuple shape = bp::make_tuple(N, M); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + + shape = bp::make_tuple(newsizeX, newsizeY); + np::ndarray npAp = np::zeros(shape, dtype); + np::ndarray npBp = np::zeros(shape, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + + /*Perform padding of image A to the size of [newsizeX * newsizeY] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); + } + else + { + /******************************3D case ****************************/ + ///*Handling output*/ + //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + /**************************************************************************/ + bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]); + bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + np::ndarray npAp = np::zeros(shape_AB, dtype); + np::ndarray npBp = np::zeros(shape_AB, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + result.append(npB); + } /*end else ndims*/ + + return result; +} + BOOST_PYTHON_MODULE(regularizers) { np::initialize(); diff --git a/src/Python/setup.py b/src/Python/setup.py index a8feb1c..a4eed14 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -52,6 +52,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 6589fa197d9f87f7a37f46943aa995d97f50bb46 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:12 +0100 Subject: added TGV_PD, removed useless code --- src/Python/fista_module.cpp | 245 ++++++++++++++++++++++++++------------------ 1 file changed, 146 insertions(+), 99 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index c2d9352..eacda3d 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "FGP_TV_core.h" #include "LLT_model_core.h" #include "PatchBased_Regul_core.h" +#include "TGV_PD_core.h" #include "utils.h" @@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th enough free heap space to create the mxArray. */ -void mexErrMessageText(char* text) { - std::cerr << text << std::endl; -} - -/* -double mxGetScalar(const mxArray *pm); -args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. -Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. -*/ - -template -double mxGetScalar(const np::ndarray plh) { - return (double)bp::extract(plh[0]); -} - - - -template -T * mxGetData(const np::ndarray pm) { - //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. - //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. - /*Access the numpy array pointer: - char * get_data() const; - Returns: Array’s raw data pointer as a char - Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. - probably this would work. - A = reinterpret_cast(prhs[0]); - */ - return reinterpret_cast(prhs[0]); -} - -template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape; - if (dims == 3) - shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - else if (dims == 2) - shape = bp::make_tuple(dim_array[0], dim_array[1]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; -} - - -bp::list mexFunction(np::ndarray input) { - int number_of_dims = input.get_nd(); - int dim_array[3]; - - dim_array[0] = input.shape(0); - dim_array[1] = input.shape(1); - if (number_of_dims == 2) { - dim_array[2] = -1; - } - else { - dim_array[2] = input.shape(2); - } - - /**************************************************************************/ - np::ndarray zz = zeros(3, dim_array, (int)0); - np::ndarray fzz = zeros(3, dim_array, (float)0); - /**************************************************************************/ - - int * A = reinterpret_cast(input.get_data()); - int * B = reinterpret_cast(zz.get_data()); - float * C = reinterpret_cast(fzz.get_data()); - - //Copy data and cast - for (int i = 0; i < dim_array[0]; i++) { - for (int j = 0; j < dim_array[1]; j++) { - for (int k = 0; k < dim_array[2]; k++) { - int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; - int val = (*(A + index)); - float fval = (float)val; - std::memcpy(B + index, &val, sizeof(int)); - std::memcpy(C + index, &fval, sizeof(float)); - } - } - } - - bp::list result; - - result.append(number_of_dims); - result.append(dim_array[0]); - result.append(dim_array[1]); - result.append(dim_array[2]); - result.append(zz); - result.append(fzz); - - //result.append(tup); - return result; - -} - bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list @@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me np::ndarray npP1_old = np::zeros(shape, dtype); np::ndarray npP2_old = np::zeros(shape, dtype); np::ndarray npR1 = np::zeros(shape, dtype); - np::ndarray npR2 = zeros(2, dim_array, (float)0); + np::ndarray npR2 = np::zeros(shape, dtype); D = reinterpret_cast(npD.get_data()); D_old = reinterpret_cast(npD_old.get_data()); @@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } -bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { +bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) { // the result is in the following list bp::list result; @@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub ///*Handling inputs*/ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + A = reinterpret_cast(input.get_data()); //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ @@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + lambda = (float)d_lambda; + h = (float)d_h; SearchW = SearchW_real + 2 * SimilW; /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ @@ -918,7 +829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ int N_dims[] = { newsizeX, newsizeY, newsizeZ }; - /******************************2D case ****************************/ if (numdims == 2) { ///*Handling output*/ @@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub /*Perform padding of image A to the size of [newsizeX * newsizeY] */ switchpad_crop = 0; /*padding*/ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); - + /* Do PB regularization with the padded array */ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); - + switchpad_crop = 1; /*cropping*/ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); } else @@ -983,6 +894,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub result.append(npB); } /*end else ndims*/ + return result; +} + +bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) { + // the result is in the following list + bp::list result; + int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; + //const int *dim_array; + float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + A = reinterpret_cast(input.get_data()); + + //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ + //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/ + //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/ + //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/ + //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations"); + lambda = (float)d_lambda; + alpha1 = (float)d_alpha1; + alpha0 = (float)d_alpha0; + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; + + if (number_of_dims == 2) { + /*2D case*/ + dimZ = 1; + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npQ1 = np::zeros(shape, dtype); + np::ndarray npQ2 = np::zeros(shape, dtype); + np::ndarray npQ3 = np::zeros(shape, dtype); + np::ndarray npV1 = np::zeros(shape, dtype); + np::ndarray npV1_old = np::zeros(shape, dtype); + np::ndarray npV2 = np::zeros(shape, dtype); + np::ndarray npV2_old = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + Q1 = reinterpret_cast(npQ1.get_data()); + Q2 = reinterpret_cast(npQ2.get_data()); + Q3 = reinterpret_cast(npQ3.get_data()); + V1 = reinterpret_cast(npV1.get_data()); + V1_old = reinterpret_cast(npV1_old.get_data()); + V2 = reinterpret_cast(npV2.get_data()); + V2_old = reinterpret_cast(npV2_old.get_data()); + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + /*dual variables*/ + /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + /*printf("%i \n", i);*/ + L2 = 12.0; /*Lipshitz constant*/ + tau = 1.0 / pow(L2, 0.5); + sigma = 1.0 / pow(L2, 0.5); + + /*Copy A to U*/ + copyIm(A, U, dimX, dimY, dimZ); + /* Here primal-dual iterations begin for 2D */ + for (ll = 0; ll < iter; ll++) { + + /* Calculate Dual Variable P */ + DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for P*/ + ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1); + + /* Calculate Dual Variable Q */ + DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for Q*/ + ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0); + + /*saving U into U_old*/ + copyIm(U, U_old, dimX, dimY, dimZ); + + /*adjoint operation -> divergence and projection of P*/ + DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau); + + /*get updated solution U*/ + newU(U, U_old, dimX, dimY, dimZ); + + /*saving V into V_old*/ + copyIm(V1, V1_old, dimX, dimY, dimZ); + copyIm(V2, V2_old, dimX, dimY, dimZ); + + /* upd V*/ + UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau); + + /*get new V*/ + newU(V1, V1_old, dimX, dimY, dimZ); + newU(V2, V2_old, dimX, dimY, dimZ); + } /*end of iterations*/ + + result.append(npU); + } + + + + return result; } @@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); def("LLT_model", LLT_model); + def("PatchBased_Regul", PatchBased_Regul); + def("TGV_PD", TGV_PD); } \ No newline at end of file -- cgit v1.2.3 From db50cddf2cfe92c652ff16ce51a3bcecca96de68 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:54 +0100 Subject: added TGV_PD --- src/Python/setup.py | 1 + src/Python/test_regularizers.py | 195 ++++++++++++++++++++++++++++++++++------ 2 files changed, 168 insertions(+), 28 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index a4eed14..0468722 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -53,6 +53,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", + "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6abfba4..6a34749 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -47,6 +47,8 @@ class Regularizer(): SplitBregman_TV = regularizers.SplitBregman_TV FGP_TV = regularizers.FGP_TV LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD # Algorithm class TotalVariationPenalty(Enum): @@ -55,13 +57,17 @@ class Regularizer(): # TotalVariationPenalty def __init__(self , algorithm): - + self.setAlgorithm ( algorithm ) + # __init__ + + def setAlgorithm(self, algorithm): self.algorithm = algorithm self.pars = self.parsForAlgorithm(algorithm) - # __init__ + # setAlgorithm def parsForAlgorithm(self, algorithm): pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -69,6 +75,7 @@ class Regularizer(): pars['number_of_iterations'] = 35 pars['tolerance_constant'] = 0.0001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -76,6 +83,7 @@ class Regularizer(): pars['number_of_iterations'] = 50 pars['tolerance_constant'] = 0.001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: pars['algorithm'] = algorithm pars['input'] = None @@ -85,6 +93,24 @@ class Regularizer(): pars['tolerance_constant'] = None pars['restrictive_Z_smoothing'] = 0 + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + + return pars # parsForAlgorithm @@ -98,6 +124,8 @@ class Regularizer(): self.pars['regularization_parameter'] = regularization_parameter #for key, value in self.pars.items(): # print("{0} = {1}".format(key, value)) + if None in self.pars: + raise Exception("Not all parameters have been provided") if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : return self.algorithm(input, regularization_parameter, @@ -112,15 +140,27 @@ class Regularizer(): elif self.algorithm == Regularizer.Algorithm.LLT_model : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - if None in self.pars: - raise Exception("Not all parameters have been provided") - else: - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # __call__ @@ -142,13 +182,40 @@ class Regularizer(): @staticmethod def LLT_model(input, regularization_parameter , time_step, number_of_iterations, tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) + reg = Regularizer(Regularizer.Algorithm.LLT_model) out = list( reg(input, regularization_parameter, time_step=time_step, number_of_iterations=number_of_iterations, tolerance_constant=tolerance_constant, restrictive_Z_smoothing=restrictive_Z_smoothing) ) out.append(reg.pars) return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + return out #Example: @@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im))) f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) u0 = f(u0).astype('float32') -# plot +## plot fig = plt.figure() -a=fig.add_subplot(2,3,1) -a.set_title('Original') -imgplot = plt.imshow(Im) +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) -a=fig.add_subplot(2,3,2) +a=fig.add_subplot(2,3,1) a.set_title('noise') imgplot = plt.imshow(u0) - +reg_output = [] ############################################################################## # Call regularizer @@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe TV_Penalty=Regularizer.TotalVariationPenalty.l1) out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) pars = out2[2] +reg_output.append(out2) -a=fig.add_subplot(2,3,3) +a=fig.add_subplot(2,3,2) a.set_title('SplitBregman_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, number_of_iterations=10) pars = out2[-1] -a=fig.add_subplot(2,3,4) +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) a.set_title('FGP_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise # [Den] = LLT_model(single(u0), 10, 0.1, 1); -out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., - time_step=0.1, - tolerance_constant=1e-4, - number_of_iterations=10) +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) pars = out2[-1] -a=fig.add_subplot(2,3,5) +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) a.set_title('LLT_model') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' textstr = textstr % (pars['regularization_parameter'], @@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) + +###################### PatchBased_Regul ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) +a.set_title('PatchBased_Regul') +textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['searching_window_ratio'], + pars['similarity_window_ratio'], + pars['PB_filtering_parameter']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) + + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) +a.set_title('TGV_PD') +textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' +textstr = textstr % (pars['regularization_parameter'], + pars['first_order_term'], + pars['second_order_term'], + pars['number_of_iterations']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) -- cgit v1.2.3 From 62b30291105e8a48633629350fc9820b404da2ff Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 17 Aug 2017 16:33:09 +0100 Subject: initial revision --- src/Python/test/astra_test.py | 85 ++++++++++++++++++++++++++++++++++++ src/Python/test/simple_astra_test.py | 25 +++++++++++ 2 files changed, 110 insertions(+) create mode 100644 src/Python/test/astra_test.py create mode 100644 src/Python/test/simple_astra_test.py (limited to 'src') diff --git a/src/Python/test/astra_test.py b/src/Python/test/astra_test.py new file mode 100644 index 0000000..42c375a --- /dev/null +++ b/src/Python/test/astra_test.py @@ -0,0 +1,85 @@ +import astra +import numpy +import filefun + + +# read in the same data as the DemoRD2 +angles = filefun.dlmread("DemoRD2/angles.csv") +darks_ar = filefun.dlmread("DemoRD2/darks_ar.csv", separator=",") +flats_ar = filefun.dlmread("DemoRD2/flats_ar.csv", separator=",") + +if True: + Sino3D = numpy.load("DemoRD2/Sino3D.npy") +else: + sino = filefun.dlmread("DemoRD2/sino_01.csv", separator=",") + a = map (lambda x:x, numpy.shape(sino)) + a.append(20) + + Sino3D = numpy.zeros(tuple(a), dtype="float") + + for i in range(1,numpy.shape(Sino3D)[2]+1): + print("Read file DemoRD2/sino_%02d.csv" % i) + sino = filefun.dlmread("DemoRD2/sino_%02d.csv" % i, separator=",") + Sino3D.T[i-1] = sino.T + +Weights3D = numpy.asarray(Sino3D, dtype="float") + +##angles_rad = angles*(pi/180); % conversion to radians +##size_det = size(data_raw3D,1); % detectors dim +##angSize = size(data_raw3D, 2); % angles dim +##slices_tot = size(data_raw3D, 3); % no of slices +##recon_size = 950; % reconstruction size + + +angles_rad = angles * numpy.pi /180. +size_det, angSize, slices_tot = numpy.shape(Sino3D) +size_det, angSize, slices_tot = [int(i) for i in numpy.shape(Sino3D)] +recon_size = 950 +Z_slices = 3; +det_row_count = Z_slices; + +#proj_geom = astra_create_proj_geom('parallel3d', 1, 1, +# det_row_count, size_det, angles_rad); + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX +proj_geom = astra.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + size_det, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +vol_geom = astra.create_vol_geom(recon_size,recon_size,Z_slices); + +sino = numpy.zeros((size_det, angSize, slices_tot), dtype="float") + +#weights = ones(size(sino)); +weights = numpy.ones(numpy.shape(sino)) + +##################################################################### +## PowerMethod for Lipschitz constant + +N = vol_geom['GridColCount'] +x1 = numpy.random.rand(1,N,N) +#sqweight = sqrt(weights(:,:,1)); +sqweight = numpy.sqrt(weights.T[0]).T +##proj_geomT = proj_geom; +proj_geomT = proj_geom.copy() +##proj_geomT.DetectorRowCount = 1; +proj_geomT['DetectorRowCount'] = 1 +##vol_geomT = vol_geom; +vol_geomT = vol_geom.copy() +##vol_geomT.GridSliceCount = 1; +vol_geomT['GridSliceCount'] = 1 + +##[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + +#sino_id, y = astra.create_sino3d_gpu(x1, proj_geomT, vol_geomT); +sino_id, y = astra.create_sino(x1, proj_geomT, vol_geomT); + +##y = sqweight.*y; +##astra_mex_data3d('delete', sino_id); + + diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 +1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From 97e0c63f883f62ed0cc84c969756517fe4bedfe8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:55:19 +0100 Subject: Regularizer.pyfirst commit --- src/Python/Regularizer.py | 322 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/Regularizer.py (limited to 'src') diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py new file mode 100644 index 0000000..15dbbb4 --- /dev/null +++ b/src/Python/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +import regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 5ed47a3fc9839b1803731fe5f422d43689f66763 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:56:09 +0100 Subject: Test module for Boost Python currently can pass a function to the C++ layer to be evaluated. --- src/Python/Matlab2Python_utils.cpp | 68 +++++++++++++++++++++++++++++++++++++- src/Python/setup_test.py | 6 ++-- src/Python/test.py | 34 ++++++++++++++++--- 3 files changed, 99 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 6aaad90..e15d738 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -175,6 +175,71 @@ bp::list mexFunction( np::ndarray input ) { } + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} +bp::list doSomething(np::ndarray input, PyObject *pyobj , PyObject *pyobj2) { + + boost::python::object output(boost::python::handle<>(boost::python::borrowed(pyobj))); + int isOutput = !(output == boost::python::api::object()); + + boost::python::object calculate(boost::python::handle<>(boost::python::borrowed(pyobj2))); + int isCalculate = !(calculate == boost::python::api::object()); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = sqrt((float)val); + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + // if the PyObj is not None evaluate the function + if (isOutput) + output(fval); + if (isCalculate) { + float nfval = (float)bp::extract(calculate(val)); + if (isOutput) + output(nfval); + std::memcpy(C + index, &nfval, sizeof(float)); + } + } + } + } + + bp::list result; result.append(number_of_dims); @@ -196,7 +261,7 @@ BOOST_PYTHON_MODULE(prova) //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "prova"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -207,4 +272,5 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); //numpy_boost_python_register_type(); def("mexFunction", mexFunction); + def("doSomething", doSomething); } \ No newline at end of file diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py index ffb9c02..7c86175 100644 --- a/src/Python/setup_test.py +++ b/src/Python/setup_test.py @@ -30,13 +30,13 @@ extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x' extra_libraries = [] if platform.system() == 'Windows': extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + #extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + #extra_include_dirs += ["../ContourTree/", "../Core/","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", + ext_modules = [Extension("prova", sources=[ "Matlab2Python_utils.cpp", ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test.py b/src/Python/test.py index e283f89..db47380 100644 --- a/src/Python/test.py +++ b/src/Python/test.py @@ -5,14 +5,38 @@ Created on Thu Aug 3 14:08:09 2017 @author: ofn77899 """ -import fista +import prova import numpy as np -a = np.asarray([i for i in range(3*4*5)]) -a = a.reshape([3,4,5]) +a = np.asarray([i for i in range(1*2*3)]) +a = a.reshape([1,2,3]) print (a) -b = fista.mexFunction(a) +b = prova.mexFunction(a) #print (b) print (b[4].shape) print (b[4]) -print (b[5]) \ No newline at end of file +print (b[5]) + +def print_element(input): + print ("f: {0}".format(input)) + +prova.doSomething(a, print_element, None) + +c = [] +def append_to_list(input, shouldPrint=False): + c.append(input) + if shouldPrint: + print ("{0} appended to list {1}".format(input, c)) + +def element_wise_algebra(input, shouldPrint=True): + ret = input - 7 + if shouldPrint: + print ("element_wise {0}".format(ret)) + return ret + +prova.doSomething(a, append_to_list, None) +#print ("this is c: {0}".format(c)) + +b = prova.doSomething(a, None, element_wise_algebra) +#print (a) +print (b[5]) -- cgit v1.2.3 From a9274a7533b6d33a99810b2c1f1ad455768820ae Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:27:28 +0100 Subject: updated test for regularizer API --- src/Python/test_regularizers.py | 590 ++++++++++++++++++++-------------------- 1 file changed, 290 insertions(+), 300 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6a34749..5d25f02 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -8,216 +8,37 @@ Created on Fri Aug 4 11:10:05 2017 from ccpi.viewer.CILViewer2D import Converter import vtk -import regularizers import matplotlib.pyplot as plt import numpy as np import os from enum import Enum - -class Regularizer(): - '''Class to handle regularizer algorithms to be used during reconstruction - - Currently 5 regularization algorithms are available: - - 1) SplitBregman_TV - 2) FGP_TV - 3) - 4) - 5) - - Usage: - the regularizer can be invoked as object or as static method - Depending on the actual regularizer the input parameter may vary, and - a different default setting is defined. - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - - out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., - number_of_iterations=30, tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - A number of optional parameters can be passed or skipped - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) - - ''' - class Algorithm(Enum): - SplitBregman_TV = regularizers.SplitBregman_TV - FGP_TV = regularizers.FGP_TV - LLT_model = regularizers.LLT_model - PatchBased_Regul = regularizers.PatchBased_Regul - TGV_PD = regularizers.TGV_PD - # Algorithm - - class TotalVariationPenalty(Enum): - isotropic = 0 - l1 = 1 - # TotalVariationPenalty - - def __init__(self , algorithm): - self.setAlgorithm ( algorithm ) - # __init__ - - def setAlgorithm(self, algorithm): - self.algorithm = algorithm - self.pars = self.parsForAlgorithm(algorithm) - # setAlgorithm - - def parsForAlgorithm(self, algorithm): - pars = dict() - - if algorithm == Regularizer.Algorithm.SplitBregman_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 35 - pars['tolerance_constant'] = 0.0001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.FGP_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 50 - pars['tolerance_constant'] = 0.001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.LLT_model: - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['time_step'] = None - pars['number_of_iterations'] = None - pars['tolerance_constant'] = None - pars['restrictive_Z_smoothing'] = 0 - - elif algorithm == Regularizer.Algorithm.PatchBased_Regul: - pars['algorithm'] = algorithm - pars['input'] = None - pars['searching_window_ratio'] = None - pars['similarity_window_ratio'] = None - pars['PB_filtering_parameter'] = None - pars['regularization_parameter'] = None - - elif algorithm == Regularizer.Algorithm.TGV_PD: - pars['algorithm'] = algorithm - pars['input'] = None - pars['first_order_term'] = None - pars['second_order_term'] = None - pars['number_of_iterations'] = None - pars['regularization_parameter'] = None - - - - return pars - # parsForAlgorithm - - def __call__(self, input, regularization_parameter, **kwargs): - - if kwargs is not None: - for key, value in kwargs.items(): - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - self.pars['input'] = input - self.pars['regularization_parameter'] = regularization_parameter - #for key, value in self.pars.items(): - # print("{0} = {1}".format(key, value)) - if None in self.pars: - raise Exception("Not all parameters have been provided") - - if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.LLT_model : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) - elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['searching_window_ratio'] , - self.pars['similarity_window_ratio'] , - self.pars['PB_filtering_parameter']) - elif self.algorithm == Regularizer.Algorithm.TGV_PD : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - - - - # __call__ - - @staticmethod - def SplitBregman_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def FGP_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def LLT_model(input, regularization_parameter , time_step, number_of_iterations, - tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.LLT_model) - out = list( reg(input, regularization_parameter, time_step=time_step, - number_of_iterations=number_of_iterations, - tolerance_constant=tolerance_constant, - restrictive_Z_smoothing=restrictive_Z_smoothing) ) - out.append(reg.pars) - return out - - @staticmethod - def PatchBased_Regul(input, regularization_parameter, - searching_window_ratio, - similarity_window_ratio, - PB_filtering_parameter): - reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) - out = list( reg(input, - regularization_parameter, - searching_window_ratio=searching_window_ratio, - similarity_window_ratio=similarity_window_ratio, - PB_filtering_parameter=PB_filtering_parameter ) - ) - out.append(reg.pars) - return out - - @staticmethod - def TGV_PD(input, regularization_parameter , first_order_term, - second_order_term, number_of_iterations): - - reg = Regularizer(Regularizer.Algorithm.TGV_PD) - out = list( reg(input, regularization_parameter, - first_order_term=first_order_term, - second_order_term=second_order_term, - number_of_iterations=number_of_iterations) ) - out.append(reg.pars) - return out - - +import timeit + +from Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE: +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) + max_val = max(np.max(im1), np.max(im2)) + min_val = min(np.min(im1), np.min(im2)) + return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +# 2D Regularizers +# +############################################################################### #Example: # figure; # Im = double(imread('lena_gray_256.tif'))/255; % loading image @@ -255,49 +76,55 @@ reg_output = [] ####################### SplitBregman_TV ##################################### # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +use_object = True +if use_object: + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + # or + # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + #TV_Penalty=Regularizer.TotalVariationPenalty.l1) + plotme = reg() [0] + pars = reg.pars + textstr = reg.printParametersToString() + + #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + # TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, - #tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) -pars = out2[2] -reg_output.append(out2) +else: + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + pars = out2[2] + reg_output.append(out2) + plotme = reg_output[-1][0] + textstr = out2[-1] a=fig.add_subplot(2,3,2) -a.set_title('SplitBregman_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(plotme) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); -out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, - number_of_iterations=10) -pars = out2[-1] +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, + number_of_iterations=200) +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,3) -a.set_title('FGP_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) @@ -316,50 +143,12 @@ out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, time_step=0.0003, tolerance_constant=0.0001, number_of_iterations=300) -pars = out2[-1] +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) -a.set_title('LLT_model') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['time_step'] - ) - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - -###################### PatchBased_Regul ######################################### -# Quick 2D denoising example in Matlab: -# Im = double(imread('lena_gray_256.tif'))/255; % loading image -# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); - -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) -a.set_title('PatchBased_Regul') -textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['searching_window_ratio'], - pars['similarity_window_ratio'], - pars['PB_filtering_parameter']) - - - - +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords @@ -367,6 +156,215 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + # searching_window_ratio=3, + # similarity_window_ratio=1, + # PB_filtering_parameter=0.08) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,5) + + +# textstr = out2[-1] + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + # first_order_term=1.3, + # second_order_term=1, + # number_of_iterations=550) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,6) + + +# textstr = out2[-1] + + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +plt.show() + +################################################################################ +## +## 3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +# reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +## #tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +##input, regularization_parameter , time_step, number_of_iterations, +## tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +# time_step=0.0003, +# tolerance_constant=0.0001, +# number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +# searching_window_ratio=3, +# similarity_window_ratio=1, +# PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# ###################### TGV_PD ######################################### # Quick 2D denoising example in Matlab: @@ -375,30 +373,22 @@ imgplot = plt.imshow(reg_output[-1][0]) # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - first_order_term=1.3, - second_order_term=1, - number_of_iterations=550) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,6) -a.set_title('TGV_PD') -textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' -textstr = textstr % (pars['regularization_parameter'], - pars['first_order_term'], - pars['second_order_term'], - pars['number_of_iterations']) - - - - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - - - +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) -- cgit v1.2.3 From 73fed4964d81f1f47a0b6ecbe66517f569327b27 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:31:16 +0100 Subject: initial commit of Reconstructor.py --- src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() -- cgit v1.2.3 From a15873ea24734c9a2a7c71eed4d106968b406a07 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:56:08 +0100 Subject: module rename to cpu_regularizers --- src/Python/setup.py | 4 ++-- src/Python/test_regularizers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index 0468722..94467c4 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("regularizers", + ext_modules = [Extension("ccpi.imaging.cpu_regularizers", sources=["fista_module.cpp", "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", @@ -60,5 +60,5 @@ setup( ], zip_safe = False, - packages = {'ccpi','ccpi.reconstruction'}, + packages = {'ccpi','ccpi.fistareconstruction'}, ) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 5d25f02..755804a 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -14,7 +14,8 @@ import os from enum import Enum import timeit -from Regularizer import Regularizer +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer ############################################################################### #https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 -- cgit v1.2.3 From 74e2b61d107f5871b8e78de3f3d0a503c494e64c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:58:11 +0100 Subject: changed the backward slash to forward --- src/Python/setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index 0468722..e6c2dc6 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -49,12 +49,12 @@ setup( cmdclass = {'build_ext': build_ext}, ext_modules = [Extension("regularizers", sources=["fista_module.cpp", - "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", - "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", - "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", - "..\\..\\main_func\\regularizers_CPU\\utils.c" + "../../main_func/regularizers_CPU/FGP_TV_core.c", + "../../main_func/regularizers_CPU/SplitBregman_TV_core.c", + "../../main_func/regularizers_CPU/LLT_model_core.c", + "../../main_func/regularizers_CPU/PatchBased_Regul_core.c", + "../../main_func/regularizers_CPU/TGV_PD_core.c", + "../../main_func/regularizers_CPU/utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 105b57a1e98c2bb7b3bf94c43b6c669925ebb1b9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:01:28 +0100 Subject: added viewer for testing --- src/Python/ccpi/viewer/CILViewer.py | 361 +++++++ src/Python/ccpi/viewer/CILViewer2D.py | 1126 ++++++++++++++++++++ src/Python/ccpi/viewer/QVTKWidget.py | 340 ++++++ src/Python/ccpi/viewer/QVTKWidget2.py | 84 ++ src/Python/ccpi/viewer/__init__.py | 1 + .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 0 -> 10542 bytes .../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 0 -> 35633 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 0 -> 10099 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 0 -> 1316 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 210 bytes src/Python/ccpi/viewer/embedvtk.py | 75 ++ 11 files changed, 1987 insertions(+) create mode 100644 src/Python/ccpi/viewer/CILViewer.py create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py create mode 100644 src/Python/ccpi/viewer/__init__.py create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py new file mode 100644 index 0000000..efcf8be --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +import math +from vtk.util import numpy_support + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + + + +class CILViewer(): + '''Simple 3D Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600): + '''creates the rendering pipeline''' + + # create a rendering window and renderer + self.ren = vtk.vtkRenderer() + self.renWin = vtk.vtkRenderWindow() + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + # img 3D as slice + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceActor = None + self.voi = None + self.wl = None + self.ia = None + self.sliceActorNo = 0 + # create a renderwindowinteractor + self.iren = vtk.vtkRenderWindowInteractor() + self.iren.SetRenderWindow(self.renWin) + + self.style = vtk.vtkInteractorStyleTrackballCamera() + self.iren.SetInteractorStyle(self.style) + + self.ren.SetBackground(.1, .2, .4) + + self.actors = {} + self.iren.RemoveObservers('MouseWheelForwardEvent') + self.iren.RemoveObservers('MouseWheelBackwardEvent') + + self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) + self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) + + self.iren.RemoveObservers('KeyPressEvent') + self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) + + + self.iren.Initialize() + + + + def getRenderer(self): + '''returns the renderer''' + return self.ren + + def getRenderWindow(self): + '''returns the render window''' + return self.renWin + + def getInteractor(self): + '''returns the render window interactor''' + return self.iren + + def getCamera(self): + '''returns the active camera''' + return self.ren.GetActiveCamera() + + def createPolyDataActor(self, polydata): + '''returns an actor for a given polydata''' + mapper = vtk.vtkPolyDataMapper() + if vtk.VTK_MAJOR_VERSION <= 5: + mapper.SetInput(polydata) + else: + mapper.SetInputData(polydata) + + # actor + actor = vtk.vtkActor() + actor.SetMapper(mapper) + #actor.GetProperty().SetOpacity(0.8) + return actor + + def setPolyDataActor(self, actor): + '''displays the given polydata''' + + self.ren.AddActor(actor) + + self.actors[len(self.actors)+1] = [actor, True] + self.iren.Initialize() + self.renWin.Render() + + def displayPolyData(self, polydata): + self.setPolyDataActor(self.createPolyDataActor(polydata)) + + def hideActor(self, actorno): + '''Hides an actor identified by its number in the list of actors''' + try: + if self.actors[actorno][1]: + self.ren.RemoveActor(self.actors[actorno][0]) + self.actors[actorno][1] = False + except KeyError as ke: + print ("Warning Actor not present") + + def showActor(self, actorno, actor = None): + '''Shows hidden actor identified by its number in the list of actors''' + try: + if not self.actors[actorno][1]: + self.ren.AddActor(self.actors[actorno][0]) + self.actors[actorno][1] = True + return actorno + except KeyError as ke: + # adds it to the actors if not there already + if actor != None: + self.ren.AddActor(actor) + self.actors[len(self.actors)+1] = [actor, True] + return len(self.actors) + + def addActor(self, actor): + '''Adds an actor to the render''' + return self.showActor(0, actor) + + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + + def startRenderLoop(self): + self.iren.Start() + + + def setupObservers(self, interactor): + interactor.RemoveObservers('LeftButtonPressEvent') + interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) + interactor.Initialize() + + + def mouseInteraction(self, interactor, event): + if event == 'MouseWheelForwardEvent': + maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] + if (self.sliceno + 1 < maxSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno + 1 + self.displaySliceActor(self.sliceno) + else: + minSlice = 0 + if (self.sliceno - 1 > minSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno - 1 + self.displaySliceActor(self.sliceno) + + + def keyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "x": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_YZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "y": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "z": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceno = int(self.img3D.GetDimensions()[2] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + if interactor.GetKeyCode() == "X": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("x") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("y") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("z") + self.keyPress(interactor, event) + else : + print ("Unhandled event %s" % interactor.GetKeyCode()) + + + + def setInput3DData(self, imageData): + self.img3D = imageData + + def setInputAsNumpy(self, numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + + def displaySliceActor(self, sliceno = 0): + self.sliceno = sliceno + first = False + + self.sliceActor , self.voi, self.wl , self.ia = \ + self.getSliceActor(self.img3D, + sliceno, + self.sliceActor, + self.voi, + self.wl, + self.ia) + no = self.showActor(self.sliceActorNo, self.sliceActor) + self.sliceActorNo = no + + self.iren.Initialize() + self.renWin.Render() + + return self.sliceActorNo + + + def getSliceActor(self, + imageData , + sliceno=0, + imageActor=None , + voi=None, + windowLevel=None, + imageAccumulate=None): + '''Slices a 3D volume and then creates an actor to be rendered''' + if (voi==None): + voi = vtk.vtkExtractVOI() + #voi = vtk.vtkImageClip() + voi.SetInputData(imageData) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = sliceno + extent[self.sliceOrientation * 2 + 1] = sliceno + voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + voi.Update() + # set window/level for all slices + if imageAccumulate == None: + imageAccumulate = vtk.vtkImageAccumulate() + + if (windowLevel == None): + windowLevel = vtk.vtkImageMapToWindowLevelColors() + imageAccumulate.SetInputData(imageData) + imageAccumulate.Update() + cmax = imageAccumulate.GetMax()[0] + cmin = imageAccumulate.GetMin()[0] + windowLevel.SetLevel((cmax+cmin)/2) + windowLevel.SetWindow(cmax-cmin) + + windowLevel.SetInputData(voi.GetOutput()) + windowLevel.Update() + + if imageActor == None: + imageActor = vtk.vtkImageActor() + imageActor.SetInputData(windowLevel.GetOutput()) + imageActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + imageActor.Update() + return (imageActor , voi, windowLevel, imageAccumulate) + + + # Set interpolation on + def setInterpolateOn(self): + self.sliceActor.SetInterpolate(True) + self.renWin.Render() + + # Set interpolation off + def setInterpolateOff(self): + self.sliceActor.SetInterpolate(False) + self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py new file mode 100644 index 0000000..c1629af --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer2D.py @@ -0,0 +1,1126 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +from vtk.util import numpy_support , vtkImageImportFromArray +from enum import Enum + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + +CONTROL_KEY = 8 +SHIFT_KEY = 4 +ALT_KEY = -128 + + +# Converter class +class Converter(): + + # Utility functions to transform numpy arrays to vtkImageData and viceversa + @staticmethod + def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Creates a vtkImageImportFromArray object and returns it. + + It handles the different axis order from numpy to VTK''' + importer = vtkImageImportFromArray.vtkImageImportFromArray() + importer.SetArray(numpy.transpose(nparray).copy()) + importer.SetDataSpacing(spacing) + importer.SetDataOrigin(origin) + return importer + + @staticmethod + def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Converts a 3D numpy array to a vtkImageData''' + importer = Converter.numpy2vtkImporter(nparray, spacing, origin) + importer.Update() + return importer.GetOutput() + + @staticmethod + def vtk2numpy(imgdata): + '''Converts the VTK data to 3D numpy array''' + img_data = numpy_support.vtk_to_numpy( + imgdata.GetPointData().GetScalars()) + + dims = imgdata.GetDimensions() + dims = (dims[2],dims[1],dims[0]) + data3d = numpy.reshape(img_data, dims) + + return numpy.transpose(data3d).copy() + + @staticmethod + def tiffStack2numpy(filename, indices, + extent = None , sampleRate = None ,\ + flatField = None, darkField = None): + '''Converts a stack of TIFF files to numpy array. + + filename must contain the whole path. The filename is supposed to be named and + have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif + + indices are the suffix, generally an increasing number + + Optionally extracts only a selection of the 2D images and (optionally) + normalizes. + ''' + + stack = vtk.vtkImageData() + reader = vtk.vtkTIFFReader() + voi = vtk.vtkExtractVOI() + + #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" + + stack_image = numpy.asarray([]) + nreduced = len(indices) + + for num in range(len(indices)): + fn = filename % indices[num] + print ("resampling %s" % ( fn ) ) + reader.SetFileName(fn) + reader.Update() + print (reader.GetOutput().GetScalarTypeAsString()) + if num == 0: + if (extent == None): + sliced = reader.GetOutput().GetExtent() + stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) + else: + sliced = extent + voi.SetVOI(extent) + + if sampleRate is not None: + voi.SetSampleRate(sampleRate) + ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) + print ("ext {0}".format(ext)) + stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) + else: + stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) + if (flatField != None and darkField != None): + stack.AllocateScalars(vtk.VTK_FLOAT, 1) + else: + stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) + print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) + stack_image = Converter.vtk2numpy(stack) + print ("Stack shape %s" % str(numpy.shape(stack_image))) + + if extent!=None: + voi.SetInputData(reader.GetOutput()) + voi.Update() + img = voi.GetOutput() + else: + img = reader.GetOutput() + + theSlice = Converter.vtk2numpy(img).T[0] + if darkField != None and flatField != None: + print("Try to normalize") + #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): + theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) + print (theSlice.dtype) + + + print ("Slice shape %s" % str(numpy.shape(theSlice))) + stack_image.T[num] = theSlice.copy() + + return stack_image + + @staticmethod + def normalize(projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + + +## Utility functions to transform numpy arrays to vtkImageData and viceversa +#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtkImporter(nparray, spacing, origin) +# +#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtk(nparray, spacing, origin) +# +#def vtk2numpy(imgdata): +# return Converter.vtk2numpy(imgdata) +# +#def tiffStack2numpy(filename, indices): +# return Converter.tiffStack2numpy(filename, indices) + +class ViewerEvent(Enum): + # left button + PICK_EVENT = 0 + # alt + right button + move + WINDOW_LEVEL_EVENT = 1 + # shift + right button + ZOOM_EVENT = 2 + # control + right button + PAN_EVENT = 3 + # control + left button + CREATE_ROI_EVENT = 4 + # alt + left button + DELETE_ROI_EVENT = 5 + # release button + NO_EVENT = -1 + + +#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): +class CILInteractorStyle(vtk.vtkInteractorStyleImage): + + def __init__(self, callback): + vtk.vtkInteractorStyleImage.__init__(self) + self.callback = callback + self._viewer = callback + priority = 1.0 + +# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) +# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) +# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) +# self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) +# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) +# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) +# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) +# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) + + self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) + self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) + self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) + self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) + self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) + self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) + self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) + self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) + + self.InitialEventPosition = (0,0) + + + def SetInitialEventPosition(self, xy): + self.InitialEventPosition = xy + + def GetInitialEventPosition(self): + return self.InitialEventPosition + + def GetKeyCode(self): + return self.GetInteractor().GetKeyCode() + + def SetKeyCode(self, keycode): + self.GetInteractor().SetKeyCode(keycode) + + def GetControlKey(self): + return self.GetInteractor().GetControlKey() == CONTROL_KEY + + def GetShiftKey(self): + return self.GetInteractor().GetShiftKey() == SHIFT_KEY + + def GetAltKey(self): + return self.GetInteractor().GetAltKey() == ALT_KEY + + def GetEventPosition(self): + return self.GetInteractor().GetEventPosition() + + def GetEventPositionInWorldCoordinates(self): + pass + + def GetDeltaEventPosition(self): + x,y = self.GetInteractor().GetEventPosition() + return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) + + def Dolly(self, factor): + self.callback.camera.Dolly(factor) + self.callback.ren.ResetCameraClippingRange() + + def GetDimensions(self): + return self._viewer.img3D.GetDimensions() + + def GetInputData(self): + return self._viewer.img3D + + def GetSliceOrientation(self): + return self._viewer.sliceOrientation + + def SetSliceOrientation(self, orientation): + self._viewer.sliceOrientation = orientation + + def GetActiveSlice(self): + return self._viewer.sliceno + + def SetActiveSlice(self, sliceno): + self._viewer.sliceno = sliceno + + def UpdatePipeline(self, reset = False): + self._viewer.updatePipeline(reset) + + def GetActiveCamera(self): + return self._viewer.ren.GetActiveCamera() + + def SetActiveCamera(self, camera): + self._viewer.ren.SetActiveCamera(camera) + + def ResetCamera(self): + self._viewer.ren.ResetCamera() + + def Render(self): + self._viewer.renWin.Render() + + def UpdateSliceActor(self): + self._viewer.sliceActor.Update() + + def AdjustCamera(self): + self._viewer.AdjustCamera() + + def SaveRender(self, filename): + self._viewer.SaveRender(filename) + + def GetRenderWindow(self): + return self._viewer.renWin + + def GetRenderer(self): + return self._viewer.ren + + def GetROIWidget(self): + return self._viewer.ROIWidget + + def SetViewerEvent(self, event): + self._viewer.event = event + + def GetViewerEvent(self): + return self._viewer.event + + def SetInitialCameraPosition(self, position): + self._viewer.InitialCameraPosition = position + + def GetInitialCameraPosition(self): + return self._viewer.InitialCameraPosition + + def SetInitialLevel(self, level): + self._viewer.InitialLevel = level + + def GetInitialLevel(self): + return self._viewer.InitialLevel + + def SetInitialWindow(self, window): + self._viewer.InitialWindow = window + + def GetInitialWindow(self): + return self._viewer.InitialWindow + + def GetWindowLevel(self): + return self._viewer.wl + + def SetROI(self, roi): + self._viewer.ROI = roi + + def GetROI(self): + return self._viewer.ROI + + def UpdateCornerAnnotation(self, text, corner): + self._viewer.updateCornerAnnotation(text, corner) + + def GetPicker(self): + return self._viewer.picker + + def GetCornerAnnotation(self): + return self._viewer.cornerAnnotation + + def UpdateROIHistogram(self): + self._viewer.updateROIHistogram() + + + ############### Handle events + def OnMouseWheelForward(self, interactor, event): + maxSlice = self.GetDimensions()[self.GetSliceOrientation()] + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + + if (self.GetActiveSlice() + advance < maxSlice): + self.SetActiveSlice(self.GetActiveSlice() + advance) + + self.UpdatePipeline() + else: + print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) + + def OnMouseWheelBackward(self, interactor, event): + minSlice = 0 + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + if (self.GetActiveSlice() - advance >= minSlice): + self.SetActiveSlice( self.GetActiveSlice() - advance) + self.UpdatePipeline() + else: + print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) + + def OnKeyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "X": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) + self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Y": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Z": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) + self.UpdatePipeline(True) + if interactor.GetKeyCode() == "x": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("X") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("Y") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,1,0) + self.SetActiveCamera(camera) + self.ResetCamera() + self.Render() + interactor.SetKeyCode("Z") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "a": + # reset color/window + cmax = self._viewer.ia.GetMax()[0] + cmin = self._viewer.ia.GetMin()[0] + + self.SetInitialLevel( (cmax+cmin)/2 ) + self.SetInitialWindow( cmax-cmin ) + + self.GetWindowLevel().SetLevel(self.GetInitialLevel()) + self.GetWindowLevel().SetWindow(self.GetInitialWindow()) + + self.GetWindowLevel().Update() + + self.UpdateSliceActor() + self.AdjustCamera() + self.Render() + + elif interactor.GetKeyCode() == "s": + filename = "current_render" + self.SaveRender(filename) + elif interactor.GetKeyCode() == "q": + print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) + interactor.SetKeyCode("e") + self.OnKeyPress(interactor, event) + else : + #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) + pass + + def OnLeftButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + if ctrl and not (alt and shift): + self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) + wsize = self.GetRenderWindow().GetSize() + position = interactor.GetEventPosition() + self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) + self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) + + self.GetROIWidget().On() + self.SetDisplayHistogram(True) + self.Render() + print ("Event %s is CREATE_ROI_EVENT" % (event)) + elif alt and not (shift and ctrl): + self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) + self.GetROIWidget().Off() + self._viewer.updateCornerAnnotation("", 1, False) + self.SetDisplayHistogram(False) + self.Render() + print ("Event %s is DELETE_ROI_EVENT" % (event)) + elif not (ctrl and alt and shift): + self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) + self.HandlePickEvent(interactor, event) + print ("Event %s is PICK_EVENT" % (event)) + + + def SetDisplayHistogram(self, display): + if display: + if (self._viewer.displayHistogram == 0): + self.GetRenderer().AddActor(self._viewer.histogramPlotActor) + self.firstHistogram = 1 + self.Render() + + self._viewer.histogramPlotActor.VisibilityOn() + self._viewer.displayHistogram = True + else: + self._viewer.histogramPlotActor.VisibilityOff() + self._viewer.displayHistogram = False + + + def OnLeftButtonReleaseEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: + #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() + #print (bc.GetValue()) + self.OnROIModifiedEvent(interactor, event) + + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def OnRightButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + + if alt and not (ctrl and shift): + self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif shift and not (ctrl and alt): + self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) + self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) + print ("Event %s is ZOOM_EVENT" % (event)) + elif ctrl and not (shift and alt): + self.SetViewerEvent (ViewerEvent.PAN_EVENT ) + self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) + print ("Event %s is PAN_EVENT" % (event)) + + def OnRightButtonReleaseEvent(self, interactor, event): + print (event) + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) + self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ + self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.SetInitialCameraPosition( () ) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + + def OnROIModifiedEvent(self, interactor, event): + + #print ("ROI EVENT " + event) + p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() + p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() + wsize = self.GetRenderWindow().GetSize() + + #print (p1.GetValue()) + #print (p2.GetValue()) + pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] + pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] + vox1 = self.viewport2imageCoordinate(pp1) + vox2 = self.viewport2imageCoordinate(pp2) + + self.SetROI( (vox1 , vox2) ) + roi = self.GetROI() + print ("Pixel1 %d,%d,%d Value %f" % vox1 ) + print ("Pixel2 %d,%d,%d Value %f" % vox2 ) + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][2] - roi[0][2]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + x = abs(roi[1][1] - roi[0][1]) + y = abs(roi[1][2] - roi[0][2]) + + text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) + print (text) + self.UpdateCornerAnnotation(text, 1) + self.UpdateROIHistogram() + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.GetPicker().GetPickPosition()) + pickPosition[self.GetSliceOrientation()] = \ + self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ + self.GetInputData().GetOrigin()[self.GetSliceOrientation()] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.GetInputData().GetDimensions() + print (dims) + spac = self.GetInputData().GetSpacing() + orig = self.GetInputData().GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + + def OnMouseMoveEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: + self.HandleZoomEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.HandlePanEvent(interactor, event) + + + def HandleZoomEvent(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + size = self.GetRenderWindow().GetSize() + dy = - 4 * dy / size[1] + + print ("distance: " + str(self.GetActiveCamera().GetDistance())) + + print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) + + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) + newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) + #print ("new position " + str(newposition)) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + print ("distance after: " + str(self.GetActiveCamera().GetDistance())) + + def HandlePanEvent(self, interactor, event): + x,y = interactor.GetEventPosition() + x0,y0 = interactor.GetInitialEventPosition() + + ic = self.viewport2imageCoordinate((x,y)) + ic0 = self.viewport2imageCoordinate((x0,y0)) + + dx = 4 *( ic[0] - ic0[0]) + dy = 4* (ic[1] - ic0[1]) + + camera = vtk.vtkCamera() + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + newposition[0] -= dx + newposition[1] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[1] = newposition[1] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[0] -= dx + newposition[2] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[2] = newposition[2] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[1] -= dx + newposition[2] -= dy + newfocalpoint[2] = newposition[2] + newfocalpoint[1] = newposition[1] + #print ("new position " + str(newposition)) + camera.SetFocalPoint(newfocalpoint) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + def HandleWindowLevel(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + print ("Event delta %d %d" % (dx,dy)) + size = self.GetRenderWindow().GetSize() + + dx = 4 * dx / size[0] + dy = 4 * dy / size[1] + window = self.GetInitialWindow() + level = self.GetInitialLevel() + + if abs(window) > 0.01: + dx = dx * window + else: + dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); + + if abs(level) > 0.01: + dy = dy * level + else: + dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) + + + # Abs so that direction does not flip + + if window < 0.0: + dx = -1*dx + if level < 0.0: + dy = -1*dy + + # Compute new window level + + newWindow = dx + window + newLevel = level - dy + + # Stay away from zero and really + + if abs(newWindow) < 0.01: + newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) + + if abs(newLevel) < 0.01: + newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) + + self.GetWindowLevel().SetWindow(newWindow) + self.GetWindowLevel().SetLevel(newLevel) + + self.GetWindowLevel().Update() + self.UpdateSliceActor() + self.AdjustCamera() + + self.Render() + + def HandlePickEvent(self, interactor, event): + position = interactor.GetEventPosition() + #print ("PICK " + str(position)) + vox = self.viewport2imageCoordinate(position) + #print ("Pixel %d,%d,%d Value %f" % vox ) + self._viewer.cornerAnnotation.VisibilityOn() + self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) + self.Render() + +############################################################################### + + + +class CILViewer2D(): + '''Simple Interactive Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): + '''creates the rendering pipeline''' + # create a rendering window and renderer + if ren == None: + self.ren = vtk.vtkRenderer() + else: + self.ren = ren + if renWin == None: + self.renWin = vtk.vtkRenderWindow() + else: + self.renWin = renWin + if iren == None: + self.iren = vtk.vtkRenderWindowInteractor() + else: + self.iren = iren + + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + self.style = CILInteractorStyle(self) + + self.iren.SetInteractorStyle(self.style) + self.iren.SetRenderWindow(self.renWin) + self.iren.Initialize() + self.ren.SetBackground(.1, .2, .4) + + self.camera = vtk.vtkCamera() + self.camera.ParallelProjectionOn() + self.ren.SetActiveCamera(self.camera) + + # data + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + + #Actors + self.sliceActor = vtk.vtkImageActor() + self.voi = vtk.vtkExtractVOI() + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia = vtk.vtkImageAccumulate() + self.sliceActorNo = 0 + + #initial Window/Level + self.InitialLevel = 0 + self.InitialWindow = 0 + + #ViewerEvent + self.event = ViewerEvent.NO_EVENT + + # ROI Widget + self.ROIWidget = vtk.vtkBorderWidget() + self.ROIWidget.SetInteractor(self.iren) + self.ROIWidget.CreateDefaultRepresentation() + self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) + self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) + + # edge points of the ROI + self.ROI = () + + #picker + self.picker = vtk.vtkPropPicker() + self.picker.PickFromListOn() + self.picker.AddPickList(self.sliceActor) + + self.iren.SetPicker(self.picker) + + # corner annotation + self.cornerAnnotation = vtk.vtkCornerAnnotation() + self.cornerAnnotation.SetMaximumFontSize(12); + self.cornerAnnotation.PickableOff(); + self.cornerAnnotation.VisibilityOff(); + self.cornerAnnotation.GetTextProperty().ShadowOn(); + self.cornerAnnotation.SetLayerNumber(1); + + + + # cursor doesn't show up + self.cursor = vtk.vtkCursor2D() + self.cursorMapper = vtk.vtkPolyDataMapper2D() + self.cursorActor = vtk.vtkActor2D() + self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) + self.cursor.SetFocalPoint(0, 0, 0) + self.cursor.AllOff() + self.cursor.AxesOn() + self.cursorActor.PickableOff() + self.cursorActor.VisibilityOn() + self.cursorActor.GetProperty().SetColor(1, 1, 1) + self.cursorActor.SetLayerNumber(1) + self.cursorMapper.SetInputData(self.cursor.GetOutput()) + self.cursorActor.SetMapper(self.cursorMapper) + + # Zoom + self.InitialCameraPosition = () + + # XY Plot actor for histogram + self.displayHistogram = False + self.firstHistogram = 0 + self.roiIA = vtk.vtkImageAccumulate() + self.roiVOI = vtk.vtkExtractVOI() + self.histogramPlotActor = vtk.vtkXYPlotActor() + self.histogramPlotActor.ExchangeAxesOff(); + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetAdjustXLabels(3) + self.histogramPlotActor.SetXTitle( "Level" ) + self.histogramPlotActor.SetYTitle( "N" ) + self.histogramPlotActor.SetXValuesToValue() + self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) + self.histogramPlotActor.SetPosition(0.6,0.6) + self.histogramPlotActor.SetPosition2(0.4,0.4) + + + + def GetInteractor(self): + return self.iren + + def GetRenderer(self): + return self.ren + + def setInput3DData(self, imageData): + self.img3D = imageData + self.installPipeline() + + def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), + rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): + importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) + importer.Update() + + if rescale: + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(importer.GetOutput()) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + if (iMax - iMin == 0): + scale = 1 + else: + if dtype == vtk.VTK_UNSIGNED_SHORT: + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + elif dtype == vtk.VTK_UNSIGNED_INT: + scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(importer.GetOutput()) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(-iMin) + shiftScaler.SetOutputScalarType(dtype) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + else: + self.img3D = importer.GetOutput() + + self.installPipeline() + + def displaySlice(self, sliceno = 0): + self.sliceno = sliceno + + self.updatePipeline() + + self.renWin.Render() + + return self.sliceActorNo + + def updatePipeline(self, resetcamera = False): + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + self.ia.Update() + self.wl.Update() + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + + self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) + + if self.displayHistogram: + self.updateROIHistogram() + + self.AdjustCamera(resetcamera) + + self.renWin.Render() + + + def installPipeline(self): + '''Slices a 3D volume and then creates an actor to be rendered''' + + self.ren.AddViewProp(self.cornerAnnotation) + + self.voi.SetInputData(self.img3D) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + # set window/level for current slices + + + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia.SetInputData(self.voi.GetOutput()) + self.ia.Update() + cmax = self.ia.GetMax()[0] + cmin = self.ia.GetMin()[0] + + self.InitialLevel = (cmax+cmin)/2 + self.InitialWindow = cmax-cmin + + + self.wl.SetLevel(self.InitialLevel) + self.wl.SetWindow(self.InitialWindow) + + self.wl.SetInputData(self.voi.GetOutput()) + self.wl.Update() + + self.sliceActor.SetInputData(self.wl.GetOutput()) + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + self.sliceActor.SetInterpolate(False) + self.ren.AddActor(self.sliceActor) + self.ren.ResetCamera() + self.ren.Render() + + self.AdjustCamera() + + self.ren.AddViewProp(self.cursorActor) + self.cursorActor.VisibilityOn() + + self.iren.Initialize() + self.renWin.Render() + #self.iren.Start() + + def AdjustCamera(self, resetcamera = False): + self.ren.ResetCameraClippingRange() + if resetcamera: + self.ren.ResetCamera() + + + def getROI(self): + return self.ROI + + def getROIExtent(self): + p0 = self.ROI[0] + p1 = self.ROI[1] + return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) + + ############### Handle events are moved to the interactor style + + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.picker.GetPickPosition()) + pickPosition[self.sliceOrientation] = \ + self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ + self.img3D.GetOrigin()[self.sliceOrientation] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.img3D.GetDimensions() + print (dims) + spac = self.img3D.GetSpacing() + orig = self.img3D.GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + def GetRenderWindow(self): + return self.renWin + + + def startRenderLoop(self): + self.iren.Start() + + def GetSliceOrientation(self): + return self.sliceOrientation + + def GetActiveSlice(self): + return self.sliceno + + def updateCornerAnnotation(self, text , idx=0, visibility=True): + if visibility: + self.cornerAnnotation.VisibilityOn() + else: + self.cornerAnnotation.VisibilityOff() + + self.cornerAnnotation.SetText(idx, text) + self.iren.Render() + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + def updateROIHistogram(self): + + extent = [0 for i in range(6)] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + extent[4] = self.GetActiveSlice() + extent[5] = self.GetActiveSlice()+1 + #y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + #x = abs(roi[1][0] - roi[0][0]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[2] = self.GetActiveSlice() + extent[3] = self.GetActiveSlice()+1 + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + #x = abs(roi[1][1] - roi[0][1]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[0] = self.GetActiveSlice() + extent[1] = self.GetActiveSlice()+1 + + self.roiVOI.SetVOI(extent) + self.roiVOI.SetInputData(self.img3D) + self.roiVOI.Update() + irange = self.roiVOI.GetOutput().GetScalarRange() + + self.roiIA.SetInputData(self.roiVOI.GetOutput()) + self.roiIA.IgnoreZeroOff() + self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) + self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); + self.roiIA.SetComponentSpacing( 1,0,0 ); + self.roiIA.Update() + + self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) + self.histogramPlotActor.SetXRange(irange[0],irange[1]) + + self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) + + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py new file mode 100644 index 0000000..906786b --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget.py @@ -0,0 +1,340 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter + +class QVTKWidget(QtWidgets.QWidget): + + """ A QVTKWidget for Python and Qt.""" + + # Map between VTK and Qt cursors. + _CURSOR_MAP = { + 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT + 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW + 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE + 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE + 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW + 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE + 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS + 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE + 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL + 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND + 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR + } + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + # the current button + self._ActiveButton = QtCore.Qt.NoButton + + # private attributes + self.__oldFocus = None + self.__saveX = 0 + self.__saveY = 0 + self.__saveModifiers = QtCore.Qt.NoModifier + self.__saveButtons = QtCore.Qt.NoButton + self.__timeframe = 0 + + # create qt-level widget + QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D() + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + + self._Iren.Register(self._RenderWindow) + self._Iren.SetRenderWindow(self._RenderWindow) + self._RenderWindow.SetWindowInfo(str(int(self.winId()))) + + # do all the necessary qt setup + self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) + self.setAttribute(QtCore.Qt.WA_PaintOnScreen) + self.setMouseTracking(True) # get all mouse events + self.setFocusPolicy(QtCore.Qt.WheelFocus) + self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) + + self._Timer = QtCore.QTimer(self) + #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) + + self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) + self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) + self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', + self.CursorChangedEvent) + + # Destructor + def __del__(self): + self._Iren.UnRegister(self._RenderWindow) + #QtWidgets.QWidget.__del__(self) + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + + # GetInteractor + def GetInteractor(self): + return self._Iren + + # Display image data + def GetPyveViewer(self): + return self._PyveViewer + + def __getattr__(self, attr): + """Makes the object behave like a vtkGenericRenderWindowInteractor""" + print (attr) + if attr == '__vtk__': + return lambda t=self._Iren: t + elif hasattr(self._Iren, attr): + return getattr(self._Iren, attr) +# else: +# raise AttributeError( self.__class__.__name__ + \ +# " has no attribute named " + attr ) + + def CreateTimer(self, obj, evt): + self._Timer.start(10) + + def DestroyTimer(self, obj, evt): + self._Timer.stop() + return 1 + + def TimerEvent(self): + self._Iren.InvokeEvent("TimerEvent") + + def CursorChangedEvent(self, obj, evt): + """Called when the CursorChangedEvent fires on the render window.""" + # This indirection is needed since when the event fires, the current + # cursor is not yet set so we defer this by which time the current + # cursor should have been set. + QtCore.QTimer.singleShot(0, self.ShowCursor) + + def HideCursor(self): + """Hides the cursor.""" + self.setCursor(QtCore.Qt.BlankCursor) + + def ShowCursor(self): + """Shows the cursor.""" + vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() + qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) + self.setCursor(qt_cursor) + + def sizeHint(self): + return QtCore.QSize(400, 400) + + def paintEngine(self): + return None + + def paintEvent(self, ev): + self._RenderWindow.Render() + + def resizeEvent(self, ev): + self._RenderWindow.Render() + w = self.width() + h = self.height() + + self._RenderWindow.SetSize(w, h) + self._Iren.SetSize(w, h) + + def _GetCtrlShiftAlt(self, ev): + ctrl = shift = alt = False + + if hasattr(ev, 'modifiers'): + if ev.modifiers() & QtCore.Qt.ShiftModifier: + shift = True + if ev.modifiers() & QtCore.Qt.ControlModifier: + ctrl = True + if ev.modifiers() & QtCore.Qt.AltModifier: + alt = True + else: + if self.__saveModifiers & QtCore.Qt.ShiftModifier: + shift = True + if self.__saveModifiers & QtCore.Qt.ControlModifier: + ctrl = True + if self.__saveModifiers & QtCore.Qt.AltModifier: + alt = True + + return ctrl, shift, alt + + def enterEvent(self, ev): + if not self.hasFocus(): + self.__oldFocus = self.focusWidget() + self.setFocus() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("EnterEvent") + + def leaveEvent(self, ev): + if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: + self.__oldFocus.setFocus() + self.__oldFocus = None + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("LeaveEvent") + + def mousePressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + repeat = 0 + if ev.type() == QtCore.QEvent.MouseButtonDblClick: + repeat = 1 + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), repeat, None) + + self._Iren.SetAltKey(alt) + self._ActiveButton = ev.button() + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonPressEvent") + + def mouseReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonReleaseEvent") + + def mouseMoveEvent(self, ev): + self.__saveModifiers = ev.modifiers() + self.__saveButtons = ev.buttons() + self.__saveX = ev.x() + self.__saveY = ev.y() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("MouseMoveEvent") + + def keyPressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = str(ev.text()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyPressEvent") + self._Iren.InvokeEvent("CharEvent") + + def keyReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = chr(ev.key()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyReleaseEvent") + + def wheelEvent(self, ev): + print ("angleDeltaX %d" % ev.angleDelta().x()) + print ("angleDeltaY %d" % ev.angleDelta().y()) + if ev.angleDelta().y() >= 0: + self._Iren.InvokeEvent("MouseWheelForwardEvent") + else: + self._Iren.InvokeEvent("MouseWheelBackwardEvent") + + def GetRenderWindow(self): + return self._RenderWindow + + def Render(self): + self.update() + + +def QVTKExample(): + """A simple example that uses the QVTKWidget class.""" + + # every QT app needs an app + app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) + page_VTK = QtWidgets.QWidget() + page_VTK.resize(500,500) + layout = QtWidgets.QVBoxLayout(page_VTK) + # create the widget + widget = QVTKWidget(parent=None) + layout.addWidget(widget) + + #reader = vtk.vtkPNGReader() + #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + widget.SetInput(reader.GetOutput()) + + # show the widget + page_VTK.show() + # start event processing + app.exec_() + +if __name__ == "__main__": + QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py new file mode 100644 index 0000000..e32e1c2 --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget2.py @@ -0,0 +1,84 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor + +class QVTKWidget(QVTKRenderWindowInteractor): + + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + kw = dict() + super().__init__(parent, **kw) + + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D(400,400) + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + kw['iren'] = self._Iren + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + kw['rw'] = self._RenderWindow + + + + + def GetInteractor(self): + return self._Iren + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py new file mode 100644 index 0000000..946188b --- /dev/null +++ b/src/Python/ccpi/viewer/__init__.py @@ -0,0 +1 @@ +from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc new file mode 100644 index 0000000..711f77a Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc new file mode 100644 index 0000000..77c2ca8 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc new file mode 100644 index 0000000..3d11b87 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc new file mode 100644 index 0000000..2fa2eaf Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..fcea537 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py new file mode 100644 index 0000000..b5eb0a7 --- /dev/null +++ b/src/Python/ccpi/viewer/embedvtk.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 27 12:18:58 2017 + +@author: ofn77899 +""" + +#!/usr/bin/env python + +import sys +import vtk +from PyQt5 import QtCore, QtWidgets +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor +import QVTKWidget2 + +class MainWindow(QtWidgets.QMainWindow): + + def __init__(self, parent = None): + QtWidgets.QMainWindow.__init__(self, parent) + + self.frame = QtWidgets.QFrame() + + self.vl = QtWidgets.QVBoxLayout() +# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) + + self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) + self.iren = self.vtkWidget.GetInteractor() + self.vl.addWidget(self.vtkWidget) + + + + + self.ren = vtk.vtkRenderer() + self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) +# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() +# +# # Create source +# source = vtk.vtkSphereSource() +# source.SetCenter(0, 0, 0) +# source.SetRadius(5.0) +# +# # Create a mapper +# mapper = vtk.vtkPolyDataMapper() +# mapper.SetInputConnection(source.GetOutputPort()) +# +# # Create an actor +# actor = vtk.vtkActor() +# actor.SetMapper(mapper) +# +# self.ren.AddActor(actor) +# +# self.ren.ResetCamera() +# + self.frame.setLayout(self.vl) + self.setCentralWidget(self.frame) + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + self.vtkWidget.SetInput(reader.GetOutput()) + + #self.vktWidget.Initialize() + #self.vktWidget.Start() + + self.show() + #self.iren.Initialize() + + +if __name__ == "__main__": + + app = QtWidgets.QApplication(sys.argv) + + window = MainWindow() + + sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From ad62962697509d977087c25d24a3ff083d9c4308 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:08:32 +0100 Subject: initial revision --- src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/ccpi/imaging/Regularizer.py (limited to 'src') diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = cpu_regularizers.SplitBregman_TV + FGP_TV = cpu_regularizers.FGP_TV + LLT_model = cpu_regularizers.LLT_model + PatchBased_Regul = cpu_regularizers.PatchBased_Regul + TGV_PD = cpu_regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From a2ca45848e354f376c53ecd3fed946d64c1ff3aa Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:10:04 +0100 Subject: module rename --- src/Python/fista_module.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index eacda3d..c36329e 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -1032,13 +1032,13 @@ bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_al return result; } -BOOST_PYTHON_MODULE(regularizers) +BOOST_PYTHON_MODULE(cpu_regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "regularizers"; + package.attr("__path__") = "cpu_regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); -- cgit v1.2.3 From cc3a464ec587e95ddfd421cd3836a7677dfb9744 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:51:18 +0100 Subject: export/import data from hdf5 Added file to export the data from DemoRD2.m to HDF5 to pass it to Python. Added file to import the data from DemoRD2.m from HDF5. --- src/Python/test/readhd5.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/Python/test/readhd5.py (limited to 'src') diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py new file mode 100644 index 0000000..1e19e14 --- /dev/null +++ b/src/Python/test/readhd5.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +""" + +import h5py +import numpy + +def getEntry(nx, location): + for item in nx[location].keys(): + print (item) + +filename = r'C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\Demos\DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file -- cgit v1.2.3 From 56915cc00ded38d24c23b9ab1a0717d52d430ddd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:54:59 +0100 Subject: initial revision for testing --- .../ccpi/reconstruction/FISTAReconstructor.py | 354 +++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..ea96b53 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + -- cgit v1.2.3 From ed5737df1e9a613ad881d3b61c62c2627027faa4 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:39:37 +0100 Subject: bugfix --- src/Python/Matlab2Python_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index e15d738..ee76bc7 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -123,7 +123,7 @@ T * mxGetData(const np::ndarray pm) { probably this would work. A = reinterpret_cast(prhs[0]); */ - return reinterpret_cast(prhs[0]); + //return reinterpret_cast(prhs[0]); } template @@ -273,4 +273,4 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); def("mexFunction", mexFunction); def("doSomething", doSomething); -} \ No newline at end of file +} -- cgit v1.2.3 From 82d2db8a7514c850887c18143626539b7ca8b794 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:41:10 +0100 Subject: initial facility to test the FISTA --- src/Python/test_reconstructor.py | 179 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 src/Python/test_reconstructor.py (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py new file mode 100644 index 0000000..0fd08f5 --- /dev/null +++ b/src/Python/test_reconstructor.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +import astra + +##def getEntry(nx, location): +## for item in nx[location].keys(): +## print (item) + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 3 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) + #N = params.vol_geom.GridColCount + +pars = dict() +pars['projector_geometry'] = proj_geom +pars['output_geometry'] = vol_geom +pars['input_sinogram'] = Sino3D +sliceZ , nangles , detectors = numpy.shape(Sino3D) +pars['detectors'] = detectors +pars['number_of_angles'] = nangles +pars['SlicesZ'] = sliceZ + + +pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) + +N = pars['output_geometry']['GridColCount'] +proj_geom = pars['projector_geometry'] +vol_geom = pars['output_geometry'] +weights = pars['weights'] +SlicesZ = pars['SlicesZ'] + +if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + import matplotlib.pyplot as plt + fig = plt.figure() + + #a.set_title('Lipschitz') + for i in range(niter): +# [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); +# s = norm(x1(:)); +# x1 = x1/s; +# [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +# y = sqweight.*y; +# astra_mex_data3d('delete', sino_id); +# astra_mex_data3d('delete', id); + print ("iteration {0}".format(i)) + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + #a=fig.add_subplot(2,1,1) + #imgplot = plt.imshow(y[0]) + + y = sqweight * y # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geomT, + vol_geomT); + print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + print ("x1 {0}".format(x1.T[:4].T)) + +# ### this line? +# sino_id, y = astra.creators.create_sino3d_gpu(x1, +# proj_geomT, +# vol_geomT); +# y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT +else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 -- cgit v1.2.3 From 7111d98258becca09e4c93e3c66edb7d524d6463 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:05 +0100 Subject: initial revision --- src/Python/ccpi/imaging/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/imaging/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 3f26b1d8ab3a632ceca97bdf04225008f9163684 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:27 +0100 Subject: initial revision --- src/Python/ccpi/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 447756a338dfa993e2969298af19f1f9707a409a Mon Sep 17 00:00:00 2001 From: algol Date: Fri, 25 Aug 2017 15:58:52 +0100 Subject: removed vtk dependency --- src/Python/test_regularizers.py | 124 +++++++++++++++++++++++++++------------- 1 file changed, 85 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 755804a..86849eb 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -5,15 +5,15 @@ Created on Fri Aug 4 11:10:05 2017 @author: ofn77899 """ -from ccpi.viewer.CILViewer2D import Converter -import vtk +#from ccpi.viewer.CILViewer2D import Converter +#import vtk import matplotlib.pyplot as plt import numpy as np import os from enum import Enum import timeit - +#from PIL import Image #from Regularizer import Regularizer from ccpi.imaging.Regularizer import Regularizer @@ -46,12 +46,20 @@ def nrmse(im1, im2): # u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" -reader = vtk.vtkTIFFReader() -reader.SetFileName(os.path.normpath(filename)) -reader.Update() +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() #vtk returns 3D images, let's take just the one slice there is as 2D -Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 +#Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 +filename = '/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' +Im = plt.imread(filename) +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + + + #imgplot = plt.imshow(Im) perc = 0.05 @@ -80,6 +88,7 @@ reg_output = [] use_object = True if use_object: reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + print (reg.pars) reg.setParameter(input=u0) reg.setParameter(regularization_parameter=10.) # or @@ -113,7 +122,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(plotme) +imgplot = plt.imshow(plotme,cmap="gray") ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -125,14 +134,32 @@ reg_output.append(out2) a=fig.add_subplot(2,3,3) +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) + verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise @@ -149,13 +176,32 @@ pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") # ###################### PatchBased_Regul ######################################### # # Quick 2D denoising example in Matlab: @@ -163,24 +209,24 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - # searching_window_ratio=3, - # similarity_window_ratio=1, - # PB_filtering_parameter=0.08) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,5) +a=fig.add_subplot(2,3,5) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") # ###################### TGV_PD ######################################### @@ -190,25 +236,25 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - # first_order_term=1.3, - # second_order_term=1, - # number_of_iterations=550) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,6) +a=fig.add_subplot(2,3,6) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") plt.show() -- cgit v1.2.3 From fb5e0ad0ad94f5b919b17f3223834380dce683d4 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 16:38:52 +0100 Subject: The calculation of the Lipschitz constant works --- src/Python/test_reconstructor.py | 60 ++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 20 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 0fd08f5..76ce3ac 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -23,10 +23,10 @@ nx = h5py.File(filename, "r") entries = [entry for entry in nx['/'].keys()] print (entries) -Sino3D = numpy.asarray(nx.get('/Sino3D')) -Weights3D = numpy.asarray(nx.get('/Weights3D')) +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] -angles_rad = numpy.asarray(nx.get('/angles_rad')) +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] @@ -62,17 +62,18 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, #N = params.vol_geom.GridColCount pars = dict() -pars['projector_geometry'] = proj_geom -pars['output_geometry'] = vol_geom -pars['input_sinogram'] = Sino3D +pars['projector_geometry'] = proj_geom.copy() +pars['output_geometry'] = vol_geom.copy() +pars['input_sinogram'] = Sino3D.copy() sliceZ , nangles , detectors = numpy.shape(Sino3D) pars['detectors'] = detectors pars['number_of_angles'] = nangles pars['SlicesZ'] = sliceZ -pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) - +#pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) +pars['weights'] = Weights3D.copy() + N = pars['output_geometry']['GridColCount'] proj_geom = pars['projector_geometry'] vol_geom = pars['output_geometry'] @@ -82,7 +83,7 @@ SlicesZ = pars['SlicesZ'] if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 15;# % number of iteration for the PM + niter = 16;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) @@ -96,7 +97,8 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); import matplotlib.pyplot as plt - fig = plt.figure() + fig = [] + props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) #a.set_title('Lipschitz') for i in range(niter): @@ -107,14 +109,27 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): # y = sqweight.*y; # astra_mex_data3d('delete', sino_id); # astra_mex_data3d('delete', id); - print ("iteration {0}".format(i)) + #print ("iteration {0}".format(i)) + fig.append(plt.figure()) + + a=fig[-1].add_subplot(1,2,1) + a.text(0.05, 0.95, "iteration {0}, x1".format(i), transform=a.transAxes, + fontsize=14,verticalalignment='top', bbox=props) + + imgplot = plt.imshow(x1[0].copy()) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT) - #a=fig.add_subplot(2,1,1) - #imgplot = plt.imshow(y[0]) + a=fig[-1].add_subplot(1,2,2) + a.text(0.05, 0.95, "iteration {0}, y".format(i), + transform=a.transAxes, fontsize=14,verticalalignment='top', + bbox=props) + + imgplot = plt.imshow(y[0].copy()) - y = sqweight * y # element wise multiplication + y = (sqweight * y).copy() # element wise multiplication #b=fig.add_subplot(2,1,2) #imgplot = plt.imshow(x1[0]) @@ -122,15 +137,17 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #astra_mex_data3d('delete', sino_id); astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), proj_geomT, - vol_geomT); - print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - print ("x1 {0}".format(x1.T[:4].T)) + x1 = (x1/s).copy(); # ### this line? # sino_id, y = astra.creators.create_sino3d_gpu(x1, @@ -138,10 +155,13 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): # vol_geomT); # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry print('Calculating Lipshitz constant for divergen beam geometry...') -- cgit v1.2.3 From 391473269674bc98697eabac0b4fb2bd89f5d85e Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 16:55:48 +0100 Subject: Reorganized code with new fista package name --- src/Python/ccpi/fista/FISTAReconstructor.py | 389 ++++++++++++++ src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 0 -> 3804 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ++++++++++++ src/Python/ccpi/fista/Reconstructor.py | 425 +++++++++++++++ src/Python/ccpi/fista/Reconstructor.py~ | 598 +++++++++++++++++++++ src/Python/ccpi/fista/__init__.py | 0 src/Python/ccpi/fista/__init__.pyc | Bin 0 -> 189 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 0 -> 3641 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 185 bytes src/Python/test_reconstructor.py | 11 +- 10 files changed, 1767 insertions(+), 5 deletions(-) create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc create mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ create mode 100644 src/Python/ccpi/fista/Reconstructor.py create mode 100644 src/Python/ccpi/fista/Reconstructor.py~ create mode 100644 src/Python/ccpi/fista/__init__.py create mode 100644 src/Python/ccpi/fista/__init__.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc create mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py new file mode 100644 index 0000000..1e76815 --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +#from ccpi.reconstruction.parallelbeam import alg + +#from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry + self.pars['output_geometry'] = output_geometry + self.pars['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_og_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + + print (self.pars) + # handle optional input parameters (at instantiation) + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not 'ideal_image' in kwargs.keys(): + self.pars['ideal_image'] = None + + if not 'region_of_interest'in kwargs.keys() : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not 'regularizer' in kwargs.keys() : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + + for i in range(niter): + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 + + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + + s = numpy.linalg.norm(x1) + ### this line? + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + + #end + del proj_geomT + del vol_geomT + #plt.show() + else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location, nx): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc new file mode 100644 index 0000000..ecc4d7d Binary files /dev/null and b/src/Python/ccpi/fista/FISTAReconstructor.pyc differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ new file mode 100644 index 0000000..6c7024d --- /dev/null +++ b/src/Python/ccpi/fista/FISTAReconstructor.py~ @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + + diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py new file mode 100644 index 0000000..d29ac0d --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/fista/Reconstructor.py~ @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc new file mode 100644 index 0000000..719e264 Binary files /dev/null and b/src/Python/ccpi/fista/__init__.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc new file mode 100644 index 0000000..84f16e2 Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..90c23ff Binary files /dev/null and b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 76ce3ac..6f46e96 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -58,7 +58,8 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, ## First pass the arguments to the FISTAReconstructor and test the ## Lipschitz constant -#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) +fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D , weights=Weights3D) +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) #N = params.vol_geom.GridColCount pars = dict() @@ -83,7 +84,7 @@ SlicesZ = pars['SlicesZ'] if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 16;# % number of iteration for the PM + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) @@ -129,7 +130,7 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): imgplot = plt.imshow(y[0].copy()) - y = (sqweight * y).copy() # element wise multiplication + y = (sqweight * y) # element wise multiplication #b=fig.add_subplot(2,1,2) #imgplot = plt.imshow(x1[0]) @@ -139,7 +140,7 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): astra.matlab.data3d('delete', sino_id) del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y), proj_geomT, vol_geomT) del y @@ -147,7 +148,7 @@ if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): s = numpy.linalg.norm(x1) ### this line? - x1 = (x1/s).copy(); + x1 = (x1/s) # ### this line? # sino_id, y = astra.creators.create_sino3d_gpu(x1, -- cgit v1.2.3 From 64c0b9e7a1bcfc54e6ed8b57274d53c3ed9bb950 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 17:03:17 +0100 Subject: use refactore code --- src/Python/ccpi/fista/FISTAReconstructor.pyc | Bin 3804 -> 0 bytes src/Python/ccpi/fista/FISTAReconstructor.py~ | 349 ------------ src/Python/ccpi/fista/Reconstructor.py~ | 598 --------------------- src/Python/ccpi/fista/__init__.pyc | Bin 189 -> 0 bytes .../__pycache__/FISTAReconstructor.cpython-35.pyc | Bin 3641 -> 0 bytes .../ccpi/fista/__pycache__/__init__.cpython-35.pyc | Bin 185 -> 0 bytes src/Python/test_reconstructor.py | 2 +- 7 files changed, 1 insertion(+), 948 deletions(-) delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.pyc delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py~ delete mode 100644 src/Python/ccpi/fista/Reconstructor.py~ delete mode 100644 src/Python/ccpi/fista/__init__.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc delete mode 100644 src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.pyc b/src/Python/ccpi/fista/FISTAReconstructor.pyc deleted file mode 100644 index ecc4d7d..0000000 Binary files a/src/Python/ccpi/fista/FISTAReconstructor.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py~ b/src/Python/ccpi/fista/FISTAReconstructor.py~ deleted file mode 100644 index 6c7024d..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py~ +++ /dev/null @@ -1,349 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -#from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - - diff --git a/src/Python/ccpi/fista/Reconstructor.py~ b/src/Python/ccpi/fista/Reconstructor.py~ deleted file mode 100644 index ba67327..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py~ +++ /dev/null @@ -1,598 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - -class Reconstructor: - - class Algorithm(Enum): - CGLS = alg.cgls - CGLS_CONV = alg.cgls_conv - SIRT = alg.sirt - MLEM = alg.mlem - CGLS_TICHONOV = alg.cgls_tikhonov - CGLS_TVREG = alg.cgls_TVreg - FISTA = 'fista' - - def __init__(self, algorithm = None, projection_data = None, - angles = None, center_of_rotation = None , - flat_field = None, dark_field = None, - iterations = None, resolution = None, isLogScale = False, threads = None, - normalized_projection = None): - - self.pars = dict() - self.pars['algorithm'] = algorithm - self.pars['projection_data'] = projection_data - self.pars['normalized_projection'] = normalized_projection - self.pars['angles'] = angles - self.pars['center_of_rotation'] = numpy.double(center_of_rotation) - self.pars['flat_field'] = flat_field - self.pars['iterations'] = iterations - self.pars['dark_field'] = dark_field - self.pars['resolution'] = resolution - self.pars['isLogScale'] = isLogScale - self.pars['threads'] = threads - if (iterations != None): - self.pars['iterationValues'] = numpy.zeros((iterations)) - - if projection_data != None and dark_field != None and flat_field != None: - norm = self.normalize(projection_data, dark_field, flat_field, 0.1) - self.pars['normalized_projection'] = norm - - - def setPars(self, parameters): - keys = ['algorithm','projection_data' ,'normalized_projection', \ - 'angles' , 'center_of_rotation' , 'flat_field', \ - 'iterations','dark_field' , 'resolution', 'isLogScale' , \ - 'threads' , 'iterationValues', 'regularize'] - - for k in keys: - if k not in parameters.keys(): - self.pars[k] = None - else: - self.pars[k] = parameters[k] - - - def sanityCheck(self): - projection_data = self.pars['projection_data'] - dark_field = self.pars['dark_field'] - flat_field = self.pars['flat_field'] - angles = self.pars['angles'] - - if projection_data != None and dark_field != None and \ - angles != None and flat_field != None: - data_shape = numpy.shape(projection_data) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - - if data_shape[1:] != numpy.shape(flat_field): - #raise Exception('Projection and flat field dimensions do not match') - return (False , 'Projection and flat field dimensions do not match') - if data_shape[1:] != numpy.shape(dark_field): - #raise Exception('Projection and dark field dimensions do not match') - return (False , 'Projection and dark field dimensions do not match') - - return (True , '' ) - elif self.pars['normalized_projection'] != None: - data_shape = numpy.shape(self.pars['normalized_projection']) - angle_shape = numpy.shape(angles) - - if angle_shape[0] != data_shape[0]: - #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ - # (angle_shape[0] , data_shape[0]) ) - return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ - (angle_shape[0] , data_shape[0]) ) - else: - return (True , '' ) - else: - return (False , 'Not enough data') - - def reconstruct(self, parameters = None): - if parameters != None: - self.setPars(parameters) - - go , reason = self.sanityCheck() - if go: - return self._reconstruct() - else: - raise Exception(reason) - - - def _reconstruct(self, parameters=None): - if parameters!=None: - self.setPars(parameters) - parameters = self.pars - - if parameters['algorithm'] != None and \ - parameters['normalized_projection'] != None and \ - parameters['angles'] != None and \ - parameters['center_of_rotation'] != None and \ - parameters['iterations'] != None and \ - parameters['resolution'] != None and\ - parameters['threads'] != None and\ - parameters['isLogScale'] != None: - - - if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, - Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): - #store parameters - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['isLogScale'] - ) - return result - elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, - Reconstructor.Algorithm.CGLS_TICHONOV, - Reconstructor.Algorithm.CGLS_TVREG) : - self.pars = parameters - result = parameters['algorithm']( - parameters['normalized_projection'] , - parameters['angles'], - parameters['center_of_rotation'], - parameters['resolution'], - parameters['iterations'], - parameters['threads'] , - parameters['regularize'], - numpy.zeros((parameters['iterations'])), - parameters['isLogScale'] - ) - - elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: - pass - - else: - if parameters['projection_data'] != None and \ - parameters['dark_field'] != None and \ - parameters['flat_field'] != None: - norm = self.normalize(parameters['projection_data'], - parameters['dark_field'], - parameters['flat_field'], 0.1) - self.pars['normalized_projection'] = norm - return self._reconstruct(parameters) - - - - def _normalize(self, projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - def normalize(self, projections, dark, flat, def_val=0): - norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] - return numpy.asarray (norm, dtype=numpy.float32) - - - -class FISTA(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.pyc b/src/Python/ccpi/fista/__init__.pyc deleted file mode 100644 index 719e264..0000000 Binary files a/src/Python/ccpi/fista/__init__.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc deleted file mode 100644 index 84f16e2..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/FISTAReconstructor.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index 90c23ff..0000000 Binary files a/src/Python/ccpi/fista/__pycache__/__init__.cpython-35.pyc and /dev/null differ diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 6f46e96..a4a622b 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -9,7 +9,7 @@ Based on DemoRD2.m import h5py import numpy -from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +from ccpi.fista.FISTAReconstructor import FISTAReconstructor import astra ##def getEntry(nx, location): -- cgit v1.2.3 From 83250cee1deff04c34d5ed9ad2d9dbde09cadcf6 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 25 Aug 2017 17:04:44 +0100 Subject: minor changes --- src/Python/test_regularizers.py | 60 +++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 29 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 755804a..d0bccaf 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -46,7 +46,9 @@ def nrmse(im1, im2): # u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" + reader = vtk.vtkTIFFReader() reader.SetFileName(os.path.normpath(filename)) reader.Update() @@ -163,24 +165,24 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - # searching_window_ratio=3, - # similarity_window_ratio=1, - # PB_filtering_parameter=0.08) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,5) +a=fig.add_subplot(2,3,5) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) # ###################### TGV_PD ######################################### @@ -190,25 +192,25 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - # first_order_term=1.3, - # second_order_term=1, - # number_of_iterations=550) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,6) +a=fig.add_subplot(2,3,6) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) plt.show() -- cgit v1.2.3 From f7e1cf04f791898737bc15b0eb437abc2c5d9305 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 13 Sep 2017 10:41:14 +0100 Subject: cleaned up code --- src/Python/test_regularizers.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index d2fbca6..665a077 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -76,7 +76,7 @@ fig = plt.figure() a=fig.add_subplot(2,3,1) a.set_title('noise') -imgplot = plt.imshow(u0) +imgplot = plt.imshow(u0,cmap="gray") reg_output = [] ############################################################################## @@ -134,20 +134,6 @@ reg_output.append(out2) a=fig.add_subplot(2,3,3) -textstr = out2[-1] - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-2] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) - - textstr = out2[-1] # these are matplotlib.patch.Patch properties @@ -176,15 +162,6 @@ pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-2] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) - textstr = out2[-1] @@ -195,13 +172,6 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0],cmap="gray") -textstr = out2[-1] -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0],cmap="gray") # ###################### PatchBased_Regul ######################################### # # Quick 2D denoising example in Matlab: -- cgit v1.2.3 From 5aaa46237fbf0a6bb008fe81576cabc61e3b1fce Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 15:26:29 +0100 Subject: Added Python modules Matlab2Python_utils.cpp contains utilities for handling numpy arrays. Together with setup_test.py it creates a functional module for testing. fista_module.cpp and setup.py are meant for the real fista module. --- src/Python/Matlab2Python_utils.cpp | 206 ++++++++++++++++++++++++ src/Python/fista_module.cpp | 315 +++++++++++++++++++++++++++++++++++++ src/Python/setup.py | 58 +++++++ src/Python/setup_test.py | 58 +++++++ 4 files changed, 637 insertions(+) create mode 100644 src/Python/Matlab2Python_utils.cpp create mode 100644 src/Python/fista_module.cpp create mode 100644 src/Python/setup.py create mode 100644 src/Python/setup_test.py (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp new file mode 100644 index 0000000..138e8da --- /dev/null +++ b/src/Python/Matlab2Python_utils.cpp @@ -0,0 +1,206 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** + +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*/ + +template +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); +} + + + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. + A = reinterpret_cast(prhs[0]); + */ + return reinterpret_cast(prhs[0]); +} + +template +np::ndarray zeros(int dims , int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list mexFunction( np::ndarray input ) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast( input.get_data() ); + int * B = reinterpret_cast( zz.get_data() ); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index , &val, sizeof(int)); + std::memcpy(C + index , &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + //import_array(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + //numpy_boost_python_register_type(); + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp new file mode 100644 index 0000000..5344083 --- /dev/null +++ b/src/Python/fista_module.cpp @@ -0,0 +1,315 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include + +#include +#include +#include "boost/tuple/tuple.hpp" + +// include the regularizers +#include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" +#include "SplitBregman_TV_core.h" +#include "TGV_PD_core.h" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** +mxGetData +args: pm Pointer to an mxArray +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value MATLAB Type MEX Type C Primitive Type +mxINT8_CLASS int8 int8_T char, byte +mxUINT8_CLASS uint8 uint8_T unsigned char, byte +mxINT16_CLASS int16 int16_T short +mxUINT16_CLASS uint16 uint16_T unsigned short +mxINT32_CLASS int32 int32_T int +mxUINT32_CLASS uint32 uint32_T unsigned int +mxINT64_CLASS int64 int64_T long long +mxUINT64_CLASS uint64 uint64_T unsigned long long +mxSINGLE_CLASS single float float +mxDOUBLE_CLASS double double double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). + If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not + enough free heap space to create the mxArray. +*/ + +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { + /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) + * + * Input Parameters: + * 1. Noisy image/volume + * 2. lambda - regularization parameter + * 3. Number of iterations [OPTIONAL parameter] + * 4. eplsilon - tolerance constant [OPTIONAL parameter] + * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] + * + * Output: + * Filtered/regularized image + * + * All sanity checks and default values are set in Python + */ + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + const int dim_array[3]; + float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -11; + } + else { + dim_array[2] = input.shape(2); + } + + /*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + //iter = 35; /* default iterations number */ + iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + methTV = TV_type; + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + lambda = 2.0f*mu; + count = 1; + re_old = 0.0f; + /*Handling Matlab output data*/ + dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2]; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /* + mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); +args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. + dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. + For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. + classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. + For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. + ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). + Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). + + mxCreateNumericArray initializes all its real data elements to 0. +*/ + +/* + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +*/ + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U = A = reinterpret_castinput.get_data(); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + /*printf("%f %i %i \n", re, ll, count); */ + + /*copyIm(U_old, U, dimX, dimY, dimZ); */ + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + if (number_of_dims == 3) { + U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + printf("SB iterations stopped at iteration: %i\n", ll); + } + bp::list result; + return result; +} + + +BOOST_PYTHON_MODULE(fista) +{ + np::initialize(); + + //To specify that this module is a package + bp::object package = bp::scope(); + package.attr("__path__") = "fista"; + + np::dtype dt1 = np::dtype::get_builtin(); + np::dtype dt2 = np::dtype::get_builtin(); + + + def("mexFunction", mexFunction); +} \ No newline at end of file diff --git a/src/Python/setup.py b/src/Python/setup.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py new file mode 100644 index 0000000..ffb9c02 --- /dev/null +++ b/src/Python/setup_test.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] + extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + extra_include_dirs += ["../ContourTree/", "../Core/","."] + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("fista", + sources=[ "Matlab2Python_utils.cpp", + ], + include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.reconstruction'}, +) -- cgit v1.2.3 From 22d41e596544668e71f3abef321d48f0a54f0f53 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 3 Aug 2017 16:52:20 +0100 Subject: added FGP_TV wrapper --- src/Python/fista_module.cpp | 576 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 473 insertions(+), 103 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 5344083..2492884 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -26,12 +26,10 @@ limitations under the License. #include #include "boost/tuple/tuple.hpp" -// include the regularizers -#include "FGP_TV_core.h" -#include "LLT_model_core.h" -#include "PatchBased_Regul_core.h" #include "SplitBregman_TV_core.h" -#include "TGV_PD_core.h" +#include "FGP_TV_core.h" + + #if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) #include @@ -43,7 +41,6 @@ __if_not_exists(uint16_t) { typedef __int8 uint16_t; } namespace bp = boost::python; namespace np = boost::python::numpy; - /*! in the Matlab implementation this is called as void mexFunction( int nlhs, mxArray *plhs[], @@ -56,9 +53,7 @@ nlhs Array of pointers to the OUTPUT mxArrays plhs int number of OUTPUT mxArrays *********************************************************** -mxGetData -args: pm Pointer to an mxArray -Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. + *********************************************************** double mxGetScalar(const mxArray *pm); args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. @@ -92,77 +87,143 @@ double *mxGetPr(const mxArray *pm); args: pm Pointer to an mxArray of type double Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. **************************************************************** -mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). - If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not - enough free heap space to create the mxArray. +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { + std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. */ template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; +double mxGetScalar(const np::ndarray plh) { + return (double)bp::extract(plh[0]); } -bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, double d_epsil, int TV_type) { - /* C-OMP implementation of Split Bregman - TV denoising-regularization model (2D/3D) - * - * Input Parameters: - * 1. Noisy image/volume - * 2. lambda - regularization parameter - * 3. Number of iterations [OPTIONAL parameter] - * 4. eplsilon - tolerance constant [OPTIONAL parameter] - * 5. TV-type: 'iso' or 'l1' [OPTIONAL parameter] - * - * Output: - * Filtered/regularized image - * - * All sanity checks and default values are set in Python + +template +T * mxGetData(const np::ndarray pm) { + //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. + //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. + /*Access the numpy array pointer: + char * get_data() const; + Returns: Array’s raw data pointer as a char + Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. + probably this would work. + A = reinterpret_cast(prhs[0]); */ + return reinterpret_cast(prhs[0]); +} + + + + +bp::list mexFunction(np::ndarray input) { + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = (float)val; + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + } + } + } + + + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int dim_array[3]; + const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; - + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - number_of_dims = input.get_nd(); + + int number_of_dims = input.get_nd(); + int dim_array[3]; dim_array[0] = input.shape(0); dim_array[1] = input.shape(1); if (number_of_dims == 2) { - dim_array[2] = -11; + dim_array[2] = -1; } else { dim_array[2] = input.shape(2); } - /*Handling Matlab input data*/ + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); - /*Handling Matlab input data*/ + ///*Handling Matlab input data*/ //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ A = reinterpret_cast(input.get_data()); - //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ mu = (float)d_mu; + //iter = 35; /* default iterations number */ - iter = niterations; + //epsil = 0.0001; /* default tolerance constant */ epsil = (float)d_epsil; //methTV = 0; /* default isotropic TV penalty */ - methTV = TV_type; //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ //if (nrhs == 5) { @@ -182,34 +243,31 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, , int niterations, doub if (number_of_dims == 2) { dimZ = 1; /*2D case*/ - /* - mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, mxClassID classid, mxComplexity ComplexFlag); -args: ndim: Number of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. - dims: Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. - For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. - classid: Identifier for the class of the array, which determines the way the numerical data is represented in memory. - For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. - ComplexFlag: If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). - Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). - - mxCreateNumericArray initializes all its real data elements to 0. -*/ - -/* - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); -*/ //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - U = A = reinterpret_castinput.get_data(); - U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + //By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ @@ -245,59 +303,370 @@ args: ndim: Number of dimensions. If you specify a value for ndim that is less /*printf("%f %i %i \n", re, ll, count); */ /*copyIm(U_old, U, dimX, dimY, dimZ); */ + result.append(npU); + result.append(ll); + } + //printf("SB iterations stopped at iteration: %i\n", ll); + if (number_of_dims == 3) { + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npDz = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npBz = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Dz = reinterpret_cast(npDz.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + Bz = reinterpret_cast(npBz.get_data()); + + copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + + /* begin outer SB iterations */ + for (ll = 0; ll 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + /*printf("%f %i %i \n", re, ll, count); */ + re_old = re; + } + //printf("SB iterations stopped at iteration: %i\n", ll); + result.append(npU); + result.append(ll); } - printf("SB iterations stopped at iteration: %i\n", ll); } - if (number_of_dims == 3) { - U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); - Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + return result; - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ +} - /* begin outer SB iterations */ +bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + + // the result is in the following list + bp::list result; + + int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + // Parameter handling is be done in Python + ///*Handling Matlab input data*/ + //if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + + ///*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ + A = reinterpret_cast(input.get_data()); + + //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ + mu = (float)d_mu; + + //iter = 35; /* default iterations number */ + + //epsil = 0.0001; /* default tolerance constant */ + epsil = (float)d_epsil; + //methTV = 0; /* default isotropic TV penalty */ + //if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5)) iter = (int)mxGetScalar(prhs[2]); /* iterations number */ + //if ((nrhs == 4) || (nrhs == 5)) epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ + //if (nrhs == 5) { + // char *penalty_type; + // penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ + // if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); + // if (strcmp(penalty_type, "l1") == 0) methTV = 1; /* enable 'l1' penalty */ + // mxFree(penalty_type); + //} + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + //plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); + bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray out1 = np::zeros(shape1, dtype); + + //float *funcvalA = (float *)mxGetData(plhs[1]); + float * funcvalA = reinterpret_cast(out1.get_data()); + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2]; + + tk = 1.0f; + tkp1 = 1.0f; + count = 1; + re_old = 0.0f; + + if (number_of_dims == 2) { + dimZ = 1; /*2D case*/ + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = zeros(2, dim_array, (float)0); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + + /* begin iterations */ for (ll = 0; ll 4) break; + if (count > 3) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j 2) { - if (re > re_old) break; + if (re > re_old) { + Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); + } + if (number_of_dims == 3) { + /*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npD = np::zeros(shape, dtype); + np::ndarray npD_old = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npP3 = np::zeros(shape, dtype); + np::ndarray npP1_old = np::zeros(shape, dtype); + np::ndarray npP2_old = np::zeros(shape, dtype); + np::ndarray npP3_old = np::zeros(shape, dtype); + np::ndarray npR1 = np::zeros(shape, dtype); + np::ndarray npR2 = np::zeros(shape, dtype); + np::ndarray npR3 = np::zeros(shape, dtype); + + D = reinterpret_cast(npD.get_data()); + D_old = reinterpret_cast(npD_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + P3 = reinterpret_cast(npP3.get_data()); + P1_old = reinterpret_cast(npP1_old.get_data()); + P2_old = reinterpret_cast(npP2_old.get_data()); + P3_old = reinterpret_cast(npP3_old.get_data()); + R1 = reinterpret_cast(npR1.get_data()); + R2 = reinterpret_cast(npR2.get_data()); + R2 = reinterpret_cast(npR3.get_data()); + /* begin iterations */ + for (ll = 0; ll 3) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j 2) { + if (re > re_old) { + Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); + funcval = 0.0f; + for (j = 0; j(npD); + result.append(out1); + result.append(ll); } - bp::list result; + return result; } - BOOST_PYTHON_MODULE(fista) { @@ -310,6 +679,7 @@ BOOST_PYTHON_MODULE(fista) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); + def("SplitBregman_TV", SplitBregman_TV); + def("FGP_TV", FGP_TV); } \ No newline at end of file -- cgit v1.2.3 From 12dbe738d5a2af5573e33a31f1745a50dba165ba Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:03 +0100 Subject: compilation fixes --- src/Python/setup.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index ffb9c02..a8feb1c 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -29,14 +29,14 @@ extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\env extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] extra_libraries = [] if platform.system() == 'Windows': - extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ] + extra_include_dirs += ["..\\..\\main_func\\regularizers_CPU\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + extra_include_dirs += ["../../main_func/regularizers_CPU","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,8 +47,12 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", - sources=[ "Matlab2Python_utils.cpp", + ext_modules = [Extension("regularizers", + sources=["fista_module.cpp", + "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", + "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 36e4c296223f67bb917511089ec59533460f1695 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:15:17 +0100 Subject: test facility for regularizers --- src/Python/test_regularizers.py | 265 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 src/Python/test_regularizers.py (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py new file mode 100644 index 0000000..6abfba4 --- /dev/null +++ b/src/Python/test_regularizers.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug 4 11:10:05 2017 + +@author: ofn77899 +""" + +from ccpi.viewer.CILViewer2D import Converter +import vtk + +import regularizers +import matplotlib.pyplot as plt +import numpy as np +import os +from enum import Enum + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) + 4) + 5) + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm): + + self.algorithm = algorithm + self.pars = self.parsForAlgorithm(algorithm) + # __init__ + + def parsForAlgorithm(self, algorithm): + pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + return pars + # parsForAlgorithm + + def __call__(self, input, regularization_parameter, **kwargs): + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + self.pars['input'] = input + self.pars['regularization_parameter'] = regularization_parameter + #for key, value in self.pars.items(): + # print("{0} = {1}".format(key, value)) + + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if None in self.pars: + raise Exception("Not all parameters have been provided") + else: + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + return out + + +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +reader = vtk.vtkTIFFReader() +reader.SetFileName(os.path.normpath(filename)) +reader.Update() +#vtk returns 3D images, let's take just the one slice there is as 2D +Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +# plot +fig = plt.figure() +a=fig.add_subplot(2,3,1) +a.set_title('Original') +imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,2) +a.set_title('noise') +imgplot = plt.imshow(u0) + + +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +pars = out2[2] + +a=fig.add_subplot(2,3,3) +a.set_title('SplitBregman_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,4) +a.set_title('FGP_TV') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['TV_penalty'].name) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., + time_step=0.1, + tolerance_constant=1e-4, + number_of_iterations=10) +pars = out2[-1] + +a=fig.add_subplot(2,3,5) +a.set_title('LLT_model') +textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['number_of_iterations'], + pars['tolerance_constant'], + pars['time_step'] + ) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(out2[0]) + + + -- cgit v1.2.3 From fd496731c8e9d4975864d76dbb6574cbeee7cf98 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:37 +0100 Subject: Added 3 regularizers SplitBregman_TV FGP_TV LLT_model --- src/Python/fista_module.cpp | 266 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 232 insertions(+), 34 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 2492884..d890b10 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -3,7 +3,7 @@ This work is part of the Core Imaging Library developed by Visual Analytics and Imaging System Group of the Science Technology Facilities Council, STFC -Copyright 2017 Daniil Kazanteev +Copyright 2017 Daniil Kazantsev Copyright 2017 Srikanth Nagella, Edoardo Pasca Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,6 +28,8 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "utils.h" @@ -131,6 +133,18 @@ T * mxGetData(const np::ndarray pm) { return reinterpret_cast(prhs[0]); } +template +np::ndarray zeros(int dims, int * dim_array, T el) { + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + np::ndarray zz = np::zeros(shape, dtype); + return zz; +} + @@ -169,7 +183,6 @@ bp::list mexFunction(np::ndarray input) { } } - bp::list result; result.append(number_of_dims); @@ -189,14 +202,14 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; - const int *dim_array; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; + //const int *dim_array; float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -252,26 +265,26 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); - np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU = np::zeros(shape, dtype); np::ndarray npU_old = np::zeros(shape, dtype); - np::ndarray npDx = np::zeros(shape, dtype); - np::ndarray npDy = np::zeros(shape, dtype); - np::ndarray npBx = np::zeros(shape, dtype); - np::ndarray npBy = np::zeros(shape, dtype); + np::ndarray npDx = np::zeros(shape, dtype); + np::ndarray npDy = np::zeros(shape, dtype); + np::ndarray npBx = np::zeros(shape, dtype); + np::ndarray npBy = np::zeros(shape, dtype); - U = reinterpret_cast(npU.get_data()); + U = reinterpret_cast(npU.get_data()); U_old = reinterpret_cast(npU_old.get_data()); - Dx = reinterpret_cast(npDx.get_data()); - Dy = reinterpret_cast(npDy.get_data()); - Bx = reinterpret_cast(npBx.get_data()); - By = reinterpret_cast(npBy.get_data()); + Dx = reinterpret_cast(npDx.get_data()); + Dy = reinterpret_cast(npDy.get_data()); + Bx = reinterpret_cast(npBx.get_data()); + By = reinterpret_cast(npBy.get_data()); + - copyIm(A, U, dimX, dimY, dimZ); /*initialize */ /* begin outer SB iterations */ - for (ll = 0; ll(npU); - result.append(ll); + } //printf("SB iterations stopped at iteration: %i\n", ll); - if (number_of_dims == 3) { + result.append(npU); + result.append(ll); + } + if (number_of_dims == 3) { /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); @@ -375,24 +390,25 @@ bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsi result.append(npU); result.append(ll); } - } return result; -} + } + + bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list bp::list result; - int number_of_dims, iter, dimX, dimY, dimZ, ll, j, count, methTV; + int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); - int number_of_dims = input.get_nd(); + number_of_dims = input.get_nd(); int dim_array[3]; dim_array[0] = input.shape(0); @@ -512,7 +528,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me for (j = 0; j(input.get_data()); + lambda = (float)d_lambda; + tau = (float)d_tau; + // iter is passed as parameter + epsil = (float)d_epsil; + // switcher is passed as parameter + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; dimZ = 1; + + if (number_of_dims == 2) { + /*2D case*/ + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der2D(U, D1, D2, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + + result.append(npU); + } + else if (number_of_dims == 3) { + /*3D case*/ + dimZ = dim_array[2]; + /*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); + if (switcher != 0) { + Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + }*/ + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + np::ndarray npD1 = np::zeros(shape, dtype); + np::ndarray npD2 = np::zeros(shape, dtype); + np::ndarray npD3 = np::zeros(shape, dtype); + np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin()); + Map = reinterpret_cast(npMap.get_data()); + if (switcher != 0) { + //Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); + + Map = reinterpret_cast(npMap.get_data()); + } + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + D1 = reinterpret_cast(npD1.get_data()); + D2 = reinterpret_cast(npD2.get_data()); + D3 = reinterpret_cast(npD2.get_data()); + + /*Copy U0 to U*/ + copyIm(U0, U, dimX, dimY, dimZ); + + count = 1; + re_old = 0.0f; + + + if (switcher == 1) { + /* apply restrictive smoothing */ + calcMap(U, Map, dimX, dimY, dimZ); + /*clear outliers */ + cleanMap(Map, dimX, dimY, dimZ); + } + for (ll = 0; ll < iter; ll++) { + + copyIm(U, U_old, dimX, dimY, dimZ); + + /*estimate inner derrivatives */ + der3D(U, D1, D2, D3, dimX, dimY, dimZ); + /* calculate div^2 and update */ + div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + + /* calculate norm to terminate earlier */ + re = 0.0f; re1 = 0.0f; + for (j = 0; j 4) break; + + /* check that the residual norm is decreasing */ + if (ll > 2) { + if (re > re_old) break; + } + re_old = re; + + } /*end of iterations*/ + //printf("HO iterations stopped at iteration: %i\n", ll); + result.append(npU); + if (switcher != 0) result.append(npMap); + + } + return result; +} + + +BOOST_PYTHON_MODULE(regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -682,4 +879,5 @@ BOOST_PYTHON_MODULE(fista) def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); + def("LLT_model", LLT_model); } \ No newline at end of file -- cgit v1.2.3 From 4bef3726577ddf1bf2b594620e106573c6f18693 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:16:53 +0100 Subject: minor change --- src/Python/Matlab2Python_utils.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 138e8da..6aaad90 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -128,7 +128,11 @@ T * mxGetData(const np::ndarray pm) { template np::ndarray zeros(int dims , int * dim_array, T el) { - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + bp::tuple shape; + if (dims == 3) + shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); + else if (dims == 2) + shape = bp::make_tuple(dim_array[0], dim_array[1]); np::dtype dtype = np::dtype::get_builtin(); np::ndarray zz = np::zeros(shape, dtype); return zz; @@ -163,7 +167,7 @@ bp::list mexFunction( np::ndarray input ) { for (int k = 0; k < dim_array[2]; k++) { int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; int val = (*(A + index)); - float fval = (float)val; + float fval = sqrt((float)val); std::memcpy(B + index , &val, sizeof(int)); std::memcpy(C + index , &fval, sizeof(float)); } @@ -186,7 +190,7 @@ bp::list mexFunction( np::ndarray input ) { } -BOOST_PYTHON_MODULE(fista) +BOOST_PYTHON_MODULE(prova) { np::initialize(); -- cgit v1.2.3 From 662ab4ac9c3d89cdc1527c2a2bdcf442f3b6a173 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:17:18 +0100 Subject: test for general boost::python / numpy routines --- src/Python/test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/Python/test.py (limited to 'src') diff --git a/src/Python/test.py b/src/Python/test.py new file mode 100644 index 0000000..e283f89 --- /dev/null +++ b/src/Python/test.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Aug 3 14:08:09 2017 + +@author: ofn77899 +""" + +import fista +import numpy as np + +a = np.asarray([i for i in range(3*4*5)]) +a = a.reshape([3,4,5]) +print (a) +b = fista.mexFunction(a) +#print (b) +print (b[4].shape) +print (b[4]) +print (b[5]) \ No newline at end of file -- cgit v1.2.3 From ecfb1146dc1de9ea6d8c6587d15417a9690f5ab4 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 4 Aug 2017 16:49:21 +0100 Subject: added PatchBased_Regul --- src/Python/fista_module.cpp | 123 +++++++++++++++++++++++++++++++++++++++++++- src/Python/setup.py | 1 + 2 files changed, 123 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index d890b10..c2d9352 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -29,6 +29,7 @@ limitations under the License. #include "SplitBregman_TV_core.h" #include "FGP_TV_core.h" #include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" #include "utils.h" @@ -793,7 +794,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d if (switcher != 0) { Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); }*/ - bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); np::dtype dtype = np::dtype::get_builtin(); @@ -865,6 +866,126 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } +bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { + // the result is in the following list + bp::list result; + + int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop; + //const int *dims; + float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda; + + numdims = input.get_nd(); + int dims[3]; + + dims[0] = input.shape(0); + dims[1] = input.shape(1); + if (numdims == 2) { + dims[2] = -1; + } + else { + dims[2] = input.shape(2); + } + /*numdims = mxGetNumberOfDimensions(prhs[0]); + dims = mxGetDimensions(prhs[0]);*/ + + N = dims[0]; + M = dims[1]; + Z = dims[2]; + + //if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); } + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + //if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + + ///*Handling inputs*/ + //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ + //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ + //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ + //lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */ + + //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); + //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + + SearchW = SearchW_real + 2 * SimilW; + + /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ + /* SimilW_full = 2*SimilW + 1; */ /* the full similarity window size */ + + + padXY = SearchW + 2 * SimilW; /* padding sizes */ + newsizeX = N + 2 * (padXY); /* the X size of the padded array */ + newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ + newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ + int N_dims[] = { newsizeX, newsizeY, newsizeZ }; + + /******************************2D case ****************************/ + if (numdims == 2) { + ///*Handling output*/ + //B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); + ///**************************************************************************/ + + bp::tuple shape = bp::make_tuple(N, M); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + + shape = bp::make_tuple(newsizeX, newsizeY); + np::ndarray npAp = np::zeros(shape, dtype); + np::ndarray npBp = np::zeros(shape, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + + /*Perform padding of image A to the size of [newsizeX * newsizeY] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); + } + else + { + /******************************3D case ****************************/ + ///*Handling output*/ + //B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); + ///*allocating memory for the padded arrays */ + //Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + //Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); + /**************************************************************************/ + bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]); + bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npB = np::zeros(shape, dtype); + np::ndarray npAp = np::zeros(shape_AB, dtype); + np::ndarray npBp = np::zeros(shape_AB, dtype); + B = reinterpret_cast(npB.get_data()); + Ap = reinterpret_cast(npAp.get_data()); + Bp = reinterpret_cast(npBp.get_data()); + /*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ + switchpad_crop = 0; /*padding*/ + pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + /* Do PB regularization with the padded array */ + PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda); + + switchpad_crop = 1; /*cropping*/ + pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + + result.append(npB); + } /*end else ndims*/ + + return result; +} + BOOST_PYTHON_MODULE(regularizers) { np::initialize(); diff --git a/src/Python/setup.py b/src/Python/setup.py index a8feb1c..a4eed14 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -52,6 +52,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", + "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 753f3477bde8fc250adc542bbeffc03d369107e1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:12 +0100 Subject: added TGV_PD, removed useless code --- src/Python/fista_module.cpp | 245 ++++++++++++++++++++++++++------------------ 1 file changed, 146 insertions(+), 99 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index c2d9352..eacda3d 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "FGP_TV_core.h" #include "LLT_model_core.h" #include "PatchBased_Regul_core.h" +#include "TGV_PD_core.h" #include "utils.h" @@ -103,101 +104,8 @@ If unsuccessful in a MEX file, the MEX file terminates and returns control to th enough free heap space to create the mxArray. */ -void mexErrMessageText(char* text) { - std::cerr << text << std::endl; -} - -/* -double mxGetScalar(const mxArray *pm); -args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. -Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray. In C, mxGetScalar returns a double. -*/ - -template -double mxGetScalar(const np::ndarray plh) { - return (double)bp::extract(plh[0]); -} - - - -template -T * mxGetData(const np::ndarray pm) { - //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. - //Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. - /*Access the numpy array pointer: - char * get_data() const; - Returns: Array’s raw data pointer as a char - Note: This returns char so stride math works properly on it.User will have to reinterpret_cast it. - probably this would work. - A = reinterpret_cast(prhs[0]); - */ - return reinterpret_cast(prhs[0]); -} - -template -np::ndarray zeros(int dims, int * dim_array, T el) { - bp::tuple shape; - if (dims == 3) - shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); - else if (dims == 2) - shape = bp::make_tuple(dim_array[0], dim_array[1]); - np::dtype dtype = np::dtype::get_builtin(); - np::ndarray zz = np::zeros(shape, dtype); - return zz; -} - - -bp::list mexFunction(np::ndarray input) { - int number_of_dims = input.get_nd(); - int dim_array[3]; - - dim_array[0] = input.shape(0); - dim_array[1] = input.shape(1); - if (number_of_dims == 2) { - dim_array[2] = -1; - } - else { - dim_array[2] = input.shape(2); - } - - /**************************************************************************/ - np::ndarray zz = zeros(3, dim_array, (int)0); - np::ndarray fzz = zeros(3, dim_array, (float)0); - /**************************************************************************/ - - int * A = reinterpret_cast(input.get_data()); - int * B = reinterpret_cast(zz.get_data()); - float * C = reinterpret_cast(fzz.get_data()); - - //Copy data and cast - for (int i = 0; i < dim_array[0]; i++) { - for (int j = 0; j < dim_array[1]; j++) { - for (int k = 0; k < dim_array[2]; k++) { - int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; - int val = (*(A + index)); - float fval = (float)val; - std::memcpy(B + index, &val, sizeof(int)); - std::memcpy(C + index, &fval, sizeof(float)); - } - } - } - - bp::list result; - - result.append(number_of_dims); - result.append(dim_array[0]); - result.append(dim_array[1]); - result.append(dim_array[2]); - result.append(zz); - result.append(fzz); - - //result.append(tup); - return result; - -} - bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { // the result is in the following list @@ -487,7 +395,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me np::ndarray npP1_old = np::zeros(shape, dtype); np::ndarray npP2_old = np::zeros(shape, dtype); np::ndarray npR1 = np::zeros(shape, dtype); - np::ndarray npR2 = zeros(2, dim_array, (float)0); + np::ndarray npR2 = np::zeros(shape, dtype); D = reinterpret_cast(npD.get_data()); D_old = reinterpret_cast(npD_old.get_data()); @@ -866,7 +774,7 @@ bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, d } -bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, double d_h, double d_lambda) { +bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW, double d_h) { // the result is in the following list bp::list result; @@ -899,6 +807,7 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub ///*Handling inputs*/ //A = (float *)mxGetData(prhs[0]); /* the image to regularize/filter */ + A = reinterpret_cast(input.get_data()); //SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ //SimilW = (int)mxGetScalar(prhs[2]); /* the similarity window ratio */ //h = (float)mxGetScalar(prhs[3]); /* parameter for the PB filtering function */ @@ -907,6 +816,8 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub //if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); //if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + lambda = (float)d_lambda; + h = (float)d_h; SearchW = SearchW_real + 2 * SimilW; /* SearchW_full = 2*SearchW + 1; */ /* the full searching window size */ @@ -918,7 +829,6 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ int N_dims[] = { newsizeX, newsizeY, newsizeZ }; - /******************************2D case ****************************/ if (numdims == 2) { ///*Handling output*/ @@ -943,12 +853,13 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub /*Perform padding of image A to the size of [newsizeX * newsizeY] */ switchpad_crop = 0; /*padding*/ pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); - + /* Do PB regularization with the padded array */ PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); - + switchpad_crop = 1; /*cropping*/ pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); + result.append(npB); } else @@ -983,6 +894,141 @@ bp::list PatchBased_Regul(np::ndarray input, int SearchW_real, int SimilW, doub result.append(npB); } /*end else ndims*/ + return result; +} + +bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) { + // the result is in the following list + bp::list result; + int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; + //const int *dim_array; + float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + + //number_of_dims = mxGetNumberOfDimensions(prhs[0]); + //dim_array = mxGetDimensions(prhs[0]); + number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + /*Handling Matlab input data*/ + //A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ + //if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + + A = reinterpret_cast(input.get_data()); + + //lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ + //alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/ + //alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/ + //iter = (int)mxGetScalar(prhs[4]); /*iterations number*/ + //if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations"); + lambda = (float)d_lambda; + alpha1 = (float)d_alpha1; + alpha0 = (float)d_alpha0; + + /*Handling Matlab output data*/ + dimX = dim_array[0]; dimY = dim_array[1]; + + if (number_of_dims == 2) { + /*2D case*/ + dimZ = 1; + bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); + np::dtype dtype = np::dtype::get_builtin(); + + np::ndarray npU = np::zeros(shape, dtype); + np::ndarray npP1 = np::zeros(shape, dtype); + np::ndarray npP2 = np::zeros(shape, dtype); + np::ndarray npQ1 = np::zeros(shape, dtype); + np::ndarray npQ2 = np::zeros(shape, dtype); + np::ndarray npQ3 = np::zeros(shape, dtype); + np::ndarray npV1 = np::zeros(shape, dtype); + np::ndarray npV1_old = np::zeros(shape, dtype); + np::ndarray npV2 = np::zeros(shape, dtype); + np::ndarray npV2_old = np::zeros(shape, dtype); + np::ndarray npU_old = np::zeros(shape, dtype); + + U = reinterpret_cast(npU.get_data()); + U_old = reinterpret_cast(npU_old.get_data()); + P1 = reinterpret_cast(npP1.get_data()); + P2 = reinterpret_cast(npP2.get_data()); + Q1 = reinterpret_cast(npQ1.get_data()); + Q2 = reinterpret_cast(npQ2.get_data()); + Q3 = reinterpret_cast(npQ3.get_data()); + V1 = reinterpret_cast(npV1.get_data()); + V1_old = reinterpret_cast(npV1_old.get_data()); + V2 = reinterpret_cast(npV2.get_data()); + V2_old = reinterpret_cast(npV2_old.get_data()); + //U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + /*dual variables*/ + /*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + + V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + /*printf("%i \n", i);*/ + L2 = 12.0; /*Lipshitz constant*/ + tau = 1.0 / pow(L2, 0.5); + sigma = 1.0 / pow(L2, 0.5); + + /*Copy A to U*/ + copyIm(A, U, dimX, dimY, dimZ); + /* Here primal-dual iterations begin for 2D */ + for (ll = 0; ll < iter; ll++) { + + /* Calculate Dual Variable P */ + DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for P*/ + ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1); + + /* Calculate Dual Variable Q */ + DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma); + + /*Projection onto convex set for Q*/ + ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0); + + /*saving U into U_old*/ + copyIm(U, U_old, dimX, dimY, dimZ); + + /*adjoint operation -> divergence and projection of P*/ + DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau); + + /*get updated solution U*/ + newU(U, U_old, dimX, dimY, dimZ); + + /*saving V into V_old*/ + copyIm(V1, V1_old, dimX, dimY, dimZ); + copyIm(V2, V2_old, dimX, dimY, dimZ); + + /* upd V*/ + UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau); + + /*get new V*/ + newU(V1, V1_old, dimX, dimY, dimZ); + newU(V2, V2_old, dimX, dimY, dimZ); + } /*end of iterations*/ + + result.append(npU); + } + + + + return result; } @@ -997,8 +1043,9 @@ BOOST_PYTHON_MODULE(regularizers) np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); - def("mexFunction", mexFunction); def("SplitBregman_TV", SplitBregman_TV); def("FGP_TV", FGP_TV); def("LLT_model", LLT_model); + def("PatchBased_Regul", PatchBased_Regul); + def("TGV_PD", TGV_PD); } \ No newline at end of file -- cgit v1.2.3 From 4534a11d1c32a65484f4f38348c27a7bb2d9ad19 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 7 Aug 2017 17:21:54 +0100 Subject: added TGV_PD --- src/Python/setup.py | 1 + src/Python/test_regularizers.py | 195 ++++++++++++++++++++++++++++++++++------ 2 files changed, 168 insertions(+), 28 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index a4eed14..0468722 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -53,6 +53,7 @@ setup( "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", + "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", "..\\..\\main_func\\regularizers_CPU\\utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6abfba4..6a34749 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -47,6 +47,8 @@ class Regularizer(): SplitBregman_TV = regularizers.SplitBregman_TV FGP_TV = regularizers.FGP_TV LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD # Algorithm class TotalVariationPenalty(Enum): @@ -55,13 +57,17 @@ class Regularizer(): # TotalVariationPenalty def __init__(self , algorithm): - + self.setAlgorithm ( algorithm ) + # __init__ + + def setAlgorithm(self, algorithm): self.algorithm = algorithm self.pars = self.parsForAlgorithm(algorithm) - # __init__ + # setAlgorithm def parsForAlgorithm(self, algorithm): pars = dict() + if algorithm == Regularizer.Algorithm.SplitBregman_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -69,6 +75,7 @@ class Regularizer(): pars['number_of_iterations'] = 35 pars['tolerance_constant'] = 0.0001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.FGP_TV : pars['algorithm'] = algorithm pars['input'] = None @@ -76,6 +83,7 @@ class Regularizer(): pars['number_of_iterations'] = 50 pars['tolerance_constant'] = 0.001 pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + elif algorithm == Regularizer.Algorithm.LLT_model: pars['algorithm'] = algorithm pars['input'] = None @@ -85,6 +93,24 @@ class Regularizer(): pars['tolerance_constant'] = None pars['restrictive_Z_smoothing'] = 0 + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + + return pars # parsForAlgorithm @@ -98,6 +124,8 @@ class Regularizer(): self.pars['regularization_parameter'] = regularization_parameter #for key, value in self.pars.items(): # print("{0} = {1}".format(key, value)) + if None in self.pars: + raise Exception("Not all parameters have been provided") if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : return self.algorithm(input, regularization_parameter, @@ -112,15 +140,27 @@ class Regularizer(): elif self.algorithm == Regularizer.Algorithm.LLT_model : #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) # no default - if None in self.pars: - raise Exception("Not all parameters have been provided") - else: - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # __call__ @@ -142,13 +182,40 @@ class Regularizer(): @staticmethod def LLT_model(input, regularization_parameter , time_step, number_of_iterations, tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) + reg = Regularizer(Regularizer.Algorithm.LLT_model) out = list( reg(input, regularization_parameter, time_step=time_step, number_of_iterations=number_of_iterations, tolerance_constant=tolerance_constant, restrictive_Z_smoothing=restrictive_Z_smoothing) ) out.append(reg.pars) return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + return out #Example: @@ -171,17 +238,17 @@ u0 = Im + (perc* np.random.normal(size=np.shape(Im))) f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) u0 = f(u0).astype('float32') -# plot +## plot fig = plt.figure() -a=fig.add_subplot(2,3,1) -a.set_title('Original') -imgplot = plt.imshow(Im) +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) -a=fig.add_subplot(2,3,2) +a=fig.add_subplot(2,3,1) a.set_title('noise') imgplot = plt.imshow(u0) - +reg_output = [] ############################################################################## # Call regularizer @@ -199,8 +266,9 @@ out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., numbe TV_Penalty=Regularizer.TotalVariationPenalty.l1) out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) pars = out2[2] +reg_output.append(out2) -a=fig.add_subplot(2,3,3) +a=fig.add_subplot(2,3,2) a.set_title('SplitBregman_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -213,7 +281,7 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); @@ -221,7 +289,9 @@ out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, number_of_iterations=10) pars = out2[-1] -a=fig.add_subplot(2,3,4) +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) a.set_title('FGP_TV') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' textstr = textstr % (pars['regularization_parameter'], @@ -234,18 +304,23 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) ###################### LLT_model ######################################### # * u0 = Im + .03*randn(size(Im)); % adding noise # [Den] = LLT_model(single(u0), 10, 0.1, 1); -out2 = Regularizer.LLT_model(input=u0, regularization_parameter=10., - time_step=0.1, - tolerance_constant=1e-4, - number_of_iterations=10) +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) pars = out2[-1] -a=fig.add_subplot(2,3,5) +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) a.set_title('LLT_model') textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' textstr = textstr % (pars['regularization_parameter'], @@ -259,7 +334,71 @@ props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(out2[0]) +imgplot = plt.imshow(reg_output[-1][0]) + +###################### PatchBased_Regul ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) +a.set_title('PatchBased_Regul') +textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' +textstr = textstr % (pars['regularization_parameter'], + pars['searching_window_ratio'], + pars['similarity_window_ratio'], + pars['PB_filtering_parameter']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) + + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-1] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) +a.set_title('TGV_PD') +textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' +textstr = textstr % (pars['regularization_parameter'], + pars['first_order_term'], + pars['second_order_term'], + pars['number_of_iterations']) + + + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) -- cgit v1.2.3 From 3fffd568589137b17d1fbe44e55a757e3745a3b1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:42:05 +0100 Subject: added simple_astra_test.py --- src/Python/test/simple_astra_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/Python/test/simple_astra_test.py (limited to 'src') diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 +1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From 0611d34c31fa1e706c3bcd7e17651f7555469e00 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 17 Aug 2017 16:33:09 +0100 Subject: initial revision --- src/Python/test/simple_astra_test.py | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 src/Python/test/simple_astra_test.py (limited to 'src') diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py deleted file mode 100644 index 905eeea..0000000 --- a/src/Python/test/simple_astra_test.py +++ /dev/null @@ -1,25 +0,0 @@ -import astra -import numpy - -detectorSpacingX = 1.0 -detectorSpacingY = 1.0 -det_row_count = 128 -det_col_count = 128 - -angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi - -proj_geom = astra.creators.create_proj_geom('parallel3d', - detectorSpacingX, - detectorSpacingY, - det_row_count, - det_col_count, - angles_rad) - -image_size_x = 64 -image_size_y = 64 -image_size_z = 32 - -vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) - -x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) -sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From bc29e0690d856ad9dd147b435d34c5761556a1e5 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:55:19 +0100 Subject: Regularizer.pyfirst commit --- src/Python/Regularizer.py | 322 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/Regularizer.py (limited to 'src') diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py new file mode 100644 index 0000000..15dbbb4 --- /dev/null +++ b/src/Python/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +import regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = regularizers.SplitBregman_TV + FGP_TV = regularizers.FGP_TV + LLT_model = regularizers.LLT_model + PatchBased_Regul = regularizers.PatchBased_Regul + TGV_PD = regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 48a4d5315b4b6ca62eaa931912b6a02993979688 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 12:56:09 +0100 Subject: Test module for Boost Python currently can pass a function to the C++ layer to be evaluated. --- src/Python/Matlab2Python_utils.cpp | 68 +++++++++++++++++++++++++++++++++++++- src/Python/setup_test.py | 6 ++-- src/Python/test.py | 34 ++++++++++++++++--- 3 files changed, 99 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index 6aaad90..e15d738 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -175,6 +175,71 @@ bp::list mexFunction( np::ndarray input ) { } + bp::list result; + + result.append(number_of_dims); + result.append(dim_array[0]); + result.append(dim_array[1]); + result.append(dim_array[2]); + result.append(zz); + result.append(fzz); + + //result.append(tup); + return result; + +} +bp::list doSomething(np::ndarray input, PyObject *pyobj , PyObject *pyobj2) { + + boost::python::object output(boost::python::handle<>(boost::python::borrowed(pyobj))); + int isOutput = !(output == boost::python::api::object()); + + boost::python::object calculate(boost::python::handle<>(boost::python::borrowed(pyobj2))); + int isCalculate = !(calculate == boost::python::api::object()); + + int number_of_dims = input.get_nd(); + int dim_array[3]; + + dim_array[0] = input.shape(0); + dim_array[1] = input.shape(1); + if (number_of_dims == 2) { + dim_array[2] = -1; + } + else { + dim_array[2] = input.shape(2); + } + + /**************************************************************************/ + np::ndarray zz = zeros(3, dim_array, (int)0); + np::ndarray fzz = zeros(3, dim_array, (float)0); + /**************************************************************************/ + + int * A = reinterpret_cast(input.get_data()); + int * B = reinterpret_cast(zz.get_data()); + float * C = reinterpret_cast(fzz.get_data()); + + //Copy data and cast + for (int i = 0; i < dim_array[0]; i++) { + for (int j = 0; j < dim_array[1]; j++) { + for (int k = 0; k < dim_array[2]; k++) { + int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; + int val = (*(A + index)); + float fval = sqrt((float)val); + std::memcpy(B + index, &val, sizeof(int)); + std::memcpy(C + index, &fval, sizeof(float)); + // if the PyObj is not None evaluate the function + if (isOutput) + output(fval); + if (isCalculate) { + float nfval = (float)bp::extract(calculate(val)); + if (isOutput) + output(nfval); + std::memcpy(C + index, &nfval, sizeof(float)); + } + } + } + } + + bp::list result; result.append(number_of_dims); @@ -196,7 +261,7 @@ BOOST_PYTHON_MODULE(prova) //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "fista"; + package.attr("__path__") = "prova"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); @@ -207,4 +272,5 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); //numpy_boost_python_register_type(); def("mexFunction", mexFunction); + def("doSomething", doSomething); } \ No newline at end of file diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py index ffb9c02..7c86175 100644 --- a/src/Python/setup_test.py +++ b/src/Python/setup_test.py @@ -30,13 +30,13 @@ extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x' extra_libraries = [] if platform.system() == 'Windows': extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB'] - extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] + #extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] if sys.version_info.major == 3 : extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] else: extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] else: - extra_include_dirs += ["../ContourTree/", "../Core/","."] + #extra_include_dirs += ["../ContourTree/", "../Core/","."] if sys.version_info.major == 3: extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] else: @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("fista", + ext_modules = [Extension("prova", sources=[ "Matlab2Python_utils.cpp", ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), diff --git a/src/Python/test.py b/src/Python/test.py index e283f89..db47380 100644 --- a/src/Python/test.py +++ b/src/Python/test.py @@ -5,14 +5,38 @@ Created on Thu Aug 3 14:08:09 2017 @author: ofn77899 """ -import fista +import prova import numpy as np -a = np.asarray([i for i in range(3*4*5)]) -a = a.reshape([3,4,5]) +a = np.asarray([i for i in range(1*2*3)]) +a = a.reshape([1,2,3]) print (a) -b = fista.mexFunction(a) +b = prova.mexFunction(a) #print (b) print (b[4].shape) print (b[4]) -print (b[5]) \ No newline at end of file +print (b[5]) + +def print_element(input): + print ("f: {0}".format(input)) + +prova.doSomething(a, print_element, None) + +c = [] +def append_to_list(input, shouldPrint=False): + c.append(input) + if shouldPrint: + print ("{0} appended to list {1}".format(input, c)) + +def element_wise_algebra(input, shouldPrint=True): + ret = input - 7 + if shouldPrint: + print ("element_wise {0}".format(ret)) + return ret + +prova.doSomething(a, append_to_list, None) +#print ("this is c: {0}".format(c)) + +b = prova.doSomething(a, None, element_wise_algebra) +#print (a) +print (b[5]) -- cgit v1.2.3 From c28385d0dd5efcb32bd2c33e4bd93ba61f959b3f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:27:28 +0100 Subject: updated test for regularizer API --- src/Python/test_regularizers.py | 590 ++++++++++++++++++++-------------------- 1 file changed, 290 insertions(+), 300 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 6a34749..5d25f02 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -8,216 +8,37 @@ Created on Fri Aug 4 11:10:05 2017 from ccpi.viewer.CILViewer2D import Converter import vtk -import regularizers import matplotlib.pyplot as plt import numpy as np import os from enum import Enum - -class Regularizer(): - '''Class to handle regularizer algorithms to be used during reconstruction - - Currently 5 regularization algorithms are available: - - 1) SplitBregman_TV - 2) FGP_TV - 3) - 4) - 5) - - Usage: - the regularizer can be invoked as object or as static method - Depending on the actual regularizer the input parameter may vary, and - a different default setting is defined. - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - - out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., - number_of_iterations=30, tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - A number of optional parameters can be passed or skipped - out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) - - ''' - class Algorithm(Enum): - SplitBregman_TV = regularizers.SplitBregman_TV - FGP_TV = regularizers.FGP_TV - LLT_model = regularizers.LLT_model - PatchBased_Regul = regularizers.PatchBased_Regul - TGV_PD = regularizers.TGV_PD - # Algorithm - - class TotalVariationPenalty(Enum): - isotropic = 0 - l1 = 1 - # TotalVariationPenalty - - def __init__(self , algorithm): - self.setAlgorithm ( algorithm ) - # __init__ - - def setAlgorithm(self, algorithm): - self.algorithm = algorithm - self.pars = self.parsForAlgorithm(algorithm) - # setAlgorithm - - def parsForAlgorithm(self, algorithm): - pars = dict() - - if algorithm == Regularizer.Algorithm.SplitBregman_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 35 - pars['tolerance_constant'] = 0.0001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.FGP_TV : - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['number_of_iterations'] = 50 - pars['tolerance_constant'] = 0.001 - pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic - - elif algorithm == Regularizer.Algorithm.LLT_model: - pars['algorithm'] = algorithm - pars['input'] = None - pars['regularization_parameter'] = None - pars['time_step'] = None - pars['number_of_iterations'] = None - pars['tolerance_constant'] = None - pars['restrictive_Z_smoothing'] = 0 - - elif algorithm == Regularizer.Algorithm.PatchBased_Regul: - pars['algorithm'] = algorithm - pars['input'] = None - pars['searching_window_ratio'] = None - pars['similarity_window_ratio'] = None - pars['PB_filtering_parameter'] = None - pars['regularization_parameter'] = None - - elif algorithm == Regularizer.Algorithm.TGV_PD: - pars['algorithm'] = algorithm - pars['input'] = None - pars['first_order_term'] = None - pars['second_order_term'] = None - pars['number_of_iterations'] = None - pars['regularization_parameter'] = None - - - - return pars - # parsForAlgorithm - - def __call__(self, input, regularization_parameter, **kwargs): - - if kwargs is not None: - for key, value in kwargs.items(): - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - self.pars['input'] = input - self.pars['regularization_parameter'] = regularization_parameter - #for key, value in self.pars.items(): - # print("{0} = {1}".format(key, value)) - if None in self.pars: - raise Exception("Not all parameters have been provided") - - if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.FGP_TV : - return self.algorithm(input, regularization_parameter, - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['TV_penalty'].value ) - elif self.algorithm == Regularizer.Algorithm.LLT_model : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, - regularization_parameter, - self.pars['time_step'] , - self.pars['number_of_iterations'], - self.pars['tolerance_constant'], - self.pars['restrictive_Z_smoothing'] ) - elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['searching_window_ratio'] , - self.pars['similarity_window_ratio'] , - self.pars['PB_filtering_parameter']) - elif self.algorithm == Regularizer.Algorithm.TGV_PD : - #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) - # no default - return self.algorithm(input, regularization_parameter, - self.pars['first_order_term'] , - self.pars['second_order_term'] , - self.pars['number_of_iterations']) - - - - # __call__ - - @staticmethod - def SplitBregman_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def FGP_TV(input, regularization_parameter , **kwargs): - reg = Regularizer(Regularizer.Algorithm.FGP_TV) - out = list( reg(input, regularization_parameter, **kwargs) ) - out.append(reg.pars) - return out - - @staticmethod - def LLT_model(input, regularization_parameter , time_step, number_of_iterations, - tolerance_constant, restrictive_Z_smoothing=0): - reg = Regularizer(Regularizer.Algorithm.LLT_model) - out = list( reg(input, regularization_parameter, time_step=time_step, - number_of_iterations=number_of_iterations, - tolerance_constant=tolerance_constant, - restrictive_Z_smoothing=restrictive_Z_smoothing) ) - out.append(reg.pars) - return out - - @staticmethod - def PatchBased_Regul(input, regularization_parameter, - searching_window_ratio, - similarity_window_ratio, - PB_filtering_parameter): - reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) - out = list( reg(input, - regularization_parameter, - searching_window_ratio=searching_window_ratio, - similarity_window_ratio=similarity_window_ratio, - PB_filtering_parameter=PB_filtering_parameter ) - ) - out.append(reg.pars) - return out - - @staticmethod - def TGV_PD(input, regularization_parameter , first_order_term, - second_order_term, number_of_iterations): - - reg = Regularizer(Regularizer.Algorithm.TGV_PD) - out = list( reg(input, regularization_parameter, - first_order_term=first_order_term, - second_order_term=second_order_term, - number_of_iterations=number_of_iterations) ) - out.append(reg.pars) - return out - - +import timeit + +from Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE: +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) + max_val = max(np.max(im1), np.max(im2)) + min_val = min(np.min(im1), np.min(im2)) + return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +# 2D Regularizers +# +############################################################################### #Example: # figure; # Im = double(imread('lena_gray_256.tif'))/255; % loading image @@ -255,49 +76,55 @@ reg_output = [] ####################### SplitBregman_TV ##################################### # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); -reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +use_object = True +if use_object: + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + # or + # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + #TV_Penalty=Regularizer.TotalVariationPenalty.l1) + plotme = reg() [0] + pars = reg.pars + textstr = reg.printParametersToString() + + #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + # TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, - #tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) -pars = out2[2] -reg_output.append(out2) +else: + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + pars = out2[2] + reg_output.append(out2) + plotme = reg_output[-1][0] + textstr = out2[-1] a=fig.add_subplot(2,3,2) -a.set_title('SplitBregman_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) +imgplot = plt.imshow(plotme) ###################### FGP_TV ######################################### # u = FGP_TV(single(u0), 0.05, 100, 1e-04); -out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.05, - number_of_iterations=10) -pars = out2[-1] +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, + number_of_iterations=200) +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,3) -a.set_title('FGP_TV') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\npenalty=%s' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['TV_penalty'].name) + +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) @@ -316,50 +143,12 @@ out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, time_step=0.0003, tolerance_constant=0.0001, number_of_iterations=300) -pars = out2[-1] +pars = out2[-2] reg_output.append(out2) a=fig.add_subplot(2,3,4) -a.set_title('LLT_model') -textstr = 'regularization_parameter=%.2f\niterations=%d\ntolerance=%.2e\ntime-step=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['number_of_iterations'], - pars['tolerance_constant'], - pars['time_step'] - ) - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - -###################### PatchBased_Regul ######################################### -# Quick 2D denoising example in Matlab: -# Im = double(imread('lena_gray_256.tif'))/255; % loading image -# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); - -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,5) -a.set_title('PatchBased_Regul') -textstr = 'regularization_parameter=%.2f\nsearching_window_ratio=%d\nsimilarity_window_ratio=%.2e\nPB_filtering_parameter=%f' -textstr = textstr % (pars['regularization_parameter'], - pars['searching_window_ratio'], - pars['similarity_window_ratio'], - pars['PB_filtering_parameter']) - - - - +textstr = out2[-1] # these are matplotlib.patch.Patch properties props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) # place a text box in upper left in axes coords @@ -367,6 +156,215 @@ a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, verticalalignment='top', bbox=props) imgplot = plt.imshow(reg_output[-1][0]) +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + # searching_window_ratio=3, + # similarity_window_ratio=1, + # PB_filtering_parameter=0.08) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,5) + + +# textstr = out2[-1] + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + # first_order_term=1.3, + # second_order_term=1, + # number_of_iterations=550) +# pars = out2[-2] +# reg_output.append(out2) + +# a=fig.add_subplot(2,3,6) + + +# textstr = out2[-1] + + +# # these are matplotlib.patch.Patch properties +# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# # place a text box in upper left in axes coords +# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + # verticalalignment='top', bbox=props) +# imgplot = plt.imshow(reg_output[-1][0]) + + +plt.show() + +################################################################################ +## +## 3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +# reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +## #tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +##input, regularization_parameter , time_step, number_of_iterations, +## tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +# time_step=0.0003, +# tolerance_constant=0.0001, +# number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +# searching_window_ratio=3, +# similarity_window_ratio=1, +# PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# ###################### TGV_PD ######################################### # Quick 2D denoising example in Matlab: @@ -375,30 +373,22 @@ imgplot = plt.imshow(reg_output[-1][0]) # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - first_order_term=1.3, - second_order_term=1, - number_of_iterations=550) -pars = out2[-1] -reg_output.append(out2) - -a=fig.add_subplot(2,3,6) -a.set_title('TGV_PD') -textstr = 'regularization_parameter=%.2f\nfirst_order_term=%.2f\nsecond_order_term=%.2f\nnumber_of_iterations=%d' -textstr = textstr % (pars['regularization_parameter'], - pars['first_order_term'], - pars['second_order_term'], - pars['number_of_iterations']) - - - - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output[-1][0]) - - - +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) -- cgit v1.2.3 From db45d96898f23c3bc97e4c19e834fa976ec301c8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:31:16 +0100 Subject: initial commit of Reconstructor.py --- src/Python/ccpi/reconstruction/Reconstructor.py | 598 ++++++++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/Reconstructor.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: + + class Algorithm(Enum): + CGLS = alg.cgls + CGLS_CONV = alg.cgls_conv + SIRT = alg.sirt + MLEM = alg.mlem + CGLS_TICHONOV = alg.cgls_tikhonov + CGLS_TVREG = alg.cgls_TVreg + FISTA = 'fista' + + def __init__(self, algorithm = None, projection_data = None, + angles = None, center_of_rotation = None , + flat_field = None, dark_field = None, + iterations = None, resolution = None, isLogScale = False, threads = None, + normalized_projection = None): + + self.pars = dict() + self.pars['algorithm'] = algorithm + self.pars['projection_data'] = projection_data + self.pars['normalized_projection'] = normalized_projection + self.pars['angles'] = angles + self.pars['center_of_rotation'] = numpy.double(center_of_rotation) + self.pars['flat_field'] = flat_field + self.pars['iterations'] = iterations + self.pars['dark_field'] = dark_field + self.pars['resolution'] = resolution + self.pars['isLogScale'] = isLogScale + self.pars['threads'] = threads + if (iterations != None): + self.pars['iterationValues'] = numpy.zeros((iterations)) + + if projection_data != None and dark_field != None and flat_field != None: + norm = self.normalize(projection_data, dark_field, flat_field, 0.1) + self.pars['normalized_projection'] = norm + + + def setPars(self, parameters): + keys = ['algorithm','projection_data' ,'normalized_projection', \ + 'angles' , 'center_of_rotation' , 'flat_field', \ + 'iterations','dark_field' , 'resolution', 'isLogScale' , \ + 'threads' , 'iterationValues', 'regularize'] + + for k in keys: + if k not in parameters.keys(): + self.pars[k] = None + else: + self.pars[k] = parameters[k] + + + def sanityCheck(self): + projection_data = self.pars['projection_data'] + dark_field = self.pars['dark_field'] + flat_field = self.pars['flat_field'] + angles = self.pars['angles'] + + if projection_data != None and dark_field != None and \ + angles != None and flat_field != None: + data_shape = numpy.shape(projection_data) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + + if data_shape[1:] != numpy.shape(flat_field): + #raise Exception('Projection and flat field dimensions do not match') + return (False , 'Projection and flat field dimensions do not match') + if data_shape[1:] != numpy.shape(dark_field): + #raise Exception('Projection and dark field dimensions do not match') + return (False , 'Projection and dark field dimensions do not match') + + return (True , '' ) + elif self.pars['normalized_projection'] != None: + data_shape = numpy.shape(self.pars['normalized_projection']) + angle_shape = numpy.shape(angles) + + if angle_shape[0] != data_shape[0]: + #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ + # (angle_shape[0] , data_shape[0]) ) + return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ + (angle_shape[0] , data_shape[0]) ) + else: + return (True , '' ) + else: + return (False , 'Not enough data') + + def reconstruct(self, parameters = None): + if parameters != None: + self.setPars(parameters) + + go , reason = self.sanityCheck() + if go: + return self._reconstruct() + else: + raise Exception(reason) + + + def _reconstruct(self, parameters=None): + if parameters!=None: + self.setPars(parameters) + parameters = self.pars + + if parameters['algorithm'] != None and \ + parameters['normalized_projection'] != None and \ + parameters['angles'] != None and \ + parameters['center_of_rotation'] != None and \ + parameters['iterations'] != None and \ + parameters['resolution'] != None and\ + parameters['threads'] != None and\ + parameters['isLogScale'] != None: + + + if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, + Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): + #store parameters + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['isLogScale'] + ) + return result + elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, + Reconstructor.Algorithm.CGLS_TICHONOV, + Reconstructor.Algorithm.CGLS_TVREG) : + self.pars = parameters + result = parameters['algorithm']( + parameters['normalized_projection'] , + parameters['angles'], + parameters['center_of_rotation'], + parameters['resolution'], + parameters['iterations'], + parameters['threads'] , + parameters['regularize'], + numpy.zeros((parameters['iterations'])), + parameters['isLogScale'] + ) + + elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: + pass + + else: + if parameters['projection_data'] != None and \ + parameters['dark_field'] != None and \ + parameters['flat_field'] != None: + norm = self.normalize(parameters['projection_data'], + parameters['dark_field'], + parameters['flat_field'], 0.1) + self.pars['normalized_projection'] = norm + return self._reconstruct(parameters) + + + + def _normalize(self, projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + def normalize(self, projections, dark, flat, def_val=0): + norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] + return numpy.asarray (norm, dtype=numpy.float32) + + + +class FISTA(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , + 'weights' , 'region_of_interest' , 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): + print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): + dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): + flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +# angles = angle_proj, center_of_rotation = 86.2 , +# flat_field = flat, dark_field = dark, +# iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, + numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off') # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off') # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off') # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off') # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off') # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off') # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + return shiftScaler.GetOutput() + +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() -- cgit v1.2.3 From c3b58791b906aa6a3b99f32fa5f69a09bb075527 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:56:08 +0100 Subject: module rename to cpu_regularizers --- src/Python/setup.py | 4 ++-- src/Python/test_regularizers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index 0468722..94467c4 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -47,7 +47,7 @@ setup( description='CCPi Core Imaging Library - FISTA Reconstruction Module', version=cil_version, cmdclass = {'build_ext': build_ext}, - ext_modules = [Extension("regularizers", + ext_modules = [Extension("ccpi.imaging.cpu_regularizers", sources=["fista_module.cpp", "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", @@ -60,5 +60,5 @@ setup( ], zip_safe = False, - packages = {'ccpi','ccpi.reconstruction'}, + packages = {'ccpi','ccpi.fistareconstruction'}, ) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 5d25f02..755804a 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -14,7 +14,8 @@ import os from enum import Enum import timeit -from Regularizer import Regularizer +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer ############################################################################### #https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 -- cgit v1.2.3 From 70d03d2c7567fac409086f015ca9e2ac47b0fc20 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 14:58:11 +0100 Subject: changed the backward slash to forward --- src/Python/setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/Python/setup.py b/src/Python/setup.py index 94467c4..154f979 100644 --- a/src/Python/setup.py +++ b/src/Python/setup.py @@ -49,12 +49,12 @@ setup( cmdclass = {'build_ext': build_ext}, ext_modules = [Extension("ccpi.imaging.cpu_regularizers", sources=["fista_module.cpp", - "..\\..\\main_func\\regularizers_CPU\\FGP_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\SplitBregman_TV_core.c", - "..\\..\\main_func\\regularizers_CPU\\LLT_model_core.c", - "..\\..\\main_func\\regularizers_CPU\\PatchBased_Regul_core.c", - "..\\..\\main_func\\regularizers_CPU\\TGV_PD_core.c", - "..\\..\\main_func\\regularizers_CPU\\utils.c" + "../../main_func/regularizers_CPU/FGP_TV_core.c", + "../../main_func/regularizers_CPU/SplitBregman_TV_core.c", + "../../main_func/regularizers_CPU/LLT_model_core.c", + "../../main_func/regularizers_CPU/PatchBased_Regul_core.c", + "../../main_func/regularizers_CPU/TGV_PD_core.c", + "../../main_func/regularizers_CPU/utils.c" ], include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ), -- cgit v1.2.3 From 396c11bd2c8bde1197b708062590a9e3b95538bd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:01:28 +0100 Subject: added viewer for testing --- src/Python/ccpi/viewer/CILViewer.py | 361 +++++++ src/Python/ccpi/viewer/CILViewer2D.py | 1126 ++++++++++++++++++++ src/Python/ccpi/viewer/QVTKWidget.py | 340 ++++++ src/Python/ccpi/viewer/QVTKWidget2.py | 84 ++ src/Python/ccpi/viewer/__init__.py | 1 + .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 0 -> 10542 bytes .../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 0 -> 35633 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 0 -> 10099 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 0 -> 1316 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 0 -> 210 bytes src/Python/ccpi/viewer/embedvtk.py | 75 ++ 11 files changed, 1987 insertions(+) create mode 100644 src/Python/ccpi/viewer/CILViewer.py create mode 100644 src/Python/ccpi/viewer/CILViewer2D.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget.py create mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py create mode 100644 src/Python/ccpi/viewer/__init__.py create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc create mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py new file mode 100644 index 0000000..efcf8be --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +import math +from vtk.util import numpy_support + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + + + +class CILViewer(): + '''Simple 3D Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600): + '''creates the rendering pipeline''' + + # create a rendering window and renderer + self.ren = vtk.vtkRenderer() + self.renWin = vtk.vtkRenderWindow() + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + # img 3D as slice + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceActor = None + self.voi = None + self.wl = None + self.ia = None + self.sliceActorNo = 0 + # create a renderwindowinteractor + self.iren = vtk.vtkRenderWindowInteractor() + self.iren.SetRenderWindow(self.renWin) + + self.style = vtk.vtkInteractorStyleTrackballCamera() + self.iren.SetInteractorStyle(self.style) + + self.ren.SetBackground(.1, .2, .4) + + self.actors = {} + self.iren.RemoveObservers('MouseWheelForwardEvent') + self.iren.RemoveObservers('MouseWheelBackwardEvent') + + self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) + self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) + + self.iren.RemoveObservers('KeyPressEvent') + self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) + + + self.iren.Initialize() + + + + def getRenderer(self): + '''returns the renderer''' + return self.ren + + def getRenderWindow(self): + '''returns the render window''' + return self.renWin + + def getInteractor(self): + '''returns the render window interactor''' + return self.iren + + def getCamera(self): + '''returns the active camera''' + return self.ren.GetActiveCamera() + + def createPolyDataActor(self, polydata): + '''returns an actor for a given polydata''' + mapper = vtk.vtkPolyDataMapper() + if vtk.VTK_MAJOR_VERSION <= 5: + mapper.SetInput(polydata) + else: + mapper.SetInputData(polydata) + + # actor + actor = vtk.vtkActor() + actor.SetMapper(mapper) + #actor.GetProperty().SetOpacity(0.8) + return actor + + def setPolyDataActor(self, actor): + '''displays the given polydata''' + + self.ren.AddActor(actor) + + self.actors[len(self.actors)+1] = [actor, True] + self.iren.Initialize() + self.renWin.Render() + + def displayPolyData(self, polydata): + self.setPolyDataActor(self.createPolyDataActor(polydata)) + + def hideActor(self, actorno): + '''Hides an actor identified by its number in the list of actors''' + try: + if self.actors[actorno][1]: + self.ren.RemoveActor(self.actors[actorno][0]) + self.actors[actorno][1] = False + except KeyError as ke: + print ("Warning Actor not present") + + def showActor(self, actorno, actor = None): + '''Shows hidden actor identified by its number in the list of actors''' + try: + if not self.actors[actorno][1]: + self.ren.AddActor(self.actors[actorno][0]) + self.actors[actorno][1] = True + return actorno + except KeyError as ke: + # adds it to the actors if not there already + if actor != None: + self.ren.AddActor(actor) + self.actors[len(self.actors)+1] = [actor, True] + return len(self.actors) + + def addActor(self, actor): + '''Adds an actor to the render''' + return self.showActor(0, actor) + + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + + def startRenderLoop(self): + self.iren.Start() + + + def setupObservers(self, interactor): + interactor.RemoveObservers('LeftButtonPressEvent') + interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) + interactor.Initialize() + + + def mouseInteraction(self, interactor, event): + if event == 'MouseWheelForwardEvent': + maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] + if (self.sliceno + 1 < maxSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno + 1 + self.displaySliceActor(self.sliceno) + else: + minSlice = 0 + if (self.sliceno - 1 > minSlice): + self.hideActor(self.sliceActorNo) + self.sliceno = self.sliceno - 1 + self.displaySliceActor(self.sliceno) + + + def keyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "x": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_YZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "y": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XZ + self.sliceno = int(self.img3D.GetDimensions()[1] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + elif interactor.GetKeyCode() == "z": + # slice on the other orientation + self.sliceOrientation = SLICE_ORIENTATION_XY + self.sliceno = int(self.img3D.GetDimensions()[2] / 2) + self.hideActor(self.sliceActorNo) + self.displaySliceActor(self.sliceno) + if interactor.GetKeyCode() == "X": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("x") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("y") + self.keyPress(interactor, event) + elif interactor.GetKeyCode() == "Z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.ren.SetActiveCamera(camera) + self.ren.ResetCamera() + self.ren.Render() + interactor.SetKeyCode("z") + self.keyPress(interactor, event) + else : + print ("Unhandled event %s" % interactor.GetKeyCode()) + + + + def setInput3DData(self, imageData): + self.img3D = imageData + + def setInputAsNumpy(self, numpyarray): + if (len(numpy.shape(numpyarray)) == 3): + doubleImg = vtk.vtkImageData() + shape = numpy.shape(numpyarray) + doubleImg.SetDimensions(shape[0], shape[1], shape[2]) + doubleImg.SetOrigin(0,0,0) + doubleImg.SetSpacing(1,1,1) + doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) + #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) + doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) + + for i in range(shape[0]): + for j in range(shape[1]): + for k in range(shape[2]): + doubleImg.SetScalarComponentFromDouble( + i,j,k,0, numpyarray[i][j][k]) + #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(doubleImg) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(doubleImg) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(iMin) + shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + + def displaySliceActor(self, sliceno = 0): + self.sliceno = sliceno + first = False + + self.sliceActor , self.voi, self.wl , self.ia = \ + self.getSliceActor(self.img3D, + sliceno, + self.sliceActor, + self.voi, + self.wl, + self.ia) + no = self.showActor(self.sliceActorNo, self.sliceActor) + self.sliceActorNo = no + + self.iren.Initialize() + self.renWin.Render() + + return self.sliceActorNo + + + def getSliceActor(self, + imageData , + sliceno=0, + imageActor=None , + voi=None, + windowLevel=None, + imageAccumulate=None): + '''Slices a 3D volume and then creates an actor to be rendered''' + if (voi==None): + voi = vtk.vtkExtractVOI() + #voi = vtk.vtkImageClip() + voi.SetInputData(imageData) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = sliceno + extent[self.sliceOrientation * 2 + 1] = sliceno + voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + voi.Update() + # set window/level for all slices + if imageAccumulate == None: + imageAccumulate = vtk.vtkImageAccumulate() + + if (windowLevel == None): + windowLevel = vtk.vtkImageMapToWindowLevelColors() + imageAccumulate.SetInputData(imageData) + imageAccumulate.Update() + cmax = imageAccumulate.GetMax()[0] + cmin = imageAccumulate.GetMin()[0] + windowLevel.SetLevel((cmax+cmin)/2) + windowLevel.SetWindow(cmax-cmin) + + windowLevel.SetInputData(voi.GetOutput()) + windowLevel.Update() + + if imageActor == None: + imageActor = vtk.vtkImageActor() + imageActor.SetInputData(windowLevel.GetOutput()) + imageActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + imageActor.Update() + return (imageActor , voi, windowLevel, imageAccumulate) + + + # Set interpolation on + def setInterpolateOn(self): + self.sliceActor.SetInterpolate(True) + self.renWin.Render() + + # Set interpolation off + def setInterpolateOff(self): + self.sliceActor.SetInterpolate(False) + self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py new file mode 100644 index 0000000..c1629af --- /dev/null +++ b/src/Python/ccpi/viewer/CILViewer2D.py @@ -0,0 +1,1126 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import vtk +import numpy +from vtk.util import numpy_support , vtkImageImportFromArray +from enum import Enum + +SLICE_ORIENTATION_XY = 2 # Z +SLICE_ORIENTATION_XZ = 1 # Y +SLICE_ORIENTATION_YZ = 0 # X + +CONTROL_KEY = 8 +SHIFT_KEY = 4 +ALT_KEY = -128 + + +# Converter class +class Converter(): + + # Utility functions to transform numpy arrays to vtkImageData and viceversa + @staticmethod + def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Creates a vtkImageImportFromArray object and returns it. + + It handles the different axis order from numpy to VTK''' + importer = vtkImageImportFromArray.vtkImageImportFromArray() + importer.SetArray(numpy.transpose(nparray).copy()) + importer.SetDataSpacing(spacing) + importer.SetDataOrigin(origin) + return importer + + @staticmethod + def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): + '''Converts a 3D numpy array to a vtkImageData''' + importer = Converter.numpy2vtkImporter(nparray, spacing, origin) + importer.Update() + return importer.GetOutput() + + @staticmethod + def vtk2numpy(imgdata): + '''Converts the VTK data to 3D numpy array''' + img_data = numpy_support.vtk_to_numpy( + imgdata.GetPointData().GetScalars()) + + dims = imgdata.GetDimensions() + dims = (dims[2],dims[1],dims[0]) + data3d = numpy.reshape(img_data, dims) + + return numpy.transpose(data3d).copy() + + @staticmethod + def tiffStack2numpy(filename, indices, + extent = None , sampleRate = None ,\ + flatField = None, darkField = None): + '''Converts a stack of TIFF files to numpy array. + + filename must contain the whole path. The filename is supposed to be named and + have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif + + indices are the suffix, generally an increasing number + + Optionally extracts only a selection of the 2D images and (optionally) + normalizes. + ''' + + stack = vtk.vtkImageData() + reader = vtk.vtkTIFFReader() + voi = vtk.vtkExtractVOI() + + #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" + + stack_image = numpy.asarray([]) + nreduced = len(indices) + + for num in range(len(indices)): + fn = filename % indices[num] + print ("resampling %s" % ( fn ) ) + reader.SetFileName(fn) + reader.Update() + print (reader.GetOutput().GetScalarTypeAsString()) + if num == 0: + if (extent == None): + sliced = reader.GetOutput().GetExtent() + stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) + else: + sliced = extent + voi.SetVOI(extent) + + if sampleRate is not None: + voi.SetSampleRate(sampleRate) + ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) + print ("ext {0}".format(ext)) + stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) + else: + stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) + if (flatField != None and darkField != None): + stack.AllocateScalars(vtk.VTK_FLOAT, 1) + else: + stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) + print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) + stack_image = Converter.vtk2numpy(stack) + print ("Stack shape %s" % str(numpy.shape(stack_image))) + + if extent!=None: + voi.SetInputData(reader.GetOutput()) + voi.Update() + img = voi.GetOutput() + else: + img = reader.GetOutput() + + theSlice = Converter.vtk2numpy(img).T[0] + if darkField != None and flatField != None: + print("Try to normalize") + #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): + theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) + print (theSlice.dtype) + + + print ("Slice shape %s" % str(numpy.shape(theSlice))) + stack_image.T[num] = theSlice.copy() + + return stack_image + + @staticmethod + def normalize(projection, dark, flat, def_val=0): + a = (projection - dark) + b = (flat-dark) + with numpy.errstate(divide='ignore', invalid='ignore'): + c = numpy.true_divide( a, b ) + c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 + return c + + + +## Utility functions to transform numpy arrays to vtkImageData and viceversa +#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtkImporter(nparray, spacing, origin) +# +#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): +# return Converter.numpy2vtk(nparray, spacing, origin) +# +#def vtk2numpy(imgdata): +# return Converter.vtk2numpy(imgdata) +# +#def tiffStack2numpy(filename, indices): +# return Converter.tiffStack2numpy(filename, indices) + +class ViewerEvent(Enum): + # left button + PICK_EVENT = 0 + # alt + right button + move + WINDOW_LEVEL_EVENT = 1 + # shift + right button + ZOOM_EVENT = 2 + # control + right button + PAN_EVENT = 3 + # control + left button + CREATE_ROI_EVENT = 4 + # alt + left button + DELETE_ROI_EVENT = 5 + # release button + NO_EVENT = -1 + + +#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): +class CILInteractorStyle(vtk.vtkInteractorStyleImage): + + def __init__(self, callback): + vtk.vtkInteractorStyleImage.__init__(self) + self.callback = callback + self._viewer = callback + priority = 1.0 + +# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) +# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) +# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) +# self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) +# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) +# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) +# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) +# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) + + self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) + self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) + self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) + self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) + self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) + self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) + self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) + self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) + + self.InitialEventPosition = (0,0) + + + def SetInitialEventPosition(self, xy): + self.InitialEventPosition = xy + + def GetInitialEventPosition(self): + return self.InitialEventPosition + + def GetKeyCode(self): + return self.GetInteractor().GetKeyCode() + + def SetKeyCode(self, keycode): + self.GetInteractor().SetKeyCode(keycode) + + def GetControlKey(self): + return self.GetInteractor().GetControlKey() == CONTROL_KEY + + def GetShiftKey(self): + return self.GetInteractor().GetShiftKey() == SHIFT_KEY + + def GetAltKey(self): + return self.GetInteractor().GetAltKey() == ALT_KEY + + def GetEventPosition(self): + return self.GetInteractor().GetEventPosition() + + def GetEventPositionInWorldCoordinates(self): + pass + + def GetDeltaEventPosition(self): + x,y = self.GetInteractor().GetEventPosition() + return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) + + def Dolly(self, factor): + self.callback.camera.Dolly(factor) + self.callback.ren.ResetCameraClippingRange() + + def GetDimensions(self): + return self._viewer.img3D.GetDimensions() + + def GetInputData(self): + return self._viewer.img3D + + def GetSliceOrientation(self): + return self._viewer.sliceOrientation + + def SetSliceOrientation(self, orientation): + self._viewer.sliceOrientation = orientation + + def GetActiveSlice(self): + return self._viewer.sliceno + + def SetActiveSlice(self, sliceno): + self._viewer.sliceno = sliceno + + def UpdatePipeline(self, reset = False): + self._viewer.updatePipeline(reset) + + def GetActiveCamera(self): + return self._viewer.ren.GetActiveCamera() + + def SetActiveCamera(self, camera): + self._viewer.ren.SetActiveCamera(camera) + + def ResetCamera(self): + self._viewer.ren.ResetCamera() + + def Render(self): + self._viewer.renWin.Render() + + def UpdateSliceActor(self): + self._viewer.sliceActor.Update() + + def AdjustCamera(self): + self._viewer.AdjustCamera() + + def SaveRender(self, filename): + self._viewer.SaveRender(filename) + + def GetRenderWindow(self): + return self._viewer.renWin + + def GetRenderer(self): + return self._viewer.ren + + def GetROIWidget(self): + return self._viewer.ROIWidget + + def SetViewerEvent(self, event): + self._viewer.event = event + + def GetViewerEvent(self): + return self._viewer.event + + def SetInitialCameraPosition(self, position): + self._viewer.InitialCameraPosition = position + + def GetInitialCameraPosition(self): + return self._viewer.InitialCameraPosition + + def SetInitialLevel(self, level): + self._viewer.InitialLevel = level + + def GetInitialLevel(self): + return self._viewer.InitialLevel + + def SetInitialWindow(self, window): + self._viewer.InitialWindow = window + + def GetInitialWindow(self): + return self._viewer.InitialWindow + + def GetWindowLevel(self): + return self._viewer.wl + + def SetROI(self, roi): + self._viewer.ROI = roi + + def GetROI(self): + return self._viewer.ROI + + def UpdateCornerAnnotation(self, text, corner): + self._viewer.updateCornerAnnotation(text, corner) + + def GetPicker(self): + return self._viewer.picker + + def GetCornerAnnotation(self): + return self._viewer.cornerAnnotation + + def UpdateROIHistogram(self): + self._viewer.updateROIHistogram() + + + ############### Handle events + def OnMouseWheelForward(self, interactor, event): + maxSlice = self.GetDimensions()[self.GetSliceOrientation()] + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + + if (self.GetActiveSlice() + advance < maxSlice): + self.SetActiveSlice(self.GetActiveSlice() + advance) + + self.UpdatePipeline() + else: + print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) + + def OnMouseWheelBackward(self, interactor, event): + minSlice = 0 + shift = interactor.GetShiftKey() + advance = 1 + if shift: + advance = 10 + if (self.GetActiveSlice() - advance >= minSlice): + self.SetActiveSlice( self.GetActiveSlice() - advance) + self.UpdatePipeline() + else: + print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) + + def OnKeyPress(self, interactor, event): + #print ("Pressed key %s" % interactor.GetKeyCode()) + # Slice Orientation + if interactor.GetKeyCode() == "X": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) + self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Y": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) + self.UpdatePipeline(True) + elif interactor.GetKeyCode() == "Z": + # slice on the other orientation + self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) + self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) + self.UpdatePipeline(True) + if interactor.GetKeyCode() == "x": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("X") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "y": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,0,-1) + self.SetActiveCamera(camera) + self.Render() + interactor.SetKeyCode("Y") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "z": + # Change the camera view point + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] + newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) + camera.SetPosition(newposition) + camera.SetViewUp(0,1,0) + self.SetActiveCamera(camera) + self.ResetCamera() + self.Render() + interactor.SetKeyCode("Z") + self.OnKeyPress(interactor, event) + elif interactor.GetKeyCode() == "a": + # reset color/window + cmax = self._viewer.ia.GetMax()[0] + cmin = self._viewer.ia.GetMin()[0] + + self.SetInitialLevel( (cmax+cmin)/2 ) + self.SetInitialWindow( cmax-cmin ) + + self.GetWindowLevel().SetLevel(self.GetInitialLevel()) + self.GetWindowLevel().SetWindow(self.GetInitialWindow()) + + self.GetWindowLevel().Update() + + self.UpdateSliceActor() + self.AdjustCamera() + self.Render() + + elif interactor.GetKeyCode() == "s": + filename = "current_render" + self.SaveRender(filename) + elif interactor.GetKeyCode() == "q": + print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) + interactor.SetKeyCode("e") + self.OnKeyPress(interactor, event) + else : + #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) + pass + + def OnLeftButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + if ctrl and not (alt and shift): + self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) + wsize = self.GetRenderWindow().GetSize() + position = interactor.GetEventPosition() + self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) + self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) + + self.GetROIWidget().On() + self.SetDisplayHistogram(True) + self.Render() + print ("Event %s is CREATE_ROI_EVENT" % (event)) + elif alt and not (shift and ctrl): + self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) + self.GetROIWidget().Off() + self._viewer.updateCornerAnnotation("", 1, False) + self.SetDisplayHistogram(False) + self.Render() + print ("Event %s is DELETE_ROI_EVENT" % (event)) + elif not (ctrl and alt and shift): + self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) + self.HandlePickEvent(interactor, event) + print ("Event %s is PICK_EVENT" % (event)) + + + def SetDisplayHistogram(self, display): + if display: + if (self._viewer.displayHistogram == 0): + self.GetRenderer().AddActor(self._viewer.histogramPlotActor) + self.firstHistogram = 1 + self.Render() + + self._viewer.histogramPlotActor.VisibilityOn() + self._viewer.displayHistogram = True + else: + self._viewer.histogramPlotActor.VisibilityOff() + self._viewer.displayHistogram = False + + + def OnLeftButtonReleaseEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: + #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() + #print (bc.GetValue()) + self.OnROIModifiedEvent(interactor, event) + + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def OnRightButtonPressEvent(self, interactor, event): + alt = interactor.GetAltKey() + shift = interactor.GetShiftKey() + ctrl = interactor.GetControlKey() +# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) +# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) +# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) + + interactor.SetInitialEventPosition(interactor.GetEventPosition()) + + + if alt and not (ctrl and shift): + self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif shift and not (ctrl and alt): + self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) + self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) + print ("Event %s is ZOOM_EVENT" % (event)) + elif ctrl and not (shift and alt): + self.SetViewerEvent (ViewerEvent.PAN_EVENT ) + self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) + print ("Event %s is PAN_EVENT" % (event)) + + def OnRightButtonReleaseEvent(self, interactor, event): + print (event) + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) + self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ + self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.SetInitialCameraPosition( () ) + + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + + def OnROIModifiedEvent(self, interactor, event): + + #print ("ROI EVENT " + event) + p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() + p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() + wsize = self.GetRenderWindow().GetSize() + + #print (p1.GetValue()) + #print (p2.GetValue()) + pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] + pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] + vox1 = self.viewport2imageCoordinate(pp1) + vox2 = self.viewport2imageCoordinate(pp2) + + self.SetROI( (vox1 , vox2) ) + roi = self.GetROI() + print ("Pixel1 %d,%d,%d Value %f" % vox1 ) + print ("Pixel2 %d,%d,%d Value %f" % vox2 ) + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + x = abs(roi[1][0] - roi[0][0]) + y = abs(roi[1][2] - roi[0][2]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + x = abs(roi[1][1] - roi[0][1]) + y = abs(roi[1][2] - roi[0][2]) + + text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) + print (text) + self.UpdateCornerAnnotation(text, 1) + self.UpdateROIHistogram() + self.SetViewerEvent( ViewerEvent.NO_EVENT ) + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.GetPicker().GetPickPosition()) + pickPosition[self.GetSliceOrientation()] = \ + self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ + self.GetInputData().GetOrigin()[self.GetSliceOrientation()] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.GetInputData().GetDimensions() + print (dims) + spac = self.GetInputData().GetSpacing() + orig = self.GetInputData().GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + + def OnMouseMoveEvent(self, interactor, event): + if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: + print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) + self.HandleWindowLevel(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: + self.HandlePickEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: + self.HandleZoomEvent(interactor, event) + elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: + self.HandlePanEvent(interactor, event) + + + def HandleZoomEvent(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + size = self.GetRenderWindow().GetSize() + dy = - 4 * dy / size[1] + + print ("distance: " + str(self.GetActiveCamera().GetDistance())) + + print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) + + camera = vtk.vtkCamera() + camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) + newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) + #print ("new position " + str(newposition)) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + print ("distance after: " + str(self.GetActiveCamera().GetDistance())) + + def HandlePanEvent(self, interactor, event): + x,y = interactor.GetEventPosition() + x0,y0 = interactor.GetInitialEventPosition() + + ic = self.viewport2imageCoordinate((x,y)) + ic0 = self.viewport2imageCoordinate((x0,y0)) + + dx = 4 *( ic[0] - ic0[0]) + dy = 4* (ic[1] - ic0[1]) + + camera = vtk.vtkCamera() + #print ("current position " + str(self.InitialCameraPosition)) + camera.SetViewUp(self.GetActiveCamera().GetViewUp()) + camera.SetPosition(self.GetInitialCameraPosition()) + newposition = [i for i in self.GetInitialCameraPosition()] + newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + newposition[0] -= dx + newposition[1] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[1] = newposition[1] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + newposition[0] -= dx + newposition[2] -= dy + newfocalpoint[0] = newposition[0] + newfocalpoint[2] = newposition[2] + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + newposition[1] -= dx + newposition[2] -= dy + newfocalpoint[2] = newposition[2] + newfocalpoint[1] = newposition[1] + #print ("new position " + str(newposition)) + camera.SetFocalPoint(newfocalpoint) + camera.SetPosition(newposition) + self.SetActiveCamera(camera) + + self.Render() + + def HandleWindowLevel(self, interactor, event): + dx,dy = interactor.GetDeltaEventPosition() + print ("Event delta %d %d" % (dx,dy)) + size = self.GetRenderWindow().GetSize() + + dx = 4 * dx / size[0] + dy = 4 * dy / size[1] + window = self.GetInitialWindow() + level = self.GetInitialLevel() + + if abs(window) > 0.01: + dx = dx * window + else: + dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); + + if abs(level) > 0.01: + dy = dy * level + else: + dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) + + + # Abs so that direction does not flip + + if window < 0.0: + dx = -1*dx + if level < 0.0: + dy = -1*dy + + # Compute new window level + + newWindow = dx + window + newLevel = level - dy + + # Stay away from zero and really + + if abs(newWindow) < 0.01: + newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) + + if abs(newLevel) < 0.01: + newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) + + self.GetWindowLevel().SetWindow(newWindow) + self.GetWindowLevel().SetLevel(newLevel) + + self.GetWindowLevel().Update() + self.UpdateSliceActor() + self.AdjustCamera() + + self.Render() + + def HandlePickEvent(self, interactor, event): + position = interactor.GetEventPosition() + #print ("PICK " + str(position)) + vox = self.viewport2imageCoordinate(position) + #print ("Pixel %d,%d,%d Value %f" % vox ) + self._viewer.cornerAnnotation.VisibilityOn() + self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) + self.Render() + +############################################################################### + + + +class CILViewer2D(): + '''Simple Interactive Viewer based on VTK classes''' + + def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): + '''creates the rendering pipeline''' + # create a rendering window and renderer + if ren == None: + self.ren = vtk.vtkRenderer() + else: + self.ren = ren + if renWin == None: + self.renWin = vtk.vtkRenderWindow() + else: + self.renWin = renWin + if iren == None: + self.iren = vtk.vtkRenderWindowInteractor() + else: + self.iren = iren + + self.renWin.SetSize(dimx,dimy) + self.renWin.AddRenderer(self.ren) + + self.style = CILInteractorStyle(self) + + self.iren.SetInteractorStyle(self.style) + self.iren.SetRenderWindow(self.renWin) + self.iren.Initialize() + self.ren.SetBackground(.1, .2, .4) + + self.camera = vtk.vtkCamera() + self.camera.ParallelProjectionOn() + self.ren.SetActiveCamera(self.camera) + + # data + self.img3D = None + self.sliceno = 0 + self.sliceOrientation = SLICE_ORIENTATION_XY + + #Actors + self.sliceActor = vtk.vtkImageActor() + self.voi = vtk.vtkExtractVOI() + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia = vtk.vtkImageAccumulate() + self.sliceActorNo = 0 + + #initial Window/Level + self.InitialLevel = 0 + self.InitialWindow = 0 + + #ViewerEvent + self.event = ViewerEvent.NO_EVENT + + # ROI Widget + self.ROIWidget = vtk.vtkBorderWidget() + self.ROIWidget.SetInteractor(self.iren) + self.ROIWidget.CreateDefaultRepresentation() + self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) + self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) + + # edge points of the ROI + self.ROI = () + + #picker + self.picker = vtk.vtkPropPicker() + self.picker.PickFromListOn() + self.picker.AddPickList(self.sliceActor) + + self.iren.SetPicker(self.picker) + + # corner annotation + self.cornerAnnotation = vtk.vtkCornerAnnotation() + self.cornerAnnotation.SetMaximumFontSize(12); + self.cornerAnnotation.PickableOff(); + self.cornerAnnotation.VisibilityOff(); + self.cornerAnnotation.GetTextProperty().ShadowOn(); + self.cornerAnnotation.SetLayerNumber(1); + + + + # cursor doesn't show up + self.cursor = vtk.vtkCursor2D() + self.cursorMapper = vtk.vtkPolyDataMapper2D() + self.cursorActor = vtk.vtkActor2D() + self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) + self.cursor.SetFocalPoint(0, 0, 0) + self.cursor.AllOff() + self.cursor.AxesOn() + self.cursorActor.PickableOff() + self.cursorActor.VisibilityOn() + self.cursorActor.GetProperty().SetColor(1, 1, 1) + self.cursorActor.SetLayerNumber(1) + self.cursorMapper.SetInputData(self.cursor.GetOutput()) + self.cursorActor.SetMapper(self.cursorMapper) + + # Zoom + self.InitialCameraPosition = () + + # XY Plot actor for histogram + self.displayHistogram = False + self.firstHistogram = 0 + self.roiIA = vtk.vtkImageAccumulate() + self.roiVOI = vtk.vtkExtractVOI() + self.histogramPlotActor = vtk.vtkXYPlotActor() + self.histogramPlotActor.ExchangeAxesOff(); + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetXLabelFormat( "%g" ) + self.histogramPlotActor.SetAdjustXLabels(3) + self.histogramPlotActor.SetXTitle( "Level" ) + self.histogramPlotActor.SetYTitle( "N" ) + self.histogramPlotActor.SetXValuesToValue() + self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) + self.histogramPlotActor.SetPosition(0.6,0.6) + self.histogramPlotActor.SetPosition2(0.4,0.4) + + + + def GetInteractor(self): + return self.iren + + def GetRenderer(self): + return self.ren + + def setInput3DData(self, imageData): + self.img3D = imageData + self.installPipeline() + + def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), + rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): + importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) + importer.Update() + + if rescale: + # rescale to appropriate VTK_UNSIGNED_SHORT + stats = vtk.vtkImageAccumulate() + stats.SetInputData(importer.GetOutput()) + stats.Update() + iMin = stats.GetMin()[0] + iMax = stats.GetMax()[0] + if (iMax - iMin == 0): + scale = 1 + else: + if dtype == vtk.VTK_UNSIGNED_SHORT: + scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + elif dtype == vtk.VTK_UNSIGNED_INT: + scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) + + shiftScaler = vtk.vtkImageShiftScale () + shiftScaler.SetInputData(importer.GetOutput()) + shiftScaler.SetScale(scale) + shiftScaler.SetShift(-iMin) + shiftScaler.SetOutputScalarType(dtype) + shiftScaler.Update() + self.img3D = shiftScaler.GetOutput() + else: + self.img3D = importer.GetOutput() + + self.installPipeline() + + def displaySlice(self, sliceno = 0): + self.sliceno = sliceno + + self.updatePipeline() + + self.renWin.Render() + + return self.sliceActorNo + + def updatePipeline(self, resetcamera = False): + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + self.ia.Update() + self.wl.Update() + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + + self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) + + if self.displayHistogram: + self.updateROIHistogram() + + self.AdjustCamera(resetcamera) + + self.renWin.Render() + + + def installPipeline(self): + '''Slices a 3D volume and then creates an actor to be rendered''' + + self.ren.AddViewProp(self.cornerAnnotation) + + self.voi.SetInputData(self.img3D) + #select one slice in Z + extent = [ i for i in self.img3D.GetExtent()] + extent[self.sliceOrientation * 2] = self.sliceno + extent[self.sliceOrientation * 2 + 1] = self.sliceno + self.voi.SetVOI(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + + self.voi.Update() + # set window/level for current slices + + + self.wl = vtk.vtkImageMapToWindowLevelColors() + self.ia.SetInputData(self.voi.GetOutput()) + self.ia.Update() + cmax = self.ia.GetMax()[0] + cmin = self.ia.GetMin()[0] + + self.InitialLevel = (cmax+cmin)/2 + self.InitialWindow = cmax-cmin + + + self.wl.SetLevel(self.InitialLevel) + self.wl.SetWindow(self.InitialWindow) + + self.wl.SetInputData(self.voi.GetOutput()) + self.wl.Update() + + self.sliceActor.SetInputData(self.wl.GetOutput()) + self.sliceActor.SetDisplayExtent(extent[0], extent[1], + extent[2], extent[3], + extent[4], extent[5]) + self.sliceActor.Update() + self.sliceActor.SetInterpolate(False) + self.ren.AddActor(self.sliceActor) + self.ren.ResetCamera() + self.ren.Render() + + self.AdjustCamera() + + self.ren.AddViewProp(self.cursorActor) + self.cursorActor.VisibilityOn() + + self.iren.Initialize() + self.renWin.Render() + #self.iren.Start() + + def AdjustCamera(self, resetcamera = False): + self.ren.ResetCameraClippingRange() + if resetcamera: + self.ren.ResetCamera() + + + def getROI(self): + return self.ROI + + def getROIExtent(self): + p0 = self.ROI[0] + p1 = self.ROI[1] + return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) + + ############### Handle events are moved to the interactor style + + + def viewport2imageCoordinate(self, viewerposition): + #Determine point index + + self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) + pickPosition = list(self.picker.GetPickPosition()) + pickPosition[self.sliceOrientation] = \ + self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ + self.img3D.GetOrigin()[self.sliceOrientation] + print ("Pick Position " + str (pickPosition)) + + if (pickPosition != [0,0,0]): + dims = self.img3D.GetDimensions() + print (dims) + spac = self.img3D.GetSpacing() + orig = self.img3D.GetOrigin() + imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] + + pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) + return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) + else: + return (0,0,0,0) + + + + def GetRenderWindow(self): + return self.renWin + + + def startRenderLoop(self): + self.iren.Start() + + def GetSliceOrientation(self): + return self.sliceOrientation + + def GetActiveSlice(self): + return self.sliceno + + def updateCornerAnnotation(self, text , idx=0, visibility=True): + if visibility: + self.cornerAnnotation.VisibilityOn() + else: + self.cornerAnnotation.VisibilityOff() + + self.cornerAnnotation.SetText(idx, text) + self.iren.Render() + + def saveRender(self, filename, renWin=None): + '''Save the render window to PNG file''' + # screenshot code: + w2if = vtk.vtkWindowToImageFilter() + if renWin == None: + renWin = self.renWin + w2if.SetInput(renWin) + w2if.Update() + + writer = vtk.vtkPNGWriter() + writer.SetFileName("%s.png" % (filename)) + writer.SetInputConnection(w2if.GetOutputPort()) + writer.Write() + + def updateROIHistogram(self): + + extent = [0 for i in range(6)] + if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + extent[4] = self.GetActiveSlice() + extent[5] = self.GetActiveSlice()+1 + #y = abs(roi[1][1] - roi[0][1]) + elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: + print ("slice orientation : XY") + extent[0] = self.ROI[0][0] + extent[1] = self.ROI[1][0] + #x = abs(roi[1][0] - roi[0][0]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[2] = self.GetActiveSlice() + extent[3] = self.GetActiveSlice()+1 + elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: + print ("slice orientation : XY") + extent[2] = self.ROI[0][1] + extent[3] = self.ROI[1][1] + #x = abs(roi[1][1] - roi[0][1]) + extent[4] = self.ROI[0][2] + extent[5] = self.ROI[1][2] + #y = abs(roi[1][2] - roi[0][2]) + extent[0] = self.GetActiveSlice() + extent[1] = self.GetActiveSlice()+1 + + self.roiVOI.SetVOI(extent) + self.roiVOI.SetInputData(self.img3D) + self.roiVOI.Update() + irange = self.roiVOI.GetOutput().GetScalarRange() + + self.roiIA.SetInputData(self.roiVOI.GetOutput()) + self.roiIA.IgnoreZeroOff() + self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) + self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); + self.roiIA.SetComponentSpacing( 1,0,0 ); + self.roiIA.Update() + + self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) + self.histogramPlotActor.SetXRange(irange[0],irange[1]) + + self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) + + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py new file mode 100644 index 0000000..906786b --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget.py @@ -0,0 +1,340 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter + +class QVTKWidget(QtWidgets.QWidget): + + """ A QVTKWidget for Python and Qt.""" + + # Map between VTK and Qt cursors. + _CURSOR_MAP = { + 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT + 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW + 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE + 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE + 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW + 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE + 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS + 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE + 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL + 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND + 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR + } + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + # the current button + self._ActiveButton = QtCore.Qt.NoButton + + # private attributes + self.__oldFocus = None + self.__saveX = 0 + self.__saveY = 0 + self.__saveModifiers = QtCore.Qt.NoModifier + self.__saveButtons = QtCore.Qt.NoButton + self.__timeframe = 0 + + # create qt-level widget + QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D() + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + + self._Iren.Register(self._RenderWindow) + self._Iren.SetRenderWindow(self._RenderWindow) + self._RenderWindow.SetWindowInfo(str(int(self.winId()))) + + # do all the necessary qt setup + self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) + self.setAttribute(QtCore.Qt.WA_PaintOnScreen) + self.setMouseTracking(True) # get all mouse events + self.setFocusPolicy(QtCore.Qt.WheelFocus) + self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) + + self._Timer = QtCore.QTimer(self) + #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) + + self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) + self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) + self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', + self.CursorChangedEvent) + + # Destructor + def __del__(self): + self._Iren.UnRegister(self._RenderWindow) + #QtWidgets.QWidget.__del__(self) + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + + # GetInteractor + def GetInteractor(self): + return self._Iren + + # Display image data + def GetPyveViewer(self): + return self._PyveViewer + + def __getattr__(self, attr): + """Makes the object behave like a vtkGenericRenderWindowInteractor""" + print (attr) + if attr == '__vtk__': + return lambda t=self._Iren: t + elif hasattr(self._Iren, attr): + return getattr(self._Iren, attr) +# else: +# raise AttributeError( self.__class__.__name__ + \ +# " has no attribute named " + attr ) + + def CreateTimer(self, obj, evt): + self._Timer.start(10) + + def DestroyTimer(self, obj, evt): + self._Timer.stop() + return 1 + + def TimerEvent(self): + self._Iren.InvokeEvent("TimerEvent") + + def CursorChangedEvent(self, obj, evt): + """Called when the CursorChangedEvent fires on the render window.""" + # This indirection is needed since when the event fires, the current + # cursor is not yet set so we defer this by which time the current + # cursor should have been set. + QtCore.QTimer.singleShot(0, self.ShowCursor) + + def HideCursor(self): + """Hides the cursor.""" + self.setCursor(QtCore.Qt.BlankCursor) + + def ShowCursor(self): + """Shows the cursor.""" + vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() + qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) + self.setCursor(qt_cursor) + + def sizeHint(self): + return QtCore.QSize(400, 400) + + def paintEngine(self): + return None + + def paintEvent(self, ev): + self._RenderWindow.Render() + + def resizeEvent(self, ev): + self._RenderWindow.Render() + w = self.width() + h = self.height() + + self._RenderWindow.SetSize(w, h) + self._Iren.SetSize(w, h) + + def _GetCtrlShiftAlt(self, ev): + ctrl = shift = alt = False + + if hasattr(ev, 'modifiers'): + if ev.modifiers() & QtCore.Qt.ShiftModifier: + shift = True + if ev.modifiers() & QtCore.Qt.ControlModifier: + ctrl = True + if ev.modifiers() & QtCore.Qt.AltModifier: + alt = True + else: + if self.__saveModifiers & QtCore.Qt.ShiftModifier: + shift = True + if self.__saveModifiers & QtCore.Qt.ControlModifier: + ctrl = True + if self.__saveModifiers & QtCore.Qt.AltModifier: + alt = True + + return ctrl, shift, alt + + def enterEvent(self, ev): + if not self.hasFocus(): + self.__oldFocus = self.focusWidget() + self.setFocus() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("EnterEvent") + + def leaveEvent(self, ev): + if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: + self.__oldFocus.setFocus() + self.__oldFocus = None + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("LeaveEvent") + + def mousePressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + repeat = 0 + if ev.type() == QtCore.QEvent.MouseButtonDblClick: + repeat = 1 + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), repeat, None) + + self._Iren.SetAltKey(alt) + self._ActiveButton = ev.button() + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonPressEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonPressEvent") + + def mouseReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + + if self._ActiveButton == QtCore.Qt.LeftButton: + self._Iren.InvokeEvent("LeftButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.RightButton: + self._Iren.InvokeEvent("RightButtonReleaseEvent") + elif self._ActiveButton == QtCore.Qt.MidButton: + self._Iren.InvokeEvent("MiddleButtonReleaseEvent") + + def mouseMoveEvent(self, ev): + self.__saveModifiers = ev.modifiers() + self.__saveButtons = ev.buttons() + self.__saveX = ev.x() + self.__saveY = ev.y() + + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), + ctrl, shift, chr(0), 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("MouseMoveEvent") + + def keyPressEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = str(ev.text()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyPressEvent") + self._Iren.InvokeEvent("CharEvent") + + def keyReleaseEvent(self, ev): + ctrl, shift, alt = self._GetCtrlShiftAlt(ev) + if ev.key() < 256: + key = chr(ev.key()) + else: + key = chr(0) + + self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, + ctrl, shift, key, 0, None) + self._Iren.SetAltKey(alt) + self._Iren.InvokeEvent("KeyReleaseEvent") + + def wheelEvent(self, ev): + print ("angleDeltaX %d" % ev.angleDelta().x()) + print ("angleDeltaY %d" % ev.angleDelta().y()) + if ev.angleDelta().y() >= 0: + self._Iren.InvokeEvent("MouseWheelForwardEvent") + else: + self._Iren.InvokeEvent("MouseWheelBackwardEvent") + + def GetRenderWindow(self): + return self._RenderWindow + + def Render(self): + self.update() + + +def QVTKExample(): + """A simple example that uses the QVTKWidget class.""" + + # every QT app needs an app + app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) + page_VTK = QtWidgets.QWidget() + page_VTK.resize(500,500) + layout = QtWidgets.QVBoxLayout(page_VTK) + # create the widget + widget = QVTKWidget(parent=None) + layout.addWidget(widget) + + #reader = vtk.vtkPNGReader() + #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + widget.SetInput(reader.GetOutput()) + + # show the widget + page_VTK.show() + # start event processing + app.exec_() + +if __name__ == "__main__": + QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py new file mode 100644 index 0000000..e32e1c2 --- /dev/null +++ b/src/Python/ccpi/viewer/QVTKWidget2.py @@ -0,0 +1,84 @@ +################################################################################ +# File: QVTKWidget.py +# Author: Edoardo Pasca +# Description: PyVE Viewer Qt widget +# +# License: +# This file is part of PyVE. PyVE is an open-source image +# analysis and visualization environment focused on medical +# imaging. More info at http://pyve.sourceforge.net +# +# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. Neither name of Edoardo Pasca or Lukas +# Batteau nor the names of any contributors may be used to endorse +# or promote products derived from this software without specific +# prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, +# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +# OF SUCH DAMAGE. +# +# CHANGE HISTORY +# +# 20120118 Edoardo Pasca Initial version +# +############################################################################### + +import os +from PyQt5 import QtCore, QtGui, QtWidgets +#import itk +import vtk +#from viewer import PyveViewer +from ccpi.viewer.CILViewer2D import CILViewer2D , Converter +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor + +class QVTKWidget(QVTKRenderWindowInteractor): + + + def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): + kw = dict() + super().__init__(parent, **kw) + + + # Link to PyVE Viewer + self._PyveViewer = CILViewer2D(400,400) + #self._Viewer = self._PyveViewer._vtkPyveViewer + + self._Iren = self._PyveViewer.GetInteractor() + kw['iren'] = self._Iren + #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() + self._RenderWindow = self._PyveViewer.GetRenderWindow() + #self._RenderWindow = self._Viewer.GetRenderWindow() + kw['rw'] = self._RenderWindow + + + + + def GetInteractor(self): + return self._Iren + + # Display image data + def SetInput(self, imageData): + self._PyveViewer.setInput3DData(imageData) + \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py new file mode 100644 index 0000000..946188b --- /dev/null +++ b/src/Python/ccpi/viewer/__init__.py @@ -0,0 +1 @@ +from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc new file mode 100644 index 0000000..711f77a Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc new file mode 100644 index 0000000..77c2ca8 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc new file mode 100644 index 0000000..3d11b87 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc new file mode 100644 index 0000000..2fa2eaf Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..fcea537 Binary files /dev/null and b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py new file mode 100644 index 0000000..b5eb0a7 --- /dev/null +++ b/src/Python/ccpi/viewer/embedvtk.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Jul 27 12:18:58 2017 + +@author: ofn77899 +""" + +#!/usr/bin/env python + +import sys +import vtk +from PyQt5 import QtCore, QtWidgets +from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor +import QVTKWidget2 + +class MainWindow(QtWidgets.QMainWindow): + + def __init__(self, parent = None): + QtWidgets.QMainWindow.__init__(self, parent) + + self.frame = QtWidgets.QFrame() + + self.vl = QtWidgets.QVBoxLayout() +# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) + + self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) + self.iren = self.vtkWidget.GetInteractor() + self.vl.addWidget(self.vtkWidget) + + + + + self.ren = vtk.vtkRenderer() + self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) +# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() +# +# # Create source +# source = vtk.vtkSphereSource() +# source.SetCenter(0, 0, 0) +# source.SetRadius(5.0) +# +# # Create a mapper +# mapper = vtk.vtkPolyDataMapper() +# mapper.SetInputConnection(source.GetOutputPort()) +# +# # Create an actor +# actor = vtk.vtkActor() +# actor.SetMapper(mapper) +# +# self.ren.AddActor(actor) +# +# self.ren.ResetCamera() +# + self.frame.setLayout(self.vl) + self.setCentralWidget(self.frame) + reader = vtk.vtkMetaImageReader() + reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") + reader.Update() + + self.vtkWidget.SetInput(reader.GetOutput()) + + #self.vktWidget.Initialize() + #self.vktWidget.Start() + + self.show() + #self.iren.Initialize() + + +if __name__ == "__main__": + + app = QtWidgets.QApplication(sys.argv) + + window = MainWindow() + + sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From 1a841b967e1db92a04e8e12c52b83489da27be1c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:08:32 +0100 Subject: initial revision --- src/Python/ccpi/imaging/Regularizer.py | 322 +++++++++++++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 src/Python/ccpi/imaging/Regularizer.py (limited to 'src') diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): + '''Class to handle regularizer algorithms to be used during reconstruction + + Currently 5 CPU (OMP) regularization algorithms are available: + + 1) SplitBregman_TV + 2) FGP_TV + 3) LLT_model + 4) PatchBased_Regul + 5) TGV_PD + + Usage: + the regularizer can be invoked as object or as static method + Depending on the actual regularizer the input parameter may vary, and + a different default setting is defined. + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + + out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., + number_of_iterations=30, tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + A number of optional parameters can be passed or skipped + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + + ''' + class Algorithm(Enum): + SplitBregman_TV = cpu_regularizers.SplitBregman_TV + FGP_TV = cpu_regularizers.FGP_TV + LLT_model = cpu_regularizers.LLT_model + PatchBased_Regul = cpu_regularizers.PatchBased_Regul + TGV_PD = cpu_regularizers.TGV_PD + # Algorithm + + class TotalVariationPenalty(Enum): + isotropic = 0 + l1 = 1 + # TotalVariationPenalty + + def __init__(self , algorithm, debug = True): + self.setAlgorithm ( algorithm ) + self.debug = debug + # __init__ + + def setAlgorithm(self, algorithm): + self.algorithm = algorithm + self.pars = self.getDefaultParsForAlgorithm(algorithm) + # setAlgorithm + + def getDefaultParsForAlgorithm(self, algorithm): + pars = dict() + + if algorithm == Regularizer.Algorithm.SplitBregman_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 35 + pars['tolerance_constant'] = 0.0001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.FGP_TV : + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['number_of_iterations'] = 50 + pars['tolerance_constant'] = 0.001 + pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic + + elif algorithm == Regularizer.Algorithm.LLT_model: + pars['algorithm'] = algorithm + pars['input'] = None + pars['regularization_parameter'] = None + pars['time_step'] = None + pars['number_of_iterations'] = None + pars['tolerance_constant'] = None + pars['restrictive_Z_smoothing'] = 0 + + elif algorithm == Regularizer.Algorithm.PatchBased_Regul: + pars['algorithm'] = algorithm + pars['input'] = None + pars['searching_window_ratio'] = None + pars['similarity_window_ratio'] = None + pars['PB_filtering_parameter'] = None + pars['regularization_parameter'] = None + + elif algorithm == Regularizer.Algorithm.TGV_PD: + pars['algorithm'] = algorithm + pars['input'] = None + pars['first_order_term'] = None + pars['second_order_term'] = None + pars['number_of_iterations'] = None + pars['regularization_parameter'] = None + + else: + raise Exception('Unknown regularizer algorithm') + + return pars + # parsForAlgorithm + + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.pars.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + def getParameter(self, **kwargs): + ret = {} + for key , value in kwargs.items(): + if key in self.pars.keys(): + ret[key] = self.pars[key] + else: + raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) + # setParameter + + + def __call__(self, input = None, regularization_parameter = None, **kwargs): + '''Actual call for the regularizer. + + One can either set the regularization parameters first and then call the + algorithm or set the regularization parameter during the call (as + is done in the static methods). + ''' + + if kwargs is not None: + for key, value in kwargs.items(): + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + if input is not None: + self.pars['input'] = input + if regularization_parameter is not None: + self.pars['regularization_parameter'] = regularization_parameter + + if self.debug: + print ("--------------------------------------------------") + for key, value in self.pars.items(): + if key== 'algorithm' : + print("{0} = {1}".format(key, value.__name__)) + elif key == 'input': + print("{0} = {1}".format(key, np.shape(value))) + else: + print("{0} = {1}".format(key, value)) + + + if None in self.pars: + raise Exception("Not all parameters have been provided") + + input = self.pars['input'] + regularization_parameter = self.pars['regularization_parameter'] + if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.FGP_TV : + return self.algorithm(input, regularization_parameter, + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['TV_penalty'].value ) + elif self.algorithm == Regularizer.Algorithm.LLT_model : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, + regularization_parameter, + self.pars['time_step'] , + self.pars['number_of_iterations'], + self.pars['tolerance_constant'], + self.pars['restrictive_Z_smoothing'] ) + elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + return self.algorithm(input, regularization_parameter, + self.pars['searching_window_ratio'] , + self.pars['similarity_window_ratio'] , + self.pars['PB_filtering_parameter']) + elif self.algorithm == Regularizer.Algorithm.TGV_PD : + #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) + # no default + if len(np.shape(input)) == 2: + return self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + elif len(np.shape(input)) == 3: + #assuming it's 3D + # run independent calls on each slice + out3d = input.copy() + for i in range(np.shape(input)[2]): + out = self.algorithm(input, regularization_parameter, + self.pars['first_order_term'] , + self.pars['second_order_term'] , + self.pars['number_of_iterations']) + # copy the result in the 3D image + out3d.T[i] = out[0].copy() + # append the rest of the info that the algorithm returns + output = [out3d] + for i in range(1,len(out)): + output.append(out[i]) + return output + + + + + + # __call__ + + @staticmethod + def SplitBregman_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def FGP_TV(input, regularization_parameter , **kwargs): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.FGP_TV) + out = list( reg(input, regularization_parameter, **kwargs) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def LLT_model(input, regularization_parameter , time_step, number_of_iterations, + tolerance_constant, restrictive_Z_smoothing=0): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.LLT_model) + out = list( reg(input, regularization_parameter, time_step=time_step, + number_of_iterations=number_of_iterations, + tolerance_constant=tolerance_constant, + restrictive_Z_smoothing=restrictive_Z_smoothing) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def PatchBased_Regul(input, regularization_parameter, + searching_window_ratio, + similarity_window_ratio, + PB_filtering_parameter): + start_time = timeit.default_timer() + reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul) + out = list( reg(input, + regularization_parameter, + searching_window_ratio=searching_window_ratio, + similarity_window_ratio=similarity_window_ratio, + PB_filtering_parameter=PB_filtering_parameter ) + ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + return out + + @staticmethod + def TGV_PD(input, regularization_parameter , first_order_term, + second_order_term, number_of_iterations): + start_time = timeit.default_timer() + + reg = Regularizer(Regularizer.Algorithm.TGV_PD) + out = list( reg(input, regularization_parameter, + first_order_term=first_order_term, + second_order_term=second_order_term, + number_of_iterations=number_of_iterations) ) + out.append(reg.pars) + txt = reg.printParametersToString() + txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) + out.append(txt) + + return out + + def printParametersToString(self): + txt = r'' + for key, value in self.pars.items(): + if key== 'algorithm' : + txt += "{0} = {1}".format(key, value.__name__) + elif key == 'input': + txt += "{0} = {1}".format(key, np.shape(value)) + else: + txt += "{0} = {1}".format(key, value) + txt += '\n' + return txt + -- cgit v1.2.3 From 9f8fb57e1e89c1ad200d9c7eada5c653be34db66 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 15:10:04 +0100 Subject: module rename --- src/Python/fista_module.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index eacda3d..c36329e 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -1032,13 +1032,13 @@ bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_al return result; } -BOOST_PYTHON_MODULE(regularizers) +BOOST_PYTHON_MODULE(cpu_regularizers) { np::initialize(); //To specify that this module is a package bp::object package = bp::scope(); - package.attr("__path__") = "regularizers"; + package.attr("__path__") = "cpu_regularizers"; np::dtype dt1 = np::dtype::get_builtin(); np::dtype dt2 = np::dtype::get_builtin(); -- cgit v1.2.3 From a203949c84484fe2641e39451f033d20d445b1f3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:51:18 +0100 Subject: export/import data from hdf5 Added file to export the data from DemoRD2.m to HDF5 to pass it to Python. Added file to import the data from DemoRD2.m from HDF5. --- src/Python/test/readhd5.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/Python/test/readhd5.py (limited to 'src') diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py new file mode 100644 index 0000000..1e19e14 --- /dev/null +++ b/src/Python/test/readhd5.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +""" + +import h5py +import numpy + +def getEntry(nx, location): + for item in nx[location].keys(): + print (item) + +filename = r'C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\Demos\DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file -- cgit v1.2.3 From 8d53e078d3dabf7107982a8d25b4d66b1d0e73ce Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 23 Aug 2017 16:54:59 +0100 Subject: initial revision for testing --- .../ccpi/reconstruction/FISTAReconstructor.py | 354 +++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 src/Python/ccpi/reconstruction/FISTAReconstructor.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..ea96b53 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +#from ccpi.reconstruction.parallelbeam import alg + +from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + + + +class FISTAReconstructor(): + '''FISTA-based reconstruction algorithm using ASTRA-toolbox + + ''' + # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> + # ___Input___: + # params.[] file: + # - .proj_geom (geometry of the projector) [required] + # - .vol_geom (geometry of the reconstructed object) [required] + # - .sino (vectorized in 2D or 3D sinogram) [required] + # - .iterFISTA (iterations for the main loop, default 40) + # - .L_const (Lipschitz constant, default Power method) ) + # - .X_ideal (ideal image, if given) + # - .weights (statisitcal weights, size of the sinogram) + # - .ROI (Region-of-interest, only if X_ideal is given) + # - .initialize (a 'warm start' using SIRT method from ASTRA) + #----------------Regularization choices------------------------ + # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) + # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) + # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) + # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) + # - .Regul_Iterations (iterations for the selected penalty, default 25) + # - .Regul_tauLLT (time step parameter for LLT term) + # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) + # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) + #----------------Visualization parameters------------------------ + # - .show (visualize reconstruction 1/0, (0 default)) + # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) + # - .slice (for 3D volumes - slice number to imshow) + # ___Output___: + # 1. X - reconstructed image/volume + # 2. output - a structure with + # - .Resid_error - residual error (if X_ideal is given) + # - .objective: value of the objective function + # - .L_const: Lipshitz constant to avoid recalculations + + # References: + # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse + # Problems" by A. Beck and M Teboulle + # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo + # 3. "A novel tomographic reconstruction method based on the robust + # Student's t function for suppressing data outliers" D. Kazantsev et.al. + # D. Kazantsev, 2016-17 + def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + self.params = dict() + self.params['projector_geometry'] = projector_geometry + self.params['output_geometry'] = output_geometry + self.params['input_sinogram'] = input_sinogram + detectors, nangles, sliceZ = numpy.shape(input_sinogram) + self.params['detectors'] = detectors + self.params['number_og_angles'] = nangles + self.params['SlicesZ'] = sliceZ + + # Accepted input keywords + kw = ('number_of_iterations', + 'Lipschitz_constant' , + 'ideal_image' , + 'weights' , + 'region_of_interest' , + 'initialize' , + 'regularizer' , + 'ring_lambda_R_L1', + 'ring_alpha') + + # handle keyworded parameters + if kwargs is not None: + for key, value in kwargs.items(): + if key in kw: + #print("{0} = {1}".format(key, value)) + self.pars[key] = value + + # set the default values for the parameters if not set + if 'number_of_iterations' in kwargs.keys(): + self.pars['number_of_iterations'] = kwargs['number_of_iterations'] + else: + self.pars['number_of_iterations'] = 40 + if 'weights' in kwargs.keys(): + self.pars['weights'] = kwargs['weights'] + else: + self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + if 'Lipschitz_constant' in kwargs.keys(): + self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] + else: + self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + + if not self.pars['ideal_image'] in kwargs.keys(): + self.pars['ideal_image'] = None + + if not self.pars['region_of_interest'] : + if self.pars['ideal_image'] == None: + pass + else: + self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + + if not self.pars['regularizer'] : + self.pars['regularizer'] = None + else: + # the regularizer must be a correctly instantiated object + if not self.pars['ring_lambda_R_L1']: + self.pars['ring_lambda_R_L1'] = 0 + if not self.pars['ring_alpha']: + self.pars['ring_alpha'] = 1 + + + + + def calculateLipschitzConstantWithPowerMethod(self): + ''' using Power method (PM) to establish L constant''' + + #N = params.vol_geom.GridColCount + N = self.pars['output_geometry'].GridColCount + proj_geom = self.params['projector_geometry'] + vol_geom = self.params['output_geometry'] + weights = self.pars['weights'] + SlicesZ = self.pars['SlicesZ'] + + if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights.T[0]) + proj_geomT = proj_geom.copy(); + proj_geomT.DetectorRowCount = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + + for i in range(niter): + if i == 0: + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight * y # element wise multiplication + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); + y = sqweight*y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT + else: + #% divergen beam geometry + #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights.T[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 + + return s + + + def setRegularizer(self, regularizer): + if regularizer is not None: + self.pars['regularizer'] = regularizer + + + + + +def getEntry(location): + for item in nx[location].keys(): + print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" +##nx = h5py.File(fname, "r") +## +### the data are stored in a particular location in the hdf5 +##for item in nx['entry1/tomo_entry/data'].keys(): +## print (item) +## +##data = nx.get('entry1/tomo_entry/data/rotation_angle') +##angles = numpy.zeros(data.shape) +##data.read_direct(angles) +##print (angles) +### angles should be in degrees +## +##data = nx.get('entry1/tomo_entry/data/data') +##stack = numpy.zeros(data.shape) +##data.read_direct(stack) +##print (data.shape) +## +##print ("Data Loaded") +## +## +### Normalize +##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +##itype = numpy.zeros(data.shape) +##data.read_direct(itype) +### 2 is dark field +##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +##dark = darks[0] +##for i in range(1, len(darks)): +## dark += darks[i] +##dark = dark / len(darks) +###dark[0][0] = dark[0][1] +## +### 1 is flat field +##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +##flat = flats[0] +##for i in range(1, len(flats)): +## flat += flats[i] +##flat = flat / len(flats) +###flat[0][0] = dark[0][1] +## +## +### 0 is projection data +##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +##angle_proj = numpy.asarray (angle_proj) +##angle_proj = angle_proj.astype(numpy.float32) +## +### normalized data are +### norm = (projection - dark)/(flat-dark) +## +##def normalize(projection, dark, flat, def_val=0.1): +## a = (projection - dark) +## b = (flat-dark) +## with numpy.errstate(divide='ignore', invalid='ignore'): +## c = numpy.true_divide( a, b ) +## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 +## return c +## +## +##norm = [normalize(projection, dark, flat) for projection in proj] +##norm = numpy.asarray (norm) +##norm = norm.astype(numpy.float32) + + +##niterations = 15 +##threads = 3 +## +##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## iteration_values, False) +##print ("iteration values %s" % str(iteration_values)) +## +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +##iteration_values = numpy.zeros((niterations,)) +##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +## numpy.double(1e-5), iteration_values , False) +##print ("iteration values %s" % str(iteration_values)) +## +## +####numpy.save("cgls_recon.npy", img_data) +##import matplotlib.pyplot as plt +##fig, ax = plt.subplots(1,6,sharey=True) +##ax[0].imshow(img_cgls[80]) +##ax[0].axis('off') # clear x- and y-axes +##ax[1].imshow(img_sirt[80]) +##ax[1].axis('off') # clear x- and y-axes +##ax[2].imshow(img_mlem[80]) +##ax[2].axis('off') # clear x- and y-axesplt.show() +##ax[3].imshow(img_cgls_conv[80]) +##ax[3].axis('off') # clear x- and y-axesplt.show() +##ax[4].imshow(img_cgls_tikhonov[80]) +##ax[4].axis('off') # clear x- and y-axesplt.show() +##ax[5].imshow(img_cgls_TVreg[80]) +##ax[5].axis('off') # clear x- and y-axesplt.show() +## +## +##plt.show() +## + -- cgit v1.2.3 From 05bd227b56ec43c97c81630f50c3b741ef86ddcd Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:39:37 +0100 Subject: bugfix --- src/Python/Matlab2Python_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp index e15d738..ee76bc7 100644 --- a/src/Python/Matlab2Python_utils.cpp +++ b/src/Python/Matlab2Python_utils.cpp @@ -123,7 +123,7 @@ T * mxGetData(const np::ndarray pm) { probably this would work. A = reinterpret_cast(prhs[0]); */ - return reinterpret_cast(prhs[0]); + //return reinterpret_cast(prhs[0]); } template @@ -273,4 +273,4 @@ BOOST_PYTHON_MODULE(prova) //numpy_boost_python_register_type(); def("mexFunction", mexFunction); def("doSomething", doSomething); -} \ No newline at end of file +} -- cgit v1.2.3 From 879c6723969eaea8e00f97291612fe22443c69f3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:41:10 +0100 Subject: initial facility to test the FISTA --- src/Python/test_reconstructor.py | 179 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 src/Python/test_reconstructor.py (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py new file mode 100644 index 0000000..0fd08f5 --- /dev/null +++ b/src/Python/test_reconstructor.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction_dev.FISTAReconstructor import FISTAReconstructor +import astra + +##def getEntry(nx, location): +## for item in nx[location].keys(): +## print (item) + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 3 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +#fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D ) + #N = params.vol_geom.GridColCount + +pars = dict() +pars['projector_geometry'] = proj_geom +pars['output_geometry'] = vol_geom +pars['input_sinogram'] = Sino3D +sliceZ , nangles , detectors = numpy.shape(Sino3D) +pars['detectors'] = detectors +pars['number_of_angles'] = nangles +pars['SlicesZ'] = sliceZ + + +pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) + +N = pars['output_geometry']['GridColCount'] +proj_geom = pars['projector_geometry'] +vol_geom = pars['output_geometry'] +weights = pars['weights'] +SlicesZ = pars['SlicesZ'] + +if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + #% for parallel geometry we can do just one slice + print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 15;# % number of iteration for the PM + #N = params.vol_geom.GridColCount; + #x1 = rand(N,N,1); + x1 = numpy.random.rand(1,N,N) + #sqweight = sqrt(weights(:,:,1)); + sqweight = numpy.sqrt(weights[0]) + proj_geomT = proj_geom.copy(); + proj_geomT['DetectorRowCount'] = 1; + vol_geomT = vol_geom.copy(); + vol_geomT['GridSliceCount'] = 1; + + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + + import matplotlib.pyplot as plt + fig = plt.figure() + + #a.set_title('Lipschitz') + for i in range(niter): +# [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); +# s = norm(x1(:)); +# x1 = x1/s; +# [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +# y = sqweight.*y; +# astra_mex_data3d('delete', sino_id); +# astra_mex_data3d('delete', id); + print ("iteration {0}".format(i)) + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + #a=fig.add_subplot(2,1,1) + #imgplot = plt.imshow(y[0]) + + y = sqweight * y # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geomT, + vol_geomT); + print ("shape {1} x1 {0}".format(x1.T[:4].T, numpy.shape(x1))) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + print ("x1 {0}".format(x1.T[:4].T)) + +# ### this line? +# sino_id, y = astra.creators.create_sino3d_gpu(x1, +# proj_geomT, +# vol_geomT); +# y = sqweight * y; + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + del proj_geomT + del vol_geomT +else: + #% divergen beam geometry + print('Calculating Lipshitz constant for divergen beam geometry...') + niter = 8; #% number of iteration for PM + x1 = numpy.random.rand(SlicesZ , N , N); + #sqweight = sqrt(weights); + sqweight = numpy.sqrt(weights[0]) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id); + + for i in range(niter): + #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); + idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, + proj_geom, + vol_geom) + s = numpy.linalg.norm(x1) + ### this line? + x1 = x1/s; + ### this line? + #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geom, + vol_geom); + + y = sqweight*y; + #astra_mex_data3d('delete', sino_id); + #astra_mex_data3d('delete', id); + astra.matlab.data3d('delete', sino_id); + astra.matlab.data3d('delete', idx); + #end + #clear x1 + del x1 -- cgit v1.2.3 From e58f774938edd3664dfb1f3905964b3add050bc9 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:05 +0100 Subject: initial revision --- src/Python/ccpi/imaging/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/imaging/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 9c974a26c0fc8060008745796fbe9f7ef5c250eb Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 24 Aug 2017 16:42:27 +0100 Subject: initial revision --- src/Python/ccpi/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From c8693f530e95e140a3fba85fc65d879b51b79e6d Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:11:00 +0100 Subject: table with regularizers output --- src/Python/test_regularizers.py | 66 ++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 33 deletions(-) (limited to 'src') diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py index 755804a..5804897 100644 --- a/src/Python/test_regularizers.py +++ b/src/Python/test_regularizers.py @@ -163,52 +163,52 @@ imgplot = plt.imshow(reg_output[-1][0]) # # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise # # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); -# out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - # searching_window_ratio=3, - # similarity_window_ratio=1, - # PB_filtering_parameter=0.08) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,5) +a=fig.add_subplot(2,3,5) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) -# ###################### TGV_PD ######################################### -# # Quick 2D denoising example in Matlab: -# # Im = double(imread('lena_gray_256.tif'))/255; % loading image -# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, - # first_order_term=1.3, - # second_order_term=1, - # number_of_iterations=550) -# pars = out2[-2] -# reg_output.append(out2) +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) -# a=fig.add_subplot(2,3,6) +a=fig.add_subplot(2,3,6) -# textstr = out2[-1] +textstr = out2[-1] -# # these are matplotlib.patch.Patch properties -# props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# # place a text box in upper left in axes coords -# a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - # verticalalignment='top', bbox=props) -# imgplot = plt.imshow(reg_output[-1][0]) +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) plt.show() -- cgit v1.2.3 From 776070e22bf95491275a023f3a5ac00cea356714 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:12:42 +0100 Subject: read and plot the hdf5 --- src/Python/test/readhd5.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py index 1e19e14..b042341 100644 --- a/src/Python/test/readhd5.py +++ b/src/Python/test/readhd5.py @@ -25,4 +25,17 @@ angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] angles_rad = numpy.asarray(nx.get('/angles_rad')) recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] -slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +#from ccpi.viewer.CILViewer2D import CILViewer2D +#v = CILViewer2D() +#v.setInputAsNumpy(Weights3D) +#v.startRenderLoop() + +import matplotlib.pyplot as plt +fig = plt.figure() + +a=fig.add_subplot(1,1,1) +a.set_title('noise') +imgplot = plt.imshow(Weights3D[0].T) +plt.show() -- cgit v1.2.3 From 5c978b706192bc5885c7e5001a4bc4626f63d29f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 11 Oct 2017 15:49:18 +0100 Subject: initial revision --- src/Python/test/simple_astra_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/Python/test/simple_astra_test.py (limited to 'src') diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 +1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) -- cgit v1.2.3 From 49c4a595c58d296c3a4b2f7fd480e9c64f638897 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 13 Oct 2017 16:48:24 +0100 Subject: Added setParameter minor beautification of code --- src/Python/ccpi/fista/FISTAReconstructor.py | 164 ++++++---------------------- 1 file changed, 34 insertions(+), 130 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 1e76815..cbd27da 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -73,7 +73,8 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): # handle parmeters: # obligatory parameters self.pars = dict() @@ -98,6 +99,7 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha') + self.acceptedInputKeywords = kw # handle keyworded parameters if kwargs is not None: @@ -114,11 +116,14 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.pars['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None @@ -127,7 +132,8 @@ class FISTAReconstructor(): if self.pars['ideal_image'] == None: pass else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) + self.pars['region_of_interest'] = numpy.nonzero( + self.pars['ideal_image']>0.0) if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None @@ -140,7 +146,29 @@ class FISTAReconstructor(): + def setParameter(self, **kwargs): + '''set named parameter for the regularization engine + raises Exception if the named parameter is not recognized + Typical usage is: + + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + + it can be also used as + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + reg.setParameter(input=u0 , regularization_parameter=10.) + ''' + + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords.keys(): + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'Reconstruction algorithm') + # setParameter + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' @@ -152,7 +180,8 @@ class FISTAReconstructor(): - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice #print('Calculating Lipshitz constant for parallel beam geometry...') niter = 5;# % number of iteration for the PM @@ -262,128 +291,3 @@ class FISTAReconstructor(): - - -def getEntry(location, nx): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## - -- cgit v1.2.3 From 24598bda0c2983664f0c5e1aefa576e5d0a36db7 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 13 Oct 2017 16:52:34 +0100 Subject: uses FISTAReconstructor class deleted the calculation of the lipschitz constant that is now in the reconstructor class. --- src/Python/test_reconstructor.py | 144 ++------------------------------------- 1 file changed, 6 insertions(+), 138 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index a4a622b..a338d34 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -58,143 +58,11 @@ vol_geom = astra.creators.create_vol_geom( image_size_x, ## First pass the arguments to the FISTAReconstructor and test the ## Lipschitz constant -fistaRecon = FISTAReconstructor(proj_geom, vol_geom, Sino3D , weights=Weights3D) -print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) - #N = params.vol_geom.GridColCount - -pars = dict() -pars['projector_geometry'] = proj_geom.copy() -pars['output_geometry'] = vol_geom.copy() -pars['input_sinogram'] = Sino3D.copy() -sliceZ , nangles , detectors = numpy.shape(Sino3D) -pars['detectors'] = detectors -pars['number_of_angles'] = nangles -pars['SlicesZ'] = sliceZ - - -#pars['weights'] = numpy.ones(numpy.shape(pars['input_sinogram'])) -pars['weights'] = Weights3D.copy() +fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) -N = pars['output_geometry']['GridColCount'] -proj_geom = pars['projector_geometry'] -vol_geom = pars['output_geometry'] -weights = pars['weights'] -SlicesZ = pars['SlicesZ'] - -if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 5;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights[0]) - proj_geomT = proj_geom.copy(); - proj_geomT['DetectorRowCount'] = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - - import matplotlib.pyplot as plt - fig = [] - props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) - - #a.set_title('Lipschitz') - for i in range(niter): -# [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); -# s = norm(x1(:)); -# x1 = x1/s; -# [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); -# y = sqweight.*y; -# astra_mex_data3d('delete', sino_id); -# astra_mex_data3d('delete', id); - #print ("iteration {0}".format(i)) - fig.append(plt.figure()) +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) - a=fig[-1].add_subplot(1,2,1) - a.text(0.05, 0.95, "iteration {0}, x1".format(i), transform=a.transAxes, - fontsize=14,verticalalignment='top', bbox=props) - - imgplot = plt.imshow(x1[0].copy()) - - - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geomT, - vol_geomT) - a=fig[-1].add_subplot(1,2,2) - a.text(0.05, 0.95, "iteration {0}, y".format(i), - transform=a.transAxes, fontsize=14,verticalalignment='top', - bbox=props) - - imgplot = plt.imshow(y[0].copy()) - - y = (sqweight * y) # element wise multiplication - - #b=fig.add_subplot(2,1,2) - #imgplot = plt.imshow(x1[0]) - #plt.show() - - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - del x1 - - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y), - proj_geomT, - vol_geomT) - del y - - - s = numpy.linalg.norm(x1) - ### this line? - x1 = (x1/s) - -# ### this line? -# sino_id, y = astra.creators.create_sino3d_gpu(x1, -# proj_geomT, -# vol_geomT); -# y = sqweight * y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx) - print ("iteration {0} s= {1}".format(i,s)) - - #end - del proj_geomT - del vol_geomT - #plt.show() -else: - #% divergen beam geometry - print('Calculating Lipshitz constant for divergen beam geometry...') - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 +## the calculation of the lipschitz constant should not start by itself -- cgit v1.2.3 From 2353624fcb8241222e2044cb9d10ffa7c11c87c6 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 13 Oct 2017 16:55:15 +0100 Subject: changes for vishighmem --- src/Python/test/readhd5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py index 1e19e14..406fda9 100644 --- a/src/Python/test/readhd5.py +++ b/src/Python/test/readhd5.py @@ -12,7 +12,7 @@ def getEntry(nx, location): for item in nx[location].keys(): print (item) -filename = r'C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\Demos\DendrData.h5' +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' nx = h5py.File(filename, "r") #getEntry(nx, '/') # I have exported the entries as children of / @@ -25,4 +25,4 @@ angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] angles_rad = numpy.asarray(nx.get('/angles_rad')) recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] -slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] \ No newline at end of file +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] -- cgit v1.2.3 From 3c2815ec1d0ddd9d00a5c1f454fcecc060126623 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2017 09:35:11 +0100 Subject: Added many methods --- src/Python/ccpi/fista/FISTAReconstructor.py | 184 ++++++++++++++++++++++++---- 1 file changed, 160 insertions(+), 24 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index cbd27da..8318ea6 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -78,19 +78,28 @@ class FISTAReconstructor(): # handle parmeters: # obligatory parameters self.pars = dict() - self.pars['projector_geometry'] = projector_geometry - self.pars['output_geometry'] = output_geometry - self.pars['input_sinogram'] = input_sinogram + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino detectors, nangles, sliceZ = numpy.shape(input_sinogram) self.pars['detectors'] = detectors - self.pars['number_og_angles'] = nangles + self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ print (self.pars) # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -98,8 +107,9 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') - self.acceptedInputKeywords = kw + 'ring_alpha', + 'subsets') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -122,8 +132,7 @@ class FISTAReconstructor(): if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None @@ -143,31 +152,44 @@ class FISTAReconstructor(): self.pars['ring_lambda_R_L1'] = 0 if not 'ring_alpha' in kwargs.keys(): self.pars['ring_alpha'] = 1 - + + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 + else: + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False def setParameter(self, **kwargs): - '''set named parameter for the regularization engine + '''set named parameter for the reconstructor engine raises Exception if the named parameter is not recognized - Typical usage is: - - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0) - reg.setParameter(regularization_parameter=10.) - it can be also used as - reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - reg.setParameter(input=u0 , regularization_parameter=10.) ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords.keys(): + if key in self.acceptedInputKeywords: self.pars[key] = value else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'Reconstruction algorithm') + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' @@ -289,5 +311,119 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + + + + + + def prepareForIteration(self): + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.rx = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + + # prepareForIteration + + def iterate(self, Xin=None): + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + t = 1 + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + X = Xin.copy() + + X_t = X.copy() + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.pars['projector_geometry']['type'] == 'parallel' or \ + self.pars['projector_geometry']['type'] == 'parallel3d': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + + #for kkk = 1:SlicesZ + # [sino_id, sino_updt(:,:,kkk)] = + # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); + # astra_mex_data3d('delete', sino_id); + for kkk in range(SlicesZ): + sino_id, sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk], proj_geomT, vol_geomT) + + else: + # for divergent 3D geometry (watch GPU memory overflow in + # Astra < 1.8 + sino_id, y = astra.creators.create_sino3d_gpu(X_t, + proj_geom, + vol_geom) - + + -- cgit v1.2.3 From dd30175d2a198a44c92cdbdb40c3512f15a637e8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2017 09:37:25 +0100 Subject: Squashing 2 commits: Added and removed hdf5 (too big) Added data in hdf5 format removed hdf5 data --- src/Python/ccpi/reconstruction/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Python/ccpi/reconstruction/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/__init__.py b/src/Python/ccpi/reconstruction/__init__.py new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3 From 2014650ab9fbf5a7d1c7334fa54ac0b1c5908915 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 17 Oct 2017 17:01:12 +0100 Subject: Progress in pythonization --- src/Python/ccpi/fista/FISTAReconstructor.py | 104 +++++++++---- src/Python/test_reconstructor.py | 229 +++++++++++++++++++++++++++- 2 files changed, 298 insertions(+), 35 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 8318ea6..87dd2c0 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -81,7 +81,7 @@ class FISTAReconstructor(): self.pars['projector_geometry'] = projector_geometry # proj_geom self.pars['output_geometry'] = output_geometry # vol_geom self.pars['input_sinogram'] = input_sinogram # sino - detectors, nangles, sliceZ = numpy.shape(input_sinogram) + sliceZ, nangles, detectors = numpy.shape(input_sinogram) self.pars['detectors'] = detectors self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ @@ -108,7 +108,9 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets') + 'subsets', + 'use_studentt_fidelity', + 'studentt') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -143,16 +145,18 @@ class FISTAReconstructor(): else: self.pars['region_of_interest'] = numpy.nonzero( self.pars['ideal_image']>0.0) - + + # the regularizer must be a correctly instantiated object if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not 'ring_lambda_R_L1' in kwargs.keys(): - self.pars['ring_lambda_R_L1'] = 0 - if not 'ring_alpha' in kwargs.keys(): - self.pars['ring_alpha'] = 1 + #RING REMOVAL + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + # ORDERED SUBSET if not 'subsets' in kwargs.keys(): self.pars['subsets'] = 0 else: @@ -160,6 +164,15 @@ class FISTAReconstructor(): if not 'initialize' in kwargs.keys(): self.pars['initialize'] = False + + if not 'use_studentt_fidelity' in kwargs.keys(): + self.setParameter(studentt=False) + else: + print ("studentt {0}".format(kwargs['use_studentt_fidelity'])) + if kwargs['use_studentt_fidelity']: + raise Exception('Not implemented') + + self.setParameter(studentt=kwargs['use_studentt_fidelity']) def setParameter(self, **kwargs): @@ -170,6 +183,8 @@ class FISTAReconstructor(): ''' for key , value in kwargs.items(): if key in self.acceptedInputKeywords: + if key == 'use_studentt_fidelity': + raise Exception('use_studentt_fidelity Not implemented') self.pars[key] = value else: raise Exception('Wrong parameter {0} for '.format(key) + @@ -354,10 +369,28 @@ class FISTAReconstructor(): #return subsets angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - - + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + + return IndicesReorg def prepareForIteration(self): @@ -368,23 +401,24 @@ class FISTAReconstructor(): detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) # another ring variable - self.rx = self.r.copy() + self.r_x = self.r.copy() self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) if self.getParameter('Lipschitz_constant') is None: self.pars['Lipschitz_constant'] = \ self.calculateLipschitzConstantWithPowerMethod() + # prepareForIteration def iterate(self, Xin=None): # convenience variable storage proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter(['projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) t = 1 if Xin is None: @@ -394,7 +428,8 @@ class FISTAReconstructor(): N = vol_geom['GridColCount'] X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) else: - X = Xin.copy() + # copy by reference + X = Xin X_t = X.copy() @@ -402,28 +437,31 @@ class FISTAReconstructor(): X_old = X.copy() t_old = t r_old = self.r.copy() - if self.pars['projector_geometry']['type'] == 'parallel' or \ - self.pars['projector_geometry']['type'] == 'parallel3d': + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'parallel3d': # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - - #for kkk = 1:SlicesZ - # [sino_id, sino_updt(:,:,kkk)] = - # astra_create_sino3d_cuda(X_t(:,:,kkk), proj_geomT, vol_geomT); - # astra_mex_data3d('delete', sino_id); for kkk in range(SlicesZ): + print (kkk) sino_id, sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( - X_t[kkk], proj_geomT, vol_geomT) - + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) else: - # for divergent 3D geometry (watch GPU memory overflow in - # Astra < 1.8 - sino_id, y = astra.creators.create_sino3d_gpu(X_t, - proj_geom, - vol_geom) - + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + + ## RING REMOVAL + ## REGULARIZATION diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index a338d34..f8f6b3c 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -31,7 +31,7 @@ recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] -Z_slices = 3 +Z_slices = 20 det_row_count = Z_slices # next definition is just for consistency of naming det_col_count = size_det @@ -64,5 +64,230 @@ fistaRecon = FISTAReconstructor(proj_geom, weights=Weights3D) print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +#fistaRecon.setParameter(use_studentt_fidelity= True) -## the calculation of the lipschitz constant should not start by itself +## Ordered subset +if False: + subsets = 16 + angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + +if True: + fistaRecon.prepareForIteration() + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + + + + proj_geom , vol_geom, sino , \ + SlicesZ = fistaRecon.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + fistaRecon.setParameter(number_of_iterations = 3) + iterFISTA = fistaRecon.getParameter('number_of_iterations') + # errors vector (if the ground truth is given) + Resid_error = numpy.zeros((iterFISTA)); + # objective function values vector + objective = numpy.zeros((iterFISTA)); + + + print ("line") + t = 1 + print ("line") + + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("X_t copy") +## % Outer FISTA iterations loop + for i in range(fistaRecon.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = fistaRecon.r.copy() + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + print (kkk) + sino_id, sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + ## RING REMOVAL + residual = fistaRecon.residual + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) + r_x = fistaRecon.r_x + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + print ("angles {0}".format(kkk)) + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + fistaRecon.r = (r_x - (1./L_const) * vec).copy() + objective[i] = (0.5 * (residual ** 2).sum()) +## % the ring removal part (Group-Huber fidelity) +## for kkk = 1:anglesNumb +## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* +## (squeeze(sino_updt(:,kkk,:)) - +## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); +## end +## vec = sum(residual,2); +## if (SlicesZ > 1) +## vec = squeeze(vec(:,1,:)); +## end +## r = r_x - (1./L_const).*vec; +## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + + else: + if fistaRecon.getParameter('use_studentt_fidelity'): + residual = weights * (sino_updt - sino) + for kkk in range(SlicesZ): + # reshape(residual(:,:,kkk), Detectors*anglesNumb, 1) + # 1D + res_vec = numpy.reshape(residual[kkk], (Detectors * anglesNumb,1)) + +## else +## if (studentt == 1) +## % artifacts removal with Students t penalty +## residual = weights.*(sino_updt - sino); +## for kkk = 1:SlicesZ +## res_vec = reshape(residual(:,:,kkk), Detectors*anglesNumb, 1); % 1D vectorized sinogram +## %s = 100; +## %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); +## [ff, gr] = studentst(res_vec, 1); +## residual(:,:,kkk) = reshape(gr, Detectors, anglesNumb); +## end +## objective(i) = ff; % for the objective function output +## else +## % no ring removal (LS model) +## residual = weights.*(sino_updt - sino); +## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output +## end +## end + + # Projection/Backprojection Routine + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + for kkk in range(SlicesZ): + print ("Projection/Backprojection Routine {0}".format( kkk )) + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + + ## REGULARIZATION + ## SKIPPING FOR NOW + ## Should be simpli + # regularizer = fistaRecon.getParameter('regularizer') + # for slices: + # out = regularizer(input=X) + + + ## FINAL + lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + fistaRecon.r = numpy.max( + numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ + numpy.sign(fistaRecon.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + fistaRecon.r_x = fistaRecon.r + \ + (((t_old-1)/t) * (fistaRecon.r - r_old)) + + if fistaRecon.getParameter('ideal_image') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, objective[i])) + +## if (lambdaR_L1 > 0) +## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector +## end +## +## t = (1 + sqrt(1 + 4*t^2))/2; % updating t +## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X +## +## if (lambdaR_L1 > 0) +## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r +## end +## +## if (show == 1) +## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); +## if (lambdaR_L1 > 0) +## figure(11); plot(r); title('Rings offset vector') +## end +## pause(0.01); +## end +## if (strcmp(X_ideal, 'none' ) == 0) +## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); +## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); +## else +## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); +## end -- cgit v1.2.3 From 1af73a75ccab1147a8d2387b7056f91f0642549f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 10:04:32 +0100 Subject: removed viewer package from tree --- src/Python/ccpi/viewer/CILViewer.py | 361 ------- src/Python/ccpi/viewer/CILViewer2D.py | 1126 -------------------- src/Python/ccpi/viewer/QVTKWidget.py | 340 ------ src/Python/ccpi/viewer/QVTKWidget2.py | 84 -- src/Python/ccpi/viewer/__init__.py | 1 - .../viewer/__pycache__/CILViewer.cpython-35.pyc | Bin 10542 -> 0 bytes .../viewer/__pycache__/CILViewer2D.cpython-35.pyc | Bin 35633 -> 0 bytes .../viewer/__pycache__/QVTKWidget.cpython-35.pyc | Bin 10099 -> 0 bytes .../viewer/__pycache__/QVTKWidget2.cpython-35.pyc | Bin 1316 -> 0 bytes .../viewer/__pycache__/__init__.cpython-35.pyc | Bin 210 -> 0 bytes src/Python/ccpi/viewer/embedvtk.py | 75 -- 11 files changed, 1987 deletions(-) delete mode 100644 src/Python/ccpi/viewer/CILViewer.py delete mode 100644 src/Python/ccpi/viewer/CILViewer2D.py delete mode 100644 src/Python/ccpi/viewer/QVTKWidget.py delete mode 100644 src/Python/ccpi/viewer/QVTKWidget2.py delete mode 100644 src/Python/ccpi/viewer/__init__.py delete mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc delete mode 100644 src/Python/ccpi/viewer/embedvtk.py (limited to 'src') diff --git a/src/Python/ccpi/viewer/CILViewer.py b/src/Python/ccpi/viewer/CILViewer.py deleted file mode 100644 index efcf8be..0000000 --- a/src/Python/ccpi/viewer/CILViewer.py +++ /dev/null @@ -1,361 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Edoardo Pasca -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import vtk -import numpy -import math -from vtk.util import numpy_support - -SLICE_ORIENTATION_XY = 2 # Z -SLICE_ORIENTATION_XZ = 1 # Y -SLICE_ORIENTATION_YZ = 0 # X - - - -class CILViewer(): - '''Simple 3D Viewer based on VTK classes''' - - def __init__(self, dimx=600,dimy=600): - '''creates the rendering pipeline''' - - # create a rendering window and renderer - self.ren = vtk.vtkRenderer() - self.renWin = vtk.vtkRenderWindow() - self.renWin.SetSize(dimx,dimy) - self.renWin.AddRenderer(self.ren) - - # img 3D as slice - self.img3D = None - self.sliceno = 0 - self.sliceOrientation = SLICE_ORIENTATION_XY - self.sliceActor = None - self.voi = None - self.wl = None - self.ia = None - self.sliceActorNo = 0 - # create a renderwindowinteractor - self.iren = vtk.vtkRenderWindowInteractor() - self.iren.SetRenderWindow(self.renWin) - - self.style = vtk.vtkInteractorStyleTrackballCamera() - self.iren.SetInteractorStyle(self.style) - - self.ren.SetBackground(.1, .2, .4) - - self.actors = {} - self.iren.RemoveObservers('MouseWheelForwardEvent') - self.iren.RemoveObservers('MouseWheelBackwardEvent') - - self.iren.AddObserver('MouseWheelForwardEvent', self.mouseInteraction, 1.0) - self.iren.AddObserver('MouseWheelBackwardEvent', self.mouseInteraction, 1.0) - - self.iren.RemoveObservers('KeyPressEvent') - self.iren.AddObserver('KeyPressEvent', self.keyPress, 1.0) - - - self.iren.Initialize() - - - - def getRenderer(self): - '''returns the renderer''' - return self.ren - - def getRenderWindow(self): - '''returns the render window''' - return self.renWin - - def getInteractor(self): - '''returns the render window interactor''' - return self.iren - - def getCamera(self): - '''returns the active camera''' - return self.ren.GetActiveCamera() - - def createPolyDataActor(self, polydata): - '''returns an actor for a given polydata''' - mapper = vtk.vtkPolyDataMapper() - if vtk.VTK_MAJOR_VERSION <= 5: - mapper.SetInput(polydata) - else: - mapper.SetInputData(polydata) - - # actor - actor = vtk.vtkActor() - actor.SetMapper(mapper) - #actor.GetProperty().SetOpacity(0.8) - return actor - - def setPolyDataActor(self, actor): - '''displays the given polydata''' - - self.ren.AddActor(actor) - - self.actors[len(self.actors)+1] = [actor, True] - self.iren.Initialize() - self.renWin.Render() - - def displayPolyData(self, polydata): - self.setPolyDataActor(self.createPolyDataActor(polydata)) - - def hideActor(self, actorno): - '''Hides an actor identified by its number in the list of actors''' - try: - if self.actors[actorno][1]: - self.ren.RemoveActor(self.actors[actorno][0]) - self.actors[actorno][1] = False - except KeyError as ke: - print ("Warning Actor not present") - - def showActor(self, actorno, actor = None): - '''Shows hidden actor identified by its number in the list of actors''' - try: - if not self.actors[actorno][1]: - self.ren.AddActor(self.actors[actorno][0]) - self.actors[actorno][1] = True - return actorno - except KeyError as ke: - # adds it to the actors if not there already - if actor != None: - self.ren.AddActor(actor) - self.actors[len(self.actors)+1] = [actor, True] - return len(self.actors) - - def addActor(self, actor): - '''Adds an actor to the render''' - return self.showActor(0, actor) - - - def saveRender(self, filename, renWin=None): - '''Save the render window to PNG file''' - # screenshot code: - w2if = vtk.vtkWindowToImageFilter() - if renWin == None: - renWin = self.renWin - w2if.SetInput(renWin) - w2if.Update() - - writer = vtk.vtkPNGWriter() - writer.SetFileName("%s.png" % (filename)) - writer.SetInputConnection(w2if.GetOutputPort()) - writer.Write() - - - def startRenderLoop(self): - self.iren.Start() - - - def setupObservers(self, interactor): - interactor.RemoveObservers('LeftButtonPressEvent') - interactor.AddObserver('LeftButtonPressEvent', self.mouseInteraction) - interactor.Initialize() - - - def mouseInteraction(self, interactor, event): - if event == 'MouseWheelForwardEvent': - maxSlice = self.img3D.GetDimensions()[self.sliceOrientation] - if (self.sliceno + 1 < maxSlice): - self.hideActor(self.sliceActorNo) - self.sliceno = self.sliceno + 1 - self.displaySliceActor(self.sliceno) - else: - minSlice = 0 - if (self.sliceno - 1 > minSlice): - self.hideActor(self.sliceActorNo) - self.sliceno = self.sliceno - 1 - self.displaySliceActor(self.sliceno) - - - def keyPress(self, interactor, event): - #print ("Pressed key %s" % interactor.GetKeyCode()) - # Slice Orientation - if interactor.GetKeyCode() == "x": - # slice on the other orientation - self.sliceOrientation = SLICE_ORIENTATION_YZ - self.sliceno = int(self.img3D.GetDimensions()[1] / 2) - self.hideActor(self.sliceActorNo) - self.displaySliceActor(self.sliceno) - elif interactor.GetKeyCode() == "y": - # slice on the other orientation - self.sliceOrientation = SLICE_ORIENTATION_XZ - self.sliceno = int(self.img3D.GetDimensions()[1] / 2) - self.hideActor(self.sliceActorNo) - self.displaySliceActor(self.sliceno) - elif interactor.GetKeyCode() == "z": - # slice on the other orientation - self.sliceOrientation = SLICE_ORIENTATION_XY - self.sliceno = int(self.img3D.GetDimensions()[2] / 2) - self.hideActor(self.sliceActorNo) - self.displaySliceActor(self.sliceno) - if interactor.GetKeyCode() == "X": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_YZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.ren.SetActiveCamera(camera) - self.ren.ResetCamera() - self.ren.Render() - interactor.SetKeyCode("x") - self.keyPress(interactor, event) - elif interactor.GetKeyCode() == "Y": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XZ] = math.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.ren.SetActiveCamera(camera) - self.ren.ResetCamera() - self.ren.Render() - interactor.SetKeyCode("y") - self.keyPress(interactor, event) - elif interactor.GetKeyCode() == "Z": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.ren.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.ren.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.ren.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XY] = math.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.ren.SetActiveCamera(camera) - self.ren.ResetCamera() - self.ren.Render() - interactor.SetKeyCode("z") - self.keyPress(interactor, event) - else : - print ("Unhandled event %s" % interactor.GetKeyCode()) - - - - def setInput3DData(self, imageData): - self.img3D = imageData - - def setInputAsNumpy(self, numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - self.img3D = shiftScaler.GetOutput() - - def displaySliceActor(self, sliceno = 0): - self.sliceno = sliceno - first = False - - self.sliceActor , self.voi, self.wl , self.ia = \ - self.getSliceActor(self.img3D, - sliceno, - self.sliceActor, - self.voi, - self.wl, - self.ia) - no = self.showActor(self.sliceActorNo, self.sliceActor) - self.sliceActorNo = no - - self.iren.Initialize() - self.renWin.Render() - - return self.sliceActorNo - - - def getSliceActor(self, - imageData , - sliceno=0, - imageActor=None , - voi=None, - windowLevel=None, - imageAccumulate=None): - '''Slices a 3D volume and then creates an actor to be rendered''' - if (voi==None): - voi = vtk.vtkExtractVOI() - #voi = vtk.vtkImageClip() - voi.SetInputData(imageData) - #select one slice in Z - extent = [ i for i in self.img3D.GetExtent()] - extent[self.sliceOrientation * 2] = sliceno - extent[self.sliceOrientation * 2 + 1] = sliceno - voi.SetVOI(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - - voi.Update() - # set window/level for all slices - if imageAccumulate == None: - imageAccumulate = vtk.vtkImageAccumulate() - - if (windowLevel == None): - windowLevel = vtk.vtkImageMapToWindowLevelColors() - imageAccumulate.SetInputData(imageData) - imageAccumulate.Update() - cmax = imageAccumulate.GetMax()[0] - cmin = imageAccumulate.GetMin()[0] - windowLevel.SetLevel((cmax+cmin)/2) - windowLevel.SetWindow(cmax-cmin) - - windowLevel.SetInputData(voi.GetOutput()) - windowLevel.Update() - - if imageActor == None: - imageActor = vtk.vtkImageActor() - imageActor.SetInputData(windowLevel.GetOutput()) - imageActor.SetDisplayExtent(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - imageActor.Update() - return (imageActor , voi, windowLevel, imageAccumulate) - - - # Set interpolation on - def setInterpolateOn(self): - self.sliceActor.SetInterpolate(True) - self.renWin.Render() - - # Set interpolation off - def setInterpolateOff(self): - self.sliceActor.SetInterpolate(False) - self.renWin.Render() \ No newline at end of file diff --git a/src/Python/ccpi/viewer/CILViewer2D.py b/src/Python/ccpi/viewer/CILViewer2D.py deleted file mode 100644 index c1629af..0000000 --- a/src/Python/ccpi/viewer/CILViewer2D.py +++ /dev/null @@ -1,1126 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2017 Edoardo Pasca -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import vtk -import numpy -from vtk.util import numpy_support , vtkImageImportFromArray -from enum import Enum - -SLICE_ORIENTATION_XY = 2 # Z -SLICE_ORIENTATION_XZ = 1 # Y -SLICE_ORIENTATION_YZ = 0 # X - -CONTROL_KEY = 8 -SHIFT_KEY = 4 -ALT_KEY = -128 - - -# Converter class -class Converter(): - - # Utility functions to transform numpy arrays to vtkImageData and viceversa - @staticmethod - def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): - '''Creates a vtkImageImportFromArray object and returns it. - - It handles the different axis order from numpy to VTK''' - importer = vtkImageImportFromArray.vtkImageImportFromArray() - importer.SetArray(numpy.transpose(nparray).copy()) - importer.SetDataSpacing(spacing) - importer.SetDataOrigin(origin) - return importer - - @staticmethod - def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): - '''Converts a 3D numpy array to a vtkImageData''' - importer = Converter.numpy2vtkImporter(nparray, spacing, origin) - importer.Update() - return importer.GetOutput() - - @staticmethod - def vtk2numpy(imgdata): - '''Converts the VTK data to 3D numpy array''' - img_data = numpy_support.vtk_to_numpy( - imgdata.GetPointData().GetScalars()) - - dims = imgdata.GetDimensions() - dims = (dims[2],dims[1],dims[0]) - data3d = numpy.reshape(img_data, dims) - - return numpy.transpose(data3d).copy() - - @staticmethod - def tiffStack2numpy(filename, indices, - extent = None , sampleRate = None ,\ - flatField = None, darkField = None): - '''Converts a stack of TIFF files to numpy array. - - filename must contain the whole path. The filename is supposed to be named and - have a suffix with the ordinal file number, i.e. /path/to/projection_%03d.tif - - indices are the suffix, generally an increasing number - - Optionally extracts only a selection of the 2D images and (optionally) - normalizes. - ''' - - stack = vtk.vtkImageData() - reader = vtk.vtkTIFFReader() - voi = vtk.vtkExtractVOI() - - #directory = "C:\\Users\\ofn77899\\Documents\\CCPi\\IMAT\\20170419_crabtomo\\crabtomo\\" - - stack_image = numpy.asarray([]) - nreduced = len(indices) - - for num in range(len(indices)): - fn = filename % indices[num] - print ("resampling %s" % ( fn ) ) - reader.SetFileName(fn) - reader.Update() - print (reader.GetOutput().GetScalarTypeAsString()) - if num == 0: - if (extent == None): - sliced = reader.GetOutput().GetExtent() - stack.SetExtent(sliced[0],sliced[1], sliced[2],sliced[3], 0, nreduced-1) - else: - sliced = extent - voi.SetVOI(extent) - - if sampleRate is not None: - voi.SetSampleRate(sampleRate) - ext = numpy.asarray([(sliced[2*i+1] - sliced[2*i])/sampleRate[i] for i in range(3)], dtype=int) - print ("ext {0}".format(ext)) - stack.SetExtent(0, ext[0] , 0, ext[1], 0, nreduced-1) - else: - stack.SetExtent(0, sliced[1] - sliced[0] , 0, sliced[3]-sliced[2], 0, nreduced-1) - if (flatField != None and darkField != None): - stack.AllocateScalars(vtk.VTK_FLOAT, 1) - else: - stack.AllocateScalars(reader.GetOutput().GetScalarType(), 1) - print ("Image Size: %d" % ((sliced[1]+1)*(sliced[3]+1) )) - stack_image = Converter.vtk2numpy(stack) - print ("Stack shape %s" % str(numpy.shape(stack_image))) - - if extent!=None: - voi.SetInputData(reader.GetOutput()) - voi.Update() - img = voi.GetOutput() - else: - img = reader.GetOutput() - - theSlice = Converter.vtk2numpy(img).T[0] - if darkField != None and flatField != None: - print("Try to normalize") - #if numpy.shape(darkField) == numpy.shape(flatField) and numpy.shape(flatField) == numpy.shape(theSlice): - theSlice = Converter.normalize(theSlice, darkField, flatField, 0.01) - print (theSlice.dtype) - - - print ("Slice shape %s" % str(numpy.shape(theSlice))) - stack_image.T[num] = theSlice.copy() - - return stack_image - - @staticmethod - def normalize(projection, dark, flat, def_val=0): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - - -## Utility functions to transform numpy arrays to vtkImageData and viceversa -#def numpy2vtkImporter(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): -# return Converter.numpy2vtkImporter(nparray, spacing, origin) -# -#def numpy2vtk(nparray, spacing=(1.,1.,1.), origin=(0,0,0)): -# return Converter.numpy2vtk(nparray, spacing, origin) -# -#def vtk2numpy(imgdata): -# return Converter.vtk2numpy(imgdata) -# -#def tiffStack2numpy(filename, indices): -# return Converter.tiffStack2numpy(filename, indices) - -class ViewerEvent(Enum): - # left button - PICK_EVENT = 0 - # alt + right button + move - WINDOW_LEVEL_EVENT = 1 - # shift + right button - ZOOM_EVENT = 2 - # control + right button - PAN_EVENT = 3 - # control + left button - CREATE_ROI_EVENT = 4 - # alt + left button - DELETE_ROI_EVENT = 5 - # release button - NO_EVENT = -1 - - -#class CILInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): -class CILInteractorStyle(vtk.vtkInteractorStyleImage): - - def __init__(self, callback): - vtk.vtkInteractorStyleImage.__init__(self) - self.callback = callback - self._viewer = callback - priority = 1.0 - -# self.AddObserver("MouseWheelForwardEvent" , callback.OnMouseWheelForward , priority) -# self.AddObserver("MouseWheelBackwardEvent" , callback.OnMouseWheelBackward, priority) -# self.AddObserver('KeyPressEvent', callback.OnKeyPress, priority) -# self.AddObserver('LeftButtonPressEvent', callback.OnLeftButtonPressEvent, priority) -# self.AddObserver('RightButtonPressEvent', callback.OnRightButtonPressEvent, priority) -# self.AddObserver('LeftButtonReleaseEvent', callback.OnLeftButtonReleaseEvent, priority) -# self.AddObserver('RightButtonReleaseEvent', callback.OnRightButtonReleaseEvent, priority) -# self.AddObserver('MouseMoveEvent', callback.OnMouseMoveEvent, priority) - - self.AddObserver("MouseWheelForwardEvent" , self.OnMouseWheelForward , priority) - self.AddObserver("MouseWheelBackwardEvent" , self.OnMouseWheelBackward, priority) - self.AddObserver('KeyPressEvent', self.OnKeyPress, priority) - self.AddObserver('LeftButtonPressEvent', self.OnLeftButtonPressEvent, priority) - self.AddObserver('RightButtonPressEvent', self.OnRightButtonPressEvent, priority) - self.AddObserver('LeftButtonReleaseEvent', self.OnLeftButtonReleaseEvent, priority) - self.AddObserver('RightButtonReleaseEvent', self.OnRightButtonReleaseEvent, priority) - self.AddObserver('MouseMoveEvent', self.OnMouseMoveEvent, priority) - - self.InitialEventPosition = (0,0) - - - def SetInitialEventPosition(self, xy): - self.InitialEventPosition = xy - - def GetInitialEventPosition(self): - return self.InitialEventPosition - - def GetKeyCode(self): - return self.GetInteractor().GetKeyCode() - - def SetKeyCode(self, keycode): - self.GetInteractor().SetKeyCode(keycode) - - def GetControlKey(self): - return self.GetInteractor().GetControlKey() == CONTROL_KEY - - def GetShiftKey(self): - return self.GetInteractor().GetShiftKey() == SHIFT_KEY - - def GetAltKey(self): - return self.GetInteractor().GetAltKey() == ALT_KEY - - def GetEventPosition(self): - return self.GetInteractor().GetEventPosition() - - def GetEventPositionInWorldCoordinates(self): - pass - - def GetDeltaEventPosition(self): - x,y = self.GetInteractor().GetEventPosition() - return (x - self.InitialEventPosition[0] , y - self.InitialEventPosition[1]) - - def Dolly(self, factor): - self.callback.camera.Dolly(factor) - self.callback.ren.ResetCameraClippingRange() - - def GetDimensions(self): - return self._viewer.img3D.GetDimensions() - - def GetInputData(self): - return self._viewer.img3D - - def GetSliceOrientation(self): - return self._viewer.sliceOrientation - - def SetSliceOrientation(self, orientation): - self._viewer.sliceOrientation = orientation - - def GetActiveSlice(self): - return self._viewer.sliceno - - def SetActiveSlice(self, sliceno): - self._viewer.sliceno = sliceno - - def UpdatePipeline(self, reset = False): - self._viewer.updatePipeline(reset) - - def GetActiveCamera(self): - return self._viewer.ren.GetActiveCamera() - - def SetActiveCamera(self, camera): - self._viewer.ren.SetActiveCamera(camera) - - def ResetCamera(self): - self._viewer.ren.ResetCamera() - - def Render(self): - self._viewer.renWin.Render() - - def UpdateSliceActor(self): - self._viewer.sliceActor.Update() - - def AdjustCamera(self): - self._viewer.AdjustCamera() - - def SaveRender(self, filename): - self._viewer.SaveRender(filename) - - def GetRenderWindow(self): - return self._viewer.renWin - - def GetRenderer(self): - return self._viewer.ren - - def GetROIWidget(self): - return self._viewer.ROIWidget - - def SetViewerEvent(self, event): - self._viewer.event = event - - def GetViewerEvent(self): - return self._viewer.event - - def SetInitialCameraPosition(self, position): - self._viewer.InitialCameraPosition = position - - def GetInitialCameraPosition(self): - return self._viewer.InitialCameraPosition - - def SetInitialLevel(self, level): - self._viewer.InitialLevel = level - - def GetInitialLevel(self): - return self._viewer.InitialLevel - - def SetInitialWindow(self, window): - self._viewer.InitialWindow = window - - def GetInitialWindow(self): - return self._viewer.InitialWindow - - def GetWindowLevel(self): - return self._viewer.wl - - def SetROI(self, roi): - self._viewer.ROI = roi - - def GetROI(self): - return self._viewer.ROI - - def UpdateCornerAnnotation(self, text, corner): - self._viewer.updateCornerAnnotation(text, corner) - - def GetPicker(self): - return self._viewer.picker - - def GetCornerAnnotation(self): - return self._viewer.cornerAnnotation - - def UpdateROIHistogram(self): - self._viewer.updateROIHistogram() - - - ############### Handle events - def OnMouseWheelForward(self, interactor, event): - maxSlice = self.GetDimensions()[self.GetSliceOrientation()] - shift = interactor.GetShiftKey() - advance = 1 - if shift: - advance = 10 - - if (self.GetActiveSlice() + advance < maxSlice): - self.SetActiveSlice(self.GetActiveSlice() + advance) - - self.UpdatePipeline() - else: - print ("maxSlice %d request %d" % (maxSlice, self.GetActiveSlice() + 1 )) - - def OnMouseWheelBackward(self, interactor, event): - minSlice = 0 - shift = interactor.GetShiftKey() - advance = 1 - if shift: - advance = 10 - if (self.GetActiveSlice() - advance >= minSlice): - self.SetActiveSlice( self.GetActiveSlice() - advance) - self.UpdatePipeline() - else: - print ("minSlice %d request %d" % (minSlice, self.GetActiveSlice() + 1 )) - - def OnKeyPress(self, interactor, event): - #print ("Pressed key %s" % interactor.GetKeyCode()) - # Slice Orientation - if interactor.GetKeyCode() == "X": - # slice on the other orientation - self.SetSliceOrientation ( SLICE_ORIENTATION_YZ ) - self.SetActiveSlice( int(self.GetDimensions()[1] / 2) ) - self.UpdatePipeline(True) - elif interactor.GetKeyCode() == "Y": - # slice on the other orientation - self.SetSliceOrientation ( SLICE_ORIENTATION_XZ ) - self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[1] / 2) ) - self.UpdatePipeline(True) - elif interactor.GetKeyCode() == "Z": - # slice on the other orientation - self.SetSliceOrientation ( SLICE_ORIENTATION_XY ) - self.SetActiveSlice ( int(self.GetInputData().GetDimensions()[2] / 2) ) - self.UpdatePipeline(True) - if interactor.GetKeyCode() == "x": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_YZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.SetActiveCamera(camera) - self.Render() - interactor.SetKeyCode("X") - self.OnKeyPress(interactor, event) - elif interactor.GetKeyCode() == "y": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XZ] = numpy.sqrt(newposition[SLICE_ORIENTATION_XY] ** 2 + newposition[SLICE_ORIENTATION_YZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,0,-1) - self.SetActiveCamera(camera) - self.Render() - interactor.SetKeyCode("Y") - self.OnKeyPress(interactor, event) - elif interactor.GetKeyCode() == "z": - # Change the camera view point - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - newposition = [i for i in self.GetActiveCamera().GetFocalPoint()] - newposition[SLICE_ORIENTATION_XY] = numpy.sqrt(newposition[SLICE_ORIENTATION_YZ] ** 2 + newposition[SLICE_ORIENTATION_XZ] ** 2) - camera.SetPosition(newposition) - camera.SetViewUp(0,1,0) - self.SetActiveCamera(camera) - self.ResetCamera() - self.Render() - interactor.SetKeyCode("Z") - self.OnKeyPress(interactor, event) - elif interactor.GetKeyCode() == "a": - # reset color/window - cmax = self._viewer.ia.GetMax()[0] - cmin = self._viewer.ia.GetMin()[0] - - self.SetInitialLevel( (cmax+cmin)/2 ) - self.SetInitialWindow( cmax-cmin ) - - self.GetWindowLevel().SetLevel(self.GetInitialLevel()) - self.GetWindowLevel().SetWindow(self.GetInitialWindow()) - - self.GetWindowLevel().Update() - - self.UpdateSliceActor() - self.AdjustCamera() - self.Render() - - elif interactor.GetKeyCode() == "s": - filename = "current_render" - self.SaveRender(filename) - elif interactor.GetKeyCode() == "q": - print ("Terminating by pressing q %s" % (interactor.GetKeyCode(), )) - interactor.SetKeyCode("e") - self.OnKeyPress(interactor, event) - else : - #print ("Unhandled event %s" % (interactor.GetKeyCode(), ))) - pass - - def OnLeftButtonPressEvent(self, interactor, event): - alt = interactor.GetAltKey() - shift = interactor.GetShiftKey() - ctrl = interactor.GetControlKey() -# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) -# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) -# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) - - interactor.SetInitialEventPosition(interactor.GetEventPosition()) - - if ctrl and not (alt and shift): - self.SetViewerEvent( ViewerEvent.CREATE_ROI_EVENT ) - wsize = self.GetRenderWindow().GetSize() - position = interactor.GetEventPosition() - self.GetROIWidget().GetBorderRepresentation().SetPosition((position[0]/wsize[0] - 0.05) , (position[1]/wsize[1] - 0.05)) - self.GetROIWidget().GetBorderRepresentation().SetPosition2( (0.1) , (0.1)) - - self.GetROIWidget().On() - self.SetDisplayHistogram(True) - self.Render() - print ("Event %s is CREATE_ROI_EVENT" % (event)) - elif alt and not (shift and ctrl): - self.SetViewerEvent( ViewerEvent.DELETE_ROI_EVENT ) - self.GetROIWidget().Off() - self._viewer.updateCornerAnnotation("", 1, False) - self.SetDisplayHistogram(False) - self.Render() - print ("Event %s is DELETE_ROI_EVENT" % (event)) - elif not (ctrl and alt and shift): - self.SetViewerEvent ( ViewerEvent.PICK_EVENT ) - self.HandlePickEvent(interactor, event) - print ("Event %s is PICK_EVENT" % (event)) - - - def SetDisplayHistogram(self, display): - if display: - if (self._viewer.displayHistogram == 0): - self.GetRenderer().AddActor(self._viewer.histogramPlotActor) - self.firstHistogram = 1 - self.Render() - - self._viewer.histogramPlotActor.VisibilityOn() - self._viewer.displayHistogram = True - else: - self._viewer.histogramPlotActor.VisibilityOff() - self._viewer.displayHistogram = False - - - def OnLeftButtonReleaseEvent(self, interactor, event): - if self.GetViewerEvent() == ViewerEvent.CREATE_ROI_EVENT: - #bc = self.ROIWidget.GetBorderRepresentation().GetPositionCoordinate() - #print (bc.GetValue()) - self.OnROIModifiedEvent(interactor, event) - - elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: - self.HandlePickEvent(interactor, event) - - self.SetViewerEvent( ViewerEvent.NO_EVENT ) - - def OnRightButtonPressEvent(self, interactor, event): - alt = interactor.GetAltKey() - shift = interactor.GetShiftKey() - ctrl = interactor.GetControlKey() -# print ("alt pressed " + (lambda x : "Yes" if x else "No")(alt)) -# print ("shift pressed " + (lambda x : "Yes" if x else "No")(shift)) -# print ("ctrl pressed " + (lambda x : "Yes" if x else "No")(ctrl)) - - interactor.SetInitialEventPosition(interactor.GetEventPosition()) - - - if alt and not (ctrl and shift): - self.SetViewerEvent( ViewerEvent.WINDOW_LEVEL_EVENT ) - print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) - self.HandleWindowLevel(interactor, event) - elif shift and not (ctrl and alt): - self.SetViewerEvent( ViewerEvent.ZOOM_EVENT ) - self.SetInitialCameraPosition( self.GetActiveCamera().GetPosition()) - print ("Event %s is ZOOM_EVENT" % (event)) - elif ctrl and not (shift and alt): - self.SetViewerEvent (ViewerEvent.PAN_EVENT ) - self.SetInitialCameraPosition ( self.GetActiveCamera().GetPosition() ) - print ("Event %s is PAN_EVENT" % (event)) - - def OnRightButtonReleaseEvent(self, interactor, event): - print (event) - if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: - self.SetInitialLevel( self.GetWindowLevel().GetLevel() ) - self.SetInitialWindow ( self.GetWindowLevel().GetWindow() ) - elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT or \ - self.GetViewerEvent() == ViewerEvent.PAN_EVENT: - self.SetInitialCameraPosition( () ) - - self.SetViewerEvent( ViewerEvent.NO_EVENT ) - - - def OnROIModifiedEvent(self, interactor, event): - - #print ("ROI EVENT " + event) - p1 = self.GetROIWidget().GetBorderRepresentation().GetPositionCoordinate() - p2 = self.GetROIWidget().GetBorderRepresentation().GetPosition2Coordinate() - wsize = self.GetRenderWindow().GetSize() - - #print (p1.GetValue()) - #print (p2.GetValue()) - pp1 = [p1.GetValue()[0] * wsize[0] , p1.GetValue()[1] * wsize[1] , 0.0] - pp2 = [p2.GetValue()[0] * wsize[0] + pp1[0] , p2.GetValue()[1] * wsize[1] + pp1[1] , 0.0] - vox1 = self.viewport2imageCoordinate(pp1) - vox2 = self.viewport2imageCoordinate(pp2) - - self.SetROI( (vox1 , vox2) ) - roi = self.GetROI() - print ("Pixel1 %d,%d,%d Value %f" % vox1 ) - print ("Pixel2 %d,%d,%d Value %f" % vox2 ) - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - print ("slice orientation : XY") - x = abs(roi[1][0] - roi[0][0]) - y = abs(roi[1][1] - roi[0][1]) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - print ("slice orientation : XY") - x = abs(roi[1][0] - roi[0][0]) - y = abs(roi[1][2] - roi[0][2]) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - print ("slice orientation : XY") - x = abs(roi[1][1] - roi[0][1]) - y = abs(roi[1][2] - roi[0][2]) - - text = "ROI: %d x %d, %.2f kp" % (x,y,float(x*y)/1024.) - print (text) - self.UpdateCornerAnnotation(text, 1) - self.UpdateROIHistogram() - self.SetViewerEvent( ViewerEvent.NO_EVENT ) - - def viewport2imageCoordinate(self, viewerposition): - #Determine point index - - self.GetPicker().Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) - pickPosition = list(self.GetPicker().GetPickPosition()) - pickPosition[self.GetSliceOrientation()] = \ - self.GetInputData().GetSpacing()[self.GetSliceOrientation()] * self.GetActiveSlice() + \ - self.GetInputData().GetOrigin()[self.GetSliceOrientation()] - print ("Pick Position " + str (pickPosition)) - - if (pickPosition != [0,0,0]): - dims = self.GetInputData().GetDimensions() - print (dims) - spac = self.GetInputData().GetSpacing() - orig = self.GetInputData().GetOrigin() - imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] - - pixelValue = self.GetInputData().GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) - return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) - else: - return (0,0,0,0) - - - - - def OnMouseMoveEvent(self, interactor, event): - if self.GetViewerEvent() == ViewerEvent.WINDOW_LEVEL_EVENT: - print ("Event %s is WINDOW_LEVEL_EVENT" % (event)) - self.HandleWindowLevel(interactor, event) - elif self.GetViewerEvent() == ViewerEvent.PICK_EVENT: - self.HandlePickEvent(interactor, event) - elif self.GetViewerEvent() == ViewerEvent.ZOOM_EVENT: - self.HandleZoomEvent(interactor, event) - elif self.GetViewerEvent() == ViewerEvent.PAN_EVENT: - self.HandlePanEvent(interactor, event) - - - def HandleZoomEvent(self, interactor, event): - dx,dy = interactor.GetDeltaEventPosition() - size = self.GetRenderWindow().GetSize() - dy = - 4 * dy / size[1] - - print ("distance: " + str(self.GetActiveCamera().GetDistance())) - - print ("\ndy: %f\ncamera dolly %f\n" % (dy, 1 + dy)) - - camera = vtk.vtkCamera() - camera.SetFocalPoint(self.GetActiveCamera().GetFocalPoint()) - #print ("current position " + str(self.InitialCameraPosition)) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - camera.SetPosition(self.GetInitialCameraPosition()) - newposition = [i for i in self.GetInitialCameraPosition()] - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - dist = newposition[SLICE_ORIENTATION_XY] * ( 1 + dy ) - newposition[SLICE_ORIENTATION_XY] *= ( 1 + dy ) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - newposition[SLICE_ORIENTATION_XZ] *= ( 1 + dy ) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - newposition[SLICE_ORIENTATION_YZ] *= ( 1 + dy ) - #print ("new position " + str(newposition)) - camera.SetPosition(newposition) - self.SetActiveCamera(camera) - - self.Render() - - print ("distance after: " + str(self.GetActiveCamera().GetDistance())) - - def HandlePanEvent(self, interactor, event): - x,y = interactor.GetEventPosition() - x0,y0 = interactor.GetInitialEventPosition() - - ic = self.viewport2imageCoordinate((x,y)) - ic0 = self.viewport2imageCoordinate((x0,y0)) - - dx = 4 *( ic[0] - ic0[0]) - dy = 4* (ic[1] - ic0[1]) - - camera = vtk.vtkCamera() - #print ("current position " + str(self.InitialCameraPosition)) - camera.SetViewUp(self.GetActiveCamera().GetViewUp()) - camera.SetPosition(self.GetInitialCameraPosition()) - newposition = [i for i in self.GetInitialCameraPosition()] - newfocalpoint = [i for i in self.GetActiveCamera().GetFocalPoint()] - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - newposition[0] -= dx - newposition[1] -= dy - newfocalpoint[0] = newposition[0] - newfocalpoint[1] = newposition[1] - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - newposition[0] -= dx - newposition[2] -= dy - newfocalpoint[0] = newposition[0] - newfocalpoint[2] = newposition[2] - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - newposition[1] -= dx - newposition[2] -= dy - newfocalpoint[2] = newposition[2] - newfocalpoint[1] = newposition[1] - #print ("new position " + str(newposition)) - camera.SetFocalPoint(newfocalpoint) - camera.SetPosition(newposition) - self.SetActiveCamera(camera) - - self.Render() - - def HandleWindowLevel(self, interactor, event): - dx,dy = interactor.GetDeltaEventPosition() - print ("Event delta %d %d" % (dx,dy)) - size = self.GetRenderWindow().GetSize() - - dx = 4 * dx / size[0] - dy = 4 * dy / size[1] - window = self.GetInitialWindow() - level = self.GetInitialLevel() - - if abs(window) > 0.01: - dx = dx * window - else: - dx = dx * (lambda x: -0.01 if x <0 else 0.01)(window); - - if abs(level) > 0.01: - dy = dy * level - else: - dy = dy * (lambda x: -0.01 if x <0 else 0.01)(level) - - - # Abs so that direction does not flip - - if window < 0.0: - dx = -1*dx - if level < 0.0: - dy = -1*dy - - # Compute new window level - - newWindow = dx + window - newLevel = level - dy - - # Stay away from zero and really - - if abs(newWindow) < 0.01: - newWindow = 0.01 * (lambda x: -1 if x <0 else 1)(newWindow) - - if abs(newLevel) < 0.01: - newLevel = 0.01 * (lambda x: -1 if x <0 else 1)(newLevel) - - self.GetWindowLevel().SetWindow(newWindow) - self.GetWindowLevel().SetLevel(newLevel) - - self.GetWindowLevel().Update() - self.UpdateSliceActor() - self.AdjustCamera() - - self.Render() - - def HandlePickEvent(self, interactor, event): - position = interactor.GetEventPosition() - #print ("PICK " + str(position)) - vox = self.viewport2imageCoordinate(position) - #print ("Pixel %d,%d,%d Value %f" % vox ) - self._viewer.cornerAnnotation.VisibilityOn() - self.UpdateCornerAnnotation("[%d,%d,%d] : %.2f" % vox , 0) - self.Render() - -############################################################################### - - - -class CILViewer2D(): - '''Simple Interactive Viewer based on VTK classes''' - - def __init__(self, dimx=600,dimy=600, ren=None, renWin=None,iren=None): - '''creates the rendering pipeline''' - # create a rendering window and renderer - if ren == None: - self.ren = vtk.vtkRenderer() - else: - self.ren = ren - if renWin == None: - self.renWin = vtk.vtkRenderWindow() - else: - self.renWin = renWin - if iren == None: - self.iren = vtk.vtkRenderWindowInteractor() - else: - self.iren = iren - - self.renWin.SetSize(dimx,dimy) - self.renWin.AddRenderer(self.ren) - - self.style = CILInteractorStyle(self) - - self.iren.SetInteractorStyle(self.style) - self.iren.SetRenderWindow(self.renWin) - self.iren.Initialize() - self.ren.SetBackground(.1, .2, .4) - - self.camera = vtk.vtkCamera() - self.camera.ParallelProjectionOn() - self.ren.SetActiveCamera(self.camera) - - # data - self.img3D = None - self.sliceno = 0 - self.sliceOrientation = SLICE_ORIENTATION_XY - - #Actors - self.sliceActor = vtk.vtkImageActor() - self.voi = vtk.vtkExtractVOI() - self.wl = vtk.vtkImageMapToWindowLevelColors() - self.ia = vtk.vtkImageAccumulate() - self.sliceActorNo = 0 - - #initial Window/Level - self.InitialLevel = 0 - self.InitialWindow = 0 - - #ViewerEvent - self.event = ViewerEvent.NO_EVENT - - # ROI Widget - self.ROIWidget = vtk.vtkBorderWidget() - self.ROIWidget.SetInteractor(self.iren) - self.ROIWidget.CreateDefaultRepresentation() - self.ROIWidget.GetBorderRepresentation().GetBorderProperty().SetColor(0,1,0) - self.ROIWidget.AddObserver(vtk.vtkWidgetEvent.Select, self.style.OnROIModifiedEvent, 1.0) - - # edge points of the ROI - self.ROI = () - - #picker - self.picker = vtk.vtkPropPicker() - self.picker.PickFromListOn() - self.picker.AddPickList(self.sliceActor) - - self.iren.SetPicker(self.picker) - - # corner annotation - self.cornerAnnotation = vtk.vtkCornerAnnotation() - self.cornerAnnotation.SetMaximumFontSize(12); - self.cornerAnnotation.PickableOff(); - self.cornerAnnotation.VisibilityOff(); - self.cornerAnnotation.GetTextProperty().ShadowOn(); - self.cornerAnnotation.SetLayerNumber(1); - - - - # cursor doesn't show up - self.cursor = vtk.vtkCursor2D() - self.cursorMapper = vtk.vtkPolyDataMapper2D() - self.cursorActor = vtk.vtkActor2D() - self.cursor.SetModelBounds(-10, 10, -10, 10, 0, 0) - self.cursor.SetFocalPoint(0, 0, 0) - self.cursor.AllOff() - self.cursor.AxesOn() - self.cursorActor.PickableOff() - self.cursorActor.VisibilityOn() - self.cursorActor.GetProperty().SetColor(1, 1, 1) - self.cursorActor.SetLayerNumber(1) - self.cursorMapper.SetInputData(self.cursor.GetOutput()) - self.cursorActor.SetMapper(self.cursorMapper) - - # Zoom - self.InitialCameraPosition = () - - # XY Plot actor for histogram - self.displayHistogram = False - self.firstHistogram = 0 - self.roiIA = vtk.vtkImageAccumulate() - self.roiVOI = vtk.vtkExtractVOI() - self.histogramPlotActor = vtk.vtkXYPlotActor() - self.histogramPlotActor.ExchangeAxesOff(); - self.histogramPlotActor.SetXLabelFormat( "%g" ) - self.histogramPlotActor.SetXLabelFormat( "%g" ) - self.histogramPlotActor.SetAdjustXLabels(3) - self.histogramPlotActor.SetXTitle( "Level" ) - self.histogramPlotActor.SetYTitle( "N" ) - self.histogramPlotActor.SetXValuesToValue() - self.histogramPlotActor.SetPlotColor(0, (0,1,1) ) - self.histogramPlotActor.SetPosition(0.6,0.6) - self.histogramPlotActor.SetPosition2(0.4,0.4) - - - - def GetInteractor(self): - return self.iren - - def GetRenderer(self): - return self.ren - - def setInput3DData(self, imageData): - self.img3D = imageData - self.installPipeline() - - def setInputAsNumpy(self, numpyarray, origin=(0,0,0), spacing=(1.,1.,1.), - rescale=True, dtype=vtk.VTK_UNSIGNED_SHORT): - importer = Converter.numpy2vtkImporter(numpyarray, spacing, origin) - importer.Update() - - if rescale: - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(importer.GetOutput()) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - if (iMax - iMin == 0): - scale = 1 - else: - if dtype == vtk.VTK_UNSIGNED_SHORT: - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - elif dtype == vtk.VTK_UNSIGNED_INT: - scale = vtk.VTK_UNSIGNED_INT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(importer.GetOutput()) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(-iMin) - shiftScaler.SetOutputScalarType(dtype) - shiftScaler.Update() - self.img3D = shiftScaler.GetOutput() - else: - self.img3D = importer.GetOutput() - - self.installPipeline() - - def displaySlice(self, sliceno = 0): - self.sliceno = sliceno - - self.updatePipeline() - - self.renWin.Render() - - return self.sliceActorNo - - def updatePipeline(self, resetcamera = False): - extent = [ i for i in self.img3D.GetExtent()] - extent[self.sliceOrientation * 2] = self.sliceno - extent[self.sliceOrientation * 2 + 1] = self.sliceno - self.voi.SetVOI(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - - self.voi.Update() - self.ia.Update() - self.wl.Update() - self.sliceActor.SetDisplayExtent(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - self.sliceActor.Update() - - self.updateCornerAnnotation("Slice %d/%d" % (self.sliceno + 1 , self.img3D.GetDimensions()[self.sliceOrientation])) - - if self.displayHistogram: - self.updateROIHistogram() - - self.AdjustCamera(resetcamera) - - self.renWin.Render() - - - def installPipeline(self): - '''Slices a 3D volume and then creates an actor to be rendered''' - - self.ren.AddViewProp(self.cornerAnnotation) - - self.voi.SetInputData(self.img3D) - #select one slice in Z - extent = [ i for i in self.img3D.GetExtent()] - extent[self.sliceOrientation * 2] = self.sliceno - extent[self.sliceOrientation * 2 + 1] = self.sliceno - self.voi.SetVOI(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - - self.voi.Update() - # set window/level for current slices - - - self.wl = vtk.vtkImageMapToWindowLevelColors() - self.ia.SetInputData(self.voi.GetOutput()) - self.ia.Update() - cmax = self.ia.GetMax()[0] - cmin = self.ia.GetMin()[0] - - self.InitialLevel = (cmax+cmin)/2 - self.InitialWindow = cmax-cmin - - - self.wl.SetLevel(self.InitialLevel) - self.wl.SetWindow(self.InitialWindow) - - self.wl.SetInputData(self.voi.GetOutput()) - self.wl.Update() - - self.sliceActor.SetInputData(self.wl.GetOutput()) - self.sliceActor.SetDisplayExtent(extent[0], extent[1], - extent[2], extent[3], - extent[4], extent[5]) - self.sliceActor.Update() - self.sliceActor.SetInterpolate(False) - self.ren.AddActor(self.sliceActor) - self.ren.ResetCamera() - self.ren.Render() - - self.AdjustCamera() - - self.ren.AddViewProp(self.cursorActor) - self.cursorActor.VisibilityOn() - - self.iren.Initialize() - self.renWin.Render() - #self.iren.Start() - - def AdjustCamera(self, resetcamera = False): - self.ren.ResetCameraClippingRange() - if resetcamera: - self.ren.ResetCamera() - - - def getROI(self): - return self.ROI - - def getROIExtent(self): - p0 = self.ROI[0] - p1 = self.ROI[1] - return (p0[0], p1[0],p0[1],p1[1],p0[2],p1[2]) - - ############### Handle events are moved to the interactor style - - - def viewport2imageCoordinate(self, viewerposition): - #Determine point index - - self.picker.Pick(viewerposition[0], viewerposition[1], 0.0, self.GetRenderer()) - pickPosition = list(self.picker.GetPickPosition()) - pickPosition[self.sliceOrientation] = \ - self.img3D.GetSpacing()[self.sliceOrientation] * self.sliceno + \ - self.img3D.GetOrigin()[self.sliceOrientation] - print ("Pick Position " + str (pickPosition)) - - if (pickPosition != [0,0,0]): - dims = self.img3D.GetDimensions() - print (dims) - spac = self.img3D.GetSpacing() - orig = self.img3D.GetOrigin() - imagePosition = [int(pickPosition[i] / spac[i] + orig[i]) for i in range(3) ] - - pixelValue = self.img3D.GetScalarComponentAsDouble(imagePosition[0], imagePosition[1], imagePosition[2], 0) - return (imagePosition[0], imagePosition[1], imagePosition[2] , pixelValue) - else: - return (0,0,0,0) - - - - def GetRenderWindow(self): - return self.renWin - - - def startRenderLoop(self): - self.iren.Start() - - def GetSliceOrientation(self): - return self.sliceOrientation - - def GetActiveSlice(self): - return self.sliceno - - def updateCornerAnnotation(self, text , idx=0, visibility=True): - if visibility: - self.cornerAnnotation.VisibilityOn() - else: - self.cornerAnnotation.VisibilityOff() - - self.cornerAnnotation.SetText(idx, text) - self.iren.Render() - - def saveRender(self, filename, renWin=None): - '''Save the render window to PNG file''' - # screenshot code: - w2if = vtk.vtkWindowToImageFilter() - if renWin == None: - renWin = self.renWin - w2if.SetInput(renWin) - w2if.Update() - - writer = vtk.vtkPNGWriter() - writer.SetFileName("%s.png" % (filename)) - writer.SetInputConnection(w2if.GetOutputPort()) - writer.Write() - - def updateROIHistogram(self): - - extent = [0 for i in range(6)] - if self.GetSliceOrientation() == SLICE_ORIENTATION_XY: - print ("slice orientation : XY") - extent[0] = self.ROI[0][0] - extent[1] = self.ROI[1][0] - extent[2] = self.ROI[0][1] - extent[3] = self.ROI[1][1] - extent[4] = self.GetActiveSlice() - extent[5] = self.GetActiveSlice()+1 - #y = abs(roi[1][1] - roi[0][1]) - elif self.GetSliceOrientation() == SLICE_ORIENTATION_XZ: - print ("slice orientation : XY") - extent[0] = self.ROI[0][0] - extent[1] = self.ROI[1][0] - #x = abs(roi[1][0] - roi[0][0]) - extent[4] = self.ROI[0][2] - extent[5] = self.ROI[1][2] - #y = abs(roi[1][2] - roi[0][2]) - extent[2] = self.GetActiveSlice() - extent[3] = self.GetActiveSlice()+1 - elif self.GetSliceOrientation() == SLICE_ORIENTATION_YZ: - print ("slice orientation : XY") - extent[2] = self.ROI[0][1] - extent[3] = self.ROI[1][1] - #x = abs(roi[1][1] - roi[0][1]) - extent[4] = self.ROI[0][2] - extent[5] = self.ROI[1][2] - #y = abs(roi[1][2] - roi[0][2]) - extent[0] = self.GetActiveSlice() - extent[1] = self.GetActiveSlice()+1 - - self.roiVOI.SetVOI(extent) - self.roiVOI.SetInputData(self.img3D) - self.roiVOI.Update() - irange = self.roiVOI.GetOutput().GetScalarRange() - - self.roiIA.SetInputData(self.roiVOI.GetOutput()) - self.roiIA.IgnoreZeroOff() - self.roiIA.SetComponentExtent(0,int(irange[1]-irange[0]-1),0,0,0,0 ) - self.roiIA.SetComponentOrigin( int(irange[0]),0,0 ); - self.roiIA.SetComponentSpacing( 1,0,0 ); - self.roiIA.Update() - - self.histogramPlotActor.AddDataSetInputConnection(self.roiIA.GetOutputPort()) - self.histogramPlotActor.SetXRange(irange[0],irange[1]) - - self.histogramPlotActor.SetYRange( self.roiIA.GetOutput().GetScalarRange() ) - - \ No newline at end of file diff --git a/src/Python/ccpi/viewer/QVTKWidget.py b/src/Python/ccpi/viewer/QVTKWidget.py deleted file mode 100644 index 906786b..0000000 --- a/src/Python/ccpi/viewer/QVTKWidget.py +++ /dev/null @@ -1,340 +0,0 @@ -################################################################################ -# File: QVTKWidget.py -# Author: Edoardo Pasca -# Description: PyVE Viewer Qt widget -# -# License: -# This file is part of PyVE. PyVE is an open-source image -# analysis and visualization environment focused on medical -# imaging. More info at http://pyve.sourceforge.net -# -# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or -# without modification, are permitted provided that the following -# conditions are met: -# -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following -# disclaimer in the documentation and/or other materials provided -# with the distribution. Neither name of Edoardo Pasca or Lukas -# Batteau nor the names of any contributors may be used to endorse -# or promote products derived from this software without specific -# prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, -# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, -# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY -# OF SUCH DAMAGE. -# -# CHANGE HISTORY -# -# 20120118 Edoardo Pasca Initial version -# -############################################################################### - -import os -from PyQt5 import QtCore, QtGui, QtWidgets -#import itk -import vtk -#from viewer import PyveViewer -from ccpi.viewer.CILViewer2D import CILViewer2D , Converter - -class QVTKWidget(QtWidgets.QWidget): - - """ A QVTKWidget for Python and Qt.""" - - # Map between VTK and Qt cursors. - _CURSOR_MAP = { - 0: QtCore.Qt.ArrowCursor, # VTK_CURSOR_DEFAULT - 1: QtCore.Qt.ArrowCursor, # VTK_CURSOR_ARROW - 2: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZENE - 3: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZENWSE - 4: QtCore.Qt.SizeBDiagCursor, # VTK_CURSOR_SIZESW - 5: QtCore.Qt.SizeFDiagCursor, # VTK_CURSOR_SIZESE - 6: QtCore.Qt.SizeVerCursor, # VTK_CURSOR_SIZENS - 7: QtCore.Qt.SizeHorCursor, # VTK_CURSOR_SIZEWE - 8: QtCore.Qt.SizeAllCursor, # VTK_CURSOR_SIZEALL - 9: QtCore.Qt.PointingHandCursor, # VTK_CURSOR_HAND - 10: QtCore.Qt.CrossCursor, # VTK_CURSOR_CROSSHAIR - } - - def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): - # the current button - self._ActiveButton = QtCore.Qt.NoButton - - # private attributes - self.__oldFocus = None - self.__saveX = 0 - self.__saveY = 0 - self.__saveModifiers = QtCore.Qt.NoModifier - self.__saveButtons = QtCore.Qt.NoButton - self.__timeframe = 0 - - # create qt-level widget - QtWidgets.QWidget.__init__(self, parent, wflags|QtCore.Qt.MSWindowsOwnDC) - - # Link to PyVE Viewer - self._PyveViewer = CILViewer2D() - #self._Viewer = self._PyveViewer._vtkPyveViewer - - self._Iren = self._PyveViewer.GetInteractor() - #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() - self._RenderWindow = self._PyveViewer.GetRenderWindow() - #self._RenderWindow = self._Viewer.GetRenderWindow() - - self._Iren.Register(self._RenderWindow) - self._Iren.SetRenderWindow(self._RenderWindow) - self._RenderWindow.SetWindowInfo(str(int(self.winId()))) - - # do all the necessary qt setup - self.setAttribute(QtCore.Qt.WA_OpaquePaintEvent) - self.setAttribute(QtCore.Qt.WA_PaintOnScreen) - self.setMouseTracking(True) # get all mouse events - self.setFocusPolicy(QtCore.Qt.WheelFocus) - self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)) - - self._Timer = QtCore.QTimer(self) - #self.connect(self._Timer, QtCore.pyqtSignal('timeout()'), self.TimerEvent) - - self._Iren.AddObserver('CreateTimerEvent', self.CreateTimer) - self._Iren.AddObserver('DestroyTimerEvent', self.DestroyTimer) - self._Iren.GetRenderWindow().AddObserver('CursorChangedEvent', - self.CursorChangedEvent) - - # Destructor - def __del__(self): - self._Iren.UnRegister(self._RenderWindow) - #QtWidgets.QWidget.__del__(self) - - # Display image data - def SetInput(self, imageData): - self._PyveViewer.setInput3DData(imageData) - - # GetInteractor - def GetInteractor(self): - return self._Iren - - # Display image data - def GetPyveViewer(self): - return self._PyveViewer - - def __getattr__(self, attr): - """Makes the object behave like a vtkGenericRenderWindowInteractor""" - print (attr) - if attr == '__vtk__': - return lambda t=self._Iren: t - elif hasattr(self._Iren, attr): - return getattr(self._Iren, attr) -# else: -# raise AttributeError( self.__class__.__name__ + \ -# " has no attribute named " + attr ) - - def CreateTimer(self, obj, evt): - self._Timer.start(10) - - def DestroyTimer(self, obj, evt): - self._Timer.stop() - return 1 - - def TimerEvent(self): - self._Iren.InvokeEvent("TimerEvent") - - def CursorChangedEvent(self, obj, evt): - """Called when the CursorChangedEvent fires on the render window.""" - # This indirection is needed since when the event fires, the current - # cursor is not yet set so we defer this by which time the current - # cursor should have been set. - QtCore.QTimer.singleShot(0, self.ShowCursor) - - def HideCursor(self): - """Hides the cursor.""" - self.setCursor(QtCore.Qt.BlankCursor) - - def ShowCursor(self): - """Shows the cursor.""" - vtk_cursor = self._Iren.GetRenderWindow().GetCurrentCursor() - qt_cursor = self._CURSOR_MAP.get(vtk_cursor, QtCore.Qt.ArrowCursor) - self.setCursor(qt_cursor) - - def sizeHint(self): - return QtCore.QSize(400, 400) - - def paintEngine(self): - return None - - def paintEvent(self, ev): - self._RenderWindow.Render() - - def resizeEvent(self, ev): - self._RenderWindow.Render() - w = self.width() - h = self.height() - - self._RenderWindow.SetSize(w, h) - self._Iren.SetSize(w, h) - - def _GetCtrlShiftAlt(self, ev): - ctrl = shift = alt = False - - if hasattr(ev, 'modifiers'): - if ev.modifiers() & QtCore.Qt.ShiftModifier: - shift = True - if ev.modifiers() & QtCore.Qt.ControlModifier: - ctrl = True - if ev.modifiers() & QtCore.Qt.AltModifier: - alt = True - else: - if self.__saveModifiers & QtCore.Qt.ShiftModifier: - shift = True - if self.__saveModifiers & QtCore.Qt.ControlModifier: - ctrl = True - if self.__saveModifiers & QtCore.Qt.AltModifier: - alt = True - - return ctrl, shift, alt - - def enterEvent(self, ev): - if not self.hasFocus(): - self.__oldFocus = self.focusWidget() - self.setFocus() - - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("EnterEvent") - - def leaveEvent(self, ev): - if self.__saveButtons == QtCore.Qt.NoButton and self.__oldFocus: - self.__oldFocus.setFocus() - self.__oldFocus = None - - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("LeaveEvent") - - def mousePressEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - repeat = 0 - if ev.type() == QtCore.QEvent.MouseButtonDblClick: - repeat = 1 - self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), - ctrl, shift, chr(0), repeat, None) - - self._Iren.SetAltKey(alt) - self._ActiveButton = ev.button() - - if self._ActiveButton == QtCore.Qt.LeftButton: - self._Iren.InvokeEvent("LeftButtonPressEvent") - elif self._ActiveButton == QtCore.Qt.RightButton: - self._Iren.InvokeEvent("RightButtonPressEvent") - elif self._ActiveButton == QtCore.Qt.MidButton: - self._Iren.InvokeEvent("MiddleButtonPressEvent") - - def mouseReleaseEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - - if self._ActiveButton == QtCore.Qt.LeftButton: - self._Iren.InvokeEvent("LeftButtonReleaseEvent") - elif self._ActiveButton == QtCore.Qt.RightButton: - self._Iren.InvokeEvent("RightButtonReleaseEvent") - elif self._ActiveButton == QtCore.Qt.MidButton: - self._Iren.InvokeEvent("MiddleButtonReleaseEvent") - - def mouseMoveEvent(self, ev): - self.__saveModifiers = ev.modifiers() - self.__saveButtons = ev.buttons() - self.__saveX = ev.x() - self.__saveY = ev.y() - - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - self._Iren.SetEventInformationFlipY(ev.x(), ev.y(), - ctrl, shift, chr(0), 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("MouseMoveEvent") - - def keyPressEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - if ev.key() < 256: - key = str(ev.text()) - else: - key = chr(0) - - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, key, 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("KeyPressEvent") - self._Iren.InvokeEvent("CharEvent") - - def keyReleaseEvent(self, ev): - ctrl, shift, alt = self._GetCtrlShiftAlt(ev) - if ev.key() < 256: - key = chr(ev.key()) - else: - key = chr(0) - - self._Iren.SetEventInformationFlipY(self.__saveX, self.__saveY, - ctrl, shift, key, 0, None) - self._Iren.SetAltKey(alt) - self._Iren.InvokeEvent("KeyReleaseEvent") - - def wheelEvent(self, ev): - print ("angleDeltaX %d" % ev.angleDelta().x()) - print ("angleDeltaY %d" % ev.angleDelta().y()) - if ev.angleDelta().y() >= 0: - self._Iren.InvokeEvent("MouseWheelForwardEvent") - else: - self._Iren.InvokeEvent("MouseWheelBackwardEvent") - - def GetRenderWindow(self): - return self._RenderWindow - - def Render(self): - self.update() - - -def QVTKExample(): - """A simple example that uses the QVTKWidget class.""" - - # every QT app needs an app - app = QtWidgets.QApplication(['PyVE QVTKWidget Example']) - page_VTK = QtWidgets.QWidget() - page_VTK.resize(500,500) - layout = QtWidgets.QVBoxLayout(page_VTK) - # create the widget - widget = QVTKWidget(parent=None) - layout.addWidget(widget) - - #reader = vtk.vtkPNGReader() - #reader.SetFileName("F:\Diagnostics\Images\PyVE\VTKData\Data\camscene.png") - reader = vtk.vtkMetaImageReader() - reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") - reader.Update() - - widget.SetInput(reader.GetOutput()) - - # show the widget - page_VTK.show() - # start event processing - app.exec_() - -if __name__ == "__main__": - QVTKExample() diff --git a/src/Python/ccpi/viewer/QVTKWidget2.py b/src/Python/ccpi/viewer/QVTKWidget2.py deleted file mode 100644 index e32e1c2..0000000 --- a/src/Python/ccpi/viewer/QVTKWidget2.py +++ /dev/null @@ -1,84 +0,0 @@ -################################################################################ -# File: QVTKWidget.py -# Author: Edoardo Pasca -# Description: PyVE Viewer Qt widget -# -# License: -# This file is part of PyVE. PyVE is an open-source image -# analysis and visualization environment focused on medical -# imaging. More info at http://pyve.sourceforge.net -# -# Copyright (c) 2011-2012 Edoardo Pasca, Lukas Batteau -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or -# without modification, are permitted provided that the following -# conditions are met: -# -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following -# disclaimer in the documentation and/or other materials provided -# with the distribution. Neither name of Edoardo Pasca or Lukas -# Batteau nor the names of any contributors may be used to endorse -# or promote products derived from this software without specific -# prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, -# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, -# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -# OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY -# OF SUCH DAMAGE. -# -# CHANGE HISTORY -# -# 20120118 Edoardo Pasca Initial version -# -############################################################################### - -import os -from PyQt5 import QtCore, QtGui, QtWidgets -#import itk -import vtk -#from viewer import PyveViewer -from ccpi.viewer.CILViewer2D import CILViewer2D , Converter -from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor - -class QVTKWidget(QVTKRenderWindowInteractor): - - - def __init__(self, parent=None, wflags=QtCore.Qt.WindowFlags(), **kw): - kw = dict() - super().__init__(parent, **kw) - - - # Link to PyVE Viewer - self._PyveViewer = CILViewer2D(400,400) - #self._Viewer = self._PyveViewer._vtkPyveViewer - - self._Iren = self._PyveViewer.GetInteractor() - kw['iren'] = self._Iren - #self._Iren = self._Viewer.GetRenderWindow().GetInteractor() - self._RenderWindow = self._PyveViewer.GetRenderWindow() - #self._RenderWindow = self._Viewer.GetRenderWindow() - kw['rw'] = self._RenderWindow - - - - - def GetInteractor(self): - return self._Iren - - # Display image data - def SetInput(self, imageData): - self._PyveViewer.setInput3DData(imageData) - \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__init__.py b/src/Python/ccpi/viewer/__init__.py deleted file mode 100644 index 946188b..0000000 --- a/src/Python/ccpi/viewer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ccpi.viewer.CILViewer import CILViewer \ No newline at end of file diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc deleted file mode 100644 index 711f77a..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/CILViewer.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc deleted file mode 100644 index 77c2ca8..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/CILViewer2D.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc deleted file mode 100644 index 3d11b87..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/QVTKWidget.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc deleted file mode 100644 index 2fa2eaf..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/QVTKWidget2.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc b/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index fcea537..0000000 Binary files a/src/Python/ccpi/viewer/__pycache__/__init__.cpython-35.pyc and /dev/null differ diff --git a/src/Python/ccpi/viewer/embedvtk.py b/src/Python/ccpi/viewer/embedvtk.py deleted file mode 100644 index b5eb0a7..0000000 --- a/src/Python/ccpi/viewer/embedvtk.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Thu Jul 27 12:18:58 2017 - -@author: ofn77899 -""" - -#!/usr/bin/env python - -import sys -import vtk -from PyQt5 import QtCore, QtWidgets -from vtk.qt.QVTKRenderWindowInteractor import QVTKRenderWindowInteractor -import QVTKWidget2 - -class MainWindow(QtWidgets.QMainWindow): - - def __init__(self, parent = None): - QtWidgets.QMainWindow.__init__(self, parent) - - self.frame = QtWidgets.QFrame() - - self.vl = QtWidgets.QVBoxLayout() -# self.vtkWidget = QVTKRenderWindowInteractor(self.frame) - - self.vtkWidget = QVTKWidget2.QVTKWidget(self.frame) - self.iren = self.vtkWidget.GetInteractor() - self.vl.addWidget(self.vtkWidget) - - - - - self.ren = vtk.vtkRenderer() - self.vtkWidget.GetRenderWindow().AddRenderer(self.ren) -# self.iren = self.vtkWidget.GetRenderWindow().GetInteractor() -# -# # Create source -# source = vtk.vtkSphereSource() -# source.SetCenter(0, 0, 0) -# source.SetRadius(5.0) -# -# # Create a mapper -# mapper = vtk.vtkPolyDataMapper() -# mapper.SetInputConnection(source.GetOutputPort()) -# -# # Create an actor -# actor = vtk.vtkActor() -# actor.SetMapper(mapper) -# -# self.ren.AddActor(actor) -# -# self.ren.ResetCamera() -# - self.frame.setLayout(self.vl) - self.setCentralWidget(self.frame) - reader = vtk.vtkMetaImageReader() - reader.SetFileName("C:\\Users\\ofn77899\\Documents\\GitHub\\CCPi-Simpleflex\\data\\head.mha") - reader.Update() - - self.vtkWidget.SetInput(reader.GetOutput()) - - #self.vktWidget.Initialize() - #self.vktWidget.Start() - - self.show() - #self.iren.Initialize() - - -if __name__ == "__main__": - - app = QtWidgets.QApplication(sys.argv) - - window = MainWindow() - - sys.exit(app.exec_()) \ No newline at end of file -- cgit v1.2.3 From 99e8a3130d6ee161fc8e73faf526d7e0a7a9db44 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 11:46:54 +0100 Subject: Modified Region of interest; removed studentt https://github.com/vais-ral/CCPi-FISTA_Reconstruction/commit/6fb8f5d188ed31d7a7077cba8ab7aea17b25b8bf --- src/Python/ccpi/fista/FISTAReconstructor.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 87dd2c0..33e67a3 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -108,9 +108,7 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets', - 'use_studentt_fidelity', - 'studentt') + 'subsets') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -141,10 +139,12 @@ class FISTAReconstructor(): if not 'region_of_interest'in kwargs.keys() : if self.pars['ideal_image'] == None: - pass + self.pars['region_of_interest'] = None else: - self.pars['region_of_interest'] = numpy.nonzero( - self.pars['ideal_image']>0.0) + ## nonzero if the image is larger than m + fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) + + self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) # the regularizer must be a correctly instantiated object if not 'regularizer' in kwargs.keys() : @@ -165,14 +165,7 @@ class FISTAReconstructor(): if not 'initialize' in kwargs.keys(): self.pars['initialize'] = False - if not 'use_studentt_fidelity' in kwargs.keys(): - self.setParameter(studentt=False) - else: - print ("studentt {0}".format(kwargs['use_studentt_fidelity'])) - if kwargs['use_studentt_fidelity']: - raise Exception('Not implemented') - - self.setParameter(studentt=kwargs['use_studentt_fidelity']) + def setParameter(self, **kwargs): -- cgit v1.2.3 From 9a126e05d03a474850c122cc44e971383069fb8d Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 11:48:16 +0100 Subject: minor reorganization of the code added RSME --- src/Python/test_reconstructor.py | 97 +++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 52 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index f8f6b3c..2f188b4 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -11,10 +11,17 @@ import numpy from ccpi.fista.FISTAReconstructor import FISTAReconstructor import astra +import matplotlib.pyplot as plt -##def getEntry(nx, location): -## for item in nx[location].keys(): -## print (item) +def RMSE(signal1, signal2): + '''RMSE Root Mean Squared Error''' + if numpy.shape(signal1) == numpy.shape(signal2): + err = (signal1 - signal2) + err = numpy.sum( err * err )/numpy.size(signal1); # MSE + err = sqrt(err); # RMSE + return err + else: + raise Exception('Input signals must have the same shape') filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' nx = h5py.File(filename, "r") @@ -68,7 +75,6 @@ fistaRecon.setParameter(number_of_iterations = 12) fistaRecon.setParameter(Lipschitz_constant = 767893952.0) fistaRecon.setParameter(ring_alpha = 21) fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) -#fistaRecon.setParameter(use_studentt_fidelity= True) ## Ordered subset if False: @@ -95,18 +101,33 @@ if False: if True: - fistaRecon.prepareForIteration() print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) - + print ("prepare for iteration") + fistaRecon.prepareForIteration() + + print("initializing ...") + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("initialized") proj_geom , vol_geom, sino , \ SlicesZ = fistaRecon.getParameter(['projector_geometry' , 'output_geometry', 'input_sinogram', 'SlicesZ']) - fistaRecon.setParameter(number_of_iterations = 3) + #fistaRecon.setParameter(number_of_iterations = 3) iterFISTA = fistaRecon.getParameter('number_of_iterations') # errors vector (if the ground truth is given) Resid_error = numpy.zeros((iterFISTA)); @@ -114,23 +135,10 @@ if True: objective = numpy.zeros((iterFISTA)); - print ("line") t = 1 - print ("line") - if False: - # if X doesn't exist - #N = params.vol_geom.GridColCount - N = vol_geom['GridColCount'] - print ("N " + str(N)) - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - #X = fistaRecon.initialize() - X = numpy.load("X.npy") - - print (numpy.shape(X)) - X_t = X.copy() - print ("X_t copy") + + print ("starting iterations") ## % Outer FISTA iterations loop for i in range(fistaRecon.getParameter('number_of_iterations')): X_old = X.copy() @@ -147,7 +155,6 @@ if True: vol_geomT['GridSliceCount'] = 1; sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) for kkk in range(SlicesZ): - print (kkk) sino_id, sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( X_t[kkk:kkk+1], proj_geomT, vol_geomT) @@ -169,8 +176,9 @@ if True: SlicesZ, anglesNumb, Detectors = \ numpy.shape(fistaRecon.getParameter('input_sinogram')) if lambdaR_L1 > 0 : + print ("ring removal") for kkk in range(anglesNumb): - print ("angles {0}".format(kkk)) + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ ((sino_updt[:,kkk,:]).squeeze() - \ (sino[:,kkk,:]).squeeze() -\ @@ -194,39 +202,15 @@ if True: ## r = r_x - (1./L_const).*vec; ## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output - else: - if fistaRecon.getParameter('use_studentt_fidelity'): - residual = weights * (sino_updt - sino) - for kkk in range(SlicesZ): - # reshape(residual(:,:,kkk), Detectors*anglesNumb, 1) - # 1D - res_vec = numpy.reshape(residual[kkk], (Detectors * anglesNumb,1)) - -## else -## if (studentt == 1) -## % artifacts removal with Students t penalty -## residual = weights.*(sino_updt - sino); -## for kkk = 1:SlicesZ -## res_vec = reshape(residual(:,:,kkk), Detectors*anglesNumb, 1); % 1D vectorized sinogram -## %s = 100; -## %gr = (2)*res_vec./(s*2 + conj(res_vec).*res_vec); -## [ff, gr] = studentst(res_vec, 1); -## residual(:,:,kkk) = reshape(gr, Detectors, anglesNumb); -## end -## objective(i) = ff; % for the objective function output -## else -## % no ring removal (LS model) -## residual = weights.*(sino_updt - sino); -## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output -## end -## end + # Projection/Backprojection Routine if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + print ("Projection/Backprojection Routine") for kkk in range(SlicesZ): - print ("Projection/Backprojection Routine {0}".format( kkk )) + x_id, x_temp[kkk] = \ astra.creators.create_backprojection3d_gpu( residual[kkk:kkk+1], @@ -248,9 +232,11 @@ if True: # regularizer = fistaRecon.getParameter('regularizer') # for slices: # out = regularizer(input=X) + print ("skipping regularizer") ## FINAL + print ("final") lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') if lambdaR_L1 > 0: fistaRecon.r = numpy.max( @@ -263,9 +249,16 @@ if True: fistaRecon.r_x = fistaRecon.r + \ (((t_old-1)/t) * (fistaRecon.r - r_old)) - if fistaRecon.getParameter('ideal_image') is None: + if fistaRecon.getParameter('region_of_interest') is None: string = 'Iteration Number {0} | Objective {1} \n' print (string.format( i, objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], objective[i])) ## if (lambdaR_L1 > 0) ## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector -- cgit v1.2.3 From c097c34a59f80a6d4475a1f783b772fa42a44862 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 16:54:06 +0100 Subject: implemented non ordered subset FISTA in reconstructor --- src/Python/ccpi/fista/FISTAReconstructor.py | 165 ++++++++++++++++++++++++---- src/Python/test_reconstructor.py | 25 ++++- 2 files changed, 165 insertions(+), 25 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index 33e67a3..fda9cf0 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -85,6 +85,7 @@ class FISTAReconstructor(): self.pars['detectors'] = detectors self.pars['number_of_angles'] = nangles self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None print (self.pars) # handle optional input parameters (at instantiation) @@ -108,7 +109,11 @@ class FISTAReconstructor(): 'regularizer' , 'ring_lambda_R_L1', 'ring_alpha', - 'subsets') + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') self.acceptedInputKeywords = list(kw) # handle keyworded parameters @@ -176,8 +181,6 @@ class FISTAReconstructor(): ''' for key , value in kwargs.items(): if key in self.acceptedInputKeywords: - if key == 'use_studentt_fidelity': - raise Exception('use_studentt_fidelity Not implemented') self.pars[key] = value else: raise Exception('Wrong parameter {0} for '.format(key) + @@ -382,11 +385,15 @@ class FISTAReconstructor(): counter = counter + binsDiscr[jj] - 1 - - return IndicesReorg + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) self.objective = numpy.zeros((self.pars['number_of_iterations'])) @@ -401,19 +408,17 @@ class FISTAReconstructor(): if self.getParameter('Lipschitz_constant') is None: self.pars['Lipschitz_constant'] = \ self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); # prepareForIteration def iterate(self, Xin=None): - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) - - t = 1 + print ("FISTA Reconstructor: iterate") + if Xin is None: if self.getParameter('initialize'): X = self.initialize() @@ -423,15 +428,25 @@ class FISTAReconstructor(): else: # copy by reference X = Xin - + # store the output volume in the parameters + self.setParameter(output_volume=X) X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 for i in range(self.getParameter('number_of_iterations')): X_old = X.copy() t_old = t r_old = self.r.copy() if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'parallel3d': + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); @@ -439,10 +454,9 @@ class FISTAReconstructor(): proj_geomT['DetectorRowCount'] = 1 vol_geomT = vol_geom.copy() vol_geomT['GridSliceCount'] = 1; - sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) for kkk in range(SlicesZ): - print (kkk) - sino_id, sino_updt[kkk] = \ + sino_id, self.sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( X_t[kkk:kkk+1], proj_geomT, vol_geomT) astra.matlab.data3d('delete', sino_id) @@ -450,11 +464,122 @@ class FISTAReconstructor(): # for divergent 3D geometry (watch the GPU memory overflow in # ASTRA versions < 1.8) #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( X_t, proj_geom, vol_geom) ## RING REMOVAL - + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate + + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py index 2f188b4..07668ba 100644 --- a/src/Python/test_reconstructor.py +++ b/src/Python/test_reconstructor.py @@ -100,7 +100,7 @@ if False: counter = counter + binsDiscr[jj] - 1 -if True: +if False: print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) print ("prepare for iteration") fistaRecon.prepareForIteration() @@ -145,7 +145,8 @@ if True: t_old = t r_old = fistaRecon.r.copy() if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : # if the geometry is parallel use slice-by-slice # projection-backprojection routine #sino_updt = zeros(size(sino),'single'); @@ -157,13 +158,13 @@ if True: for kkk in range(SlicesZ): sino_id, sino_updt[kkk] = \ astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomT, vol_geomT) + X_t[kkk:kkk+1], proj_geom, vol_geom) astra.matlab.data3d('delete', sino_id) else: # for divergent 3D geometry (watch the GPU memory overflow in # ASTRA versions < 1.8) #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, sino_updt = astra.matlab.create_sino3d_gpu( + sino_id, sino_updt = astra.creators.create_sino3d_gpu( X_t, proj_geom, vol_geom) ## RING REMOVAL @@ -206,7 +207,8 @@ if True: # Projection/Backprojection Routine if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'parallel3d': + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) print ("Projection/Backprojection Routine") for kkk in range(SlicesZ): @@ -284,3 +286,16 @@ if True: ## else ## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); ## end +else: + fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + fistaRecon.setParameter(number_of_iterations = 12) + fistaRecon.setParameter(Lipschitz_constant = 767893952.0) + fistaRecon.setParameter(ring_alpha = 21) + fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + fistaRecon.prepareForIteration() + X = fistaRecon.iterate(numpy.load("X.npy")) -- cgit v1.2.3 From c7f0f2268f94b62d2e2deee736939ad75d3dc1b1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 18 Oct 2017 16:54:39 +0100 Subject: added to repository --- src/Python/test_reconstructor-os.py | 379 ++++++++++++++++++++++++++++++++++++ 1 file changed, 379 insertions(+) create mode 100644 src/Python/test_reconstructor-os.py (limited to 'src') diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py new file mode 100644 index 0000000..6f3721f --- /dev/null +++ b/src/Python/test_reconstructor-os.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.fista.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt + +def RMSE(signal1, signal2): + '''RMSE Root Mean Squared Error''' + if numpy.shape(signal1) == numpy.shape(signal2): + err = (signal1 - signal2) + err = numpy.sum( err * err )/numpy.size(signal1); # MSE + err = sqrt(err); # RMSE + return err + else: + raise Exception('Input signals must have the same shape') + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +## Ordered subset +if True: + subsets = 16 + angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + +if True: + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + print ("prepare for iteration") + fistaRecon.prepareForIteration() + + + + print("initializing ...") + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("initialized") + proj_geom , vol_geom, sino , \ + SlicesZ = fistaRecon.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + #fistaRecon.setParameter(number_of_iterations = 3) + iterFISTA = fistaRecon.getParameter('number_of_iterations') + # errors vector (if the ground truth is given) + Resid_error = numpy.zeros((iterFISTA)); + # objective function values vector + objective = numpy.zeros((iterFISTA)); + + + t = 1 + + ## additional for + proj_geomSUB = proj_geom.copy() + fistaRecon.residual2 = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + print ("starting iterations") +## % Outer FISTA iterations loop + for i in range(fistaRecon.getParameter('number_of_iterations')): +## % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required +## % one solution is to work with a full sinogram at times +## if ((i >= 3) && (lambdaR_L1 > 0)) +## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); +## astra_mex_data3d('delete', sino_id2); +## end + # With OS approach it becomes trickier to correlate independent subsets, + # hence additional work is required one solution is to work with a full + # sinogram at times + + ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 + if (lambdaR_L1 > 0) : + sino_id2, sino_updt2 = astra.creators.create_sino3d_gpu( + X, proj_geom, vol_geom) + astra.matlab.data3d('delete', sino_id2) + + # subset loop + counterInd = 1 + for ss in range(fistaRecon.getParameter('subsets')): + print ("Subset {0}".format(ss)) + X_old = X.copy() + t_old = t + r_old = fistaRecon.r.copy() + + # the number of projections per subset + numProjSub = fistaRecon.getParameter('os_bins')[ss] + CurrSubIndices = fistaRecon.getParameter('os_indices')\ + [counterInd:counterInd+numProjSub-1] + proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces] + +## if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ +## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ +## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : +## # if the geometry is parallel use slice-by-slice +## # projection-backprojection routine +## #sino_updt = zeros(size(sino),'single'); +## proj_geomT = proj_geom.copy() +## proj_geomT['DetectorRowCount'] = 1 +## vol_geomT = vol_geom.copy() +## vol_geomT['GridSliceCount'] = 1; +## sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) +## for kkk in range(SlicesZ): +## sino_id, sino_updt[kkk] = \ +## astra.creators.create_sino3d_gpu( +## X_t[kkk:kkk+1], proj_geom, vol_geom) +## astra.matlab.data3d('delete', sino_id) +## else: +## # for divergent 3D geometry (watch the GPU memory overflow in +## # ASTRA versions < 1.8) +## #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); +## sino_id, sino_updt = astra.creators.create_sino3d_gpu( +## X_t, proj_geom, vol_geom) + + ## RING REMOVAL + residual = fistaRecon.residual + residual2 = fistaRecon.residual2 + + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) + r_x = fistaRecon.r_x + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + print ("ring removal") +## % the ring removal part (Group-Huber fidelity) +## % first 2 iterations do additional work reconstructing whole dataset to ensure +## % the stablility +## if (i < 3) +## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); +## astra_mex_data3d('delete', sino_id2); +## else +## [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); +## end + +## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 + if i < 3: + pass + else: + sino_id, sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geomSUB, vol_geom) +## sino_id, sino_updt = astra.creators.create_sino3d_gpu( +## X, proj_geom, vol_geom) +## astra.matlab.data3d('delete', sino_id) + + for kkk in range(anglesNumb): + + residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt2[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) + shape[1] = numProjSub + fistaRecon.residual = numpy.zeros(shape) + if fistaRecon.residual.__hash__() != residual.__hash__(): + residual = fistaRecon.residual +## for kkk = 1:numProjSub +## indC = CurrSubIndeces(kkk); +## if (i < 3) +## residual(:,kkk,:) = squeeze(residual2(:,indC,:)); +## else +## residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); +## end +## end + for kk in range(numProjSub): + indC = fistaRecon.getParameter('os_indices')[kkk] + if i < 3: + residual[:,kkk,:] = residual2[:,indC,:].squeeze() + else: + residual(:,kkk,:) = \ + weights[:,indC,:].squeeze() * sino_updt[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * fistaRecon.r_x + #squeeze(weights(:,indC,:)).* \ + # (squeeze(sino_updt(:,kkk,:)) - \ + #(squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + + + + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + fistaRecon.r = (r_x - (1./L_const) * vec).copy() + objective[i] = (0.5 * (residual ** 2).sum()) +## % the ring removal part (Group-Huber fidelity) +## for kkk = 1:anglesNumb +## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* +## (squeeze(sino_updt(:,kkk,:)) - +## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); +## end +## vec = sum(residual,2); +## if (SlicesZ > 1) +## vec = squeeze(vec(:,1,:)); +## end +## r = r_x - (1./L_const).*vec; +## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + + + + # Projection/Backprojection Routine + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + print ("Projection/Backprojection Routine") + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + + ## REGULARIZATION + ## SKIPPING FOR NOW + ## Should be simpli + # regularizer = fistaRecon.getParameter('regularizer') + # for slices: + # out = regularizer(input=X) + print ("skipping regularizer") + + + ## FINAL + print ("final") + lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + fistaRecon.r = numpy.max( + numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ + numpy.sign(fistaRecon.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + fistaRecon.r_x = fistaRecon.r + \ + (((t_old-1)/t) * (fistaRecon.r - r_old)) + + if fistaRecon.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], objective[i])) + +## if (lambdaR_L1 > 0) +## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector +## end +## +## t = (1 + sqrt(1 + 4*t^2))/2; % updating t +## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X +## +## if (lambdaR_L1 > 0) +## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r +## end +## +## if (show == 1) +## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); +## if (lambdaR_L1 > 0) +## figure(11); plot(r); title('Rings offset vector') +## end +## pause(0.01); +## end +## if (strcmp(X_ideal, 'none' ) == 0) +## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); +## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); +## else +## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); +## end +else: + fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + fistaRecon.setParameter(number_of_iterations = 12) + fistaRecon.setParameter(Lipschitz_constant = 767893952.0) + fistaRecon.setParameter(ring_alpha = 21) + fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + fistaRecon.prepareForIteration() + X = fistaRecon.iterate(numpy.load("X.npy")) -- cgit v1.2.3 From d0de394cc4d2be254fc6b1c7c89571b58f7bd30d Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 19 Oct 2017 13:00:32 +0100 Subject: progress in pythonization --- src/Python/test_reconstructor-os.py | 66 ++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 22 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index 6f3721f..3ee92fa 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -139,7 +139,10 @@ if True: ## additional for proj_geomSUB = proj_geom.copy() - fistaRecon.residual2 = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) + residual2 = fistaRecon.residual2 + sino_updt_FULL = residual.copy() + print ("starting iterations") ## % Outer FISTA iterations loop for i in range(fistaRecon.getParameter('number_of_iterations')): @@ -153,11 +156,23 @@ if True: # hence additional work is required one solution is to work with a full # sinogram at times - ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if (lambdaR_L1 > 0) : - sino_id2, sino_updt2 = astra.creators.create_sino3d_gpu( - X, proj_geom, vol_geom) - astra.matlab.data3d('delete', sino_id2) + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 + if (i > 1 and lambdaR_L1 > 0) : + for kkk in range(anglesNumb): + + residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt_FULL[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + r_old = fistaRecon.r.copy() + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:] # 1 or 0? + r_x = fistaRecon.r_x + fistaRecon.r = (r_x - (1./L_const) * vec).copy() # subset loop counterInd = 1 @@ -165,8 +180,7 @@ if True: print ("Subset {0}".format(ss)) X_old = X.copy() t_old = t - r_old = fistaRecon.r.copy() - + # the number of projections per subset numProjSub = fistaRecon.getParameter('os_bins')[ss] CurrSubIndices = fistaRecon.getParameter('os_indices')\ @@ -198,27 +212,35 @@ if True: ## RING REMOVAL residual = fistaRecon.residual - residual2 = fistaRecon.residual2 lambdaR_L1 , alpha_ring , weights , L_const= \ fistaRecon.getParameter(['ring_lambda_R_L1', 'ring_alpha' , 'weights', 'Lipschitz_constant']) - r_x = fistaRecon.r_x - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(fistaRecon.getParameter('input_sinogram')) + sino_updt_Sub = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) if lambdaR_L1 > 0 : print ("ring removal") -## % the ring removal part (Group-Huber fidelity) -## % first 2 iterations do additional work reconstructing whole dataset to ensure -## % the stablility -## if (i < 3) -## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); -## astra_mex_data3d('delete', sino_id2); -## else -## [sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geomSUB, vol_geom); -## end + geometry_type = fistaRecon.getParameter('projector_geometry')['type'] + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # here + for kkk in range(SlicesZ): + sino_id, sinoT[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + astra.matlab.data3d('delete', sino_id) + +## % if geometry is 2D use slice-by-slice projection-backprojection routine +## for kkk = 1:SlicesZ +## [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom); +## sino_updt_Sub(:,:,kkk) = sinoT'; +## astra_mex_data2d('delete', sino_id); +## end + +## ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 if i < 3: pass @@ -254,7 +276,7 @@ if True: if i < 3: residual[:,kkk,:] = residual2[:,indC,:].squeeze() else: - residual(:,kkk,:) = \ + residual[:,kkk,:] = \ weights[:,indC,:].squeeze() * sino_updt[:,kkk,:].squeeze() - \ sino[:,indC,:].squeeze() - alpha_ring * fistaRecon.r_x #squeeze(weights(:,indC,:)).* \ -- cgit v1.2.3 From 98443072f33a1f46eb8ea5ab27741a6400970afa Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 19 Oct 2017 14:56:20 +0100 Subject: further progress --- src/Python/test_reconstructor-os.py | 88 +++++++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 34 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index 3ee92fa..8820db7 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -176,6 +176,23 @@ if True: # subset loop counterInd = 1 + geometry_type = fistaRecon.getParameter('projector_geometry')['type'] + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) + for ss in range(fistaRecon.getParameter('subsets')): print ("Subset {0}".format(ss)) X_old = X.copy() @@ -186,29 +203,29 @@ if True: CurrSubIndices = fistaRecon.getParameter('os_indices')\ [counterInd:counterInd+numProjSub-1] proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces] + + shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) + shape[1] = numProjSub + sino_updt_Sub = numpy.zeros(shape) + + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT = astra.creators.create_sino3d_gpu ( + X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8) + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) + -## if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ -## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ -## fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : -## # if the geometry is parallel use slice-by-slice -## # projection-backprojection routine -## #sino_updt = zeros(size(sino),'single'); -## proj_geomT = proj_geom.copy() -## proj_geomT['DetectorRowCount'] = 1 -## vol_geomT = vol_geom.copy() -## vol_geomT['GridSliceCount'] = 1; -## sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) -## for kkk in range(SlicesZ): -## sino_id, sino_updt[kkk] = \ -## astra.creators.create_sino3d_gpu( -## X_t[kkk:kkk+1], proj_geom, vol_geom) -## astra.matlab.data3d('delete', sino_id) -## else: -## # for divergent 3D geometry (watch the GPU memory overflow in -## # ASTRA versions < 1.8) -## #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); -## sino_id, sino_updt = astra.creators.create_sino3d_gpu( -## X_t, proj_geom, vol_geom) + ## RING REMOVAL residual = fistaRecon.residual @@ -217,20 +234,23 @@ if True: fistaRecon.getParameter(['ring_lambda_R_L1', 'ring_alpha' , 'weights', 'Lipschitz_constant']) - sino_updt_Sub = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) if lambdaR_L1 > 0 : print ("ring removal") - geometry_type = fistaRecon.getParameter('projector_geometry')['type'] - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - # here - for kkk in range(SlicesZ): - sino_id, sinoT[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomSUB, vol_geom) - sino_updt_Sub[kkk] = sinoT.T.copy() - astra.matlab.data3d('delete', sino_id) + residualSub = numpy.zeros(shape) +## for a chosen subset +## for kkk = 1:numProjSub +## indC = CurrSubIndeces(kkk); +## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); +## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram +## end + for kkk in range(numProjSub): + indC = CurrSubIndices[kkk] + residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ + (sino_updt_Sub[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * r_x) + # filling the full sinogram + sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + ## % if geometry is 2D use slice-by-slice projection-backprojection routine ## for kkk = 1:SlicesZ -- cgit v1.2.3 From d2ce1b74b4ecad5cdecb29207181e09ef0f6013a Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 19 Oct 2017 15:09:06 +0100 Subject: finished first draft --- src/Python/test_reconstructor-os.py | 136 ++++++------------------------------ 1 file changed, 21 insertions(+), 115 deletions(-) (limited to 'src') diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index 8820db7..f6d7d4b 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -250,102 +250,34 @@ if True: sino[:,indC,:].squeeze() - alpha_ring * r_x) # filling the full sinogram sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() - - -## % if geometry is 2D use slice-by-slice projection-backprojection routine -## for kkk = 1:SlicesZ -## [sino_id, sinoT] = astra_create_sino_cuda(X_t(:,:,kkk), proj_geomSUB, vol_geom); -## sino_updt_Sub(:,:,kkk) = sinoT'; -## astra_mex_data2d('delete', sino_id); -## end - -## -## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if i < 3: - pass - else: - sino_id, sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geomSUB, vol_geom) -## sino_id, sino_updt = astra.creators.create_sino3d_gpu( -## X, proj_geom, vol_geom) -## astra.matlab.data3d('delete', sino_id) - - for kkk in range(anglesNumb): - - residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt2[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) - shape[1] = numProjSub - fistaRecon.residual = numpy.zeros(shape) - if fistaRecon.residual.__hash__() != residual.__hash__(): - residual = fistaRecon.residual -## for kkk = 1:numProjSub -## indC = CurrSubIndeces(kkk); -## if (i < 3) -## residual(:,kkk,:) = squeeze(residual2(:,indC,:)); -## else -## residual(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); -## end -## end - for kk in range(numProjSub): - indC = fistaRecon.getParameter('os_indices')[kkk] - if i < 3: - residual[:,kkk,:] = residual2[:,indC,:].squeeze() - else: - residual[:,kkk,:] = \ - weights[:,indC,:].squeeze() * sino_updt[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * fistaRecon.r_x - #squeeze(weights(:,indC,:)).* \ - # (squeeze(sino_updt(:,kkk,:)) - \ - #(squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); - - - - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - fistaRecon.r = (r_x - (1./L_const) * vec).copy() - objective[i] = (0.5 * (residual ** 2).sum()) -## % the ring removal part (Group-Huber fidelity) -## for kkk = 1:anglesNumb -## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* -## (squeeze(sino_updt(:,kkk,:)) - -## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); -## end -## vec = sum(residual,2); -## if (SlicesZ > 1) -## vec = squeeze(vec(:,1,:)); -## end -## r = r_x - (1./L_const).*vec; -## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output - - + else: + #PWLS model + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) - # Projection/Backprojection Routine - if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ - fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ - fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - print ("Projection/Backprojection Routine") + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) for kkk in range(SlicesZ): x_id, x_temp[kkk] = \ astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + else: x_id, x_temp = \ astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) + residualSub, proj_geomSUB, vol_geom) - X = X_t - (1/L_const) * x_temp - astra.matlab.data3d('delete', sino_id) astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp + ## REGULARIZATION @@ -364,12 +296,9 @@ if True: fistaRecon.r = numpy.max( numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ numpy.sign(fistaRecon.r) - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - fistaRecon.r_x = fistaRecon.r + \ - (((t_old-1)/t) * (fistaRecon.r - r_old)) + # updating r + r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) + if fistaRecon.getParameter('region_of_interest') is None: string = 'Iteration Number {0} | Objective {1} \n' @@ -382,30 +311,7 @@ if True: string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], objective[i])) -## if (lambdaR_L1 > 0) -## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector -## end -## -## t = (1 + sqrt(1 + 4*t^2))/2; % updating t -## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X -## -## if (lambdaR_L1 > 0) -## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r -## end -## -## if (show == 1) -## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); -## if (lambdaR_L1 > 0) -## figure(11); plot(r); title('Rings offset vector') -## end -## pause(0.01); -## end -## if (strcmp(X_ideal, 'none' ) == 0) -## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); -## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); -## else -## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); -## end + else: fistaRecon = FISTAReconstructor(proj_geom, vol_geom, -- cgit v1.2.3 From b6c314b371ef3081828fa007cd3fcaf1dc820477 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 19 Oct 2017 17:07:48 +0100 Subject: First commit of CMakeLists.txt attempting to locate conda python environment --- src/Python/CMakeLists.txt | 79 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 src/Python/CMakeLists.txt (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt new file mode 100644 index 0000000..a75c062 --- /dev/null +++ b/src/Python/CMakeLists.txt @@ -0,0 +1,79 @@ +cmake_minimum_required (VERSION 3.0) + +project(FISTA) + +# The version number. +set (FISTA_VERSION_MAJOR 1) +set (FISTA_VERSION_MINOR 0) + +set (CIL_VERSION_MAJOR 0) +set (CIL_VERSION_MINOR 9) +set (CIL_VERSION_PATCH 1) + +set (CIL_VERSION '${CIL_VERSION_MAJOR}.${CIL_VERSION_MINOR}.${CIL_VERSION_PATCH}') + +message("CIL VERSION " ${CIL_VERSION}) + +# variables we need to run the conda build +#PREFIX=C:\Apps\Miniconda2\envs\cil\Library +#LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include + +set (NUMPY_VERSION 1.12) +#set (PYTHON_VERSION 3.5) + +#https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs +set (CONDA_ENVIRONMENT "C:\\Apps\\Miniconda2\\envs\\cil27" CACHE PATH "env dir") + +function (findPythonForAnacondaEnvironment env) + + file(TO_CMAKE_PATH ${env}/python.exe PYTHON_EXECUTABLE) + + message("Found " ${PYTHON_EXECUTABLE}) + execute_process(COMMAND ${PYTHON_EXECUTABLE} pythonversion.py major + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE major + ERROR_QUIET) + execute_process(COMMAND ${PYTHON_EXECUTABLE} pythonversion.py minor + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE minor + ERROR_QUIET) + execute_process(COMMAND ${PYTHON_EXECUTABLE} pythonversion.py patch + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE patch + ERROR_QUIET) + execute_process(COMMAND ${PYTHON_EXECUTABLE} pythonversion.py + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE version + ERROR_QUIET) + + + set (PYTHON_EXECUTABLE ${PYTHON_EXECUTABLE} PARENT_SCOPE) + set (PYTHONINTERP_FOUND "ON" PARENT_SCOPE) + set (PYTHON_VERSION_STRING ${version}) + message("My version found " ${PYTHON_VERSION_STRING}) + +endfunction() + +findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT}) + +set(Python_ADDITIONAL_VERSIONS 3) + +find_package(PythonInterp) +if (PYTHONINTERP_FOUND) + + message("Found interpret " ${PYTHON_EXECUTABLE}) + + foreach(pv ${PYTHON_VERSION_STRING}) + message("Found interpret " ${pv}) + endforeach() +endif() + +find_package(PythonLibs) +if (PYTHONLIB_FOUND) + message("Found PythonLibs PYTHON_LIBRARIES " ${PYTHON_LIBRARIES}) + message("Found PythonLibs PYTHON_INCLUDE_PATH " ${PYTHON_INCLUDE_PATH}) + message("Found PythonLibs PYTHON_INCLUDE_DIRS " ${PYTHON_INCLUDE_DIRS}) + message("Found PythonLibs PYTHONLIBS_VERSION_STRING " ${PYTHONLIBS_VERSION_STRING} ) + +endif() + -- cgit v1.2.3 From 8b427d82acfaeb4671484bc459343c5e2e412736 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 19 Oct 2017 17:01:55 +0100 Subject: Initial revision of build environment made with CMake Initial revision of build environment made with CMake First commit of CMakeLists.txt attempt to use CMake to create the build environment First commit of CMakeLists.txt attempting to locate conda python environment Added a few files for CMake Many changes for the CMake compilation. Tested CMake build Bugfixes --- src/CMakeLists.txt | 14 +++ src/Python/CMakeLists.txt | 58 +++++++++++ src/Python/FindAnacondaEnvironment.cmake | 166 +++++++++++++++++++++++++++++++ src/Python/compile.bat.in | 4 + src/Python/compile.sh.in | 6 ++ src/Python/conda-recipe/bld.bat | 14 +++ src/Python/conda-recipe/build.sh | 14 +++ src/Python/conda-recipe/meta.yaml | 30 ++++++ 8 files changed, 306 insertions(+) create mode 100644 src/CMakeLists.txt create mode 100644 src/Python/CMakeLists.txt create mode 100644 src/Python/FindAnacondaEnvironment.cmake create mode 100644 src/Python/compile.bat.in create mode 100644 src/Python/compile.sh.in create mode 100644 src/Python/conda-recipe/bld.bat create mode 100644 src/Python/conda-recipe/build.sh create mode 100644 src/Python/conda-recipe/meta.yaml (limited to 'src') diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..cbe2fec --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_subdirectory(Python) \ No newline at end of file diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt new file mode 100644 index 0000000..fd377cc --- /dev/null +++ b/src/Python/CMakeLists.txt @@ -0,0 +1,58 @@ +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +message("CIL VERSION " ${CIL_VERSION}) + +# variables that are set by conda +#PREFIX=C:\Apps\Miniconda2\envs\cil\Library +#LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include + +set (NUMPY_VERSION 1.12) +#set (PYTHON_VERSION 3.5) + +#https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs +set (CONDA_ENVIRONMENT "cil") +set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "env dir") + +message("CIL VERSION " ${CIL_VERSION}) + +# set the Python variables for the Conda environment +include(FindAnacondaEnvironment.cmake) +findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH}) +message("Python found " ${PYTHON_VERSION_STRING}) +findPythonPackagesPath() +message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) + +# copy the Pyhon files of the package +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) + + +# Copy and configure the relative conda build and recipes +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) + +if (WIN32) + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) +elseif(UNIX) + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) + # assumes we will use bash + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) +endif() + + diff --git a/src/Python/FindAnacondaEnvironment.cmake b/src/Python/FindAnacondaEnvironment.cmake new file mode 100644 index 0000000..3abb5d1 --- /dev/null +++ b/src/Python/FindAnacondaEnvironment.cmake @@ -0,0 +1,166 @@ +# Copyright 2017 Edoardo Pasca +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# #.rst: +# FindAnacondaEnvironment +# -------------- +# +# Find Python executable and library for a specific Anaconda environment +# +# This module finds the Python interpreter for a specific Anaconda enviroment, +# if installed and determines where the include files and libraries are. +# This code sets the following variables: +# +# :: +# PYTHONINTERP_FOUND - if the Python interpret has been found +# PYTHON_EXECUTABLE - the Python interpret found +# PYTHON_LIBRARY - path to the python library +# PYTHON_INCLUDE_PATH - path to where Python.h is found (deprecated) +# PYTHON_INCLUDE_DIRS - path to where Python.h is found +# PYTHONLIBS_VERSION_STRING - version of the Python libs found (since CMake 2.8.8) +# PYTHON_VERSION_MAJOR - major Python version +# PYTHON_VERSION_MINOR - minor Python version +# PYTHON_VERSION_PATCH - patch Python version + + + +function (findPythonForAnacondaEnvironment env) + + file(TO_CMAKE_PATH ${env}/python.exe PYTHON_EXECUTABLE) + + message("Found " ${PYTHON_EXECUTABLE}) + ####### FROM FindPythonInterpr ######## + # determine python version string + if(PYTHON_EXECUTABLE) + execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c + "import sys; sys.stdout.write(';'.join([str(x) for x in sys.version_info[:3]]))" + OUTPUT_VARIABLE _VERSION + RESULT_VARIABLE _PYTHON_VERSION_RESULT + ERROR_QUIET) + if(NOT _PYTHON_VERSION_RESULT) + string(REPLACE ";" "." _PYTHON_VERSION_STRING "${_VERSION}") + list(GET _VERSION 0 _PYTHON_VERSION_MAJOR) + list(GET _VERSION 1 _PYTHON_VERSION_MINOR) + list(GET _VERSION 2 _PYTHON_VERSION_PATCH) + if(PYTHON_VERSION_PATCH EQUAL 0) + # it's called "Python 2.7", not "2.7.0" + string(REGEX REPLACE "\\.0$" "" _PYTHON_VERSION_STRING "${PYTHON_VERSION_STRING}") + endif() + else() + # sys.version predates sys.version_info, so use that + execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c "import sys; sys.stdout.write(sys.version)" + OUTPUT_VARIABLE _VERSION + RESULT_VARIABLE _PYTHON_VERSION_RESULT + ERROR_QUIET) + if(NOT _PYTHON_VERSION_RESULT) + string(REGEX REPLACE " .*" "" _PYTHON_VERSION_STRING "${_VERSION}") + string(REGEX REPLACE "^([0-9]+)\\.[0-9]+.*" "\\1" _PYTHON_VERSION_MAJOR "${PYTHON_VERSION_STRING}") + string(REGEX REPLACE "^[0-9]+\\.([0-9])+.*" "\\1" _PYTHON_VERSION_MINOR "${PYTHON_VERSION_STRING}") + if(PYTHON_VERSION_STRING MATCHES "^[0-9]+\\.[0-9]+\\.([0-9]+)") + set(PYTHON_VERSION_PATCH "${CMAKE_MATCH_1}") + else() + set(PYTHON_VERSION_PATCH "0") + endif() + else() + # sys.version was first documented for Python 1.5, so assume + # this is older. + set(PYTHON_VERSION_STRING "1.4" PARENT_SCOPE) + set(PYTHON_VERSION_MAJOR "1" PARENT_SCOPE) + set(PYTHON_VERSION_MINOR "4" PARENT_SCOPE) + set(PYTHON_VERSION_PATCH "0" PARENT_SCOPE) + endif() + endif() + unset(_PYTHON_VERSION_RESULT) + unset(_VERSION) + endif() + ############################################### + + set (PYTHON_EXECUTABLE ${PYTHON_EXECUTABLE} PARENT_SCOPE) + set (PYTHONINTERP_FOUND "ON" PARENT_SCOPE) + set (PYTHON_VERSION_STRING ${_PYTHON_VERSION_STRING} PARENT_SCOPE) + set (PYTHON_VERSION_MAJOR ${_PYTHON_VERSION_MAJOR} PARENT_SCOPE) + set (PYTHON_VERSION_MINOR ${_PYTHON_VERSION_MINOR} PARENT_SCOPE) + set (PYTHON_VERSION_PATCH ${_PYTHON_VERSION_PATCH} PARENT_SCOPE) + message("My version found " ${PYTHON_VERSION_STRING}) + +endfunction() + + + +set(Python_ADDITIONAL_VERSIONS 3.5) + +find_package(PythonInterp) +if (PYTHONINTERP_FOUND) + + message("Found interpret " ${PYTHON_EXECUTABLE}) + message("Python Library " ${PYTHON_LIBRARY}) + message("Python Include Dir " ${PYTHON_INCLUDE_DIR}) + message("Python Include Path " ${PYTHON_INCLUDE_PATH}) + + foreach(pv ${PYTHON_VERSION_STRING}) + message("Found interpret " ${pv}) + endforeach() +endif() + + + +find_package(PythonLibs) +if (PYTHONLIB_FOUND) + message("Found PythonLibs PYTHON_LIBRARIES " ${PYTHON_LIBRARIES}) + message("Found PythonLibs PYTHON_INCLUDE_PATH " ${PYTHON_INCLUDE_PATH}) + message("Found PythonLibs PYTHON_INCLUDE_DIRS " ${PYTHON_INCLUDE_DIRS}) + message("Found PythonLibs PYTHONLIBS_VERSION_STRING " ${PYTHONLIBS_VERSION_STRING} ) +else() + message("No PythonLibs Found") +endif() + + + + +function(findPythonPackagesPath) +### https://openlab.ncl.ac.uk/gitlab/john.shearer/clappertracker/raw/549885e5decd37f7b23e9c1fd39e86f207156795/src/3rdparty/opencv/cmake/OpenCVDetectPython.cmake +### +if(CMAKE_HOST_UNIX) + execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import *; print get_python_lib()" + RESULT_VARIABLE PYTHON_CVPY_PROCESS + OUTPUT_VARIABLE PYTHON_STD_PACKAGES_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE) + if("${PYTHON_STD_PACKAGES_PATH}" MATCHES "site-packages") + set(_PYTHON_PACKAGES_PATH "python${PYTHON_VERSION_MAJOR_MINOR}/site-packages") + else() #debian based assumed, install to the dist-packages. + set(_PYTHON_PACKAGES_PATH "python${PYTHON_VERSION_MAJOR_MINOR}/dist-packages") + endif() + if(EXISTS "${CMAKE_INSTALL_PREFIX}/lib${LIB_SUFFIX}/${PYTHON_PACKAGES_PATH}") + set(_PYTHON_PACKAGES_PATH "lib${LIB_SUFFIX}/${_PYTHON_PACKAGES_PATH}") + else() + set(_PYTHON_PACKAGES_PATH "lib/${_PYTHON_PACKAGES_PATH}") + endif() + elseif(CMAKE_HOST_WIN32) + get_filename_component(PYTHON_PATH "${PYTHON_EXECUTABLE}" PATH) + file(TO_CMAKE_PATH "${PYTHON_PATH}" PYTHON_PATH) + if(NOT EXISTS "${PYTHON_PATH}/Lib/site-packages") + unset(PYTHON_PATH) + get_filename_component(PYTHON_PATH "[HKEY_LOCAL_MACHINE\\SOFTWARE\\Python\\PythonCore\\${PYTHON_VERSION_MAJOR_MINOR}\\InstallPath]" ABSOLUTE) + if(NOT PYTHON_PATH) + get_filename_component(PYTHON_PATH "[HKEY_CURRENT_USER\\SOFTWARE\\Python\\PythonCore\\${PYTHON_VERSION_MAJOR_MINOR}\\InstallPath]" ABSOLUTE) + endif() + file(TO_CMAKE_PATH "${PYTHON_PATH}" PYTHON_PATH) + endif() + set(_PYTHON_PACKAGES_PATH "${PYTHON_PATH}/Lib/site-packages") + endif() + SET(PYTHON_PACKAGES_PATH "${_PYTHON_PACKAGES_PATH}" PARENT_SCOPE) + +endfunction() + + diff --git a/src/Python/compile.bat.in b/src/Python/compile.bat.in new file mode 100644 index 0000000..d4ddc92 --- /dev/null +++ b/src/Python/compile.bat.in @@ -0,0 +1,4 @@ +set CIL_VERSION=@CIL_VERSION@ + +activate @CONDA_ENVIRONMENT@ +conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi \ No newline at end of file diff --git a/src/Python/compile.sh.in b/src/Python/compile.sh.in new file mode 100644 index 0000000..dd29973 --- /dev/null +++ b/src/Python/compile.sh.in @@ -0,0 +1,6 @@ +#!/bin/sh + +export CIL_VERSION=@CIL_VERSION@ +module load python/anaconda +source activate @CONDA_ENVIRONMENT@ +conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi \ No newline at end of file diff --git a/src/Python/conda-recipe/bld.bat b/src/Python/conda-recipe/bld.bat new file mode 100644 index 0000000..69491de --- /dev/null +++ b/src/Python/conda-recipe/bld.bat @@ -0,0 +1,14 @@ +IF NOT DEFINED CIL_VERSION ( +ECHO CIL_VERSION Not Defined. +exit 1 +) + +mkdir "%SRC_DIR%\ccpi" +xcopy /e "%RECIPE_DIR%\..\.." "%SRC_DIR%\ccpi" + +cd %SRC_DIR%\ccpi\Python + +%PYTHON% setup.py build_ext +if errorlevel 1 exit 1 +%PYTHON% setup.py install +if errorlevel 1 exit 1 diff --git a/src/Python/conda-recipe/build.sh b/src/Python/conda-recipe/build.sh new file mode 100644 index 0000000..855047f --- /dev/null +++ b/src/Python/conda-recipe/build.sh @@ -0,0 +1,14 @@ + +if [ -z "$CIL_VERSION" ]; then + echo "Need to set CIL_VERSION" + exit 1 +fi +mkdir "$SRC_DIR/ccpi" +cp -r "$RECIPE_DIR/../.." "$SRC_DIR/ccpi" + +cd $SRC_DIR/ccpi/Python + +$PYTHON setup.py build_ext +$PYTHON setup.py install + + diff --git a/src/Python/conda-recipe/meta.yaml b/src/Python/conda-recipe/meta.yaml new file mode 100644 index 0000000..c5b7a89 --- /dev/null +++ b/src/Python/conda-recipe/meta.yaml @@ -0,0 +1,30 @@ +package: + name: ccpi-fista + version: {{ environ['CIL_VERSION'] }} + + +build: + preserve_egg_dir: False + script_env: + - CIL_VERSION +# number: 0 + +requirements: + build: + - python + - numpy + - setuptools + - boost ==1.64 + - boost-cpp ==1.64 + - cython + + run: + - python + - numpy + - boost ==1.64 + + +about: + home: http://www.ccpi.ac.uk + license: BSD license + summary: 'CCPi Core Imaging Library Quantification Toolbox' -- cgit v1.2.3 From 903175ed67f7645fa35edf4623b27999d6cb990f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 20 Oct 2017 17:04:26 +0100 Subject: Further development --- src/Python/ccpi/fista/FISTAReconstructor.py | 24 ++++++ src/Python/test_reconstructor-os.py | 112 ++++++++++++++-------------- 2 files changed, 81 insertions(+), 55 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py index fda9cf0..85bfac5 100644 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ b/src/Python/ccpi/fista/FISTAReconstructor.py @@ -583,3 +583,27 @@ class FISTAReconstructor(): string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], self.objective[i])) return (X , X_t, t) + + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py index f6d7d4b..aee70a4 100644 --- a/src/Python/test_reconstructor-os.py +++ b/src/Python/test_reconstructor-os.py @@ -122,10 +122,13 @@ if True: X_t = X.copy() print ("initialized") proj_geom , vol_geom, sino , \ - SlicesZ = fistaRecon.getParameter(['projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ']) + SlicesZ, weights , alpha_ring = fistaRecon.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha']) + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) #fistaRecon.setParameter(number_of_iterations = 3) iterFISTA = fistaRecon.getParameter('number_of_iterations') @@ -136,12 +139,13 @@ if True: t = 1 + ## additional for proj_geomSUB = proj_geom.copy() fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) residual2 = fistaRecon.residual2 - sino_updt_FULL = residual.copy() + sino_updt_FULL = fistaRecon.residual.copy() print ("starting iterations") ## % Outer FISTA iterations loop @@ -156,7 +160,8 @@ if True: # hence additional work is required one solution is to work with a full # sinogram at times - + r_old = fistaRecon.r.copy() + t_old = t SlicesZ, anglesNumb, Detectors = \ numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 if (i > 1 and lambdaR_L1 > 0) : @@ -167,8 +172,8 @@ if True: (sino[:,kkk,:]).squeeze() -\ (alpha_ring * r_x) ) - r_old = fistaRecon.r.copy() - vec = residual.sum(axis = 1) + + vec = fistaRecon.residual.sum(axis = 1) #if SlicesZ > 1: # vec = vec[:,1,:] # 1 or 0? r_x = fistaRecon.r_x @@ -227,56 +232,53 @@ if True: - ## RING REMOVAL - residual = fistaRecon.residual - - lambdaR_L1 , alpha_ring , weights , L_const= \ - fistaRecon.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant']) - if lambdaR_L1 > 0 : - print ("ring removal") - residualSub = numpy.zeros(shape) -## for a chosen subset -## for kkk = 1:numProjSub -## indC = CurrSubIndeces(kkk); -## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); -## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram -## end - for kkk in range(numProjSub): - indC = CurrSubIndices[kkk] - residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ - (sino_updt_Sub[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * r_x) - # filling the full sinogram - sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + ## RING REMOVAL + residual = fistaRecon.residual + + + if lambdaR_L1 > 0 : + print ("ring removal") + residualSub = numpy.zeros(shape) + ## for a chosen subset + ## for kkk = 1:numProjSub + ## indC = CurrSubIndeces(kkk); + ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram + ## end + for kkk in range(numProjSub): + indC = CurrSubIndices[kkk] + residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ + (sino_updt_Sub[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * r_x) + # filling the full sinogram + sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() - else: - #PWLS model - residualSub = weights[:,CurrSubIndices,:] * \ - ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) - objective[i] = 0.5 * numpy.linalg.norm(residualSub) + else: + #PWLS model + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - # if geometry is 2D use slice-by-slice projection-backprojection - # routine - x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residualSub[kkk:kkk+1], - proj_geomSUB, vol_geom) - - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residualSub, proj_geomSUB, vol_geom) + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residualSub, proj_geomSUB, vol_geom) - astra.matlab.data3d('delete', x_id) - X = X_t - (1/L_const) * x_temp + astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp -- cgit v1.2.3 From 52f7080153a00bd3f7276a4be0a79f7aa82c6196 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Fri, 20 Oct 2017 17:07:05 +0100 Subject: add conda-forge to the channel --- src/Python/compile.bat.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/compile.bat.in b/src/Python/compile.bat.in index d4ddc92..ab23404 100644 --- a/src/Python/compile.bat.in +++ b/src/Python/compile.bat.in @@ -1,4 +1,4 @@ set CIL_VERSION=@CIL_VERSION@ activate @CONDA_ENVIRONMENT@ -conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi \ No newline at end of file +conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge \ No newline at end of file -- cgit v1.2.3 From e3dee52c17c9da457cc4c4e98b7dbb8ce1a644f6 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 10:50:41 +0100 Subject: Add needed environment variables Adds the environment variables that are needed to compile via conda. --- src/Python/CMakeLists.txt | 24 +++++++++++++++++++++--- src/Python/compile.bat.in | 5 ++++- src/Python/compile.sh.in | 7 +++++-- 3 files changed, 30 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index 3eb4158..b84f5a3 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -14,17 +14,25 @@ message("CIL VERSION " ${CIL_VERSION}) -# variables that are set by conda + +# variables that must be set for conda compilation #PREFIX=C:\Apps\Miniconda2\envs\cil\Library #LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include - set (NUMPY_VERSION 1.12) #set (PYTHON_VERSION 3.5) #https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs set (CONDA_ENVIRONMENT "cil") -set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "env dir") +if (WIN32) + set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") + set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") + set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}\\include" CACHE PATH "env dir") +elseif (UNIX) + set (CONDA_ENVIRONMENT_PATH "/apps/anaconda/2.4/envs/${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") + set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}" CACHE PATH "env dir") + set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir") +endif() message("CIL VERSION " ${CIL_VERSION}) @@ -56,4 +64,14 @@ elseif(UNIX) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) endif() +### add tests +#add_executable(RegularizersTest ) +find_package(tiff) +if (TIFF_FOUND) + message("LibTIFF Found") + message("TIFF_INCLUDE_DIR "${TIFF_INCLUDE_DIR}) + message("TIFF_LIBRARIES"${TIFF_LIBRARIES}) +else() + message("LibTIFF not found") +endif() \ No newline at end of file diff --git a/src/Python/compile.bat.in b/src/Python/compile.bat.in index ab23404..e5342ed 100644 --- a/src/Python/compile.bat.in +++ b/src/Python/compile.bat.in @@ -1,4 +1,7 @@ set CIL_VERSION=@CIL_VERSION@ -activate @CONDA_ENVIRONMENT@ +set PREFIX=@CONDA_ENVIRONMENT_PREFIX@ +set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +REM activate @CONDA_ENVIRONMENT@ conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge \ No newline at end of file diff --git a/src/Python/compile.sh.in b/src/Python/compile.sh.in index dd29973..ca9f310 100644 --- a/src/Python/compile.sh.in +++ b/src/Python/compile.sh.in @@ -1,6 +1,9 @@ #!/bin/sh +# compile within the right conda environment +#module load python/anaconda +#source activate @CONDA_ENVIRONMENT@ export CIL_VERSION=@CIL_VERSION@ -module load python/anaconda -source activate @CONDA_ENVIRONMENT@ +export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi \ No newline at end of file -- cgit v1.2.3 From 72e66dcdc5a9297846dfed89f7801653e2e45aa3 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 10:59:51 +0100 Subject: Added setup.py.in --- src/Python/setup.py.in | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/Python/setup.py.in (limited to 'src') diff --git a/src/Python/setup.py.in b/src/Python/setup.py.in new file mode 100644 index 0000000..0a1f4ad --- /dev/null +++ b/src/Python/setup.py.in @@ -0,0 +1,67 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform + +cil_version=@CIL_VERSION@ + +library_include_path = "" +library_lib_path = "" +try: + library_include_path = os.environ['LIBRARY_INC'] + library_lib_path = os.environ['LIBRARY_LIB'] +except: + library_include_path = os.environ['PREFIX']+'/include' + pass + +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [os.path.join(library_include_path, "..", "lib")] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +extra_include_dirs += [os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU"), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_GPU") , + "@CMAKE_CURRENT_SOURCE_DIR@"] + +if platform.system() == 'Windows': + extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ] + + if sys.version_info.major == 3 : + extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] + else: + extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: + if sys.version_info.major == 3: + extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] + else: + extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( + name='ccpi', + description='CCPi Core Imaging Library - FISTA Reconstruction Module', + version=cil_version, + cmdclass = {'build_ext': build_ext}, + ext_modules = [Extension("ccpi.imaging.cpu_regularizers", + sources=[os.path.join("@CMAKE_CURRENT_SOURCE_DIR@" , "fista_module.cpp" ), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "FGP_TV_core.c"), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "SplitBregman_TV_core.c"), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "LLT_model_core.c"), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "PatchBased_Regul_core.c"), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "TGV_PD_core.c"), + os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" , "regularizers_CPU", "utils.c") + ], + include_dirs=extra_include_dirs, + library_dirs=extra_library_dirs, + extra_compile_args=extra_compile_args, + libraries=extra_libraries ), + + ], + zip_safe = False, + packages = {'ccpi','ccpi.imaging'}, +) -- cgit v1.2.3 From dcceb4e6b0aa515d47fb54f7325983f378aa743f Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 13:55:35 +0100 Subject: Finds the active conda environment sets up for the current active conda environmnet --- src/Python/CMakeLists.txt | 60 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index b84f5a3..707f006 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -23,7 +23,59 @@ set (NUMPY_VERSION 1.12) #set (PYTHON_VERSION 3.5) #https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs -set (CONDA_ENVIRONMENT "cil") +#set (CONDA_ENVIRONMENT "cil") + +execute_process(COMMAND "conda" "env" "list" + OUTPUT_VARIABLE _CONDA_ENVS + RESULT_VARIABLE _CONDA_RESULT + ERROR_VARIABLE _CONDA_ERR) + if(NOT _CONDA_RESULT) + #message("conda envs list " ${_CONDA_ENVS}) + string(REPLACE "\n" ";" ENV_LIST ${_CONDA_ENVS}) + #string(REGEX MATCHALL "^.*[\t\n]" matches ${_CONDA_ENVS}) + foreach(line ${ENV_LIST}) + message("line='${line}'") + string(REGEX MATCHALL "(.+)[*](.+)" match ${line}) + #list(LENGTH ${match} N) + #string(LENGTH ${match} Ns) + #string(REPLACE "*" ";" env_dir ) + #message("list length " ${N} " string length " ${Ns}) + #list(GET ${env_dir} 0 CONDA_ENVIRONMENT) + #list(GET ${env_dir} 1 CONDA_PATH) + if (NOT ${match} EQUAL "") + message("match='${match}'") + string(REPLACE "*" ";" ENV_DIR ${match}) + #string(MATCHALL "(.*)[*](.*)" ENV_DIR ${match}) + list (APPEND cc "") + foreach(conda ${ENV_DIR}) + message("THERE YOU ARE " ${conda}) + list(APPEND cc ${conda}) + endforeach() + list(LENGTH cc Ns) + message("cc " ${cc} " " ${Ns}) + if (${Ns} EQUAL 2) + list(GET cc 0 CONDA_ENVIRONMENT) + list(GET cc 1 CONDA_ENVIRONMENT_PATH) + message("Current conda environmnet " ${CONDA_ENVIRONMENT}) + message("Current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) + + endif() + + #list(GET ${env_dir} 0 CONDA_ENVIRONMENT) + #list(GET ${env_dir} 1 CONDA_PATH) + #message("******" ${env_dir}) + #message("******" ${CONDA_ENVIRONMENT} " " ${CONDA_PATH} ) + endif() + endforeach() + #string(REGEX REPLACE "^.*[*].*" "" CONDA_ENVIRONMENT ${_CONDA_ENVS}) + else() + message("conda result false" ${_CONDA_ERR}) + endif() + +message("**********************************************************") +message("Current conda environmnet " ${CONDA_ENVIRONMENT}) +message("Current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) + if (WIN32) set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") @@ -70,8 +122,8 @@ endif() find_package(tiff) if (TIFF_FOUND) message("LibTIFF Found") - message("TIFF_INCLUDE_DIR "${TIFF_INCLUDE_DIR}) - message("TIFF_LIBRARIES"${TIFF_LIBRARIES}) + message("TIFF_INCLUDE_DIR " ${TIFF_INCLUDE_DIR}) + message("TIFF_LIBRARIES" ${TIFF_LIBRARIES}) else() message("LibTIFF not found") -endif() \ No newline at end of file +endif() -- cgit v1.2.3 From 808167b9a333e5f351d39c8a791104c1b7a08aab Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 13:56:36 +0100 Subject: executable extension is for WIN32 --- src/Python/FindAnacondaEnvironment.cmake | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/FindAnacondaEnvironment.cmake b/src/Python/FindAnacondaEnvironment.cmake index 3abb5d1..6d91dba 100644 --- a/src/Python/FindAnacondaEnvironment.cmake +++ b/src/Python/FindAnacondaEnvironment.cmake @@ -36,8 +36,12 @@ function (findPythonForAnacondaEnvironment env) - - file(TO_CMAKE_PATH ${env}/python.exe PYTHON_EXECUTABLE) + set (EXE "") + if (WIN32) + set (EXE ".exe") + endif() + + file(TO_CMAKE_PATH ${env}/python${EXE} PYTHON_EXECUTABLE) message("Found " ${PYTHON_EXECUTABLE}) ####### FROM FindPythonInterpr ######## -- cgit v1.2.3 From 99b106a2ced10136a2a71b42440d8c1ffa4f8d2c Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 14:47:43 +0100 Subject: Bugfixes for UNIX compilation --- src/Python/CMakeLists.txt | 93 ++++++++++++++------------------ src/Python/FindAnacondaEnvironment.cmake | 41 ++++---------- src/Python/compile.sh.in | 2 +- 3 files changed, 52 insertions(+), 84 deletions(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index 707f006..e539eda 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -25,73 +25,52 @@ set (NUMPY_VERSION 1.12) #https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs #set (CONDA_ENVIRONMENT "cil") +## Tries to parse the output of conda env list to determine the current +## active conda environment execute_process(COMMAND "conda" "env" "list" OUTPUT_VARIABLE _CONDA_ENVS RESULT_VARIABLE _CONDA_RESULT ERROR_VARIABLE _CONDA_ERR) if(NOT _CONDA_RESULT) - #message("conda envs list " ${_CONDA_ENVS}) string(REPLACE "\n" ";" ENV_LIST ${_CONDA_ENVS}) - #string(REGEX MATCHALL "^.*[\t\n]" matches ${_CONDA_ENVS}) foreach(line ${ENV_LIST}) - message("line='${line}'") string(REGEX MATCHALL "(.+)[*](.+)" match ${line}) - #list(LENGTH ${match} N) - #string(LENGTH ${match} Ns) - #string(REPLACE "*" ";" env_dir ) - #message("list length " ${N} " string length " ${Ns}) - #list(GET ${env_dir} 0 CONDA_ENVIRONMENT) - #list(GET ${env_dir} 1 CONDA_PATH) if (NOT ${match} EQUAL "") - message("match='${match}'") - string(REPLACE "*" ";" ENV_DIR ${match}) - #string(MATCHALL "(.*)[*](.*)" ENV_DIR ${match}) - list (APPEND cc "") - foreach(conda ${ENV_DIR}) - message("THERE YOU ARE " ${conda}) - list(APPEND cc ${conda}) - endforeach() - list(LENGTH cc Ns) - message("cc " ${cc} " " ${Ns}) - if (${Ns} EQUAL 2) - list(GET cc 0 CONDA_ENVIRONMENT) - list(GET cc 1 CONDA_ENVIRONMENT_PATH) - message("Current conda environmnet " ${CONDA_ENVIRONMENT}) - message("Current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) - - endif() - - #list(GET ${env_dir} 0 CONDA_ENVIRONMENT) - #list(GET ${env_dir} 1 CONDA_PATH) - #message("******" ${env_dir}) - #message("******" ${CONDA_ENVIRONMENT} " " ${CONDA_PATH} ) + string(REPLACE "*" ";" ENV_DIR ${match}) + list (APPEND cc "") + foreach(conda ${ENV_DIR}) + string(STRIP ${conda} stripped) + list(APPEND cc ${stripped}) + endforeach() + list(LENGTH cc Ns) + if (${Ns} EQUAL 2) + list(GET cc 0 CONDA_ENVIRONMENT) + list(GET cc 1 CONDA_ENVIRONMENT_PATH) + endif() endif() endforeach() - #string(REGEX REPLACE "^.*[*].*" "" CONDA_ENVIRONMENT ${_CONDA_ENVS}) else() message("conda result false" ${_CONDA_ERR}) endif() -message("**********************************************************") -message("Current conda environmnet " ${CONDA_ENVIRONMENT}) -message("Current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) - -if (WIN32) - set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") - set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") - set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}\\include" CACHE PATH "env dir") -elseif (UNIX) - set (CONDA_ENVIRONMENT_PATH "/apps/anaconda/2.4/envs/${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") - set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}" CACHE PATH "env dir") - set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir") +if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH}) + message (FATAL_ERROR "CONDA NOT FOUND") +else() + message("**********************************************************") + message("Using current conda environmnet " ${CONDA_ENVIRONMENT}) + message("Using current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) endif() + + message("CIL VERSION " ${CIL_VERSION}) # set the Python variables for the Conda environment include(FindAnacondaEnvironment.cmake) findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH}) message("Python found " ${PYTHON_VERSION_STRING}) +message("Python found Major " ${PYTHON_VERSION_MAJOR}) +message("Python found Minor " ${PYTHON_VERSION_MINOR}) findPythonPackagesPath() message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) @@ -101,6 +80,15 @@ file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRE file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +if (WIN32) + #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") + set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") + set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}\\include" CACHE PATH "env dir") +elseif (UNIX) + #set (CONDA_ENVIRONMENT_PATH "/apps/anaconda/2.4/envs/${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") + set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}/lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}" CACHE PATH "env dir") + set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir") +endif() # Copy and configure the relative conda build and recipes configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) @@ -111,6 +99,7 @@ if (WIN32) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) elseif(UNIX) + message ("We are on UNIX") file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) # assumes we will use bash configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) @@ -119,11 +108,11 @@ endif() ### add tests #add_executable(RegularizersTest ) -find_package(tiff) -if (TIFF_FOUND) - message("LibTIFF Found") - message("TIFF_INCLUDE_DIR " ${TIFF_INCLUDE_DIR}) - message("TIFF_LIBRARIES" ${TIFF_LIBRARIES}) -else() - message("LibTIFF not found") -endif() +#find_package(tiff) +#if (TIFF_FOUND) +# message("LibTIFF Found") +# message("TIFF_INCLUDE_DIR " ${TIFF_INCLUDE_DIR}) +# message("TIFF_LIBRARIES" ${TIFF_LIBRARIES}) +#else() +# message("LibTIFF not found") +#endif() diff --git a/src/Python/FindAnacondaEnvironment.cmake b/src/Python/FindAnacondaEnvironment.cmake index 6d91dba..fa4637a 100644 --- a/src/Python/FindAnacondaEnvironment.cmake +++ b/src/Python/FindAnacondaEnvironment.cmake @@ -36,14 +36,14 @@ function (findPythonForAnacondaEnvironment env) - set (EXE "") if (WIN32) - set (EXE ".exe") + file(TO_CMAKE_PATH ${env}/python.exe PYTHON_EXECUTABLE) + elseif (UNIX) + file(TO_CMAKE_PATH ${env}/bin/python PYTHON_EXECUTABLE) endif() - file(TO_CMAKE_PATH ${env}/python${EXE} PYTHON_EXECUTABLE) - message("Found " ${PYTHON_EXECUTABLE}) + message("findPythonForAnacondaEnvironment Found Python Executable" ${PYTHON_EXECUTABLE}) ####### FROM FindPythonInterpr ######## # determine python version string if(PYTHON_EXECUTABLE) @@ -133,37 +133,16 @@ endif() function(findPythonPackagesPath) -### https://openlab.ncl.ac.uk/gitlab/john.shearer/clappertracker/raw/549885e5decd37f7b23e9c1fd39e86f207156795/src/3rdparty/opencv/cmake/OpenCVDetectPython.cmake -### -if(CMAKE_HOST_UNIX) - execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import *; print get_python_lib()" + execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import *; print (get_python_lib())" RESULT_VARIABLE PYTHON_CVPY_PROCESS OUTPUT_VARIABLE PYTHON_STD_PACKAGES_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) - if("${PYTHON_STD_PACKAGES_PATH}" MATCHES "site-packages") + #message("STD_PACKAGES " ${PYTHON_STD_PACKAGES_PATH}) + if("${PYTHON_STD_PACKAGES_PATH}" MATCHES "site-packages") set(_PYTHON_PACKAGES_PATH "python${PYTHON_VERSION_MAJOR_MINOR}/site-packages") - else() #debian based assumed, install to the dist-packages. - set(_PYTHON_PACKAGES_PATH "python${PYTHON_VERSION_MAJOR_MINOR}/dist-packages") - endif() - if(EXISTS "${CMAKE_INSTALL_PREFIX}/lib${LIB_SUFFIX}/${PYTHON_PACKAGES_PATH}") - set(_PYTHON_PACKAGES_PATH "lib${LIB_SUFFIX}/${_PYTHON_PACKAGES_PATH}") - else() - set(_PYTHON_PACKAGES_PATH "lib/${_PYTHON_PACKAGES_PATH}") - endif() - elseif(CMAKE_HOST_WIN32) - get_filename_component(PYTHON_PATH "${PYTHON_EXECUTABLE}" PATH) - file(TO_CMAKE_PATH "${PYTHON_PATH}" PYTHON_PATH) - if(NOT EXISTS "${PYTHON_PATH}/Lib/site-packages") - unset(PYTHON_PATH) - get_filename_component(PYTHON_PATH "[HKEY_LOCAL_MACHINE\\SOFTWARE\\Python\\PythonCore\\${PYTHON_VERSION_MAJOR_MINOR}\\InstallPath]" ABSOLUTE) - if(NOT PYTHON_PATH) - get_filename_component(PYTHON_PATH "[HKEY_CURRENT_USER\\SOFTWARE\\Python\\PythonCore\\${PYTHON_VERSION_MAJOR_MINOR}\\InstallPath]" ABSOLUTE) - endif() - file(TO_CMAKE_PATH "${PYTHON_PATH}" PYTHON_PATH) - endif() - set(_PYTHON_PACKAGES_PATH "${PYTHON_PATH}/Lib/site-packages") - endif() - SET(PYTHON_PACKAGES_PATH "${_PYTHON_PACKAGES_PATH}" PARENT_SCOPE) + endif() + + SET(PYTHON_PACKAGES_PATH "${PYTHON_STD_PACKAGES_PATH}" PARENT_SCOPE) endfunction() diff --git a/src/Python/compile.sh.in b/src/Python/compile.sh.in index ca9f310..93fdba2 100644 --- a/src/Python/compile.sh.in +++ b/src/Python/compile.sh.in @@ -6,4 +6,4 @@ export CIL_VERSION=@CIL_VERSION@ export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ -conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi \ No newline at end of file +conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -- cgit v1.2.3 From 44ec01fa2e8d8da2dce4950ea3d822fe7c8cd8d5 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 15:00:57 +0100 Subject: minor cleanup --- src/Python/CMakeLists.txt | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index e539eda..fd74ee7 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -12,18 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -message("CIL VERSION " ${CIL_VERSION}) - - # variables that must be set for conda compilation #PREFIX=C:\Apps\Miniconda2\envs\cil\Library #LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include set (NUMPY_VERSION 1.12) -#set (PYTHON_VERSION 3.5) - -#https://groups.google.com/a/continuum.io/forum/#!topic/anaconda/R9gWjl09UFs -#set (CONDA_ENVIRONMENT "cil") ## Tries to parse the output of conda env list to determine the current ## active conda environment @@ -50,27 +43,23 @@ execute_process(COMMAND "conda" "env" "list" endif() endforeach() else() - message("conda result false" ${_CONDA_ERR}) + message(FATAL_ERROR "conda error " ${_CONDA_ERR}) endif() -if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH}) - message (FATAL_ERROR "CONDA NOT FOUND") -else() - message("**********************************************************") - message("Using current conda environmnet " ${CONDA_ENVIRONMENT}) - message("Using current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) -endif() - - +message("**********************************************************") +message("Using current conda environmnet " ${CONDA_ENVIRONMENT}) +message("Using current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) message("CIL VERSION " ${CIL_VERSION}) # set the Python variables for the Conda environment include(FindAnacondaEnvironment.cmake) findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH}) + message("Python found " ${PYTHON_VERSION_STRING}) message("Python found Major " ${PYTHON_VERSION_MAJOR}) message("Python found Minor " ${PYTHON_VERSION_MINOR}) + findPythonPackagesPath() message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) @@ -105,6 +94,7 @@ elseif(UNIX) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) endif() + ### add tests #add_executable(RegularizersTest ) -- cgit v1.2.3 From ece0bfc45cf2e339fc517a4f2c078f0b8fe274ad Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Mon, 23 Oct 2017 17:04:41 +0100 Subject: add fista-recipe and stuff --- src/Python/CMakeLists.txt | 16 +++++++++++++++- src/Python/fista-recipe/build.sh | 10 ++++++++++ src/Python/fista-recipe/meta.yaml | 28 ++++++++++++++++++++++++++++ src/Python/setup-fista.py.in | 27 +++++++++++++++++++++++++++ src/Python/setup.py.in | 4 +++- 5 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 src/Python/fista-recipe/build.sh create mode 100644 src/Python/fista-recipe/meta.yaml create mode 100644 src/Python/setup-fista.py.in (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index fd74ee7..33ebc08 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -63,12 +63,19 @@ message("Python found Minor " ${PYTHON_VERSION_MINOR}) findPythonPackagesPath() message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) -# copy the Pyhon files of the package +######### CONFIGURE REGULARIZER ############# + +# copy the Pyhon files of the package regularizer file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) +# regularizers file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +# fista reconstructor +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) +#file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) + if (WIN32) #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") @@ -84,9 +91,16 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DI file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup-fista.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup-fista.py) +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) + if (WIN32) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) + + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) elseif(UNIX) message ("We are on UNIX") file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) diff --git a/src/Python/fista-recipe/build.sh b/src/Python/fista-recipe/build.sh new file mode 100644 index 0000000..e3f3552 --- /dev/null +++ b/src/Python/fista-recipe/build.sh @@ -0,0 +1,10 @@ +if [ -z "$CIL_VERSION" ]; then + echo "Need to set CIL_VERSION" + exit 1 +fi +mkdir "$SRC_DIR/ccpifista" +cp -r "$RECIPE_DIR/.." "$SRC_DIR/ccpifista" + +cd $SRC_DIR/ccpifista + +$PYTHON setup-fista.py install diff --git a/src/Python/fista-recipe/meta.yaml b/src/Python/fista-recipe/meta.yaml new file mode 100644 index 0000000..64c9b5d --- /dev/null +++ b/src/Python/fista-recipe/meta.yaml @@ -0,0 +1,28 @@ +package: + name: ccpi-fista + version: {{ environ['CIL_VERSION'] }} + + +build: + preserve_egg_dir: False + script_env: + - CIL_VERSION +# number: 0 + +requirements: + build: + - python + - numpy + - setuptools + + run: + - python + - numpy + - astra + + + +about: + home: http://www.ccpi.ac.uk + license: Apache v.2.0 license + summary: 'CCPi Core Imaging Library (Viewer)' diff --git a/src/Python/setup-fista.py.in b/src/Python/setup-fista.py.in new file mode 100644 index 0000000..c5c9f4d --- /dev/null +++ b/src/Python/setup-fista.py.in @@ -0,0 +1,27 @@ +from distutils.core import setup +#from setuptools import setup, find_packages +import os + +cil_version=os.environ['CIL_VERSION'] +if cil_version == '': + print("Please set the environmental variable CIL_VERSION") + sys.exit(1) + +setup( + name="ccpi-fista", + version=cil_version, + packages=['ccpi','ccpi.reconstruction'], + install_requires=['numpy'], + + zip_safe = False, + + # metadata for upload to PyPI + author="Edoardo Pasca", + author_email="edo.paskino@gmail.com", + description='CCPi Core Imaging Library - FISTA Reconstructor module', + license="Apache v2.0", + keywords="tomography interative reconstruction", + url="http://www.ccpi.ac.uk", # project home page, if any + + # could also include long_description, download_url, classifiers, etc. +) diff --git a/src/Python/setup.py.in b/src/Python/setup.py.in index 0a1f4ad..12e8af1 100644 --- a/src/Python/setup.py.in +++ b/src/Python/setup.py.in @@ -44,7 +44,7 @@ else: setup( name='ccpi', - description='CCPi Core Imaging Library - FISTA Reconstruction Module', + description='CCPi Core Imaging Library - Image Regularizers', version=cil_version, cmdclass = {'build_ext': build_ext}, ext_modules = [Extension("ccpi.imaging.cpu_regularizers", @@ -65,3 +65,5 @@ setup( zip_safe = False, packages = {'ccpi','ccpi.imaging'}, ) + + -- cgit v1.2.3 From 4fd4f187a70c0e4f56d5194b09ab4a528d20ee51 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 10:38:37 +0100 Subject: Builds 2 packages fista and regularizers --- src/Python/CMakeLists.txt | 63 ++++--- src/Python/compile-fista.bat.in | 7 + src/Python/compile-fista.sh.in | 9 + src/Python/conda-recipe/meta.yaml | 2 +- src/Python/fista-recipe/meta.yaml | 3 +- src/Python/test/test_reconstructor.py | 301 ++++++++++++++++++++++++++++++++++ 6 files changed, 359 insertions(+), 26 deletions(-) create mode 100644 src/Python/compile-fista.bat.in create mode 100644 src/Python/compile-fista.sh.in create mode 100644 src/Python/test/test_reconstructor.py (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index 33ebc08..c5e14ea 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -20,6 +20,7 @@ set (NUMPY_VERSION 1.12) ## Tries to parse the output of conda env list to determine the current ## active conda environment +message ("Trying to determine your active conda environment...") execute_process(COMMAND "conda" "env" "list" OUTPUT_VARIABLE _CONDA_ENVS RESULT_VARIABLE _CONDA_RESULT @@ -43,12 +44,12 @@ execute_process(COMMAND "conda" "env" "list" endif() endforeach() else() - message(FATAL_ERROR "conda error " ${_CONDA_ERR}) + message(FATAL_ERROR "ERROR with conda command " ${_CONDA_ERR}) endif() message("**********************************************************") -message("Using current conda environmnet " ${CONDA_ENVIRONMENT}) -message("Using current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) +message("Active conda environmnet: " ${CONDA_ENVIRONMENT}) +message("Active conda environmnet path: " ${CONDA_ENVIRONMENT_PATH}) message("CIL VERSION " ${CIL_VERSION}) @@ -63,19 +64,6 @@ message("Python found Minor " ${PYTHON_VERSION_MINOR}) findPythonPackagesPath() message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) -######### CONFIGURE REGULARIZER ############# - -# copy the Pyhon files of the package regularizer -file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) -# regularizers -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) - -# fista reconstructor -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) -#file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) - if (WIN32) #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") @@ -86,26 +74,53 @@ elseif (UNIX) set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir") endif() +######### CONFIGURE REGULARIZER PACKAGE ############# + +# copy the Pyhon files of the package regularizer +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) +# regularizers +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) + # Copy and configure the relative conda build and recipes configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) +if (WIN32) + + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) + +elseif(UNIX) + + message ("We are on UNIX") + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) + # assumes we will use bash + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) + +endif() + +########## CONFIGURE FISTA RECONSTRUCTOR PACKAGE +# fista reconstructor +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup-fista.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup-fista.py) file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) if (WIN32) - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.bat) + elseif(UNIX) - message ("We are on UNIX") - file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) - # assumes we will use bash - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) + message ("We are on UNIX") + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) + # assumes we will use bash + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.sh) endif() diff --git a/src/Python/compile-fista.bat.in b/src/Python/compile-fista.bat.in new file mode 100644 index 0000000..b1db686 --- /dev/null +++ b/src/Python/compile-fista.bat.in @@ -0,0 +1,7 @@ +set CIL_VERSION=@CIL_VERSION@ + +set PREFIX=@CONDA_ENVIRONMENT_PREFIX@ +set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +REM activate @CONDA_ENVIRONMENT@ +conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge diff --git a/src/Python/compile-fista.sh.in b/src/Python/compile-fista.sh.in new file mode 100644 index 0000000..267f014 --- /dev/null +++ b/src/Python/compile-fista.sh.in @@ -0,0 +1,9 @@ +#!/bin/sh +# compile within the right conda environment +#module load python/anaconda +#source activate @CONDA_ENVIRONMENT@ + +export CIL_VERSION=@CIL_VERSION@ +export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi diff --git a/src/Python/conda-recipe/meta.yaml b/src/Python/conda-recipe/meta.yaml index c5b7a89..7068e9d 100644 --- a/src/Python/conda-recipe/meta.yaml +++ b/src/Python/conda-recipe/meta.yaml @@ -1,5 +1,5 @@ package: - name: ccpi-fista + name: ccpi-regularizers version: {{ environ['CIL_VERSION'] }} diff --git a/src/Python/fista-recipe/meta.yaml b/src/Python/fista-recipe/meta.yaml index 64c9b5d..89bf597 100644 --- a/src/Python/fista-recipe/meta.yaml +++ b/src/Python/fista-recipe/meta.yaml @@ -18,7 +18,8 @@ requirements: run: - python - numpy - - astra + - astra-toolbox + - ccpi-regularizers diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py new file mode 100644 index 0000000..07668ba --- /dev/null +++ b/src/Python/test/test_reconstructor.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.fista.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt + +def RMSE(signal1, signal2): + '''RMSE Root Mean Squared Error''' + if numpy.shape(signal1) == numpy.shape(signal2): + err = (signal1 - signal2) + err = numpy.sum( err * err )/numpy.size(signal1); # MSE + err = sqrt(err); # RMSE + return err + else: + raise Exception('Input signals must have the same shape') + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +## Ordered subset +if False: + subsets = 16 + angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + +if False: + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + print ("prepare for iteration") + fistaRecon.prepareForIteration() + + + + print("initializing ...") + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("initialized") + proj_geom , vol_geom, sino , \ + SlicesZ = fistaRecon.getParameter(['projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ']) + + #fistaRecon.setParameter(number_of_iterations = 3) + iterFISTA = fistaRecon.getParameter('number_of_iterations') + # errors vector (if the ground truth is given) + Resid_error = numpy.zeros((iterFISTA)); + # objective function values vector + objective = numpy.zeros((iterFISTA)); + + + t = 1 + + + print ("starting iterations") +## % Outer FISTA iterations loop + for i in range(fistaRecon.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = fistaRecon.r.copy() + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + sino_id, sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geom, vol_geom) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + ## RING REMOVAL + residual = fistaRecon.residual + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) + r_x = fistaRecon.r_x + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + print ("ring removal") + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + fistaRecon.r = (r_x - (1./L_const) * vec).copy() + objective[i] = (0.5 * (residual ** 2).sum()) +## % the ring removal part (Group-Huber fidelity) +## for kkk = 1:anglesNumb +## residual(:,kkk,:) = squeeze(weights(:,kkk,:)).* +## (squeeze(sino_updt(:,kkk,:)) - +## (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); +## end +## vec = sum(residual,2); +## if (SlicesZ > 1) +## vec = squeeze(vec(:,1,:)); +## end +## r = r_x - (1./L_const).*vec; +## objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + + + + # Projection/Backprojection Routine + if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ + fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + print ("Projection/Backprojection Routine") + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + + ## REGULARIZATION + ## SKIPPING FOR NOW + ## Should be simpli + # regularizer = fistaRecon.getParameter('regularizer') + # for slices: + # out = regularizer(input=X) + print ("skipping regularizer") + + + ## FINAL + print ("final") + lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + fistaRecon.r = numpy.max( + numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ + numpy.sign(fistaRecon.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + fistaRecon.r_x = fistaRecon.r + \ + (((t_old-1)/t) * (fistaRecon.r - r_old)) + + if fistaRecon.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], objective[i])) + +## if (lambdaR_L1 > 0) +## r = max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector +## end +## +## t = (1 + sqrt(1 + 4*t^2))/2; % updating t +## X_t = X + ((t_old-1)/t).*(X - X_old); % updating X +## +## if (lambdaR_L1 > 0) +## r_x = r + ((t_old-1)/t).*(r - r_old); % updating r +## end +## +## if (show == 1) +## figure(10); imshow(X(:,:,slice), [0 maxvalplot]); +## if (lambdaR_L1 > 0) +## figure(11); plot(r); title('Rings offset vector') +## end +## pause(0.01); +## end +## if (strcmp(X_ideal, 'none' ) == 0) +## Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); +## fprintf('%s %i %s %s %.4f %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); +## else +## fprintf('%s %i %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); +## end +else: + fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + fistaRecon.setParameter(number_of_iterations = 12) + fistaRecon.setParameter(Lipschitz_constant = 767893952.0) + fistaRecon.setParameter(ring_alpha = 21) + fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + fistaRecon.prepareForIteration() + X = fistaRecon.iterate(numpy.load("X.npy")) -- cgit v1.2.3 From a11c59651ec125e24371a2049606df0f80f458d0 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 11:26:46 +0100 Subject: latest dev --- .../ccpi/reconstruction/FISTAReconstructor.py | 599 +++++++++++++++------ 1 file changed, 427 insertions(+), 172 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index ea96b53..85bfac5 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -21,10 +21,9 @@ import numpy -import h5py #from ccpi.reconstruction.parallelbeam import alg -from ccpi.imaging.Regularizer import Regularizer +#from ccpi.imaging.Regularizer import Regularizer from enum import Enum import astra @@ -74,18 +73,34 @@ class FISTAReconstructor(): # 3. "A novel tomographic reconstruction method based on the robust # Student's t function for suppressing data outliers" D. Kazantsev et.al. # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ + def __init__(self, projector_geometry, output_geometry, input_sinogram, + **kwargs): + # handle parmeters: + # obligatory parameters + self.pars = dict() + self.pars['projector_geometry'] = projector_geometry # proj_geom + self.pars['output_geometry'] = output_geometry # vol_geom + self.pars['input_sinogram'] = input_sinogram # sino + sliceZ, nangles, detectors = numpy.shape(input_sinogram) + self.pars['detectors'] = detectors + self.pars['number_of_angles'] = nangles + self.pars['SlicesZ'] = sliceZ + self.pars['output_volume'] = None + + print (self.pars) + # handle optional input parameters (at instantiation) # Accepted input keywords - kw = ('number_of_iterations', + kw = ( + # mandatory fields + 'projector_geometry', + 'output_geometry', + 'input_sinogram', + 'detectors', + 'number_of_angles', + 'SlicesZ', + # optional fields + 'number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , 'weights' , @@ -93,7 +108,13 @@ class FISTAReconstructor(): 'initialize' , 'regularizer' , 'ring_lambda_R_L1', - 'ring_alpha') + 'ring_alpha', + 'subsets', + 'output_volume', + 'os_subsets', + 'os_indices', + 'os_bins') + self.acceptedInputKeywords = list(kw) # handle keyworded parameters if kwargs is not None: @@ -110,85 +131,160 @@ class FISTAReconstructor(): if 'weights' in kwargs.keys(): self.pars['weights'] = kwargs['weights'] else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) + self.pars['weights'] = \ + numpy.ones(numpy.shape( + self.pars['input_sinogram'])) if 'Lipschitz_constant' in kwargs.keys(): self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() + self.pars['Lipschitz_constant'] = None - if not self.pars['ideal_image'] in kwargs.keys(): + if not 'ideal_image' in kwargs.keys(): self.pars['ideal_image'] = None - if not self.pars['region_of_interest'] : + if not 'region_of_interest'in kwargs.keys() : if self.pars['ideal_image'] == None: - pass + self.pars['region_of_interest'] = None else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : + ## nonzero if the image is larger than m + fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) + + self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) + + # the regularizer must be a correctly instantiated object + if not 'regularizer' in kwargs.keys() : self.pars['regularizer'] = None + + #RING REMOVAL + if not 'ring_lambda_R_L1' in kwargs.keys(): + self.pars['ring_lambda_R_L1'] = 0 + if not 'ring_alpha' in kwargs.keys(): + self.pars['ring_alpha'] = 1 + + # ORDERED SUBSET + if not 'subsets' in kwargs.keys(): + self.pars['subsets'] = 0 else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 + self.createOrderedSubsets() + + if not 'initialize' in kwargs.keys(): + self.pars['initialize'] = False + + def setParameter(self, **kwargs): + '''set named parameter for the reconstructor engine + + raises Exception if the named parameter is not recognized + ''' + for key , value in kwargs.items(): + if key in self.acceptedInputKeywords: + self.pars[key] = value + else: + raise Exception('Wrong parameter {0} for '.format(key) + + 'reconstructor') + # setParameter + + def getParameter(self, key): + if type(key) is str: + if key in self.acceptedInputKeywords: + return self.pars[key] + else: + raise Exception('Unrecongnised parameter: {0} '.format(key) ) + elif type(key) is list: + outpars = [] + for k in key: + outpars.append(self.getParameter(k)) + return outpars + else: + raise Exception('Unhandled input {0}' .format(str(type(key)))) + + def calculateLipschitzConstantWithPowerMethod(self): ''' using Power method (PM) to establish L constant''' - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] + N = self.pars['output_geometry']['GridColCount'] + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] weights = self.pars['weights'] SlicesZ = self.pars['SlicesZ'] - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): + + + if (proj_geom['type'] == 'parallel') or \ + (proj_geom['type'] == 'parallel3d'): #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM + #print('Calculating Lipshitz constant for parallel beam geometry...') + niter = 5;# % number of iteration for the PM #N = params.vol_geom.GridColCount; #x1 = rand(N,N,1); x1 = numpy.random.rand(1,N,N) #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; + proj_geomT['DetectorRowCount'] = 1; vol_geomT = vol_geom.copy(); vol_geomT['GridSliceCount'] = 1; + #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) + # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); + # s = norm(x1(:)); + # x1 = x1/s; + # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + # y = sqweight.*y; + # astra_mex_data3d('delete', sino_id); + # astra_mex_data3d('delete', id); + #print ("iteration {0}".format(i)) + + sino_id, y = astra.creators.create_sino3d_gpu(x1, + proj_geomT, + vol_geomT) + + y = (sqweight * y).copy() # element wise multiplication + + #b=fig.add_subplot(2,1,2) + #imgplot = plt.imshow(x1[0]) + #plt.show() + + #astra_mex_data3d('delete', sino_id); + astra.matlab.data3d('delete', sino_id) + del x1 - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); + idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), + proj_geomT, + vol_geomT) + del y + + s = numpy.linalg.norm(x1) ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight*y; + x1 = (x1/s).copy(); + + # ### this line? + # sino_id, y = astra.creators.create_sino3d_gpu(x1, + # proj_geomT, + # vol_geomT); + # y = sqweight * y; astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); + astra.matlab.data3d('delete', idx) + print ("iteration {0} s= {1}".format(i,s)) + #end del proj_geomT del vol_geomT + #plt.show() else: #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); + print('Calculating Lipshitz constant for divergen beam geometry...') niter = 8; #% number of iteration for PM x1 = numpy.random.rand(SlicesZ , N , N); #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) + sqweight = numpy.sqrt(weights[0]) sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); y = sqweight*y; @@ -217,6 +313,7 @@ class FISTAReconstructor(): #end #clear x1 del x1 + return s @@ -225,130 +322,288 @@ class FISTAReconstructor(): if regularizer is not None: self.pars['regularizer'] = regularizer + + def initialize(self): + # convenience variable storage + proj_geom = self.pars['projector_geometry'] + vol_geom = self.pars['output_geometry'] + sino = self.pars['input_sinogram'] + + # a 'warm start' with SIRT method + # Create a data object for the reconstruction + rec_id = astra.matlab.data3d('create', '-vol', + vol_geom); + + #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); + sinogram_id = astra.matlab.data3d('create', '-proj3d', + proj_geom, + sino) + + sirt_config = astra.astra_dict('SIRT3D_CUDA') + sirt_config['ReconstructionDataId' ] = rec_id + sirt_config['ProjectionDataId'] = sinogram_id + + sirt = astra.algorithm.create(sirt_config) + astra.algorithm.run(sirt, iterations=35) + X = astra.matlab.data3d('get', rec_id) + + # clean up memory + astra.matlab.data3d('delete', rec_id) + astra.matlab.data3d('delete', sinogram_id) + astra.algorithm.delete(sirt) + + + + return X + + def createOrderedSubsets(self, subsets=None): + if subsets is None: + try: + subsets = self.getParameter('subsets') + except Exception(): + subsets = 0 + #return subsets + + angles = self.getParameter('projector_geometry')['ProjectionAngles'] + + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + # store the OS in parameters + self.setParameter(os_subsets=subsets, + os_bins=binsDiscr, + os_indices=IndicesReorg) + + + def prepareForIteration(self): + print ("FISTA Reconstructor: prepare for iteration") + + self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) + self.objective = numpy.zeros((self.pars['number_of_iterations'])) + + #2D array (for 3D data) of sparse "ring" + detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) + self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) + # another ring variable + self.r_x = self.r.copy() + + self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) + + if self.getParameter('Lipschitz_constant') is None: + self.pars['Lipschitz_constant'] = \ + self.calculateLipschitzConstantWithPowerMethod() + # errors vector (if the ground truth is given) + self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); + # objective function values vector + self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); + + + # prepareForIteration + + def iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # convenience variable storage + proj_geom , vol_geom, sino , \ + SlicesZ = self.getParameter([ 'projector_geometry' , + 'output_geometry', + 'input_sinogram', + 'SlicesZ' ]) + + t = 1 + + for i in range(self.getParameter('number_of_iterations')): + X_old = X.copy() + t_old = t + r_old = self.r.copy() + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) + for kkk in range(SlicesZ): + sino_id, self.sino_updt[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomT, vol_geomT) + astra.matlab.data3d('delete', sino_id) + else: + # for divergent 3D geometry (watch the GPU memory overflow in + # ASTRA versions < 1.8) + #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); + sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( + X_t, proj_geom, vol_geom) + + + ## RING REMOVAL + self.ringRemoval(i) + ## Projection/Backprojection Routine + self.projectionBackprojection(X, X_t) + astra.matlab.data3d('delete', sino_id) + ## REGULARIZATION + X = self.regularize(X) + ## Update Loop + X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) + self.setParameter(output_volume=X) + return X + ## iterate - + def ringRemoval(self, i): + print ("FISTA Reconstructor: ring removal") + residual = self.residual + lambdaR_L1 , alpha_ring , weights , L_const , sino= \ + self.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant', + 'input_sinogram']) + r_x = self.r_x + sino_updt = self.sino_updt + + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + if lambdaR_L1 > 0 : + for kkk in range(anglesNumb): + + residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + vec = residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:].squeeze() + self.r = (r_x - (1./L_const) * vec).copy() + self.objective[i] = (0.5 * (residual ** 2).sum()) + def projectionBackprojection(self, X, X_t): + print ("FISTA Reconstructor: projection-backprojection routine") + + # a few useful variables + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(self.getParameter('input_sinogram')) + residual = self.residual + proj_geom , vol_geom , L_const = \ + self.getParameter(['projector_geometry' , + 'output_geometry', + 'Lipschitz_constant']) + + + if self.getParameter('projector_geometry')['type'] == 'parallel' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat' or \ + self.getParameter('projector_geometry')['type'] == 'fanflat_vec': + # if the geometry is parallel use slice-by-slice + # projection-backprojection routine + #sino_updt = zeros(size(sino),'single'); + proj_geomT = proj_geom.copy() + proj_geomT['DetectorRowCount'] = 1 + vol_geomT = vol_geom.copy() + vol_geomT['GridSliceCount'] = 1; + x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) + + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residual[kkk:kkk+1], + proj_geomT, vol_geomT) + astra.matlab.data3d('delete', x_id) + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residual, proj_geom, vol_geom) + + X = X_t - (1/L_const) * x_temp + #astra.matlab.data3d('delete', sino_id) + astra.matlab.data3d('delete', x_id) + + def regularize(self, X): + print ("FISTA Reconstructor: regularize") + + regularizer = self.getParameter('regularizer') + if regularizer is not None: + return regularizer(input=X) + else: + return X + + def updateLoop(self, i, X, X_old, r_old, t, t_old): + print ("FISTA Reconstructor: update loop") + lambdaR_L1 = self.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + self.r = numpy.max( + numpy.abs(self.r) - lambdaR_L1 , 0) * \ + numpy.sign(self.r) + t = (1 + numpy.sqrt(1 + 4 * t**2))/2 + X_t = X + (((t_old -1)/t) * (X - X_old)) + + if lambdaR_L1 > 0: + self.r_x = self.r + \ + (((t_old-1)/t) * (self.r - r_old)) + + if self.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, self.objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], self.objective[i])) + return (X , X_t, t) -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -#fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -##fname = "/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/dendr.h5" -##nx = h5py.File(fname, "r") -## -### the data are stored in a particular location in the hdf5 -##for item in nx['entry1/tomo_entry/data'].keys(): -## print (item) -## -##data = nx.get('entry1/tomo_entry/data/rotation_angle') -##angles = numpy.zeros(data.shape) -##data.read_direct(angles) -##print (angles) -### angles should be in degrees -## -##data = nx.get('entry1/tomo_entry/data/data') -##stack = numpy.zeros(data.shape) -##data.read_direct(stack) -##print (data.shape) -## -##print ("Data Loaded") -## -## -### Normalize -##data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -##itype = numpy.zeros(data.shape) -##data.read_direct(itype) -### 2 is dark field -##darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -##dark = darks[0] -##for i in range(1, len(darks)): -## dark += darks[i] -##dark = dark / len(darks) -###dark[0][0] = dark[0][1] -## -### 1 is flat field -##flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -##flat = flats[0] -##for i in range(1, len(flats)): -## flat += flats[i] -##flat = flat / len(flats) -###flat[0][0] = dark[0][1] -## -## -### 0 is projection data -##proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -##angle_proj = numpy.asarray (angle_proj) -##angle_proj = angle_proj.astype(numpy.float32) -## -### normalized data are -### norm = (projection - dark)/(flat-dark) -## -##def normalize(projection, dark, flat, def_val=0.1): -## a = (projection - dark) -## b = (flat-dark) -## with numpy.errstate(divide='ignore', invalid='ignore'): -## c = numpy.true_divide( a, b ) -## c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 -## return c -## -## -##norm = [normalize(projection, dark, flat) for projection in proj] -##norm = numpy.asarray (norm) -##norm = norm.astype(numpy.float32) - - -##niterations = 15 -##threads = 3 -## -##img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -##img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## iteration_values, False) -##print ("iteration values %s" % str(iteration_values)) -## -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -##iteration_values = numpy.zeros((niterations,)) -##img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, -## numpy.double(1e-5), iteration_values , False) -##print ("iteration values %s" % str(iteration_values)) -## -## -####numpy.save("cgls_recon.npy", img_data) -##import matplotlib.pyplot as plt -##fig, ax = plt.subplots(1,6,sharey=True) -##ax[0].imshow(img_cgls[80]) -##ax[0].axis('off') # clear x- and y-axes -##ax[1].imshow(img_sirt[80]) -##ax[1].axis('off') # clear x- and y-axes -##ax[2].imshow(img_mlem[80]) -##ax[2].axis('off') # clear x- and y-axesplt.show() -##ax[3].imshow(img_cgls_conv[80]) -##ax[3].axis('off') # clear x- and y-axesplt.show() -##ax[4].imshow(img_cgls_tikhonov[80]) -##ax[4].axis('off') # clear x- and y-axesplt.show() -##ax[5].imshow(img_cgls_TVreg[80]) -##ax[5].axis('off') # clear x- and y-axesplt.show() -## -## -##plt.show() -## + def os_iterate(self, Xin=None): + print ("FISTA Reconstructor: iterate") + + if Xin is None: + if self.getParameter('initialize'): + X = self.initialize() + else: + N = vol_geom['GridColCount'] + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + # copy by reference + X = Xin + # store the output volume in the parameters + self.setParameter(output_volume=X) + X_t = X.copy() + # some useful constants + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring , + lambdaR_L1 , L_const = self.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , + 'ring_lambda_R_L1', 'Lipschitz_constant']) -- cgit v1.2.3 From 909a7bb4d71bdb14d4e68f42c2297f6154a77ed0 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 11:27:17 +0100 Subject: use system package --- src/Python/test/test_reconstructor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py index 07668ba..377f3cf 100644 --- a/src/Python/test/test_reconstructor.py +++ b/src/Python/test/test_reconstructor.py @@ -9,7 +9,7 @@ Based on DemoRD2.m import h5py import numpy -from ccpi.fista.FISTAReconstructor import FISTAReconstructor +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor import astra import matplotlib.pyplot as plt -- cgit v1.2.3 From 546104f8dfea5691801137c1be99d09e1e999d82 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 11:31:36 +0100 Subject: removed fista directory use the standard package reconstruction directory for the fista code --- src/Python/ccpi/fista/FISTAReconstructor.py | 609 ---------------------------- src/Python/ccpi/fista/Reconstructor.py | 425 ------------------- src/Python/ccpi/fista/__init__.py | 0 3 files changed, 1034 deletions(-) delete mode 100644 src/Python/ccpi/fista/FISTAReconstructor.py delete mode 100644 src/Python/ccpi/fista/Reconstructor.py delete mode 100644 src/Python/ccpi/fista/__init__.py (limited to 'src') diff --git a/src/Python/ccpi/fista/FISTAReconstructor.py b/src/Python/ccpi/fista/FISTAReconstructor.py deleted file mode 100644 index 85bfac5..0000000 --- a/src/Python/ccpi/fista/FISTAReconstructor.py +++ /dev/null @@ -1,609 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -#from ccpi.reconstruction.parallelbeam import alg - -#from ccpi.imaging.Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, - **kwargs): - # handle parmeters: - # obligatory parameters - self.pars = dict() - self.pars['projector_geometry'] = projector_geometry # proj_geom - self.pars['output_geometry'] = output_geometry # vol_geom - self.pars['input_sinogram'] = input_sinogram # sino - sliceZ, nangles, detectors = numpy.shape(input_sinogram) - self.pars['detectors'] = detectors - self.pars['number_of_angles'] = nangles - self.pars['SlicesZ'] = sliceZ - self.pars['output_volume'] = None - - print (self.pars) - # handle optional input parameters (at instantiation) - - # Accepted input keywords - kw = ( - # mandatory fields - 'projector_geometry', - 'output_geometry', - 'input_sinogram', - 'detectors', - 'number_of_angles', - 'SlicesZ', - # optional fields - 'number_of_iterations', - 'Lipschitz_constant' , - 'ideal_image' , - 'weights' , - 'region_of_interest' , - 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha', - 'subsets', - 'output_volume', - 'os_subsets', - 'os_indices', - 'os_bins') - self.acceptedInputKeywords = list(kw) - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = \ - numpy.ones(numpy.shape( - self.pars['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = None - - if not 'ideal_image' in kwargs.keys(): - self.pars['ideal_image'] = None - - if not 'region_of_interest'in kwargs.keys() : - if self.pars['ideal_image'] == None: - self.pars['region_of_interest'] = None - else: - ## nonzero if the image is larger than m - fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) - - self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) - - # the regularizer must be a correctly instantiated object - if not 'regularizer' in kwargs.keys() : - self.pars['regularizer'] = None - - #RING REMOVAL - if not 'ring_lambda_R_L1' in kwargs.keys(): - self.pars['ring_lambda_R_L1'] = 0 - if not 'ring_alpha' in kwargs.keys(): - self.pars['ring_alpha'] = 1 - - # ORDERED SUBSET - if not 'subsets' in kwargs.keys(): - self.pars['subsets'] = 0 - else: - self.createOrderedSubsets() - - if not 'initialize' in kwargs.keys(): - self.pars['initialize'] = False - - - - - def setParameter(self, **kwargs): - '''set named parameter for the reconstructor engine - - raises Exception if the named parameter is not recognized - - ''' - for key , value in kwargs.items(): - if key in self.acceptedInputKeywords: - self.pars[key] = value - else: - raise Exception('Wrong parameter {0} for '.format(key) + - 'reconstructor') - # setParameter - - def getParameter(self, key): - if type(key) is str: - if key in self.acceptedInputKeywords: - return self.pars[key] - else: - raise Exception('Unrecongnised parameter: {0} '.format(key) ) - elif type(key) is list: - outpars = [] - for k in key: - outpars.append(self.getParameter(k)) - return outpars - else: - raise Exception('Unhandled input {0}' .format(str(type(key)))) - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - N = self.pars['output_geometry']['GridColCount'] - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - - - if (proj_geom['type'] == 'parallel') or \ - (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #print('Calculating Lipshitz constant for parallel beam geometry...') - niter = 5;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights[0]) - proj_geomT = proj_geom.copy(); - proj_geomT['DetectorRowCount'] = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - - - for i in range(niter): - # [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); - # s = norm(x1(:)); - # x1 = x1/s; - # [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - # y = sqweight.*y; - # astra_mex_data3d('delete', sino_id); - # astra_mex_data3d('delete', id); - #print ("iteration {0}".format(i)) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geomT, - vol_geomT) - - y = (sqweight * y).copy() # element wise multiplication - - #b=fig.add_subplot(2,1,2) - #imgplot = plt.imshow(x1[0]) - #plt.show() - - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - del x1 - - idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(), - proj_geomT, - vol_geomT) - del y - - - s = numpy.linalg.norm(x1) - ### this line? - x1 = (x1/s).copy(); - - # ### this line? - # sino_id, y = astra.creators.create_sino3d_gpu(x1, - # proj_geomT, - # vol_geomT); - # y = sqweight * y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx) - print ("iteration {0} s= {1}".format(i,s)) - - #end - del proj_geomT - del vol_geomT - #plt.show() - else: - #% divergen beam geometry - print('Calculating Lipshitz constant for divergen beam geometry...') - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - - return s - - - def setRegularizer(self, regularizer): - if regularizer is not None: - self.pars['regularizer'] = regularizer - - - def initialize(self): - # convenience variable storage - proj_geom = self.pars['projector_geometry'] - vol_geom = self.pars['output_geometry'] - sino = self.pars['input_sinogram'] - - # a 'warm start' with SIRT method - # Create a data object for the reconstruction - rec_id = astra.matlab.data3d('create', '-vol', - vol_geom); - - #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); - sinogram_id = astra.matlab.data3d('create', '-proj3d', - proj_geom, - sino) - - sirt_config = astra.astra_dict('SIRT3D_CUDA') - sirt_config['ReconstructionDataId' ] = rec_id - sirt_config['ProjectionDataId'] = sinogram_id - - sirt = astra.algorithm.create(sirt_config) - astra.algorithm.run(sirt, iterations=35) - X = astra.matlab.data3d('get', rec_id) - - # clean up memory - astra.matlab.data3d('delete', rec_id) - astra.matlab.data3d('delete', sinogram_id) - astra.algorithm.delete(sirt) - - - - return X - - def createOrderedSubsets(self, subsets=None): - if subsets is None: - try: - subsets = self.getParameter('subsets') - except Exception(): - subsets = 0 - #return subsets - - angles = self.getParameter('projector_geometry')['ProjectionAngles'] - - #binEdges = numpy.linspace(angles.min(), - # angles.max(), - # subsets + 1) - binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) - # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles))) - counterM = 0 - for ii in range(binsDiscr.max()): - counter = 0 - for jj in range(subsets): - curr_index = ii + jj + counter - #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) - if binsDiscr[jj] > ii: - if (counterM < numpy.size(IndicesReorg)): - IndicesReorg[counterM] = curr_index - counterM = counterM + 1 - - counter = counter + binsDiscr[jj] - 1 - - # store the OS in parameters - self.setParameter(os_subsets=subsets, - os_bins=binsDiscr, - os_indices=IndicesReorg) - - - def prepareForIteration(self): - print ("FISTA Reconstructor: prepare for iteration") - - self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) - self.objective = numpy.zeros((self.pars['number_of_iterations'])) - - #2D array (for 3D data) of sparse "ring" - detectors, nangles, sliceZ = numpy.shape(self.pars['input_sinogram']) - self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) - # another ring variable - self.r_x = self.r.copy() - - self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) - - if self.getParameter('Lipschitz_constant') is None: - self.pars['Lipschitz_constant'] = \ - self.calculateLipschitzConstantWithPowerMethod() - # errors vector (if the ground truth is given) - self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); - # objective function values vector - self.objective = numpy.zeros((self.getParameter('number_of_iterations'))); - - - # prepareForIteration - - def iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - # convenience variable storage - proj_geom , vol_geom, sino , \ - SlicesZ = self.getParameter([ 'projector_geometry' , - 'output_geometry', - 'input_sinogram', - 'SlicesZ' ]) - - t = 1 - - for i in range(self.getParameter('number_of_iterations')): - X_old = X.copy() - t_old = t - r_old = self.r.copy() - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) - for kkk in range(SlicesZ): - sino_id, self.sino_updt[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomT, vol_geomT) - astra.matlab.data3d('delete', sino_id) - else: - # for divergent 3D geometry (watch the GPU memory overflow in - # ASTRA versions < 1.8) - #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); - sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( - X_t, proj_geom, vol_geom) - - - ## RING REMOVAL - self.ringRemoval(i) - ## Projection/Backprojection Routine - self.projectionBackprojection(X, X_t) - astra.matlab.data3d('delete', sino_id) - ## REGULARIZATION - X = self.regularize(X) - ## Update Loop - X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) - self.setParameter(output_volume=X) - return X - ## iterate - - def ringRemoval(self, i): - print ("FISTA Reconstructor: ring removal") - residual = self.residual - lambdaR_L1 , alpha_ring , weights , L_const , sino= \ - self.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant', - 'input_sinogram']) - r_x = self.r_x - sino_updt = self.sino_updt - - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - if lambdaR_L1 > 0 : - for kkk in range(anglesNumb): - - residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - vec = residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:].squeeze() - self.r = (r_x - (1./L_const) * vec).copy() - self.objective[i] = (0.5 * (residual ** 2).sum()) - - def projectionBackprojection(self, X, X_t): - print ("FISTA Reconstructor: projection-backprojection routine") - - # a few useful variables - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(self.getParameter('input_sinogram')) - residual = self.residual - proj_geom , vol_geom , L_const = \ - self.getParameter(['projector_geometry' , - 'output_geometry', - 'Lipschitz_constant']) - - - if self.getParameter('projector_geometry')['type'] == 'parallel' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat' or \ - self.getParameter('projector_geometry')['type'] == 'fanflat_vec': - # if the geometry is parallel use slice-by-slice - # projection-backprojection routine - #sino_updt = zeros(size(sino),'single'); - proj_geomT = proj_geom.copy() - proj_geomT['DetectorRowCount'] = 1 - vol_geomT = vol_geom.copy() - vol_geomT['GridSliceCount'] = 1; - x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) - - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residual[kkk:kkk+1], - proj_geomT, vol_geomT) - astra.matlab.data3d('delete', x_id) - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residual, proj_geom, vol_geom) - - X = X_t - (1/L_const) * x_temp - #astra.matlab.data3d('delete', sino_id) - astra.matlab.data3d('delete', x_id) - - def regularize(self, X): - print ("FISTA Reconstructor: regularize") - - regularizer = self.getParameter('regularizer') - if regularizer is not None: - return regularizer(input=X) - else: - return X - - def updateLoop(self, i, X, X_old, r_old, t, t_old): - print ("FISTA Reconstructor: update loop") - lambdaR_L1 = self.getParameter('ring_lambda_R_L1') - if lambdaR_L1 > 0: - self.r = numpy.max( - numpy.abs(self.r) - lambdaR_L1 , 0) * \ - numpy.sign(self.r) - t = (1 + numpy.sqrt(1 + 4 * t**2))/2 - X_t = X + (((t_old -1)/t) * (X - X_old)) - - if lambdaR_L1 > 0: - self.r_x = self.r + \ - (((t_old-1)/t) * (self.r - r_old)) - - if self.getParameter('region_of_interest') is None: - string = 'Iteration Number {0} | Objective {1} \n' - print (string.format( i, self.objective[i])) - else: - ROI , X_ideal = fistaRecon.getParameter('region_of_interest', - 'ideal_image') - - Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) - string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' - print (string.format(i,Resid_error[i], self.objective[i])) - return (X , X_t, t) - - def os_iterate(self, Xin=None): - print ("FISTA Reconstructor: iterate") - - if Xin is None: - if self.getParameter('initialize'): - X = self.initialize() - else: - N = vol_geom['GridColCount'] - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - # copy by reference - X = Xin - # store the output volume in the parameters - self.setParameter(output_volume=X) - X_t = X.copy() - - # some useful constants - proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring , - lambdaR_L1 , L_const = self.getParameter( - ['projector_geometry' , 'output_geometry', - 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , - 'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/ccpi/fista/Reconstructor.py b/src/Python/ccpi/fista/Reconstructor.py deleted file mode 100644 index d29ac0d..0000000 --- a/src/Python/ccpi/fista/Reconstructor.py +++ /dev/null @@ -1,425 +0,0 @@ -# -*- coding: utf-8 -*- -############################################################################### -#This work is part of the Core Imaging Library developed by -#Visual Analytics and Imaging System Group of the Science Technology -#Facilities Council, STFC -# -#Copyright 2017 Edoardo Pasca, Srikanth Nagella -#Copyright 2017 Daniil Kazantsev -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -#http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. -############################################################################### - - - -import numpy -import h5py -from ccpi.reconstruction.parallelbeam import alg - -from Regularizer import Regularizer -from enum import Enum - -import astra - - - -class FISTAReconstructor(): - '''FISTA-based reconstruction algorithm using ASTRA-toolbox - - ''' - # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> - # ___Input___: - # params.[] file: - # - .proj_geom (geometry of the projector) [required] - # - .vol_geom (geometry of the reconstructed object) [required] - # - .sino (vectorized in 2D or 3D sinogram) [required] - # - .iterFISTA (iterations for the main loop, default 40) - # - .L_const (Lipschitz constant, default Power method) ) - # - .X_ideal (ideal image, if given) - # - .weights (statisitcal weights, size of the sinogram) - # - .ROI (Region-of-interest, only if X_ideal is given) - # - .initialize (a 'warm start' using SIRT method from ASTRA) - #----------------Regularization choices------------------------ - # - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) - # - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) - # - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) - # - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) - # - .Regul_Iterations (iterations for the selected penalty, default 25) - # - .Regul_tauLLT (time step parameter for LLT term) - # - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) - # - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) - #----------------Visualization parameters------------------------ - # - .show (visualize reconstruction 1/0, (0 default)) - # - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) - # - .slice (for 3D volumes - slice number to imshow) - # ___Output___: - # 1. X - reconstructed image/volume - # 2. output - a structure with - # - .Resid_error - residual error (if X_ideal is given) - # - .objective: value of the objective function - # - .L_const: Lipshitz constant to avoid recalculations - - # References: - # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse - # Problems" by A. Beck and M Teboulle - # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo - # 3. "A novel tomographic reconstruction method based on the robust - # Student's t function for suppressing data outliers" D. Kazantsev et.al. - # D. Kazantsev, 2016-17 - def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): - self.params = dict() - self.params['projector_geometry'] = projector_geometry - self.params['output_geometry'] = output_geometry - self.params['input_sinogram'] = input_sinogram - detectors, nangles, sliceZ = numpy.shape(input_sinogram) - self.params['detectors'] = detectors - self.params['number_og_angles'] = nangles - self.params['SlicesZ'] = sliceZ - - # Accepted input keywords - kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , - 'weights' , 'region_of_interest' , 'initialize' , - 'regularizer' , - 'ring_lambda_R_L1', - 'ring_alpha') - - # handle keyworded parameters - if kwargs is not None: - for key, value in kwargs.items(): - if key in kw: - #print("{0} = {1}".format(key, value)) - self.pars[key] = value - - # set the default values for the parameters if not set - if 'number_of_iterations' in kwargs.keys(): - self.pars['number_of_iterations'] = kwargs['number_of_iterations'] - else: - self.pars['number_of_iterations'] = 40 - if 'weights' in kwargs.keys(): - self.pars['weights'] = kwargs['weights'] - else: - self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) - if 'Lipschitz_constant' in kwargs.keys(): - self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] - else: - self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() - - if not self.pars['ideal_image'] in kwargs.keys(): - self.pars['ideal_image'] = None - - if not self.pars['region_of_interest'] : - if self.pars['ideal_image'] == None: - pass - else: - self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) - - if not self.pars['regularizer'] : - self.pars['regularizer'] = None - else: - # the regularizer must be a correctly instantiated object - if not self.pars['ring_lambda_R_L1']: - self.pars['ring_lambda_R_L1'] = 0 - if not self.pars['ring_alpha']: - self.pars['ring_alpha'] = 1 - - - - - def calculateLipschitzConstantWithPowerMethod(self): - ''' using Power method (PM) to establish L constant''' - - #N = params.vol_geom.GridColCount - N = self.pars['output_geometry'].GridColCount - proj_geom = self.params['projector_geometry'] - vol_geom = self.params['output_geometry'] - weights = self.pars['weights'] - SlicesZ = self.pars['SlicesZ'] - - if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): - #% for parallel geometry we can do just one slice - #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); - niter = 15;# % number of iteration for the PM - #N = params.vol_geom.GridColCount; - #x1 = rand(N,N,1); - x1 = numpy.random.rand(1,N,N) - #sqweight = sqrt(weights(:,:,1)); - sqweight = numpy.sqrt(weights.T[0]) - proj_geomT = proj_geom.copy(); - proj_geomT.DetectorRowCount = 1; - vol_geomT = vol_geom.copy(); - vol_geomT['GridSliceCount'] = 1; - - - for i in range(niter): - if i == 0: - #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); - y = sqweight * y # element wise multiplication - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id) - - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); - y = sqweight*y; - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - del proj_geomT - del vol_geomT - else - #% divergen beam geometry - #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); - niter = 8; #% number of iteration for PM - x1 = numpy.random.rand(SlicesZ , N , N); - #sqweight = sqrt(weights); - sqweight = numpy.sqrt(weights.T[0]) - - sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - astra.matlab.data3d('delete', sino_id); - - for i in range(niter): - #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); - idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, - proj_geom, - vol_geom) - s = numpy.linalg.norm(x1) - ### this line? - x1 = x1/s; - ### this line? - #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); - sino_id, y = astra.creators.create_sino3d_gpu(x1, - proj_geom, - vol_geom); - - y = sqweight*y; - #astra_mex_data3d('delete', sino_id); - #astra_mex_data3d('delete', id); - astra.matlab.data3d('delete', sino_id); - astra.matlab.data3d('delete', idx); - #end - #clear x1 - del x1 - - return s - - - def setRegularizer(self, regularizer): - if regularizer - self.pars['regularizer'] = regularizer - - - - - -def getEntry(location): - for item in nx[location].keys(): - print (item) - - -print ("Loading Data") - -##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" -####ind = [i * 1049 for i in range(360)] -#### use only 360 images -##images = 200 -##ind = [int(i * 1049 / images) for i in range(images)] -##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) - -#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" -fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" -nx = h5py.File(fname, "r") - -# the data are stored in a particular location in the hdf5 -for item in nx['entry1/tomo_entry/data'].keys(): - print (item) - -data = nx.get('entry1/tomo_entry/data/rotation_angle') -angles = numpy.zeros(data.shape) -data.read_direct(angles) -print (angles) -# angles should be in degrees - -data = nx.get('entry1/tomo_entry/data/data') -stack = numpy.zeros(data.shape) -data.read_direct(stack) -print (data.shape) - -print ("Data Loaded") - - -# Normalize -data = nx.get('entry1/tomo_entry/instrument/detector/image_key') -itype = numpy.zeros(data.shape) -data.read_direct(itype) -# 2 is dark field -darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] -dark = darks[0] -for i in range(1, len(darks)): - dark += darks[i] -dark = dark / len(darks) -#dark[0][0] = dark[0][1] - -# 1 is flat field -flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] -flat = flats[0] -for i in range(1, len(flats)): - flat += flats[i] -flat = flat / len(flats) -#flat[0][0] = dark[0][1] - - -# 0 is projection data -proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] -angle_proj = numpy.asarray (angle_proj) -angle_proj = angle_proj.astype(numpy.float32) - -# normalized data are -# norm = (projection - dark)/(flat-dark) - -def normalize(projection, dark, flat, def_val=0.1): - a = (projection - dark) - b = (flat-dark) - with numpy.errstate(divide='ignore', invalid='ignore'): - c = numpy.true_divide( a, b ) - c[ ~ numpy.isfinite( c )] = def_val # set to not zero if 0/0 - return c - - -norm = [normalize(projection, dark, flat) for projection in proj] -norm = numpy.asarray (norm) -norm = norm.astype(numpy.float32) - -#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) - -#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, -# angles = angle_proj, center_of_rotation = 86.2 , -# flat_field = flat, dark_field = dark, -# iterations = 15, resolution = 1, isLogScale = False, threads = 3) -#img_cgls = recon.reconstruct() -# -#pars = dict() -#pars['algorithm'] = Reconstructor.Algorithm.SIRT -#pars['projection_data'] = proj -#pars['angles'] = angle_proj -#pars['center_of_rotation'] = numpy.double(86.2) -#pars['flat_field'] = flat -#pars['iterations'] = 15 -#pars['dark_field'] = dark -#pars['resolution'] = 1 -#pars['isLogScale'] = False -#pars['threads'] = 3 -# -#img_sirt = recon.reconstruct(pars) -# -#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM -#img_mlem = recon.reconstruct() - -############################################################ -############################################################ -#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV -#recon.pars['regularize'] = numpy.double(0.1) -#img_cgls_conv = recon.reconstruct() - -niterations = 15 -threads = 3 - -img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) -img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - iteration_values, False) -print ("iteration values %s" % str(iteration_values)) - -iteration_values = numpy.zeros((niterations,)) -img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) -iteration_values = numpy.zeros((niterations,)) -img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, - numpy.double(1e-5), iteration_values , False) -print ("iteration values %s" % str(iteration_values)) - - -##numpy.save("cgls_recon.npy", img_data) -import matplotlib.pyplot as plt -fig, ax = plt.subplots(1,6,sharey=True) -ax[0].imshow(img_cgls[80]) -ax[0].axis('off') # clear x- and y-axes -ax[1].imshow(img_sirt[80]) -ax[1].axis('off') # clear x- and y-axes -ax[2].imshow(img_mlem[80]) -ax[2].axis('off') # clear x- and y-axesplt.show() -ax[3].imshow(img_cgls_conv[80]) -ax[3].axis('off') # clear x- and y-axesplt.show() -ax[4].imshow(img_cgls_tikhonov[80]) -ax[4].axis('off') # clear x- and y-axesplt.show() -ax[5].imshow(img_cgls_TVreg[80]) -ax[5].axis('off') # clear x- and y-axesplt.show() - - -plt.show() - -#viewer = edo.CILViewer() -#viewer.setInputAsNumpy(img_cgls2) -#viewer.displaySliceActor(0) -#viewer.startRenderLoop() - -import vtk - -def NumpyToVTKImageData(numpyarray): - if (len(numpy.shape(numpyarray)) == 3): - doubleImg = vtk.vtkImageData() - shape = numpy.shape(numpyarray) - doubleImg.SetDimensions(shape[0], shape[1], shape[2]) - doubleImg.SetOrigin(0,0,0) - doubleImg.SetSpacing(1,1,1) - doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) - #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) - doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) - - for i in range(shape[0]): - for j in range(shape[1]): - for k in range(shape[2]): - doubleImg.SetScalarComponentFromDouble( - i,j,k,0, numpyarray[i][j][k]) - #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) - # rescale to appropriate VTK_UNSIGNED_SHORT - stats = vtk.vtkImageAccumulate() - stats.SetInputData(doubleImg) - stats.Update() - iMin = stats.GetMin()[0] - iMax = stats.GetMax()[0] - scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) - - shiftScaler = vtk.vtkImageShiftScale () - shiftScaler.SetInputData(doubleImg) - shiftScaler.SetScale(scale) - shiftScaler.SetShift(iMin) - shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) - shiftScaler.Update() - return shiftScaler.GetOutput() - -#writer = vtk.vtkMetaImageWriter() -#writer.SetFileName(alg + "_recon.mha") -#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) -#writer.Write() diff --git a/src/Python/ccpi/fista/__init__.py b/src/Python/ccpi/fista/__init__.py deleted file mode 100644 index e69de29..0000000 -- cgit v1.2.3 From 57ccd56cacb3c437f706324f330af40a3a715f18 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 12:59:12 +0100 Subject: added targets and cache variables --- src/Python/CMakeLists.txt | 48 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index c5e14ea..1399c71 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -64,6 +64,16 @@ message("Python found Minor " ${PYTHON_VERSION_MINOR}) findPythonPackagesPath() message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) +## CACHE SOME VARIABLES ## +set (CONDA_ENVIRONMENT ${CONDA_ENVIRONMENT} CACHE INTERNAL "active conda environment" FORCE) +set (CONDA_ENVIRONMENT_PATH ${CONDA_ENVIRONMENT_PATH} CACHE INTERNAL "active conda environment" FORCE) + +set (PYTHON_VERSION_STRING ${PYTHON_VERSION_STRING} CACHE INTERNAL "conda environment Python version string" FORCE) +set (PYTHON_VERSION_MAJOR ${PYTHON_VERSION_MAJOR} CACHE INTERNAL "conda environment Python version major" FORCE) +set (PYTHON_VERSION_MINOR ${PYTHON_VERSION_MINOR} CACHE INTERNAL "conda environment Python version minor" FORCE) +set (PYTHON_VERSION_PATCH ${PYTHON_VERSION_PATCH} CACHE INTERNAL "conda environment Python version patch" FORCE) +set (PYTHON_PACKAGES_PATH ${PYTHON_PACKAGES_PATH} CACHE INTERNAL "conda environment Python packages path" FORCE) + if (WIN32) #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") @@ -123,7 +133,43 @@ elseif(UNIX) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.sh) endif() - +############################# TARGETS + +########################## REGULARIZER PACKAGE ############################### + +# runs cmake on the build tree to update the code from source +add_custom_target(update_code + COMMAND ${CMAKE_COMMAND} + ARGS ${CMAKE_SOURCE_DIR} + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + ) + + +add_custom_target(fista + COMMAND bash + compile-fista.sh + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${update_code} + ) + +add_custom_target(regularizers + COMMAND bash + compile.sh + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${update_code} + ) + +add_custom_target(install-fista + COMMAND conda + install --force --use-local ccpi-fista=${CIL_VERSION} -c ccpi -c conda-forge + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${fista}) + +add_custom_target(install-regularizers + COMMAND conda + install --force --use-local ccpi-regularizers=${CIL_VERSION} -c ccpi -c conda-forge + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${fista}) ### add tests #add_executable(RegularizersTest ) -- cgit v1.2.3 From cf741b21f5a66d4b6157bef401a8ca240d8702b8 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 12:59:53 +0100 Subject: fix wrong indentation --- src/Python/ccpi/reconstruction/FISTAReconstructor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index 85bfac5..f43966c 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -602,7 +602,7 @@ class FISTAReconstructor(): # some useful constants proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring , + SlicesZ, weights , alpha_ring ,\ lambdaR_L1 , L_const = self.getParameter( ['projector_geometry' , 'output_geometry', 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha' , -- cgit v1.2.3 From ac4408e8984be8ca23a46b2b75bb243a0a4720aa Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 13:00:23 +0100 Subject: Remove dependency on astra-toolbox in meta.yaml this just because otherwise conda wants to install a version from astra-toolbox which I'm not sure works for us. --- src/Python/fista-recipe/meta.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/fista-recipe/meta.yaml b/src/Python/fista-recipe/meta.yaml index 89bf597..265541f 100644 --- a/src/Python/fista-recipe/meta.yaml +++ b/src/Python/fista-recipe/meta.yaml @@ -18,7 +18,7 @@ requirements: run: - python - numpy - - astra-toolbox + #- astra-toolbox - ccpi-regularizers -- cgit v1.2.3 From bb4f7dc7e3a3bf4b4da18e36a2fc69e2195c5a96 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 15:47:40 +0100 Subject: moved to test directory --- src/Python/test/test_reconstructor-os.py | 338 +++++++++++++++++++++++++++++++ src/Python/test_reconstructor-os.py | 329 ------------------------------ 2 files changed, 338 insertions(+), 329 deletions(-) create mode 100644 src/Python/test/test_reconstructor-os.py delete mode 100644 src/Python/test_reconstructor-os.py (limited to 'src') diff --git a/src/Python/test/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py new file mode 100644 index 0000000..a36feda --- /dev/null +++ b/src/Python/test/test_reconstructor-os.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt + +def RMSE(signal1, signal2): + '''RMSE Root Mean Squared Error''' + if numpy.shape(signal1) == numpy.shape(signal2): + err = (signal1 - signal2) + err = numpy.sum( err * err )/numpy.size(signal1); # MSE + err = sqrt(err); # RMSE + return err + else: + raise Exception('Input signals must have the same shape') + +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', + detectorSpacingX, + detectorSpacingY, + det_row_count, + det_col_count, + angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, + image_size_y, + image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + + +## Ordered subset +if True: + subsets = 16 + fistaRecon.setParameter(subsets=subsets) + fistaRecon.createOrderedSubsets() + angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] + #binEdges = numpy.linspace(angles.min(), + # angles.max(), + # subsets + 1) + binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) + # get rearranged subset indices + IndicesReorg = numpy.zeros((numpy.shape(angles))) + counterM = 0 + for ii in range(binsDiscr.max()): + counter = 0 + for jj in range(subsets): + curr_index = ii + jj + counter + #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) + if binsDiscr[jj] > ii: + if (counterM < numpy.size(IndicesReorg)): + IndicesReorg[counterM] = curr_index + counterM = counterM + 1 + + counter = counter + binsDiscr[jj] - 1 + + +if True: + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + print ("prepare for iteration") + fistaRecon.prepareForIteration() + + + + print("initializing ...") + if False: + # if X doesn't exist + #N = params.vol_geom.GridColCount + N = vol_geom['GridColCount'] + print ("N " + str(N)) + X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) + else: + #X = fistaRecon.initialize() + X = numpy.load("X.npy") + + print (numpy.shape(X)) + X_t = X.copy() + print ("initialized") + proj_geom , vol_geom, sino , \ + SlicesZ, weights , alpha_ring = fistaRecon.getParameter( + ['projector_geometry' , 'output_geometry', + 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha']) + lambdaR_L1 , alpha_ring , weights , L_const= \ + fistaRecon.getParameter(['ring_lambda_R_L1', + 'ring_alpha' , 'weights', + 'Lipschitz_constant']) + + #fistaRecon.setParameter(number_of_iterations = 3) + iterFISTA = fistaRecon.getParameter('number_of_iterations') + # errors vector (if the ground truth is given) + Resid_error = numpy.zeros((iterFISTA)); + # objective function values vector + objective = numpy.zeros((iterFISTA)); + + + t = 1 + + + ## additional for + proj_geomSUB = proj_geom.copy() + fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) + residual2 = fistaRecon.residual2 + sino_updt_FULL = fistaRecon.residual.copy() + r_x = fistaRecon.r.copy() + + print ("starting iterations") +## % Outer FISTA iterations loop + for i in range(fistaRecon.getParameter('number_of_iterations')): +## % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required +## % one solution is to work with a full sinogram at times +## if ((i >= 3) && (lambdaR_L1 > 0)) +## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); +## astra_mex_data3d('delete', sino_id2); +## end + # With OS approach it becomes trickier to correlate independent subsets, + # hence additional work is required one solution is to work with a full + # sinogram at times + + r_old = fistaRecon.r.copy() + t_old = t + SlicesZ, anglesNumb, Detectors = \ + numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 + if (i > 1 and lambdaR_L1 > 0) : + for kkk in range(anglesNumb): + + residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ + ((sino_updt_FULL[:,kkk,:]).squeeze() - \ + (sino[:,kkk,:]).squeeze() -\ + (alpha_ring * r_x) + ) + + vec = fistaRecon.residual.sum(axis = 1) + #if SlicesZ > 1: + # vec = vec[:,1,:] # 1 or 0? + r_x = fistaRecon.r_x + fistaRecon.r = (r_x - (1./L_const) * vec).copy() + + # subset loop + counterInd = 1 + geometry_type = fistaRecon.getParameter('projector_geometry')['type'] + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT[kkk] = \ + astra.creators.create_sino3d_gpu( + X_t[kkk:kkk+1], proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) + + for ss in range(fistaRecon.getParameter('subsets')): + print ("Subset {0}".format(ss)) + X_old = X.copy() + t_old = t + + # the number of projections per subset + numProjSub = fistaRecon.getParameter('os_bins')[ss] + CurrSubIndices = fistaRecon.getParameter('os_indices')\ + [counterInd:counterInd+numProjSub-1] + mask = numpy.zeros(numpy.shape(angles), dtype=bool) + cc = 0 + for i in range(len(CurrSubIndices)): + mask[int(CurrSubIndices[i])] = True + proj_geomSUB['ProjectionAngles'] = angles[mask] + + shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) + shape[1] = numProjSub + sino_updt_Sub = numpy.zeros(shape) + + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + + for kkk in range(SlicesZ): + sino_id, sinoT = astra.creators.create_sino3d_gpu ( + X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) + sino_updt_Sub[kkk] = sinoT.T.copy() + + else: + # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8) + sino_id, sino_updt_Sub = \ + astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', sino_id) + + + + + ## RING REMOVAL + residual = fistaRecon.residual + + + if lambdaR_L1 > 0 : + print ("ring removal") + residualSub = numpy.zeros(shape) + ## for a chosen subset + ## for kkk = 1:numProjSub + ## indC = CurrSubIndeces(kkk); + ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); + ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram + ## end + for kkk in range(numProjSub): + print ("ring removal indC ... {0}".format(kkk)) + indC = int(CurrSubIndices[kkk]) + residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ + (sino_updt_Sub[:,kkk,:].squeeze() - \ + sino[:,indC,:].squeeze() - alpha_ring * r_x) + # filling the full sinogram + sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + + else: + #PWLS model + residualSub = weights[:,CurrSubIndices,:] * \ + ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) + objective[i] = 0.5 * numpy.linalg.norm(residualSub) + + if geometry_type == 'parallel' or \ + geometry_type == 'fanflat' or \ + geometry_type == 'fanflat_vec' : + # if geometry is 2D use slice-by-slice projection-backprojection + # routine + x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) + for kkk in range(SlicesZ): + + x_id, x_temp[kkk] = \ + astra.creators.create_backprojection3d_gpu( + residualSub[kkk:kkk+1], + proj_geomSUB, vol_geom) + + else: + x_id, x_temp = \ + astra.creators.create_backprojection3d_gpu( + residualSub, proj_geomSUB, vol_geom) + + astra.matlab.data3d('delete', x_id) + X = X_t - (1/L_const) * x_temp + + + + ## REGULARIZATION + ## SKIPPING FOR NOW + ## Should be simpli + # regularizer = fistaRecon.getParameter('regularizer') + # for slices: + # out = regularizer(input=X) + print ("skipping regularizer") + + + ## FINAL + print ("final") + lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') + if lambdaR_L1 > 0: + fistaRecon.r = numpy.max( + numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ + numpy.sign(fistaRecon.r) + # updating r + r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) + + + if fistaRecon.getParameter('region_of_interest') is None: + string = 'Iteration Number {0} | Objective {1} \n' + print (string.format( i, objective[i])) + else: + ROI , X_ideal = fistaRecon.getParameter('region_of_interest', + 'ideal_image') + + Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) + string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' + print (string.format(i,Resid_error[i], objective[i])) + + +else: + fistaRecon = FISTAReconstructor(proj_geom, + vol_geom, + Sino3D , + weights=Weights3D) + + print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) + fistaRecon.setParameter(number_of_iterations = 12) + fistaRecon.setParameter(Lipschitz_constant = 767893952.0) + fistaRecon.setParameter(ring_alpha = 21) + fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + fistaRecon.prepareForIteration() + X = fistaRecon.iterate(numpy.load("X.npy")) diff --git a/src/Python/test_reconstructor-os.py b/src/Python/test_reconstructor-os.py deleted file mode 100644 index aee70a4..0000000 --- a/src/Python/test_reconstructor-os.py +++ /dev/null @@ -1,329 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Wed Aug 23 16:34:49 2017 - -@author: ofn77899 -Based on DemoRD2.m -""" - -import h5py -import numpy - -from ccpi.fista.FISTAReconstructor import FISTAReconstructor -import astra -import matplotlib.pyplot as plt - -def RMSE(signal1, signal2): - '''RMSE Root Mean Squared Error''' - if numpy.shape(signal1) == numpy.shape(signal2): - err = (signal1 - signal2) - err = numpy.sum( err * err )/numpy.size(signal1); # MSE - err = sqrt(err); # RMSE - return err - else: - raise Exception('Input signals must have the same shape') - -filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' -nx = h5py.File(filename, "r") -#getEntry(nx, '/') -# I have exported the entries as children of / -entries = [entry for entry in nx['/'].keys()] -print (entries) - -Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") -Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") -angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] -angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") -recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] -size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] -slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] - -Z_slices = 20 -det_row_count = Z_slices -# next definition is just for consistency of naming -det_col_count = size_det - -detectorSpacingX = 1.0 -detectorSpacingY = detectorSpacingX - - -proj_geom = astra.creators.create_proj_geom('parallel3d', - detectorSpacingX, - detectorSpacingY, - det_row_count, - det_col_count, - angles_rad) - -#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); -image_size_x = recon_size -image_size_y = recon_size -image_size_z = Z_slices -vol_geom = astra.creators.create_vol_geom( image_size_x, - image_size_y, - image_size_z) - -## First pass the arguments to the FISTAReconstructor and test the -## Lipschitz constant - -fistaRecon = FISTAReconstructor(proj_geom, - vol_geom, - Sino3D , - weights=Weights3D) - -print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) -fistaRecon.setParameter(number_of_iterations = 12) -fistaRecon.setParameter(Lipschitz_constant = 767893952.0) -fistaRecon.setParameter(ring_alpha = 21) -fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) - -## Ordered subset -if True: - subsets = 16 - angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] - #binEdges = numpy.linspace(angles.min(), - # angles.max(), - # subsets + 1) - binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) - # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles))) - counterM = 0 - for ii in range(binsDiscr.max()): - counter = 0 - for jj in range(subsets): - curr_index = ii + jj + counter - #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) - if binsDiscr[jj] > ii: - if (counterM < numpy.size(IndicesReorg)): - IndicesReorg[counterM] = curr_index - counterM = counterM + 1 - - counter = counter + binsDiscr[jj] - 1 - - -if True: - print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) - print ("prepare for iteration") - fistaRecon.prepareForIteration() - - - - print("initializing ...") - if False: - # if X doesn't exist - #N = params.vol_geom.GridColCount - N = vol_geom['GridColCount'] - print ("N " + str(N)) - X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) - else: - #X = fistaRecon.initialize() - X = numpy.load("X.npy") - - print (numpy.shape(X)) - X_t = X.copy() - print ("initialized") - proj_geom , vol_geom, sino , \ - SlicesZ, weights , alpha_ring = fistaRecon.getParameter( - ['projector_geometry' , 'output_geometry', - 'input_sinogram', 'SlicesZ' , 'weights', 'ring_alpha']) - lambdaR_L1 , alpha_ring , weights , L_const= \ - fistaRecon.getParameter(['ring_lambda_R_L1', - 'ring_alpha' , 'weights', - 'Lipschitz_constant']) - - #fistaRecon.setParameter(number_of_iterations = 3) - iterFISTA = fistaRecon.getParameter('number_of_iterations') - # errors vector (if the ground truth is given) - Resid_error = numpy.zeros((iterFISTA)); - # objective function values vector - objective = numpy.zeros((iterFISTA)); - - - t = 1 - - - ## additional for - proj_geomSUB = proj_geom.copy() - fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) - residual2 = fistaRecon.residual2 - sino_updt_FULL = fistaRecon.residual.copy() - - print ("starting iterations") -## % Outer FISTA iterations loop - for i in range(fistaRecon.getParameter('number_of_iterations')): -## % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required -## % one solution is to work with a full sinogram at times -## if ((i >= 3) && (lambdaR_L1 > 0)) -## [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); -## astra_mex_data3d('delete', sino_id2); -## end - # With OS approach it becomes trickier to correlate independent subsets, - # hence additional work is required one solution is to work with a full - # sinogram at times - - r_old = fistaRecon.r.copy() - t_old = t - SlicesZ, anglesNumb, Detectors = \ - numpy.shape(fistaRecon.getParameter('input_sinogram')) ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 - if (i > 1 and lambdaR_L1 > 0) : - for kkk in range(anglesNumb): - - residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ - ((sino_updt_FULL[:,kkk,:]).squeeze() - \ - (sino[:,kkk,:]).squeeze() -\ - (alpha_ring * r_x) - ) - - vec = fistaRecon.residual.sum(axis = 1) - #if SlicesZ > 1: - # vec = vec[:,1,:] # 1 or 0? - r_x = fistaRecon.r_x - fistaRecon.r = (r_x - (1./L_const) * vec).copy() - - # subset loop - counterInd = 1 - geometry_type = fistaRecon.getParameter('projector_geometry')['type'] - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - - for kkk in range(SlicesZ): - sino_id, sinoT[kkk] = \ - astra.creators.create_sino3d_gpu( - X_t[kkk:kkk+1], proj_geomSUB, vol_geom) - sino_updt_Sub[kkk] = sinoT.T.copy() - - else: - sino_id, sino_updt_Sub = \ - astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) - - astra.matlab.data3d('delete', sino_id) - - for ss in range(fistaRecon.getParameter('subsets')): - print ("Subset {0}".format(ss)) - X_old = X.copy() - t_old = t - - # the number of projections per subset - numProjSub = fistaRecon.getParameter('os_bins')[ss] - CurrSubIndices = fistaRecon.getParameter('os_indices')\ - [counterInd:counterInd+numProjSub-1] - proj_geomSUB['ProjectionAngles'] = angles[CurrSubIndeces] - - shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) - shape[1] = numProjSub - sino_updt_Sub = numpy.zeros(shape) - - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - - for kkk in range(SlicesZ): - sino_id, sinoT = astra.creators.create_sino3d_gpu ( - X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) - sino_updt_Sub[kkk] = sinoT.T.copy() - - else: - # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8) - sino_id, sino_updt_Sub = \ - astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) - - astra.matlab.data3d('delete', sino_id) - - - - - ## RING REMOVAL - residual = fistaRecon.residual - - - if lambdaR_L1 > 0 : - print ("ring removal") - residualSub = numpy.zeros(shape) - ## for a chosen subset - ## for kkk = 1:numProjSub - ## indC = CurrSubIndeces(kkk); - ## residualSub(:,kkk,:) = squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); - ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram - ## end - for kkk in range(numProjSub): - indC = CurrSubIndices[kkk] - residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ - (sino_updt_Sub[:,kkk,:].squeeze() - \ - sino[:,indC,:].squeeze() - alpha_ring * r_x) - # filling the full sinogram - sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() - - else: - #PWLS model - residualSub = weights[:,CurrSubIndices,:] * \ - ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) - objective[i] = 0.5 * numpy.linalg.norm(residualSub) - - if geometry_type == 'parallel' or \ - geometry_type == 'fanflat' or \ - geometry_type == 'fanflat_vec' : - # if geometry is 2D use slice-by-slice projection-backprojection - # routine - x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) - for kkk in range(SlicesZ): - - x_id, x_temp[kkk] = \ - astra.creators.create_backprojection3d_gpu( - residualSub[kkk:kkk+1], - proj_geomSUB, vol_geom) - - else: - x_id, x_temp = \ - astra.creators.create_backprojection3d_gpu( - residualSub, proj_geomSUB, vol_geom) - - astra.matlab.data3d('delete', x_id) - X = X_t - (1/L_const) * x_temp - - - - ## REGULARIZATION - ## SKIPPING FOR NOW - ## Should be simpli - # regularizer = fistaRecon.getParameter('regularizer') - # for slices: - # out = regularizer(input=X) - print ("skipping regularizer") - - - ## FINAL - print ("final") - lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') - if lambdaR_L1 > 0: - fistaRecon.r = numpy.max( - numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ - numpy.sign(fistaRecon.r) - # updating r - r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) - - - if fistaRecon.getParameter('region_of_interest') is None: - string = 'Iteration Number {0} | Objective {1} \n' - print (string.format( i, objective[i])) - else: - ROI , X_ideal = fistaRecon.getParameter('region_of_interest', - 'ideal_image') - - Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) - string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' - print (string.format(i,Resid_error[i], objective[i])) - - -else: - fistaRecon = FISTAReconstructor(proj_geom, - vol_geom, - Sino3D , - weights=Weights3D) - - print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) - fistaRecon.setParameter(number_of_iterations = 12) - fistaRecon.setParameter(Lipschitz_constant = 767893952.0) - fistaRecon.setParameter(ring_alpha = 21) - fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) - fistaRecon.prepareForIteration() - X = fistaRecon.iterate(numpy.load("X.npy")) -- cgit v1.2.3 From 7f6e90ed9569e6f935813d8ceb6b3c00feed3bc0 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 15:48:06 +0100 Subject: saves to file --- src/Python/test/test_reconstructor.py | 1 + 1 file changed, 1 insertion(+) (limited to 'src') diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py index 377f3cf..be28624 100644 --- a/src/Python/test/test_reconstructor.py +++ b/src/Python/test/test_reconstructor.py @@ -299,3 +299,4 @@ else: fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) fistaRecon.prepareForIteration() X = fistaRecon.iterate(numpy.load("X.npy")) + numpy.save("X_out.npy", X) -- cgit v1.2.3 From 455ca86825c157512f61441d3d27b8148ca795a7 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Tue, 24 Oct 2017 16:37:21 +0100 Subject: Add regularization step Add regularization step OS seems to work --- .../ccpi/reconstruction/FISTAReconstructor.py | 5 ++++- src/Python/test/test_reconstructor-os.py | 22 ++++++++++++++++------ src/Python/test/test_reconstructor.py | 7 +++++++ 3 files changed, 27 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py index f43966c..c903712 100644 --- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -363,6 +363,9 @@ class FISTAReconstructor(): except Exception(): subsets = 0 #return subsets + else: + self.setParameter(subsets=subsets) + angles = self.getParameter('projector_geometry')['ProjectionAngles'] @@ -371,7 +374,7 @@ class FISTAReconstructor(): # subsets + 1) binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) # get rearranged subset indices - IndicesReorg = numpy.zeros((numpy.shape(angles))) + IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32) counterM = 0 for ii in range(binsDiscr.max()): counter = 0 diff --git a/src/Python/test/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py index a36feda..6c82ae0 100644 --- a/src/Python/test/test_reconstructor-os.py +++ b/src/Python/test/test_reconstructor-os.py @@ -12,6 +12,7 @@ import numpy from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor import astra import matplotlib.pyplot as plt +from ccpi.imaging.Regularizer import Regularizer def RMSE(signal1, signal2): '''RMSE Root Mean Squared Error''' @@ -77,6 +78,12 @@ fistaRecon.setParameter(ring_alpha = 21) fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +reg = Regularizer(Regularizer.Algorithm.LLT_model) +reg.setParameter(regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) + ## Ordered subset if True: subsets = 16 @@ -210,11 +217,12 @@ if True: # the number of projections per subset numProjSub = fistaRecon.getParameter('os_bins')[ss] CurrSubIndices = fistaRecon.getParameter('os_indices')\ - [counterInd:counterInd+numProjSub-1] + [counterInd:counterInd+numProjSub] + #print ("Len CurrSubIndices {0}".format(numProjSub)) mask = numpy.zeros(numpy.shape(angles), dtype=bool) cc = 0 - for i in range(len(CurrSubIndices)): - mask[int(CurrSubIndices[i])] = True + for j in range(len(CurrSubIndices)): + mask[int(CurrSubIndices[j])] = True proj_geomSUB['ProjectionAngles'] = angles[mask] shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) @@ -254,7 +262,7 @@ if True: ## sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram ## end for kkk in range(numProjSub): - print ("ring removal indC ... {0}".format(kkk)) + #print ("ring removal indC ... {0}".format(kkk)) indC = int(CurrSubIndices[kkk]) residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ (sino_updt_Sub[:,kkk,:].squeeze() - \ @@ -297,7 +305,8 @@ if True: # regularizer = fistaRecon.getParameter('regularizer') # for slices: # out = regularizer(input=X) - print ("skipping regularizer") + print ("regularizer") + #X = reg(input=X) ## FINAL @@ -321,7 +330,8 @@ if True: Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' print (string.format(i,Resid_error[i], objective[i])) - + + numpy.save("X_out_os.npy", X) else: fistaRecon = FISTAReconstructor(proj_geom, diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py index be28624..3342301 100644 --- a/src/Python/test/test_reconstructor.py +++ b/src/Python/test/test_reconstructor.py @@ -76,6 +76,13 @@ fistaRecon.setParameter(Lipschitz_constant = 767893952.0) fistaRecon.setParameter(ring_alpha = 21) fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +reg = Regularizer(Regularizer.Algorithm.LLT_model) +reg.setParameter(regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) +fistaRecon.setParameter(regularizer = reg) + ## Ordered subset if False: subsets = 16 -- cgit v1.2.3 From 31097954f87d0f30f667b29a12f7098710c284ab Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 10:48:48 +0100 Subject: Use MATCH and CMAKE_MATCH_ --- src/Python/CMakeLists.txt | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index e539eda..b464059 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -36,17 +36,11 @@ execute_process(COMMAND "conda" "env" "list" foreach(line ${ENV_LIST}) string(REGEX MATCHALL "(.+)[*](.+)" match ${line}) if (NOT ${match} EQUAL "") - string(REPLACE "*" ";" ENV_DIR ${match}) - list (APPEND cc "") - foreach(conda ${ENV_DIR}) - string(STRIP ${conda} stripped) - list(APPEND cc ${stripped}) - endforeach() - list(LENGTH cc Ns) - if (${Ns} EQUAL 2) - list(GET cc 0 CONDA_ENVIRONMENT) - list(GET cc 1 CONDA_ENVIRONMENT_PATH) - endif() + #message("MATCHED " ${CMAKE_MATCH_0}) + #message("MATCHED " ${CMAKE_MATCH_1}) + #message("MATCHED " ${CMAKE_MATCH_2}) + string(STRIP ${CMAKE_MATCH_1} CONDA_ENVIRONMENT) + string(STRIP ${CMAKE_MATCH_2} CONDA_ENVIRONMENT_PATH) endif() endforeach() else() @@ -58,7 +52,7 @@ if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH}) else() message("**********************************************************") message("Using current conda environmnet " ${CONDA_ENVIRONMENT}) - message("Using current conda environmnet path" ${CONDA_ENVIRONMENT_PATH}) + message("Using current conda environmnet path " ${CONDA_ENVIRONMENT_PATH}) endif() -- cgit v1.2.3 From 7033c692bd836ca300fc6d91ce7cd733a2342cde Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 12:25:59 +0100 Subject: Added print to screen to debug faulty FGP_TV Bug #2 came from not initializing lambda but mu. --- src/Python/fista_module.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index c36329e..f8fd812 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -338,7 +338,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me A = reinterpret_cast(input.get_data()); //mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ - mu = (float)d_mu; + lambda = (float)d_mu; //iter = 35; /* default iterations number */ @@ -408,15 +408,17 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me /* begin iterations */ for (ll = 0; ll 3) { Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + std::cout << "Obj_func2D D[0] " << D[0]<< " A[0]" << A[0] << " R1[0] " << R1[0] << " R2[0] " << R2[0] << " lambda " << lambda << " dimX " << dimX << " dimY " << dimY << std::endl; funcval = 0.0f; for (j = 0; j 2) { if (re > re_old) { Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); + std::cout << "Obj_func2D D[0] " << D[0]<< " A[0]" << A[0] << " R1[0] " << R1[0] << " R2[0] " << R2[0] << " lambda " << lambda << " dimX " << dimX << " dimY " << dimY << std::endl; funcval = 0.0f; for (j = 0; j Date: Wed, 25 Oct 2017 12:28:01 +0100 Subject: removed print to screen Fixed #2 --- src/Python/fista_module.cpp | 6 ------ 1 file changed, 6 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index f8fd812..94d156c 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -408,14 +408,11 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me /* begin iterations */ for (ll = 0; ll 3) { Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); - std::cout << "Obj_func2D D[0] " << D[0]<< " A[0]" << A[0] << " R1[0] " << R1[0] << " R2[0] " << R2[0] << " lambda " << lambda << " dimX " << dimX << " dimY " << dimY << std::endl; funcval = 0.0f; for (j = 0; j 2) { if (re > re_old) { Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); - std::cout << "Obj_func2D D[0] " << D[0]<< " A[0]" << A[0] << " R1[0] " << R1[0] << " R2[0] " << R2[0] << " lambda " << lambda << " dimX " << dimX << " dimY " << dimY << std::endl; funcval = 0.0f; for (j = 0; j Date: Wed, 25 Oct 2017 16:25:50 +0100 Subject: added to repository --- src/Python/test/test_regularizers_3d.py | 380 ++++++++++++++++++++++++++++++++ 1 file changed, 380 insertions(+) create mode 100644 src/Python/test/test_regularizers_3d.py (limited to 'src') diff --git a/src/Python/test/test_regularizers_3d.py b/src/Python/test/test_regularizers_3d.py new file mode 100644 index 0000000..a2e3027 --- /dev/null +++ b/src/Python/test/test_regularizers_3d.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug 4 11:10:05 2017 + +@author: ofn77899 +""" + +from ccpi.viewer.CILViewer2D import Converter +import vtk + +import regularizers +import matplotlib.pyplot as plt +import numpy as np +import os +from enum import Enum +import timeit + +from Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE: +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) + max_val = max(np.max(im1), np.max(im2)) + min_val = min(np.min(im1), np.min(im2)) + return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +# 2D Regularizers +# +############################################################################### +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 +# +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +# +### plot +#fig = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +# +#a=fig.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0) +# +#reg_output = [] +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +## +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +## #tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +##out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +## tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +#pars = out2[2] +#reg_output.append(out2) +# +#a=fig.add_subplot(2,3,2) +# +#textstr = out2[-1] +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output[-1][0]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +# +#reg_output.append(out2) +# +#a=fig.add_subplot(2,3,3) +# +#textstr = out2[-1] +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output[-1][0]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +##input, regularization_parameter , time_step, number_of_iterations, +## tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +# time_step=0.0003, +# tolerance_constant=0.0001, +# number_of_iterations=300) +#pars = out2[-2] +# +#reg_output.append(out2) +# +#a=fig.add_subplot(2,3,4) +#textstr = out2[-1] +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output[-1][0]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +# searching_window_ratio=3, +# similarity_window_ratio=1, +# PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output.append(out2) +# +#a=fig.add_subplot(2,3,5) +# +# +#textstr = out2[-1] +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output[-1][0]) +# +# +####################### TGV_PD ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); +# +# +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output.append(out2) +# +#a=fig.add_subplot(2,3,6) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output[-1][0]) +# + +############################################################################### +# +# 3D Regularizers +# +############################################################################### +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" + +reader = vtk.vtkMetaImageReader() +reader.SetFileName(os.path.normpath(filename)) +reader.Update() +#vtk returns 3D images, let's take just the one slice there is as 2D +Im = Converter.vtk2numpy(reader.GetOutput()) +Im = Im.astype('float32') +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') +converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), + reader.GetOutput().GetOrigin()) +converter.Update() +writer = vtk.vtkMetaImageWriter() +writer.SetInputData(converter.GetOutput()) +writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +#writer.Write() + + +## plot +fig3D = plt.figure(figsize=(20,16)) + +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) +sliceNo = 32 + +a=fig3D.add_subplot(2,3,1) +a.set_title('noise') +imgplot = plt.imshow(u0.T[sliceNo]) + +reg_output3d = [] + +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +#reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +#out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +# #tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, + tolerance_constant=1e-4, + TV_Penalty=Regularizer.TotalVariationPenalty.l1) + + +pars = out2[-2] +reg_output3d.append(out2) + +a=fig3D.add_subplot(2,3,2) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) +pars = out2[-2] +reg_output3d.append(out2) + +a=fig3D.add_subplot(2,3,3) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) + +###################### PatchBased_Regul ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output3d.append(out2) + +a=fig3D.add_subplot(2,3,4) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) + + +####################### TGV_PD ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); +# +# +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +fig3D.savefig('test\\3d.png') +plt.close(fig3D) \ No newline at end of file -- cgit v1.2.3 From b1b65784db7c01911be8a8d57dc030f521352b68 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 16:31:29 +0100 Subject: removed dependency --- src/Python/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt index 66630cb..1b73380 100644 --- a/src/Python/CMakeLists.txt +++ b/src/Python/CMakeLists.txt @@ -154,20 +154,20 @@ add_custom_target(regularizers COMMAND bash compile.sh WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${update_code} + DEPENDS update_code ) add_custom_target(install-fista COMMAND conda install --force --use-local ccpi-fista=${CIL_VERSION} -c ccpi -c conda-forge WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${fista}) + ) add_custom_target(install-regularizers COMMAND conda install --force --use-local ccpi-regularizers=${CIL_VERSION} -c ccpi -c conda-forge WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${fista}) + ) ### add tests #add_executable(RegularizersTest ) -- cgit v1.2.3 From 23668cba99464fb0189e80a883ab9234ee6a9965 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 16:32:09 +0100 Subject: Removed unused variables --- src/Python/fista_module.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp index 94d156c..aca3be0 100644 --- a/src/Python/fista_module.cpp +++ b/src/Python/fista_module.cpp @@ -312,7 +312,7 @@ bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int me int number_of_dims, dimX, dimY, dimZ, ll, j, count; float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; - float lambda, tk, tkp1, re, re1, re_old, epsil, funcval, mu; + float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); @@ -901,7 +901,7 @@ bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_al bp::list result; int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; //const int *dim_array; - float *A, *U, *U_old, *P1, *P2, *P3, *Q1, *Q2, *Q3, *Q4, *Q5, *Q6, *Q7, *Q8, *Q9, *V1, *V1_old, *V2, *V2_old, *V3, *V3_old, lambda, L2, tau, sigma, alpha1, alpha0; + float *A, *U, *U_old, *P1, *P2, *Q1, *Q2, *Q3, *V1, *V1_old, *V2, *V2_old, lambda, L2, tau, sigma, alpha1, alpha0; //number_of_dims = mxGetNumberOfDimensions(prhs[0]); //dim_array = mxGetDimensions(prhs[0]); -- cgit v1.2.3 From 6b24ef4e1e0780dc1eade61df025f886712339bc Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 16:35:12 +0100 Subject: development --- src/Python/test/test_reconstructor-os.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/Python/test/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py index 6c82ae0..3f419cf 100644 --- a/src/Python/test/test_reconstructor-os.py +++ b/src/Python/test/test_reconstructor-os.py @@ -78,12 +78,16 @@ fistaRecon.setParameter(ring_alpha = 21) fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +#reg = Regularizer(Regularizer.Algorithm.FGP_TV) +#reg.setParameter(regularization_parameter=0.005, +# number_of_iterations=50) reg = Regularizer(Regularizer.Algorithm.LLT_model) reg.setParameter(regularization_parameter=25, time_step=0.0003, tolerance_constant=0.0001, number_of_iterations=300) + ## Ordered subset if True: subsets = 16 @@ -306,7 +310,7 @@ if True: # for slices: # out = regularizer(input=X) print ("regularizer") - #X = reg(input=X) + X = reg(input=X)[0] ## FINAL -- cgit v1.2.3 From ff9cc12694172e1e8720f7ea7f5b22e647722e21 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 16:56:17 +0100 Subject: doing work --- src/Python/test/test_regularizers_3d.py | 471 +++++++++++++++++--------------- 1 file changed, 258 insertions(+), 213 deletions(-) (limited to 'src') diff --git a/src/Python/test/test_regularizers_3d.py b/src/Python/test/test_regularizers_3d.py index a2e3027..2d11a7e 100644 --- a/src/Python/test/test_regularizers_3d.py +++ b/src/Python/test/test_regularizers_3d.py @@ -5,17 +5,17 @@ Created on Fri Aug 4 11:10:05 2017 @author: ofn77899 """ -from ccpi.viewer.CILViewer2D import Converter -import vtk +#from ccpi.viewer.CILViewer2D import Converter +#import vtk -import regularizers import matplotlib.pyplot as plt import numpy as np import os from enum import Enum import timeit - -from Regularizer import Regularizer +#from PIL import Image +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer ############################################################################### #https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 @@ -46,77 +46,303 @@ def nrmse(im1, im2): # u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; # u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + #filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" +#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' + #reader = vtk.vtkTIFFReader() #reader.SetFileName(os.path.normpath(filename)) #reader.Update() -##vtk returns 3D images, let's take just the one slice there is as 2D -#Im = Converter.vtk2numpy(reader.GetOutput()).T[0]/255 +Im = plt.imread(filename) +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + +# create a 3D image by stacking N of this images + + +#imgplot = plt.imshow(Im) +perc = 0.05 +u_n = Im + (perc* np.random.normal(size=np.shape(Im))) +y,z = np.shape(u_n) +u_n = np.reshape(u_n , (1,y,z)) + +u0 = u_n.copy() +for i in range (19): + u_n = Im + (perc* np.random.normal(size=np.shape(Im))) + u_n = np.reshape(u_n , (1,y,z)) + + u0 = np.vstack ( (u0, u_n) ) + +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +print ("Passed image shape {0}".format(np.shape(u0))) + +## plot +fig = plt.figure() +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) +sliceno = 10 + +a=fig.add_subplot(2,3,1) +a.set_title('noise') +imgplot = plt.imshow(u0[sliceno],cmap="gray") + +reg_output = [] +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +use_object = True +if use_object: + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + print (reg.pars) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + # or + # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + #TV_Penalty=Regularizer.TotalVariationPenalty.l1) + plotme = reg() [0] + pars = reg.pars + textstr = reg.printParametersToString() + + #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + # TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +else: + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + pars = out2[2] + reg_output.append(out2) + plotme = reg_output[-1][0] + textstr = out2[-1] + +a=fig.add_subplot(2,3,2) + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(plotme[sliceno],cmap="gray") + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.0005, + number_of_iterations=50) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + + +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + + +plt.show() + +################################################################################ +## +## 3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" # +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') ##imgplot = plt.imshow(Im) #perc = 0.05 #u0 = Im + (perc* np.random.normal(size=np.shape(Im))) ## map the u0 u0->u0>0 #f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) #u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +# reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# # ### plot -#fig = plt.figure() +#fig3D = plt.figure() ##a=fig.add_subplot(3,3,1) ##a.set_title('Original') ##imgplot = plt.imshow(Im) +#sliceNo = 32 # -#a=fig.add_subplot(2,3,1) +#a=fig3D.add_subplot(2,3,1) #a.set_title('noise') -#imgplot = plt.imshow(u0) +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] # -#reg_output = [] ############################################################################### ## Call regularizer # ######################## SplitBregman_TV ##################################### ## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# ##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) -## +# ##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, ## #tolerance_constant=1e-4, ## TV_Penalty=Regularizer.TotalVariationPenalty.l1) # -##out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, -## tolerance_constant=1e-4, -## TV_Penalty=Regularizer.TotalVariationPenalty.l1) -#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) -#pars = out2[2] -#reg_output.append(out2) +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) # -#a=fig.add_subplot(2,3,2) # #textstr = out2[-1] +# +# ## these are matplotlib.patch.Patch properties #props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) ## place a text box in upper left in axes coords #a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, # verticalalignment='top', bbox=props) -#imgplot = plt.imshow(reg_output[-1][0]) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) # ####################### FGP_TV ######################################### ## u = FGP_TV(single(u0), 0.05, 100, 1e-04); #out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, # number_of_iterations=200) #pars = out2[-2] +#reg_output3d.append(out2) # -#reg_output.append(out2) +#a=fig3D.add_subplot(2,3,2) # -#a=fig.add_subplot(2,3,3) # #textstr = out2[-1] # +# ## these are matplotlib.patch.Patch properties #props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) ## place a text box in upper left in axes coords #a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, # verticalalignment='top', bbox=props) -#imgplot = plt.imshow(reg_output[-1][0]) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) # ####################### LLT_model ######################################### ## * u0 = Im + .03*randn(size(Im)); % adding noise @@ -129,17 +355,20 @@ def nrmse(im1, im2): # tolerance_constant=0.0001, # number_of_iterations=300) #pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) # -#reg_output.append(out2) # -#a=fig.add_subplot(2,3,4) #textstr = out2[-1] +# +# ## these are matplotlib.patch.Patch properties #props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) ## place a text box in upper left in axes coords #a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, # verticalalignment='top', bbox=props) -#imgplot = plt.imshow(reg_output[-1][0]) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) # ####################### PatchBased_Regul ######################################### ## Quick 2D denoising example in Matlab: @@ -152,136 +381,6 @@ def nrmse(im1, im2): # similarity_window_ratio=1, # PB_filtering_parameter=0.08) #pars = out2[-2] -#reg_output.append(out2) -# -#a=fig.add_subplot(2,3,5) -# -# -#textstr = out2[-1] -# -## these are matplotlib.patch.Patch properties -#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -## place a text box in upper left in axes coords -#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, -# verticalalignment='top', bbox=props) -#imgplot = plt.imshow(reg_output[-1][0]) -# -# -####################### TGV_PD ######################################### -## Quick 2D denoising example in Matlab: -## Im = double(imread('lena_gray_256.tif'))/255; % loading image -## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -## u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# -# -#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, -# first_order_term=1.3, -# second_order_term=1, -# number_of_iterations=550) -#pars = out2[-2] -#reg_output.append(out2) -# -#a=fig.add_subplot(2,3,6) -# -# -#textstr = out2[-1] -# -# -## these are matplotlib.patch.Patch properties -#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -## place a text box in upper left in axes coords -#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, -# verticalalignment='top', bbox=props) -#imgplot = plt.imshow(reg_output[-1][0]) -# - -############################################################################### -# -# 3D Regularizers -# -############################################################################### -#Example: -# figure; -# Im = double(imread('lena_gray_256.tif'))/255; % loading image -# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; -# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); - -#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" -filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" - -reader = vtk.vtkMetaImageReader() -reader.SetFileName(os.path.normpath(filename)) -reader.Update() -#vtk returns 3D images, let's take just the one slice there is as 2D -Im = Converter.vtk2numpy(reader.GetOutput()) -Im = Im.astype('float32') -#imgplot = plt.imshow(Im) -perc = 0.05 -u0 = Im + (perc* np.random.normal(size=np.shape(Im))) -# map the u0 u0->u0>0 -f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) -u0 = f(u0).astype('float32') -converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), - reader.GetOutput().GetOrigin()) -converter.Update() -writer = vtk.vtkMetaImageWriter() -writer.SetInputData(converter.GetOutput()) -writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") -#writer.Write() - - -## plot -fig3D = plt.figure(figsize=(20,16)) - -#a=fig.add_subplot(3,3,1) -#a.set_title('Original') -#imgplot = plt.imshow(Im) -sliceNo = 32 - -a=fig3D.add_subplot(2,3,1) -a.set_title('noise') -imgplot = plt.imshow(u0.T[sliceNo]) - -reg_output3d = [] - -############################################################################## -# Call regularizer - -####################### SplitBregman_TV ##################################### -# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); - -#reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) - -#out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, -# #tolerance_constant=1e-4, -# TV_Penalty=Regularizer.TotalVariationPenalty.l1) - -out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, - tolerance_constant=1e-4, - TV_Penalty=Regularizer.TotalVariationPenalty.l1) - - -pars = out2[-2] -reg_output3d.append(out2) - -a=fig3D.add_subplot(2,3,2) - - -textstr = out2[-1] - - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) - -###################### FGP_TV ######################################### -# u = FGP_TV(single(u0), 0.05, 100, 1e-04); -#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, -# number_of_iterations=200) -#pars = out2[-2] #reg_output3d.append(out2) # #a=fig3D.add_subplot(2,3,2) @@ -296,67 +395,15 @@ imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) #a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, # verticalalignment='top', bbox=props) #imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# -###################### LLT_model ######################################### -# * u0 = Im + .03*randn(size(Im)); % adding noise -# [Den] = LLT_model(single(u0), 10, 0.1, 1); -#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); -#input, regularization_parameter , time_step, number_of_iterations, -# tolerance_constant, restrictive_Z_smoothing=0 -out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, - time_step=0.0003, - tolerance_constant=0.0001, - number_of_iterations=300) -pars = out2[-2] -reg_output3d.append(out2) - -a=fig3D.add_subplot(2,3,3) - - -textstr = out2[-1] - - -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) - -###################### PatchBased_Regul ######################################### +###################### TGV_PD ######################################### # Quick 2D denoising example in Matlab: # Im = double(imread('lena_gray_256.tif'))/255; % loading image # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -# ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); - -out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, - searching_window_ratio=3, - similarity_window_ratio=1, - PB_filtering_parameter=0.08) -pars = out2[-2] -reg_output3d.append(out2) - -a=fig3D.add_subplot(2,3,4) - - -textstr = out2[-1] +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# these are matplotlib.patch.Patch properties -props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) -# place a text box in upper left in axes coords -a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, - verticalalignment='top', bbox=props) -imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) - - -####################### TGV_PD ######################################### -## Quick 2D denoising example in Matlab: -## Im = double(imread('lena_gray_256.tif'))/255; % loading image -## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise -## u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); -# -# #out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, # first_order_term=1.3, # second_order_term=1, @@ -376,5 +423,3 @@ imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) #a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, # verticalalignment='top', bbox=props) #imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) -fig3D.savefig('test\\3d.png') -plt.close(fig3D) \ No newline at end of file -- cgit v1.2.3 From 01861a7022cb7855bc1a8cd7f8cfd6282690a4f1 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Wed, 25 Oct 2017 16:57:19 +0100 Subject: added to repository --- src/Python/test/test_regularizers.py | 412 +++++++++++++++++++++++++++++++++++ 1 file changed, 412 insertions(+) create mode 100644 src/Python/test/test_regularizers.py (limited to 'src') diff --git a/src/Python/test/test_regularizers.py b/src/Python/test/test_regularizers.py new file mode 100644 index 0000000..27e4ed3 --- /dev/null +++ b/src/Python/test/test_regularizers.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug 4 11:10:05 2017 + +@author: ofn77899 +""" + +#from ccpi.viewer.CILViewer2D import Converter +#import vtk + +import matplotlib.pyplot as plt +import numpy as np +import os +from enum import Enum +import timeit +#from PIL import Image +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE: +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): + a, b = im1.shape + rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) + max_val = max(np.max(im1), np.max(im2)) + min_val = min(np.min(im1), np.min(im2)) + return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +# 2D Regularizers +# +############################################################################### +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" +#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' + +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +Im = plt.imread(filename) +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + + + + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +## plot +fig = plt.figure() +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,1) +a.set_title('noise') +imgplot = plt.imshow(u0,cmap="gray") + +reg_output = [] +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +use_object = True +if use_object: + reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + print (reg.pars) + reg.setParameter(input=u0) + reg.setParameter(regularization_parameter=10.) + # or + # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + #TV_Penalty=Regularizer.TotalVariationPenalty.l1) + plotme = reg() [0] + pars = reg.pars + textstr = reg.printParametersToString() + + #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, + #tolerance_constant=1e-4, + # TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +else: + out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + pars = out2[2] + reg_output.append(out2) + plotme = reg_output[-1][0] + textstr = out2[-1] + +a=fig.add_subplot(2,3,2) + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(plotme,cmap="gray") + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.0005, + number_of_iterations=50) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +#input, regularization_parameter , time_step, number_of_iterations, +# tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, + time_step=0.0003, + tolerance_constant=0.0001, + number_of_iterations=300) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, + searching_window_ratio=3, + similarity_window_ratio=1, + PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab: +# # Im = double(imread('lena_gray_256.tif'))/255; % loading image +# # u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# # u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, + first_order_term=1.3, + second_order_term=1, + number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +plt.show() + +################################################################################ +## +## 3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +# reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +## #tolerance_constant=1e-4, +## TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +# tolerance_constant=1e-4, +# TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +# number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0); +##input, regularization_parameter , time_step, number_of_iterations, +## tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +# time_step=0.0003, +# tolerance_constant=0.0001, +# number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab: +## Im = double(imread('lena_gray_256.tif'))/255; % loading image +## u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +## ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05); +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +# searching_window_ratio=3, +# similarity_window_ratio=1, +# PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab: +# Im = double(imread('lena_gray_256.tif'))/255; % loading image +# u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +# first_order_term=1.3, +# second_order_term=1, +# number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +# verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) -- cgit v1.2.3