diff options
Diffstat (limited to 'src/Python')
34 files changed, 6262 insertions, 0 deletions
diff --git a/src/Python/CMakeLists.txt b/src/Python/CMakeLists.txt new file mode 100644 index 0000000..1b73380 --- /dev/null +++ b/src/Python/CMakeLists.txt @@ -0,0 +1,181 @@ +#   Copyright 2017 Edoardo Pasca +# +#   Licensed under the Apache License, Version 2.0 (the "License"); +#   you may not use this file except in compliance with the License. +#   You may obtain a copy of the License at +# +#       http://www.apache.org/licenses/LICENSE-2.0 +# +#   Unless required by applicable law or agreed to in writing, software +#   distributed under the License is distributed on an "AS IS" BASIS, +#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#   See the License for the specific language governing permissions and +#   limitations under the License. + +# variables that must be set for conda compilation + +#PREFIX=C:\Apps\Miniconda2\envs\cil\Library +#LIBRARY_INC=C:\\Apps\\Miniconda2\\envs\\cil\\Library\\include +set (NUMPY_VERSION 1.12) + +## Tries to parse the output of conda env list to determine the current +## active conda environment +message ("Trying to determine your active conda environment...") +execute_process(COMMAND "conda" "env" "list" +					OUTPUT_VARIABLE _CONDA_ENVS +					RESULT_VARIABLE _CONDA_RESULT +					ERROR_VARIABLE _CONDA_ERR) +			if(NOT _CONDA_RESULT) +				string(REPLACE "\n" ";" ENV_LIST ${_CONDA_ENVS}) +				foreach(line ${ENV_LIST}) +				  string(REGEX MATCHALL "(.+)[*](.+)" match ${line}) +				  if (NOT ${match} EQUAL "") +				    #message("MATCHED " ${CMAKE_MATCH_0}) +					#message("MATCHED " ${CMAKE_MATCH_1}) +					#message("MATCHED " ${CMAKE_MATCH_2}) +					string(STRIP ${CMAKE_MATCH_1} CONDA_ENVIRONMENT) +					string(STRIP ${CMAKE_MATCH_2} CONDA_ENVIRONMENT_PATH)    +				  endif() +				endforeach() +			else() +				message(FATAL_ERROR "ERROR with conda command " ${_CONDA_ERR}) +			endif() + +if (${CONDA_ENVIRONMENT} AND ${CONDA_ENVIRONMENT_PATH}) +  message (FATAL_ERROR "CONDA NOT FOUND") +else() +  message("**********************************************************") +  message("Using current conda environmnet " ${CONDA_ENVIRONMENT}) +  message("Using current conda environmnet path " ${CONDA_ENVIRONMENT_PATH}) +endif() + +message("CIL VERSION " ${CIL_VERSION}) + +# set the Python variables for the Conda environment +include(FindAnacondaEnvironment.cmake) +findPythonForAnacondaEnvironment(${CONDA_ENVIRONMENT_PATH}) + +message("Python found " ${PYTHON_VERSION_STRING}) +message("Python found Major " ${PYTHON_VERSION_MAJOR}) +message("Python found Minor " ${PYTHON_VERSION_MINOR}) + +findPythonPackagesPath() +message("PYTHON_PACKAGES_FOUND " ${PYTHON_PACKAGES_PATH}) + +## CACHE SOME VARIABLES ## +set (CONDA_ENVIRONMENT ${CONDA_ENVIRONMENT} CACHE INTERNAL "active conda environment" FORCE) +set (CONDA_ENVIRONMENT_PATH ${CONDA_ENVIRONMENT_PATH} CACHE INTERNAL "active conda environment" FORCE) + +set (PYTHON_VERSION_STRING ${PYTHON_VERSION_STRING} CACHE INTERNAL "conda environment Python version string" FORCE) +set (PYTHON_VERSION_MAJOR ${PYTHON_VERSION_MAJOR} CACHE INTERNAL "conda environment Python version major" FORCE) +set (PYTHON_VERSION_MINOR ${PYTHON_VERSION_MINOR} CACHE INTERNAL "conda environment Python version minor" FORCE) +set (PYTHON_VERSION_PATCH ${PYTHON_VERSION_PATCH} CACHE INTERNAL "conda environment Python version patch" FORCE) +set (PYTHON_PACKAGES_PATH ${PYTHON_PACKAGES_PATH} CACHE INTERNAL "conda environment Python packages path" FORCE) + +if (WIN32) +  #set (CONDA_ENVIRONMENT_PATH "C:\\Apps\\Miniconda2\\envs\\${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") +  set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}\\Library" CACHE PATH "env dir") +  set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}\\include" CACHE PATH "env dir") +elseif (UNIX) +  #set (CONDA_ENVIRONMENT_PATH "/apps/anaconda/2.4/envs/${CONDA_ENVIRONMENT}" CACHE PATH "Main environment directory") +  set (CONDA_ENVIRONMENT_PREFIX "${CONDA_ENVIRONMENT_PATH}/lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}" CACHE PATH "env dir") +  set (CONDA_ENVIRONMENT_LIBRARY_INC "${CONDA_ENVIRONMENT_PREFIX}/include" CACHE PATH "env dir") +endif() + +######### CONFIGURE REGULARIZER PACKAGE ############# + +# copy the Pyhon files of the package regularizer +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging/) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi) +# regularizers +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/imaging/Regularizer.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/imaging) + +# Copy and configure the relative conda build and recipes +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe) + +if (WIN32) +	 +  file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/)  +  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile.bat) + +elseif(UNIX) +   +  message ("We are on UNIX") +  file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/conda-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/conda-recipe/) +  # assumes we will use bash +  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile.sh) + +endif() + +########## CONFIGURE FISTA RECONSTRUCTOR PACKAGE +# fista reconstructor +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/FISTAReconstructor.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/ccpi/reconstruction/__init__.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/ccpi/reconstruction) + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup-fista.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup-fista.py) +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/meta.yaml DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe) + +if (WIN32) + +  file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/bld.bat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) +  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.bat.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.bat) + +elseif(UNIX) +  message ("We are on UNIX") +  file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/fista-recipe/build.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/fista-recipe/) +  # assumes we will use bash +  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/compile-fista.sh.in ${CMAKE_CURRENT_BINARY_DIR}/compile-fista.sh) +endif() + +#############################  TARGETS + +##########################  REGULARIZER PACKAGE ############################### + +# runs cmake on the build tree to update the code from source +add_custom_target(update_code  +		COMMAND ${CMAKE_COMMAND}  +		ARGS ${CMAKE_SOURCE_DIR} +		WORKING_DIRECTORY ${CMAKE_BINARY_DIR} +		)  + + +add_custom_target(fista +		COMMAND bash +		compile-fista.sh +		WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +		DEPENDS ${update_code} +		) + +add_custom_target(regularizers +		COMMAND bash +		compile.sh +		WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +		DEPENDS update_code +		) + +add_custom_target(install-fista +		COMMAND conda +		install --force --use-local ccpi-fista=${CIL_VERSION} -c ccpi -c conda-forge +		WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +		) + +add_custom_target(install-regularizers +		COMMAND conda +		install --force --use-local ccpi-regularizers=${CIL_VERSION} -c ccpi -c conda-forge +		WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +		) +### add tests + +#add_executable(RegularizersTest ) +#find_package(tiff) +#if (TIFF_FOUND) +#  message("LibTIFF Found") +#  message("TIFF_INCLUDE_DIR " ${TIFF_INCLUDE_DIR}) +#  message("TIFF_LIBRARIES" ${TIFF_LIBRARIES}) +#else() +#  message("LibTIFF not found") +#endif() diff --git a/src/Python/FindAnacondaEnvironment.cmake b/src/Python/FindAnacondaEnvironment.cmake new file mode 100644 index 0000000..fa4637a --- /dev/null +++ b/src/Python/FindAnacondaEnvironment.cmake @@ -0,0 +1,149 @@ +#   Copyright 2017 Edoardo Pasca +# +#   Licensed under the Apache License, Version 2.0 (the "License"); +#   you may not use this file except in compliance with the License. +#   You may obtain a copy of the License at +# +#       http://www.apache.org/licenses/LICENSE-2.0 +# +#   Unless required by applicable law or agreed to in writing, software +#   distributed under the License is distributed on an "AS IS" BASIS, +#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#   See the License for the specific language governing permissions and +#   limitations under the License. + +# #.rst: +# FindAnacondaEnvironment +# -------------- +# +# Find Python executable and library for a specific Anaconda environment +# +# This module finds the Python interpreter for a specific Anaconda enviroment,  +# if installed and determines where the include files and libraries are.   +# This code sets the following variables: +# +# :: +#   PYTHONINTERP_FOUND         - if the Python interpret has been found +#   PYTHON_EXECUTABLE          - the Python interpret found +#   PYTHON_LIBRARY             - path to the python library +#   PYTHON_INCLUDE_PATH        - path to where Python.h is found (deprecated) +#   PYTHON_INCLUDE_DIRS        - path to where Python.h is found +#   PYTHONLIBS_VERSION_STRING  - version of the Python libs found (since CMake 2.8.8) +#   PYTHON_VERSION_MAJOR       - major Python version +#   PYTHON_VERSION_MINOR       - minor Python version +#   PYTHON_VERSION_PATCH       - patch Python version + + + +function (findPythonForAnacondaEnvironment env) +	if (WIN32) +	  file(TO_CMAKE_PATH ${env}/python.exe PYTHON_EXECUTABLE) +        elseif (UNIX) +  	  file(TO_CMAKE_PATH ${env}/bin/python PYTHON_EXECUTABLE) +	endif() + +	 +	message("findPythonForAnacondaEnvironment Found Python Executable" ${PYTHON_EXECUTABLE}) +	####### FROM FindPythonInterpr ######## +	# determine python version string +	if(PYTHON_EXECUTABLE) +		execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c +								"import sys; sys.stdout.write(';'.join([str(x) for x in sys.version_info[:3]]))" +						OUTPUT_VARIABLE _VERSION +						RESULT_VARIABLE _PYTHON_VERSION_RESULT +						ERROR_QUIET) +		if(NOT _PYTHON_VERSION_RESULT) +			string(REPLACE ";" "." _PYTHON_VERSION_STRING "${_VERSION}") +			list(GET _VERSION 0 _PYTHON_VERSION_MAJOR) +			list(GET _VERSION 1 _PYTHON_VERSION_MINOR) +			list(GET _VERSION 2 _PYTHON_VERSION_PATCH) +			if(PYTHON_VERSION_PATCH EQUAL 0) +				# it's called "Python 2.7", not "2.7.0" +				string(REGEX REPLACE "\\.0$" "" _PYTHON_VERSION_STRING "${PYTHON_VERSION_STRING}") +			endif() +		else() +			# sys.version predates sys.version_info, so use that +			execute_process(COMMAND "${PYTHON_EXECUTABLE}" -c "import sys; sys.stdout.write(sys.version)" +							OUTPUT_VARIABLE _VERSION +							RESULT_VARIABLE _PYTHON_VERSION_RESULT +							ERROR_QUIET) +			if(NOT _PYTHON_VERSION_RESULT) +				string(REGEX REPLACE " .*" "" _PYTHON_VERSION_STRING "${_VERSION}") +				string(REGEX REPLACE "^([0-9]+)\\.[0-9]+.*" "\\1" _PYTHON_VERSION_MAJOR "${PYTHON_VERSION_STRING}") +				string(REGEX REPLACE "^[0-9]+\\.([0-9])+.*" "\\1" _PYTHON_VERSION_MINOR "${PYTHON_VERSION_STRING}") +				if(PYTHON_VERSION_STRING MATCHES "^[0-9]+\\.[0-9]+\\.([0-9]+)") +					set(PYTHON_VERSION_PATCH "${CMAKE_MATCH_1}") +				else() +					set(PYTHON_VERSION_PATCH "0") +				endif() +			else() +				# sys.version was first documented for Python 1.5, so assume +				# this is older. +				set(PYTHON_VERSION_STRING "1.4" PARENT_SCOPE) +				set(PYTHON_VERSION_MAJOR "1" PARENT_SCOPE) +				set(PYTHON_VERSION_MINOR "4" PARENT_SCOPE) +				set(PYTHON_VERSION_PATCH "0" PARENT_SCOPE) +			endif() +		endif() +		unset(_PYTHON_VERSION_RESULT) +		unset(_VERSION) +	endif() +	############################################### +	 +	set (PYTHON_EXECUTABLE ${PYTHON_EXECUTABLE} PARENT_SCOPE) +	set (PYTHONINTERP_FOUND "ON" PARENT_SCOPE) +	set (PYTHON_VERSION_STRING ${_PYTHON_VERSION_STRING} PARENT_SCOPE) +	set (PYTHON_VERSION_MAJOR ${_PYTHON_VERSION_MAJOR} PARENT_SCOPE) +	set (PYTHON_VERSION_MINOR ${_PYTHON_VERSION_MINOR} PARENT_SCOPE) +	set (PYTHON_VERSION_PATCH ${_PYTHON_VERSION_PATCH} PARENT_SCOPE) +	message("My version found " ${PYTHON_VERSION_STRING}) + +endfunction() + + + +set(Python_ADDITIONAL_VERSIONS 3.5) + +find_package(PythonInterp) +if (PYTHONINTERP_FOUND) +   +  message("Found interpret " ${PYTHON_EXECUTABLE}) +  message("Python Library " ${PYTHON_LIBRARY}) +  message("Python Include Dir " ${PYTHON_INCLUDE_DIR}) +  message("Python Include Path " ${PYTHON_INCLUDE_PATH}) +   +  foreach(pv ${PYTHON_VERSION_STRING}) +    message("Found interpret " ${pv}) +  endforeach() +endif() + + + +find_package(PythonLibs) +if (PYTHONLIB_FOUND)  +  message("Found PythonLibs PYTHON_LIBRARIES " ${PYTHON_LIBRARIES}) +  message("Found PythonLibs PYTHON_INCLUDE_PATH " ${PYTHON_INCLUDE_PATH}) +  message("Found PythonLibs PYTHON_INCLUDE_DIRS " ${PYTHON_INCLUDE_DIRS}) +  message("Found PythonLibs PYTHONLIBS_VERSION_STRING " ${PYTHONLIBS_VERSION_STRING}  ) +else() +  message("No PythonLibs Found")   +endif() + + + + +function(findPythonPackagesPath) +   execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from distutils.sysconfig import *; print (get_python_lib())" +                      RESULT_VARIABLE PYTHON_CVPY_PROCESS +                      OUTPUT_VARIABLE PYTHON_STD_PACKAGES_PATH +                      OUTPUT_STRIP_TRAILING_WHITESPACE) +   #message("STD_PACKAGES " ${PYTHON_STD_PACKAGES_PATH}) +   if("${PYTHON_STD_PACKAGES_PATH}" MATCHES "site-packages") +        set(_PYTHON_PACKAGES_PATH "python${PYTHON_VERSION_MAJOR_MINOR}/site-packages") +   endif() + +    SET(PYTHON_PACKAGES_PATH "${PYTHON_STD_PACKAGES_PATH}" PARENT_SCOPE) + +endfunction() + + diff --git a/src/Python/Matlab2Python_utils.cpp b/src/Python/Matlab2Python_utils.cpp new file mode 100644 index 0000000..ee76bc7 --- /dev/null +++ b/src/Python/Matlab2Python_utils.cpp @@ -0,0 +1,276 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazanteev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include <iostream> +#include <cmath> + +#include <boost/python.hpp> +#include <boost/python/numpy.hpp> +#include "boost/tuple/tuple.hpp" + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include <windows.h> +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** +	 +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value	  MATLAB Type   MEX Type	 C Primitive Type +mxINT8_CLASS 	  int8	        int8_T	     char, byte +mxUINT8_CLASS	  uint8	        uint8_T	     unsigned char, byte +mxINT16_CLASS	  int16	        int16_T	     short +mxUINT16_CLASS	  uint16	    uint16_T	 unsigned short +mxINT32_CLASS	  int32	        int32_T	     int +mxUINT32_CLASS	  uint32	    uint32_T	 unsigned int +mxINT64_CLASS	  int64	        int64_T	     long long +mxUINT64_CLASS	  uint64	    uint64_T 	 unsigned long long +mxSINGLE_CLASS	  single	    float	     float +mxDOUBLE_CLASS	  double	    double	     double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + +void mexErrMessageText(char* text) { +	std::cerr << text << std::endl; +} + +/* +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double. +*/ + +template<typename T> +double mxGetScalar(const np::ndarray plh) { +	return (double)bp::extract<T>(plh[0]); +} + + + +template<typename T> +T * mxGetData(const np::ndarray pm) { +    //args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +	//Returns: Pointer to the value of the first real(nonimaginary) element of the mxArray.In C, mxGetScalar returns a double. +	/*Access the numpy array pointer: +	char * get_data() const; +	Returns:	Array’s raw data pointer as a char +	Note:	This returns char so stride math works properly on it.User will have to reinterpret_cast it. +	probably this would work. +	A = reinterpret_cast<float *>(prhs[0]); +	*/ +	//return reinterpret_cast<T *>(prhs[0]); +} + +template<typename T> +np::ndarray zeros(int dims , int * dim_array, T el) { +	bp::tuple shape; +	if (dims == 3) +		shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); +	else if (dims == 2) +		shape = bp::make_tuple(dim_array[0], dim_array[1]); +	np::dtype dtype = np::dtype::get_builtin<T>(); +	np::ndarray zz = np::zeros(shape, dtype); +	return zz; +} + + +bp::list mexFunction( np::ndarray input ) { +	int number_of_dims = input.get_nd(); +	int dim_array[3]; + +	dim_array[0] = input.shape(0); +	dim_array[1] = input.shape(1); +	if (number_of_dims == 2) { +		dim_array[2] = -1; +	} +	else { +		dim_array[2] = input.shape(2); +	} + +	/**************************************************************************/ +	np::ndarray zz = zeros(3, dim_array, (int)0); +	np::ndarray fzz = zeros(3, dim_array, (float)0); +	/**************************************************************************/ +	 +	int * A = reinterpret_cast<int *>( input.get_data() ); +	int * B = reinterpret_cast<int *>( zz.get_data() ); +	float * C = reinterpret_cast<float *>(fzz.get_data()); + +	//Copy data and cast +	for (int i = 0; i < dim_array[0]; i++) { +		for (int j = 0; j < dim_array[1]; j++) { +			for (int k = 0; k < dim_array[2]; k++) { +				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; +				int val = (*(A + index)); +				float fval = sqrt((float)val); +				std::memcpy(B + index , &val, sizeof(int)); +				std::memcpy(C + index , &fval, sizeof(float)); +			} +		} +	} + + +	bp::list result; + +	result.append<int>(number_of_dims); +	result.append<int>(dim_array[0]); +	result.append<int>(dim_array[1]); +	result.append<int>(dim_array[2]); +	result.append<np::ndarray>(zz); +	result.append<np::ndarray>(fzz); + +	//result.append<bp::tuple>(tup); +	return result; + +} +bp::list doSomething(np::ndarray input, PyObject *pyobj , PyObject *pyobj2) { + +	boost::python::object output(boost::python::handle<>(boost::python::borrowed(pyobj))); +	int isOutput = !(output == boost::python::api::object()); + +	boost::python::object calculate(boost::python::handle<>(boost::python::borrowed(pyobj2))); +	int isCalculate = !(calculate == boost::python::api::object()); + +	int number_of_dims = input.get_nd(); +	int dim_array[3]; + +	dim_array[0] = input.shape(0); +	dim_array[1] = input.shape(1); +	if (number_of_dims == 2) { +		dim_array[2] = -1; +	} +	else { +		dim_array[2] = input.shape(2); +	} + +	/**************************************************************************/ +	np::ndarray zz = zeros(3, dim_array, (int)0); +	np::ndarray fzz = zeros(3, dim_array, (float)0); +	/**************************************************************************/ + +	int * A = reinterpret_cast<int *>(input.get_data()); +	int * B = reinterpret_cast<int *>(zz.get_data()); +	float * C = reinterpret_cast<float *>(fzz.get_data()); + +	//Copy data and cast +	for (int i = 0; i < dim_array[0]; i++) { +		for (int j = 0; j < dim_array[1]; j++) { +			for (int k = 0; k < dim_array[2]; k++) { +				int index = k + dim_array[2] * j + dim_array[2] * dim_array[1] * i; +				int val = (*(A + index)); +				float fval = sqrt((float)val); +				std::memcpy(B + index, &val, sizeof(int)); +				std::memcpy(C + index, &fval, sizeof(float)); +				// if the PyObj is not None evaluate the function  +				if (isOutput)	 +					output(fval); +				if (isCalculate) { +					float nfval = (float)bp::extract<float>(calculate(val)); +					if (isOutput) +						output(nfval); +					std::memcpy(C + index, &nfval, sizeof(float)); +				} +			} +		} +	} + + +	bp::list result; + +	result.append<int>(number_of_dims); +	result.append<int>(dim_array[0]); +	result.append<int>(dim_array[1]); +	result.append<int>(dim_array[2]); +	result.append<np::ndarray>(zz); +	result.append<np::ndarray>(fzz); + +	//result.append<bp::tuple>(tup); +	return result; + +} + + +BOOST_PYTHON_MODULE(prova) +{ +	np::initialize(); + +	//To specify that this module is a package +	bp::object package = bp::scope(); +	package.attr("__path__") = "prova"; + +	np::dtype dt1 = np::dtype::get_builtin<uint8_t>(); +	np::dtype dt2 = np::dtype::get_builtin<uint16_t>(); +	 +	//import_array(); +	//numpy_boost_python_register_type<float, 1>(); +	//numpy_boost_python_register_type<float, 2>(); +	//numpy_boost_python_register_type<float, 3>(); +	//numpy_boost_python_register_type<double, 3>(); +	def("mexFunction", mexFunction); +	def("doSomething", doSomething); +} diff --git a/src/Python/Regularizer.py b/src/Python/Regularizer.py new file mode 100644 index 0000000..15dbbb4 --- /dev/null +++ b/src/Python/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug  8 14:26:00 2017 + +@author: ofn77899 +""" + +import regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): +    '''Class to handle regularizer algorithms to be used during reconstruction +     +    Currently 5 CPU (OMP) regularization algorithms are available: +         +    1) SplitBregman_TV +    2) FGP_TV +    3) LLT_model +    4) PatchBased_Regul +    5) TGV_PD +     +    Usage: +        the regularizer can be invoked as object or as static method +        Depending on the actual regularizer the input parameter may vary, and  +        a different default setting is defined. +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +        out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, +          tolerance_constant=1e-4,  +          TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., +          number_of_iterations=30, tolerance_constant=1e-4,  +          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +         +        A number of optional parameters can be passed or skipped +        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + +    ''' +    class Algorithm(Enum): +        SplitBregman_TV = regularizers.SplitBregman_TV +        FGP_TV = regularizers.FGP_TV +        LLT_model = regularizers.LLT_model +        PatchBased_Regul = regularizers.PatchBased_Regul +        TGV_PD = regularizers.TGV_PD +    # Algorithm +     +    class TotalVariationPenalty(Enum): +        isotropic = 0 +        l1 = 1 +    # TotalVariationPenalty +         +    def __init__(self , algorithm, debug = True): +        self.setAlgorithm ( algorithm ) +        self.debug = debug +    # __init__ +     +    def setAlgorithm(self, algorithm): +        self.algorithm = algorithm +        self.pars = self.getDefaultParsForAlgorithm(algorithm) +    # setAlgorithm +         +    def getDefaultParsForAlgorithm(self, algorithm): +        pars = dict() +         +        if algorithm == Regularizer.Algorithm.SplitBregman_TV : +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['regularization_parameter'] = None +            pars['number_of_iterations'] = 35 +            pars['tolerance_constant'] = 0.0001 +            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic +             +        elif algorithm == Regularizer.Algorithm.FGP_TV : +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['regularization_parameter'] = None +            pars['number_of_iterations'] = 50 +            pars['tolerance_constant'] = 0.001 +            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic +             +        elif algorithm == Regularizer.Algorithm.LLT_model: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['regularization_parameter'] = None +            pars['time_step'] = None +            pars['number_of_iterations'] = None +            pars['tolerance_constant'] = None +            pars['restrictive_Z_smoothing'] = 0 +             +        elif algorithm == Regularizer.Algorithm.PatchBased_Regul: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['searching_window_ratio'] = None +            pars['similarity_window_ratio'] = None +            pars['PB_filtering_parameter'] = None +            pars['regularization_parameter'] = None +             +        elif algorithm == Regularizer.Algorithm.TGV_PD: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['first_order_term'] = None +            pars['second_order_term'] = None +            pars['number_of_iterations'] = None +            pars['regularization_parameter'] = None +             +        else: +            raise Exception('Unknown regularizer algorithm') +             +        return pars +    # parsForAlgorithm +     +    def setParameter(self, **kwargs): +        '''set named parameter for the regularization engine +         +        raises Exception if the named parameter is not recognized +        Typical usage is: +             +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +        reg.setParameter(input=u0)     +        reg.setParameter(regularization_parameter=10.) +         +        it can be also used as +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +        reg.setParameter(input=u0 , regularization_parameter=10.) +        ''' +         +        for key , value in kwargs.items(): +            if key in self.pars.keys(): +                self.pars[key] = value +            else: +                raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) +    # setParameter +	 +    def getParameter(self, **kwargs): +        ret = {} +        for key , value in kwargs.items(): +            if key in self.pars.keys(): +                ret[key] = self.pars[key] +        else: +            raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) +    # setParameter +	 +         +    def __call__(self, input = None, regularization_parameter = None, **kwargs): +        '''Actual call for the regularizer.  +         +        One can either set the regularization parameters first and then call the +        algorithm or set the regularization parameter during the call (as  +        is done in the static methods).  +        ''' +         +        if kwargs is not None: +            for key, value in kwargs.items(): +                #print("{0} = {1}".format(key, value))                         +                self.pars[key] = value +                     +        if input is not None:  +            self.pars['input'] = input +        if regularization_parameter is not None: +            self.pars['regularization_parameter'] = regularization_parameter +             +        if self.debug: +            print ("--------------------------------------------------") +            for key, value in self.pars.items(): +                if key== 'algorithm' : +                    print("{0} = {1}".format(key, value.__name__)) +                elif key == 'input': +                    print("{0} = {1}".format(key, np.shape(value))) +                else: +                    print("{0} = {1}".format(key, value)) +         +             +        if None in self.pars: +                raise Exception("Not all parameters have been provided") +         +        input = self.pars['input'] +        regularization_parameter = self.pars['regularization_parameter'] +        if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : +            return self.algorithm(input, regularization_parameter, +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['TV_penalty'].value )     +        elif self.algorithm == Regularizer.Algorithm.FGP_TV : +            return self.algorithm(input, regularization_parameter, +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['TV_penalty'].value ) +        elif self.algorithm == Regularizer.Algorithm.LLT_model : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            return self.algorithm(input,  +                              regularization_parameter, +                              self.pars['time_step'] ,  +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['restrictive_Z_smoothing'] ) +        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            return self.algorithm(input, regularization_parameter, +                                  self.pars['searching_window_ratio'] ,  +                                  self.pars['similarity_window_ratio'] ,  +                                  self.pars['PB_filtering_parameter']) +        elif self.algorithm == Regularizer.Algorithm.TGV_PD : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            if len(np.shape(input)) == 2: +                return self.algorithm(input, regularization_parameter, +                                  self.pars['first_order_term'] ,  +                                  self.pars['second_order_term'] ,  +                                  self.pars['number_of_iterations']) +            elif len(np.shape(input)) == 3: +                #assuming it's 3D +                # run independent calls on each slice +                out3d = input.copy() +                for i in range(np.shape(input)[2]): +                    out = self.algorithm(input, regularization_parameter, +                                 self.pars['first_order_term'] ,  +                                 self.pars['second_order_term'] ,  +                                 self.pars['number_of_iterations']) +                    # copy the result in the 3D image +                    out3d.T[i] = out[0].copy() +                # append the rest of the info that the algorithm returns +                output = [out3d] +                for i in range(1,len(out)): +                    output.append(out[i]) +                return output +                 +                 +             +             +         +    # __call__ +     +    @staticmethod +    def SplitBregman_TV(input, regularization_parameter , **kwargs): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +        out = list( reg(input, regularization_parameter, **kwargs) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +         +    @staticmethod +    def FGP_TV(input, regularization_parameter , **kwargs): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.FGP_TV) +        out = list( reg(input, regularization_parameter, **kwargs) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +     +    @staticmethod +    def LLT_model(input, regularization_parameter , time_step, number_of_iterations, +                  tolerance_constant, restrictive_Z_smoothing=0): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.LLT_model) +        out = list( reg(input, regularization_parameter, time_step=time_step,  +                        number_of_iterations=number_of_iterations, +                        tolerance_constant=tolerance_constant,  +                        restrictive_Z_smoothing=restrictive_Z_smoothing) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +     +    @staticmethod +    def PatchBased_Regul(input, regularization_parameter, +                        searching_window_ratio,  +                        similarity_window_ratio, +                        PB_filtering_parameter): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)    +        out = list( reg(input,  +                        regularization_parameter, +                        searching_window_ratio=searching_window_ratio,  +                        similarity_window_ratio=similarity_window_ratio, +                        PB_filtering_parameter=PB_filtering_parameter ) +            ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +     +    @staticmethod +    def TGV_PD(input, regularization_parameter , first_order_term,  +               second_order_term, number_of_iterations): +        start_time = timeit.default_timer() +         +        reg = Regularizer(Regularizer.Algorithm.TGV_PD) +        out = list( reg(input, regularization_parameter,  +                        first_order_term=first_order_term,  +                        second_order_term=second_order_term, +                        number_of_iterations=number_of_iterations) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +         +        return out +     +    def printParametersToString(self): +        txt = r'' +        for key, value in self.pars.items(): +            if key== 'algorithm' : +                txt += "{0} = {1}".format(key, value.__name__) +            elif key == 'input': +                txt += "{0} = {1}".format(key, np.shape(value)) +            else: +                txt += "{0} = {1}".format(key, value) +            txt += '\n' +        return txt +         diff --git a/src/Python/ccpi/__init__.py b/src/Python/ccpi/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/Python/ccpi/__init__.py diff --git a/src/Python/ccpi/imaging/Regularizer.py b/src/Python/ccpi/imaging/Regularizer.py new file mode 100644 index 0000000..fb9ae08 --- /dev/null +++ b/src/Python/ccpi/imaging/Regularizer.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug  8 14:26:00 2017 + +@author: ofn77899 +""" + +from ccpi.imaging import cpu_regularizers +import numpy as np +from enum import Enum +import timeit + +class Regularizer(): +    '''Class to handle regularizer algorithms to be used during reconstruction +     +    Currently 5 CPU (OMP) regularization algorithms are available: +         +    1) SplitBregman_TV +    2) FGP_TV +    3) LLT_model +    4) PatchBased_Regul +    5) TGV_PD +     +    Usage: +        the regularizer can be invoked as object or as static method +        Depending on the actual regularizer the input parameter may vary, and  +        a different default setting is defined. +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) + +        out = reg(input=u0, regularization_parameter=10., number_of_iterations=30, +          tolerance_constant=1e-4,  +          TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., +          number_of_iterations=30, tolerance_constant=1e-4,  +          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +         +        A number of optional parameters can be passed or skipped +        out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) + +    ''' +    class Algorithm(Enum): +        SplitBregman_TV = cpu_regularizers.SplitBregman_TV +        FGP_TV = cpu_regularizers.FGP_TV +        LLT_model = cpu_regularizers.LLT_model +        PatchBased_Regul = cpu_regularizers.PatchBased_Regul +        TGV_PD = cpu_regularizers.TGV_PD +    # Algorithm +     +    class TotalVariationPenalty(Enum): +        isotropic = 0 +        l1 = 1 +    # TotalVariationPenalty +         +    def __init__(self , algorithm, debug = True): +        self.setAlgorithm ( algorithm ) +        self.debug = debug +    # __init__ +     +    def setAlgorithm(self, algorithm): +        self.algorithm = algorithm +        self.pars = self.getDefaultParsForAlgorithm(algorithm) +    # setAlgorithm +         +    def getDefaultParsForAlgorithm(self, algorithm): +        pars = dict() +         +        if algorithm == Regularizer.Algorithm.SplitBregman_TV : +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['regularization_parameter'] = None +            pars['number_of_iterations'] = 35 +            pars['tolerance_constant'] = 0.0001 +            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic +             +        elif algorithm == Regularizer.Algorithm.FGP_TV : +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['regularization_parameter'] = None +            pars['number_of_iterations'] = 50 +            pars['tolerance_constant'] = 0.001 +            pars['TV_penalty'] = Regularizer.TotalVariationPenalty.isotropic +             +        elif algorithm == Regularizer.Algorithm.LLT_model: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['regularization_parameter'] = None +            pars['time_step'] = None +            pars['number_of_iterations'] = None +            pars['tolerance_constant'] = None +            pars['restrictive_Z_smoothing'] = 0 +             +        elif algorithm == Regularizer.Algorithm.PatchBased_Regul: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['searching_window_ratio'] = None +            pars['similarity_window_ratio'] = None +            pars['PB_filtering_parameter'] = None +            pars['regularization_parameter'] = None +             +        elif algorithm == Regularizer.Algorithm.TGV_PD: +            pars['algorithm'] = algorithm +            pars['input'] = None +            pars['first_order_term'] = None +            pars['second_order_term'] = None +            pars['number_of_iterations'] = None +            pars['regularization_parameter'] = None +             +        else: +            raise Exception('Unknown regularizer algorithm') +             +        return pars +    # parsForAlgorithm +     +    def setParameter(self, **kwargs): +        '''set named parameter for the regularization engine +         +        raises Exception if the named parameter is not recognized +        Typical usage is: +             +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +        reg.setParameter(input=u0)     +        reg.setParameter(regularization_parameter=10.) +         +        it can be also used as +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +        reg.setParameter(input=u0 , regularization_parameter=10.) +        ''' +         +        for key , value in kwargs.items(): +            if key in self.pars.keys(): +                self.pars[key] = value +            else: +                raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) +    # setParameter +	 +    def getParameter(self, **kwargs): +        ret = {} +        for key , value in kwargs.items(): +            if key in self.pars.keys(): +                ret[key] = self.pars[key] +        else: +            raise Exception('Wrong parameter {0} for regularizer algorithm'.format(key)) +    # setParameter +	 +         +    def __call__(self, input = None, regularization_parameter = None, **kwargs): +        '''Actual call for the regularizer.  +         +        One can either set the regularization parameters first and then call the +        algorithm or set the regularization parameter during the call (as  +        is done in the static methods).  +        ''' +         +        if kwargs is not None: +            for key, value in kwargs.items(): +                #print("{0} = {1}".format(key, value))                         +                self.pars[key] = value +                     +        if input is not None:  +            self.pars['input'] = input +        if regularization_parameter is not None: +            self.pars['regularization_parameter'] = regularization_parameter +             +        if self.debug: +            print ("--------------------------------------------------") +            for key, value in self.pars.items(): +                if key== 'algorithm' : +                    print("{0} = {1}".format(key, value.__name__)) +                elif key == 'input': +                    print("{0} = {1}".format(key, np.shape(value))) +                else: +                    print("{0} = {1}".format(key, value)) +         +             +        if None in self.pars: +                raise Exception("Not all parameters have been provided") +         +        input = self.pars['input'] +        regularization_parameter = self.pars['regularization_parameter'] +        if self.algorithm == Regularizer.Algorithm.SplitBregman_TV : +            return self.algorithm(input, regularization_parameter, +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['TV_penalty'].value )     +        elif self.algorithm == Regularizer.Algorithm.FGP_TV : +            return self.algorithm(input, regularization_parameter, +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['TV_penalty'].value ) +        elif self.algorithm == Regularizer.Algorithm.LLT_model : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            return self.algorithm(input,  +                              regularization_parameter, +                              self.pars['time_step'] ,  +                              self.pars['number_of_iterations'], +                              self.pars['tolerance_constant'], +                              self.pars['restrictive_Z_smoothing'] ) +        elif self.algorithm == Regularizer.Algorithm.PatchBased_Regul : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            return self.algorithm(input, regularization_parameter, +                                  self.pars['searching_window_ratio'] ,  +                                  self.pars['similarity_window_ratio'] ,  +                                  self.pars['PB_filtering_parameter']) +        elif self.algorithm == Regularizer.Algorithm.TGV_PD : +            #LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) +            # no default +            if len(np.shape(input)) == 2: +                return self.algorithm(input, regularization_parameter, +                                  self.pars['first_order_term'] ,  +                                  self.pars['second_order_term'] ,  +                                  self.pars['number_of_iterations']) +            elif len(np.shape(input)) == 3: +                #assuming it's 3D +                # run independent calls on each slice +                out3d = input.copy() +                for i in range(np.shape(input)[2]): +                    out = self.algorithm(input, regularization_parameter, +                                 self.pars['first_order_term'] ,  +                                 self.pars['second_order_term'] ,  +                                 self.pars['number_of_iterations']) +                    # copy the result in the 3D image +                    out3d.T[i] = out[0].copy() +                # append the rest of the info that the algorithm returns +                output = [out3d] +                for i in range(1,len(out)): +                    output.append(out[i]) +                return output +                 +                 +             +             +         +    # __call__ +     +    @staticmethod +    def SplitBregman_TV(input, regularization_parameter , **kwargs): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +        out = list( reg(input, regularization_parameter, **kwargs) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +         +    @staticmethod +    def FGP_TV(input, regularization_parameter , **kwargs): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.FGP_TV) +        out = list( reg(input, regularization_parameter, **kwargs) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +     +    @staticmethod +    def LLT_model(input, regularization_parameter , time_step, number_of_iterations, +                  tolerance_constant, restrictive_Z_smoothing=0): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.LLT_model) +        out = list( reg(input, regularization_parameter, time_step=time_step,  +                        number_of_iterations=number_of_iterations, +                        tolerance_constant=tolerance_constant,  +                        restrictive_Z_smoothing=restrictive_Z_smoothing) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +     +    @staticmethod +    def PatchBased_Regul(input, regularization_parameter, +                        searching_window_ratio,  +                        similarity_window_ratio, +                        PB_filtering_parameter): +        start_time = timeit.default_timer() +        reg = Regularizer(Regularizer.Algorithm.PatchBased_Regul)    +        out = list( reg(input,  +                        regularization_parameter, +                        searching_window_ratio=searching_window_ratio,  +                        similarity_window_ratio=similarity_window_ratio, +                        PB_filtering_parameter=PB_filtering_parameter ) +            ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +        return out +     +    @staticmethod +    def TGV_PD(input, regularization_parameter , first_order_term,  +               second_order_term, number_of_iterations): +        start_time = timeit.default_timer() +         +        reg = Regularizer(Regularizer.Algorithm.TGV_PD) +        out = list( reg(input, regularization_parameter,  +                        first_order_term=first_order_term,  +                        second_order_term=second_order_term, +                        number_of_iterations=number_of_iterations) ) +        out.append(reg.pars) +        txt = reg.printParametersToString() +        txt += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +        out.append(txt) +         +        return out +     +    def printParametersToString(self): +        txt = r'' +        for key, value in self.pars.items(): +            if key== 'algorithm' : +                txt += "{0} = {1}".format(key, value.__name__) +            elif key == 'input': +                txt += "{0} = {1}".format(key, np.shape(value)) +            else: +                txt += "{0} = {1}".format(key, value) +            txt += '\n' +        return txt +         diff --git a/src/Python/ccpi/imaging/__init__.py b/src/Python/ccpi/imaging/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/Python/ccpi/imaging/__init__.py diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py new file mode 100644 index 0000000..c903712 --- /dev/null +++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py @@ -0,0 +1,612 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +#from ccpi.reconstruction.parallelbeam import alg + +#from ccpi.imaging.Regularizer import Regularizer +from enum import Enum + +import astra + +    +     +class FISTAReconstructor(): +    '''FISTA-based reconstruction algorithm using ASTRA-toolbox +     +    ''' +    # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> +    # ___Input___: +    # params.[] file: +    #       - .proj_geom (geometry of the projector) [required] +    #       - .vol_geom (geometry of the reconstructed object) [required] +    #       - .sino (vectorized in 2D or 3D sinogram) [required] +    #       - .iterFISTA (iterations for the main loop, default 40) +    #       - .L_const (Lipschitz constant, default Power method)                                                                                                    ) +    #       - .X_ideal (ideal image, if given) +    #       - .weights (statisitcal weights, size of the sinogram) +    #       - .ROI (Region-of-interest, only if X_ideal is given) +    #       - .initialize (a 'warm start' using SIRT method from ASTRA) +    #----------------Regularization choices------------------------ +    #       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) +    #       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) +    #       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) +    #       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) +    #       - .Regul_Iterations (iterations for the selected penalty, default 25) +    #       - .Regul_tauLLT (time step parameter for LLT term) +    #       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) +    #       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) +    #----------------Visualization parameters------------------------ +    #       - .show (visualize reconstruction 1/0, (0 default)) +    #       - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) +    #       - .slice (for 3D volumes - slice number to imshow) +    # ___Output___: +    # 1. X - reconstructed image/volume +    # 2. output - a structure with +    #    - .Resid_error - residual error (if X_ideal is given) +    #    - .objective: value of the objective function +    #    - .L_const: Lipshitz constant to avoid recalculations +     +    # References: +    # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse +    # Problems" by A. Beck and M Teboulle +    # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo +    # 3. "A novel tomographic reconstruction method based on the robust +    # Student's t function for suppressing data outliers" D. Kazantsev et.al. +    # D. Kazantsev, 2016-17 +    def __init__(self, projector_geometry, output_geometry, input_sinogram, +                 **kwargs): +        # handle parmeters: +        # obligatory parameters +        self.pars = dict() +        self.pars['projector_geometry'] = projector_geometry # proj_geom +        self.pars['output_geometry'] = output_geometry       # vol_geom +        self.pars['input_sinogram'] = input_sinogram         # sino +        sliceZ, nangles, detectors = numpy.shape(input_sinogram) +        self.pars['detectors'] = detectors +        self.pars['number_of_angles'] = nangles +        self.pars['SlicesZ'] = sliceZ +        self.pars['output_volume'] = None + +        print (self.pars) +        # handle optional input parameters (at instantiation) +         +        # Accepted input keywords +        kw = ( +              # mandatory fields +              'projector_geometry', +              'output_geometry', +              'input_sinogram', +              'detectors', +              'number_of_angles', +              'SlicesZ', +              # optional fields +              'number_of_iterations',  +              'Lipschitz_constant' ,  +              'ideal_image' , +              'weights' ,  +              'region_of_interest' ,  +              'initialize' ,  +              'regularizer' ,  +              'ring_lambda_R_L1', +              'ring_alpha', +              'subsets', +              'output_volume', +              'os_subsets', +              'os_indices', +              'os_bins') +        self.acceptedInputKeywords = list(kw) +         +        # handle keyworded parameters +        if kwargs is not None: +            for key, value in kwargs.items(): +                if key in kw: +                    #print("{0} = {1}".format(key, value))                         +                    self.pars[key] = value +                     +        # set the default values for the parameters if not set +        if 'number_of_iterations' in kwargs.keys(): +            self.pars['number_of_iterations'] = kwargs['number_of_iterations'] +        else: +            self.pars['number_of_iterations'] = 40 +        if 'weights' in kwargs.keys(): +            self.pars['weights'] = kwargs['weights'] +        else: +            self.pars['weights'] = \ +                                 numpy.ones(numpy.shape( +                                     self.pars['input_sinogram'])) +        if 'Lipschitz_constant' in kwargs.keys(): +            self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] +        else: +            self.pars['Lipschitz_constant'] = None +         +        if not 'ideal_image' in kwargs.keys(): +            self.pars['ideal_image'] = None +         +        if not 'region_of_interest'in kwargs.keys() : +            if self.pars['ideal_image'] == None: +                self.pars['region_of_interest'] = None +            else: +                ## nonzero if the image is larger than m +                fsm = numpy.frompyfunc(lambda x,m: 1 if x>m else 0, 2,1) +                 +                self.pars['region_of_interest'] = fsm(self.pars['ideal_image'], 0) +                 +        # the regularizer must be a correctly instantiated object     +        if not 'regularizer' in kwargs.keys() : +            self.pars['regularizer'] = None + +        #RING REMOVAL +        if not 'ring_lambda_R_L1' in kwargs.keys(): +            self.pars['ring_lambda_R_L1'] = 0 +        if not 'ring_alpha' in kwargs.keys(): +            self.pars['ring_alpha'] = 1 + +        # ORDERED SUBSET +        if not 'subsets' in kwargs.keys(): +            self.pars['subsets'] = 0 +        else: +            self.createOrderedSubsets() + +        if not 'initialize' in kwargs.keys(): +            self.pars['initialize'] = False + +         +             +             +    def setParameter(self, **kwargs): +        '''set named parameter for the reconstructor engine +         +        raises Exception if the named parameter is not recognized +         +        ''' +        for key , value in kwargs.items(): +            if key in self.acceptedInputKeywords: +                self.pars[key] = value +            else: +                raise Exception('Wrong parameter {0} for '.format(key) + +                                'reconstructor') +    # setParameter + +    def getParameter(self, key): +        if type(key) is str: +            if key in self.acceptedInputKeywords: +                return self.pars[key] +            else: +                raise Exception('Unrecongnised parameter: {0} '.format(key) ) +        elif type(key) is list: +            outpars = [] +            for k in key: +                outpars.append(self.getParameter(k)) +            return outpars +        else: +            raise Exception('Unhandled input {0}' .format(str(type(key)))) +             +     +    def calculateLipschitzConstantWithPowerMethod(self): +        ''' using Power method (PM) to establish L constant''' +         +        N = self.pars['output_geometry']['GridColCount'] +        proj_geom = self.pars['projector_geometry'] +        vol_geom = self.pars['output_geometry'] +        weights = self.pars['weights'] +        SlicesZ = self.pars['SlicesZ'] +         +             +                                +        if (proj_geom['type'] == 'parallel') or \ +           (proj_geom['type'] == 'parallel3d'): +            #% for parallel geometry we can do just one slice +            #print('Calculating Lipshitz constant for parallel beam geometry...') +            niter = 5;# % number of iteration for the PM +            #N = params.vol_geom.GridColCount; +            #x1 = rand(N,N,1); +            x1 = numpy.random.rand(1,N,N) +            #sqweight = sqrt(weights(:,:,1)); +            sqweight = numpy.sqrt(weights[0]) +            proj_geomT = proj_geom.copy(); +            proj_geomT['DetectorRowCount'] = 1; +            vol_geomT = vol_geom.copy(); +            vol_geomT['GridSliceCount'] = 1; +             +            #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +             +             +            for i in range(niter): +            #        [id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geomT, vol_geomT); +            #            s = norm(x1(:)); +            #            x1 = x1/s; +            #            [sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +            #            y = sqweight.*y; +            #            astra_mex_data3d('delete', sino_id); +            #            astra_mex_data3d('delete', id); +                #print ("iteration {0}".format(i)) +                                 +                sino_id, y = astra.creators.create_sino3d_gpu(x1, +                                                          proj_geomT, +                                                          vol_geomT) +                 +                y = (sqweight * y).copy() # element wise multiplication +                 +                #b=fig.add_subplot(2,1,2) +                #imgplot = plt.imshow(x1[0]) +                #plt.show() +                 +                #astra_mex_data3d('delete', sino_id); +                astra.matlab.data3d('delete', sino_id) +                del x1 +                     +                idx,x1 = astra.creators.create_backprojection3d_gpu((sqweight*y).copy(),  +                                                                    proj_geomT, +                                                                    vol_geomT) +                del y +                 +                                                                     +                s = numpy.linalg.norm(x1) +                ### this line? +                x1 = (x1/s).copy(); +                 +            #        ### this line? +            #        sino_id, y = astra.creators.create_sino3d_gpu(x1,  +            #                                                      proj_geomT,  +            #                                                      vol_geomT); +            #        y = sqweight * y; +                astra.matlab.data3d('delete', sino_id); +                astra.matlab.data3d('delete', idx) +                print ("iteration {0} s= {1}".format(i,s)) +                 +            #end +            del proj_geomT +            del vol_geomT +            #plt.show() +        else: +            #% divergen beam geometry +            print('Calculating Lipshitz constant for divergen beam geometry...') +            niter = 8; #% number of iteration for PM +            x1 = numpy.random.rand(SlicesZ , N , N); +            #sqweight = sqrt(weights); +            sqweight = numpy.sqrt(weights[0]) +             +            sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); +            y = sqweight*y; +            #astra_mex_data3d('delete', sino_id); +            astra.matlab.data3d('delete', sino_id); +             +            for i in range(niter): +                #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); +                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,  +                                                                    proj_geom,  +                                                                    vol_geom) +                s = numpy.linalg.norm(x1) +                ### this line? +                x1 = x1/s; +                ### this line? +                #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); +                sino_id, y = astra.creators.create_sino3d_gpu(x1,  +                                                              proj_geom,  +                                                              vol_geom); +                 +                y = sqweight*y; +                #astra_mex_data3d('delete', sino_id); +                #astra_mex_data3d('delete', id); +                astra.matlab.data3d('delete', sino_id); +                astra.matlab.data3d('delete', idx); +            #end +            #clear x1 +            del x1 + +         +        return s +     +     +    def setRegularizer(self, regularizer): +        if regularizer is not None: +            self.pars['regularizer'] = regularizer +         + +    def initialize(self): +        # convenience variable storage +        proj_geom = self.pars['projector_geometry'] +        vol_geom = self.pars['output_geometry'] +        sino = self.pars['input_sinogram'] +         +        # a 'warm start' with SIRT method +        # Create a data object for the reconstruction +        rec_id = astra.matlab.data3d('create', '-vol', +                                    vol_geom); +         +        #sinogram_id = astra_mex_data3d('create', '-proj3d', proj_geom, sino); +        sinogram_id = astra.matlab.data3d('create', '-proj3d', +                                          proj_geom, +                                          sino) + +        sirt_config = astra.astra_dict('SIRT3D_CUDA') +        sirt_config['ReconstructionDataId' ] = rec_id +        sirt_config['ProjectionDataId'] = sinogram_id + +        sirt = astra.algorithm.create(sirt_config) +        astra.algorithm.run(sirt, iterations=35) +        X = astra.matlab.data3d('get', rec_id) + +        # clean up memory +        astra.matlab.data3d('delete', rec_id) +        astra.matlab.data3d('delete', sinogram_id) +        astra.algorithm.delete(sirt) + +         + +        return X + +    def createOrderedSubsets(self, subsets=None): +        if subsets is None: +            try: +                subsets = self.getParameter('subsets') +            except Exception(): +                subsets = 0 +            #return subsets +        else: +            self.setParameter(subsets=subsets) +             + +        angles = self.getParameter('projector_geometry')['ProjectionAngles']  +         +        #binEdges = numpy.linspace(angles.min(), +        #                          angles.max(), +        #                          subsets + 1) +        binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) +        # get rearranged subset indices +        IndicesReorg = numpy.zeros((numpy.shape(angles)), dtype=numpy.int32) +        counterM = 0 +        for ii in range(binsDiscr.max()): +            counter = 0 +            for jj in range(subsets): +                curr_index = ii + jj  + counter +                #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) +                if binsDiscr[jj] > ii: +                    if (counterM < numpy.size(IndicesReorg)): +                        IndicesReorg[counterM] = curr_index +                    counterM = counterM + 1 +                     +                counter = counter + binsDiscr[jj] - 1     +                 +        # store the OS in parameters +        self.setParameter(os_subsets=subsets, +                          os_bins=binsDiscr, +                          os_indices=IndicesReorg) +             + +    def prepareForIteration(self): +        print ("FISTA Reconstructor: prepare for iteration") +         +        self.residual_error = numpy.zeros((self.pars['number_of_iterations'])) +        self.objective = numpy.zeros((self.pars['number_of_iterations'])) + +        #2D array (for 3D data) of sparse "ring"  +        detectors, nangles, sliceZ  = numpy.shape(self.pars['input_sinogram']) +        self.r = numpy.zeros((detectors, sliceZ), dtype=numpy.float) +        # another ring variable +        self.r_x = self.r.copy() + +        self.residual = numpy.zeros(numpy.shape(self.pars['input_sinogram'])) +         +        if self.getParameter('Lipschitz_constant') is None: +            self.pars['Lipschitz_constant'] = \ +                            self.calculateLipschitzConstantWithPowerMethod() +        # errors vector (if the ground truth is given) +        self.Resid_error = numpy.zeros((self.getParameter('number_of_iterations'))); +        # objective function values vector +        self.objective = numpy.zeros((self.getParameter('number_of_iterations')));       +         + +    # prepareForIteration + +    def iterate(self, Xin=None): +        print ("FISTA Reconstructor: iterate") +         +        if Xin is None:     +            if self.getParameter('initialize'): +                X = self.initialize() +            else: +                N = vol_geom['GridColCount'] +                X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) +        else: +            # copy by reference +            X = Xin +        # store the output volume in the parameters +        self.setParameter(output_volume=X) +        X_t = X.copy() +        # convenience variable storage +        proj_geom , vol_geom, sino , \ +          SlicesZ  = self.getParameter([ 'projector_geometry' , +                                                'output_geometry', +                                                'input_sinogram', +                                                'SlicesZ' ]) +                    +        t = 1 +         +        for i in range(self.getParameter('number_of_iterations')): +            X_old = X.copy() +            t_old = t +            r_old = self.r.copy() +            if self.getParameter('projector_geometry')['type'] == 'parallel' or \ +               self.getParameter('projector_geometry')['type'] == 'fanflat' or \ +               self.getParameter('projector_geometry')['type'] == 'fanflat_vec': +                # if the geometry is parallel use slice-by-slice +                # projection-backprojection routine +                #sino_updt = zeros(size(sino),'single'); +                proj_geomT = proj_geom.copy() +                proj_geomT['DetectorRowCount'] = 1 +                vol_geomT = vol_geom.copy() +                vol_geomT['GridSliceCount'] = 1; +                self.sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) +                for kkk in range(SlicesZ): +                    sino_id, self.sino_updt[kkk] = \ +                             astra.creators.create_sino3d_gpu( +                                 X_t[kkk:kkk+1], proj_geomT, vol_geomT) +                    astra.matlab.data3d('delete', sino_id) +            else: +                # for divergent 3D geometry (watch the GPU memory overflow in +                # ASTRA versions < 1.8) +                #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); +                sino_id, self.sino_updt = astra.creators.create_sino3d_gpu( +                    X_t, proj_geom, vol_geom) + + +            ## RING REMOVAL +            self.ringRemoval(i) +            ## Projection/Backprojection Routine +            self.projectionBackprojection(X, X_t) +            astra.matlab.data3d('delete', sino_id) +            ## REGULARIZATION +            X = self.regularize(X) +            ## Update Loop +            X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old) +            self.setParameter(output_volume=X) +        return X +    ## iterate +     +    def ringRemoval(self, i): +        print ("FISTA Reconstructor: ring removal") +        residual = self.residual +        lambdaR_L1 , alpha_ring , weights , L_const , sino= \ +                   self.getParameter(['ring_lambda_R_L1', +                                      'ring_alpha' , 'weights', +                                      'Lipschitz_constant', +                                      'input_sinogram']) +        r_x = self.r_x +        sino_updt = self.sino_updt +         +        SlicesZ, anglesNumb, Detectors = \ +                    numpy.shape(self.getParameter('input_sinogram')) +        if lambdaR_L1 > 0 : +             for kkk in range(anglesNumb): +                  +                 residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ +                                       ((sino_updt[:,kkk,:]).squeeze() - \ +                                        (sino[:,kkk,:]).squeeze() -\ +                                        (alpha_ring * r_x) +                                        ) +             vec = residual.sum(axis = 1) +             #if SlicesZ > 1: +             #    vec = vec[:,1,:].squeeze() +             self.r = (r_x - (1./L_const) * vec).copy() +             self.objective[i] = (0.5 * (residual ** 2).sum()) + +    def projectionBackprojection(self, X, X_t): +        print ("FISTA Reconstructor: projection-backprojection routine") +         +        # a few useful variables +        SlicesZ, anglesNumb, Detectors = \ +                    numpy.shape(self.getParameter('input_sinogram')) +        residual = self.residual +        proj_geom , vol_geom , L_const = \ +                  self.getParameter(['projector_geometry' , +                                                  'output_geometry', +                                                  'Lipschitz_constant']) +         +         +        if self.getParameter('projector_geometry')['type'] == 'parallel' or \ +           self.getParameter('projector_geometry')['type'] == 'fanflat' or \ +           self.getParameter('projector_geometry')['type'] == 'fanflat_vec': +            # if the geometry is parallel use slice-by-slice +            # projection-backprojection routine +            #sino_updt = zeros(size(sino),'single'); +            proj_geomT = proj_geom.copy() +            proj_geomT['DetectorRowCount'] = 1 +            vol_geomT = vol_geom.copy() +            vol_geomT['GridSliceCount'] = 1; +            x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) +             +            for kkk in range(SlicesZ): +                 +                x_id, x_temp[kkk] = \ +                         astra.creators.create_backprojection3d_gpu( +                             residual[kkk:kkk+1], +                             proj_geomT, vol_geomT) +                astra.matlab.data3d('delete', x_id) +        else: +            x_id, x_temp = \ +                  astra.creators.create_backprojection3d_gpu( +                      residual, proj_geom, vol_geom)             + +        X = X_t - (1/L_const) * x_temp +        #astra.matlab.data3d('delete', sino_id) +        astra.matlab.data3d('delete', x_id) + +    def regularize(self, X): +        print ("FISTA Reconstructor: regularize") +         +        regularizer = self.getParameter('regularizer') +        if regularizer is not None: +            return regularizer(input=X) +        else: +            return X + +    def updateLoop(self, i, X, X_old, r_old, t, t_old): +        print ("FISTA Reconstructor: update loop") +        lambdaR_L1 = self.getParameter('ring_lambda_R_L1') +        if lambdaR_L1 > 0: +            self.r = numpy.max( +                numpy.abs(self.r) - lambdaR_L1 , 0) * \ +                numpy.sign(self.r) +        t = (1 + numpy.sqrt(1 + 4 * t**2))/2 +        X_t = X + (((t_old -1)/t) * (X - X_old)) + +        if lambdaR_L1 > 0: +            self.r_x = self.r + \ +                             (((t_old-1)/t) * (self.r - r_old)) + +        if self.getParameter('region_of_interest') is None: +            string = 'Iteration Number {0} | Objective {1} \n' +            print (string.format( i, self.objective[i])) +        else: +            ROI , X_ideal = fistaRecon.getParameter('region_of_interest', +                                                    'ideal_image') +             +            Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) +            string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' +            print (string.format(i,Resid_error[i], self.objective[i])) +        return (X , X_t, t) + +    def os_iterate(self, Xin=None): +        print ("FISTA Reconstructor: iterate") +         +        if Xin is None:     +            if self.getParameter('initialize'): +                X = self.initialize() +            else: +                N = vol_geom['GridColCount'] +                X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) +        else: +            # copy by reference +            X = Xin +        # store the output volume in the parameters +        self.setParameter(output_volume=X) +        X_t = X.copy() + +        # some useful constants +        proj_geom , vol_geom, sino , \ +          SlicesZ, weights , alpha_ring ,\ +          lambdaR_L1 , L_const = self.getParameter( +            ['projector_geometry' , 'output_geometry', +             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha' , +             'ring_lambda_R_L1', 'Lipschitz_constant']) diff --git a/src/Python/ccpi/reconstruction/Reconstructor.py b/src/Python/ccpi/reconstruction/Reconstructor.py new file mode 100644 index 0000000..ba67327 --- /dev/null +++ b/src/Python/ccpi/reconstruction/Reconstructor.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- +############################################################################### +#This work is part of the Core Imaging Library developed by +#Visual Analytics and Imaging System Group of the Science Technology +#Facilities Council, STFC +# +#Copyright 2017 Edoardo Pasca, Srikanth Nagella +#Copyright 2017 Daniil Kazantsev +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +#http://www.apache.org/licenses/LICENSE-2.0 +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +############################################################################### + + + +import numpy +import h5py +from ccpi.reconstruction.parallelbeam import alg + +from Regularizer import Regularizer +from enum import Enum + +import astra + + +class Reconstructor: +     +    class Algorithm(Enum): +        CGLS = alg.cgls +        CGLS_CONV = alg.cgls_conv +        SIRT = alg.sirt +        MLEM = alg.mlem +        CGLS_TICHONOV = alg.cgls_tikhonov +        CGLS_TVREG = alg.cgls_TVreg +        FISTA = 'fista' +         +    def __init__(self, algorithm = None, projection_data = None, +                 angles = None, center_of_rotation = None ,  +                 flat_field = None, dark_field = None,  +                 iterations = None, resolution = None, isLogScale = False, threads = None,  +                 normalized_projection = None): +     +        self.pars = dict() +        self.pars['algorithm'] = algorithm +        self.pars['projection_data'] = projection_data +        self.pars['normalized_projection'] = normalized_projection +        self.pars['angles'] = angles +        self.pars['center_of_rotation'] = numpy.double(center_of_rotation) +        self.pars['flat_field'] = flat_field +        self.pars['iterations'] = iterations +        self.pars['dark_field'] = dark_field +        self.pars['resolution'] = resolution +        self.pars['isLogScale'] = isLogScale +        self.pars['threads'] = threads +        if (iterations != None): +            self.pars['iterationValues'] = numpy.zeros((iterations))  +         +        if projection_data != None and dark_field != None and flat_field != None: +            norm = self.normalize(projection_data, dark_field, flat_field, 0.1) +            self.pars['normalized_projection'] = norm +             +     +    def setPars(self, parameters): +        keys = ['algorithm','projection_data' ,'normalized_projection', \ +                'angles' , 'center_of_rotation' , 'flat_field', \ +                'iterations','dark_field' , 'resolution', 'isLogScale' , \ +                'threads' , 'iterationValues', 'regularize'] +         +        for k in keys: +            if k not in parameters.keys(): +                self.pars[k] = None +            else: +                self.pars[k] = parameters[k] +                 +         +    def sanityCheck(self): +        projection_data = self.pars['projection_data'] +        dark_field = self.pars['dark_field'] +        flat_field = self.pars['flat_field'] +        angles = self.pars['angles'] +         +        if projection_data != None and dark_field != None and \ +            angles != None and flat_field != None: +            data_shape =  numpy.shape(projection_data) +            angle_shape = numpy.shape(angles) +             +            if angle_shape[0] != data_shape[0]: +                #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ +                #                (angle_shape[0] , data_shape[0]) ) +                return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ +                                (angle_shape[0] , data_shape[0]) ) +             +            if data_shape[1:] != numpy.shape(flat_field): +                #raise Exception('Projection and flat field dimensions do not match') +                return (False , 'Projection and flat field dimensions do not match') +            if data_shape[1:] != numpy.shape(dark_field): +                #raise Exception('Projection and dark field dimensions do not match') +                return (False , 'Projection and dark field dimensions do not match') +             +            return (True , '' ) +        elif self.pars['normalized_projection'] != None: +            data_shape =  numpy.shape(self.pars['normalized_projection']) +            angle_shape = numpy.shape(angles) +             +            if angle_shape[0] != data_shape[0]: +                #raise Exception('Projections and angles dimensions do not match: %d vs %d' % \ +                #                (angle_shape[0] , data_shape[0]) ) +                return (False , 'Projections and angles dimensions do not match: %d vs %d' % \ +                                (angle_shape[0] , data_shape[0]) ) +            else: +                return (True , '' ) +        else: +            return (False , 'Not enough data') +             +    def reconstruct(self, parameters = None): +        if parameters != None: +            self.setPars(parameters) +         +        go , reason = self.sanityCheck() +        if go: +            return self._reconstruct() +        else: +            raise Exception(reason) +             +             +    def _reconstruct(self, parameters=None): +        if parameters!=None: +            self.setPars(parameters) +        parameters = self.pars +         +        if parameters['algorithm'] != None and \ +           parameters['normalized_projection'] != None and \ +           parameters['angles'] != None and \ +           parameters['center_of_rotation'] != None and \ +           parameters['iterations'] != None and \ +           parameters['resolution'] != None and\ +           parameters['threads'] != None and\ +           parameters['isLogScale'] != None: +                +                +           if parameters['algorithm'] in (Reconstructor.Algorithm.CGLS, +                        Reconstructor.Algorithm.MLEM, Reconstructor.Algorithm.SIRT): +               #store parameters +               self.pars = parameters +               result = parameters['algorithm']( +                           parameters['normalized_projection'] , +                           parameters['angles'], +                           parameters['center_of_rotation'], +                           parameters['resolution'], +                           parameters['iterations'], +                           parameters['threads'] , +                           parameters['isLogScale'] +                           ) +               return result +           elif parameters['algorithm'] in (Reconstructor.Algorithm.CGLS_CONV, +                          Reconstructor.Algorithm.CGLS_TICHONOV,  +                          Reconstructor.Algorithm.CGLS_TVREG) : +               self.pars = parameters +               result = parameters['algorithm']( +                           parameters['normalized_projection'] , +                           parameters['angles'], +                           parameters['center_of_rotation'], +                           parameters['resolution'], +                           parameters['iterations'], +                           parameters['threads'] , +                           parameters['regularize'], +                           numpy.zeros((parameters['iterations'])), +                           parameters['isLogScale'] +                           ) +                +           elif parameters['algorithm'] == Reconstructor.Algorithm.FISTA: +               pass +              +        else: +           if parameters['projection_data'] != None and \ +                     parameters['dark_field'] != None and \ +                     parameters['flat_field'] != None: +               norm = self.normalize(parameters['projection_data'], +                                   parameters['dark_field'],  +                                   parameters['flat_field'], 0.1) +               self.pars['normalized_projection'] = norm +               return self._reconstruct(parameters) +               +                 +                 +    def _normalize(self, projection, dark, flat, def_val=0): +        a = (projection - dark) +        b = (flat-dark) +        with numpy.errstate(divide='ignore', invalid='ignore'): +            c = numpy.true_divide( a, b ) +            c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0  +        return c +     +    def normalize(self, projections, dark, flat, def_val=0): +        norm = [self._normalize(projection, dark, flat, def_val) for projection in projections] +        return numpy.asarray (norm, dtype=numpy.float32) +         +     +     +class FISTA(): +    '''FISTA-based reconstruction algorithm using ASTRA-toolbox +     +    ''' +    # <<<< FISTA-based reconstruction algorithm using ASTRA-toolbox >>>> +    # ___Input___: +    # params.[] file: +    #       - .proj_geom (geometry of the projector) [required] +    #       - .vol_geom (geometry of the reconstructed object) [required] +    #       - .sino (vectorized in 2D or 3D sinogram) [required] +    #       - .iterFISTA (iterations for the main loop, default 40) +    #       - .L_const (Lipschitz constant, default Power method)                                                                                                    ) +    #       - .X_ideal (ideal image, if given) +    #       - .weights (statisitcal weights, size of the sinogram) +    #       - .ROI (Region-of-interest, only if X_ideal is given) +    #       - .initialize (a 'warm start' using SIRT method from ASTRA) +    #----------------Regularization choices------------------------ +    #       - .Regul_Lambda_FGPTV (FGP-TV regularization parameter) +    #       - .Regul_Lambda_SBTV (SplitBregman-TV regularization parameter) +    #       - .Regul_Lambda_TVLLT (Higher order SB-LLT regularization parameter) +    #       - .Regul_tol (tolerance to terminate regul iterations, default 1.0e-04) +    #       - .Regul_Iterations (iterations for the selected penalty, default 25) +    #       - .Regul_tauLLT (time step parameter for LLT term) +    #       - .Ring_LambdaR_L1 (regularization parameter for L1-ring minimization, if lambdaR_L1 > 0 then switch on ring removal) +    #       - .Ring_Alpha (larger values can accelerate convergence but check stability, default 1) +    #----------------Visualization parameters------------------------ +    #       - .show (visualize reconstruction 1/0, (0 default)) +    #       - .maxvalplot (maximum value to use for imshow[0 maxvalplot]) +    #       - .slice (for 3D volumes - slice number to imshow) +    # ___Output___: +    # 1. X - reconstructed image/volume +    # 2. output - a structure with +    #    - .Resid_error - residual error (if X_ideal is given) +    #    - .objective: value of the objective function +    #    - .L_const: Lipshitz constant to avoid recalculations +     +    # References: +    # 1. "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse +    # Problems" by A. Beck and M Teboulle +    # 2. "Ring artifacts correction in compressed sensing..." by P. Paleo +    # 3. "A novel tomographic reconstruction method based on the robust +    # Student's t function for suppressing data outliers" D. Kazantsev et.al. +    # D. Kazantsev, 2016-17 +    def __init__(self, projector_geometry, output_geometry, input_sinogram, **kwargs): +        self.params = dict() +        self.params['projector_geometry'] = projector_geometry +        self.params['output_geometry'] = output_geometry +        self.params['input_sinogram'] = input_sinogram +        detectors, nangles, sliceZ = numpy.shape(input_sinogram) +        self.params['detectors'] = detectors +        self.params['number_og_angles'] = nangles +        self.params['SlicesZ'] = sliceZ +         +        # Accepted input keywords +        kw = ('number_of_iterations', 'Lipschitz_constant' , 'ideal_image' , +              'weights' , 'region_of_interest' , 'initialize' ,  +              'regularizer' ,  +              'ring_lambda_R_L1', +              'ring_alpha') +         +        # handle keyworded parameters +        if kwargs is not None: +            for key, value in kwargs.items(): +                if key in kw: +                    #print("{0} = {1}".format(key, value))                         +                    self.pars[key] = value +                     +        # set the default values for the parameters if not set +        if 'number_of_iterations' in kwargs.keys(): +            self.pars['number_of_iterations'] = kwargs['number_of_iterations'] +        else: +            self.pars['number_of_iterations'] = 40 +        if 'weights' in kwargs.keys(): +            self.pars['weights'] = kwargs['weights'] +        else: +            self.pars['weights'] = numpy.ones(numpy.shape(self.params['input_sinogram'])) +        if 'Lipschitz_constant' in kwargs.keys(): +            self.pars['Lipschitz_constant'] = kwargs['Lipschitz_constant'] +        else: +            self.pars['Lipschitz_constant'] = self.calculateLipschitzConstantWithPowerMethod() +         +        if not self.pars['ideal_image'] in kwargs.keys(): +            self.pars['ideal_image'] = None +         +        if not self.pars['region_of_interest'] : +            if self.pars['ideal_image'] == None: +                pass +            else: +                self.pars['region_of_interest'] = numpy.nonzero(self.pars['ideal_image']>0.0) +             +        if not self.pars['regularizer'] : +            self.pars['regularizer'] = None +        else: +            # the regularizer must be a correctly instantiated object +            if not self.pars['ring_lambda_R_L1']: +                self.pars['ring_lambda_R_L1'] = 0 +            if not self.pars['ring_alpha']: +                self.pars['ring_alpha'] = 1 +         +             +             +         +    def calculateLipschitzConstantWithPowerMethod(self): +        ''' using Power method (PM) to establish L constant''' +         +        #N = params.vol_geom.GridColCount +        N = self.pars['output_geometry'].GridColCount +        proj_geom = self.params['projector_geometry'] +        vol_geom = self.params['output_geometry'] +        weights = self.pars['weights'] +        SlicesZ = self.pars['SlicesZ'] +         +        if (proj_geom['type'] == 'parallel') or (proj_geom['type'] == 'parallel3d'): +            #% for parallel geometry we can do just one slice +            #fprintf('%s \n', 'Calculating Lipshitz constant for parallel beam geometry...'); +            niter = 15;# % number of iteration for the PM +            #N = params.vol_geom.GridColCount; +            #x1 = rand(N,N,1); +            x1 = numpy.random.rand(1,N,N) +            #sqweight = sqrt(weights(:,:,1)); +            sqweight = numpy.sqrt(weights.T[0]) +            proj_geomT = proj_geom.copy(); +            proj_geomT.DetectorRowCount = 1; +            vol_geomT = vol_geom.copy(); +            vol_geomT['GridSliceCount'] = 1; +             +             +            for i in range(niter): +                if i == 0: +                    #[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +                    sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geomT, vol_geomT); +                    y = sqweight * y # element wise multiplication +                    #astra_mex_data3d('delete', sino_id); +                    astra.matlab.data3d('delete', sino_id) +                     +                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y, proj_geomT, vol_geomT); +                s = numpy.linalg.norm(x1) +                ### this line? +                x1 = x1/s; +                ### this line? +                sino_id, y = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); +                y = sqweight*y; +                astra.matlab.data3d('delete', sino_id); +                astra.matlab.data3d('delete', idx); +            #end +            del proj_geomT +            del vol_geomT +        else +            #% divergen beam geometry +            #fprintf('%s \n', 'Calculating Lipshitz constant for divergen beam geometry...'); +            niter = 8; #% number of iteration for PM +            x1 = numpy.random.rand(SlicesZ , N , N); +            #sqweight = sqrt(weights); +            sqweight = numpy.sqrt(weights.T[0]) +             +            sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom); +            y = sqweight*y; +            #astra_mex_data3d('delete', sino_id); +            astra.matlab.data3d('delete', sino_id); +             +            for i in range(niter): +                #[id,x1] = astra_create_backprojection3d_cuda(sqweight.*y, proj_geom, vol_geom); +                idx,x1 = astra.creators.create_backprojection3d_gpu(sqweight*y,  +                                                                    proj_geom,  +                                                                    vol_geom) +                s = numpy.linalg.norm(x1) +                ### this line? +                x1 = x1/s; +                ### this line? +                #[sino_id, y] = astra_create_sino3d_gpu(x1, proj_geom, vol_geom); +                sino_id, y = astra.creators.create_sino3d_gpu(x1,  +                                                              proj_geom,  +                                                              vol_geom); +                 +                y = sqweight*y; +                #astra_mex_data3d('delete', sino_id); +                #astra_mex_data3d('delete', id); +                astra.matlab.data3d('delete', sino_id); +                astra.matlab.data3d('delete', idx); +            #end +            #clear x1 +            del x1 +         +        return s +     +     +    def setRegularizer(self, regularizer): +        if regularizer +        self.pars['regularizer'] = regularizer +         +     +     + + +def getEntry(location): +    for item in nx[location].keys(): +        print (item) + + +print ("Loading Data") + +##fname = "D:\\Documents\\Dataset\\IMAT\\20170419_crabtomo\\crabtomo\\Sample\\IMAT00005153_crabstomo_Sample_000.tif" +####ind = [i * 1049 for i in range(360)] +#### use only 360 images +##images = 200 +##ind = [int(i * 1049 / images) for i in range(images)] +##stack_image = dxchange.reader.read_tiff_stack(fname, ind, digit=None, slc=None) + +#fname = "D:\\Documents\\Dataset\\CGLS\\24737_fd.nxs" +fname = "C:\\Users\\ofn77899\\Documents\\CCPi\\CGLS\\24737_fd_2.nxs" +nx = h5py.File(fname, "r") + +# the data are stored in a particular location in the hdf5 +for item in nx['entry1/tomo_entry/data'].keys(): +    print (item) + +data = nx.get('entry1/tomo_entry/data/rotation_angle') +angles = numpy.zeros(data.shape) +data.read_direct(angles) +print (angles) +# angles should be in degrees + +data = nx.get('entry1/tomo_entry/data/data') +stack = numpy.zeros(data.shape) +data.read_direct(stack) +print (data.shape) + +print ("Data Loaded") + + +# Normalize +data = nx.get('entry1/tomo_entry/instrument/detector/image_key') +itype = numpy.zeros(data.shape) +data.read_direct(itype) +# 2 is dark field +darks = [stack[i] for i in range(len(itype)) if itype[i] == 2 ] +dark = darks[0] +for i in range(1, len(darks)): +    dark += darks[i] +dark = dark / len(darks) +#dark[0][0] = dark[0][1] + +# 1 is flat field +flats = [stack[i] for i in range(len(itype)) if itype[i] == 1 ] +flat = flats[0] +for i in range(1, len(flats)): +    flat += flats[i] +flat = flat / len(flats) +#flat[0][0] = dark[0][1] + + +# 0 is projection data +proj = [stack[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = [angles[i] for i in range(len(itype)) if itype[i] == 0 ] +angle_proj = numpy.asarray (angle_proj) +angle_proj = angle_proj.astype(numpy.float32) + +# normalized data are +# norm = (projection - dark)/(flat-dark) + +def normalize(projection, dark, flat, def_val=0.1): +    a = (projection - dark) +    b = (flat-dark) +    with numpy.errstate(divide='ignore', invalid='ignore'): +        c = numpy.true_divide( a, b ) +        c[ ~ numpy.isfinite( c )] = def_val  # set to not zero if 0/0  +    return c +     + +norm = [normalize(projection, dark, flat) for projection in proj] +norm = numpy.asarray (norm) +norm = norm.astype(numpy.float32) + +#recon = Reconstructor(algorithm = Algorithm.CGLS, normalized_projection = norm, +#                 angles = angle_proj, center_of_rotation = 86.2 ,  +#                 flat_field = flat, dark_field = dark,  +#                 iterations = 15, resolution = 1, isLogScale = False, threads = 3) + +#recon = Reconstructor(algorithm = Reconstructor.Algorithm.CGLS, projection_data = proj, +#                 angles = angle_proj, center_of_rotation = 86.2 ,  +#                 flat_field = flat, dark_field = dark,  +#                 iterations = 15, resolution = 1, isLogScale = False, threads = 3) +#img_cgls = recon.reconstruct() +# +#pars = dict() +#pars['algorithm'] = Reconstructor.Algorithm.SIRT +#pars['projection_data'] = proj +#pars['angles'] = angle_proj +#pars['center_of_rotation'] = numpy.double(86.2) +#pars['flat_field'] = flat +#pars['iterations'] = 15 +#pars['dark_field'] = dark +#pars['resolution'] = 1 +#pars['isLogScale'] = False +#pars['threads'] = 3 +# +#img_sirt = recon.reconstruct(pars) +# +#recon.pars['algorithm'] = Reconstructor.Algorithm.MLEM +#img_mlem = recon.reconstruct() + +############################################################ +############################################################ +#recon.pars['algorithm'] = Reconstructor.Algorithm.CGLS_CONV +#recon.pars['regularize'] = numpy.double(0.1) +#img_cgls_conv = recon.reconstruct() + +niterations = 15 +threads = 3 + +img_cgls = alg.cgls(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_mlem = alg.mlem(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) +img_sirt = alg.sirt(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, False) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_conv = alg.cgls_conv(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +                              iteration_values, False) +print ("iteration values %s" % str(iteration_values)) + +iteration_values = numpy.zeros((niterations,)) +img_cgls_tikhonov = alg.cgls_tikhonov(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +                                      numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) +iteration_values = numpy.zeros((niterations,)) +img_cgls_TVreg = alg.cgls_TVreg(norm, angle_proj, numpy.double(86.2), 1 , niterations, threads, +                                      numpy.double(1e-5), iteration_values , False) +print ("iteration values %s" % str(iteration_values)) + + +##numpy.save("cgls_recon.npy", img_data) +import matplotlib.pyplot as plt +fig, ax = plt.subplots(1,6,sharey=True) +ax[0].imshow(img_cgls[80]) +ax[0].axis('off')  # clear x- and y-axes +ax[1].imshow(img_sirt[80]) +ax[1].axis('off')  # clear x- and y-axes +ax[2].imshow(img_mlem[80]) +ax[2].axis('off')  # clear x- and y-axesplt.show() +ax[3].imshow(img_cgls_conv[80]) +ax[3].axis('off')  # clear x- and y-axesplt.show() +ax[4].imshow(img_cgls_tikhonov[80]) +ax[4].axis('off')  # clear x- and y-axesplt.show() +ax[5].imshow(img_cgls_TVreg[80]) +ax[5].axis('off')  # clear x- and y-axesplt.show() + + +plt.show() + +#viewer = edo.CILViewer() +#viewer.setInputAsNumpy(img_cgls2) +#viewer.displaySliceActor(0) +#viewer.startRenderLoop() + +import vtk + +def NumpyToVTKImageData(numpyarray): +    if (len(numpy.shape(numpyarray)) == 3): +        doubleImg = vtk.vtkImageData() +        shape = numpy.shape(numpyarray) +        doubleImg.SetDimensions(shape[0], shape[1], shape[2]) +        doubleImg.SetOrigin(0,0,0) +        doubleImg.SetSpacing(1,1,1) +        doubleImg.SetExtent(0, shape[0]-1, 0, shape[1]-1, 0, shape[2]-1) +        #self.img3D.SetScalarType(vtk.VTK_UNSIGNED_SHORT, vtk.vtkInformation()) +        doubleImg.AllocateScalars(vtk.VTK_DOUBLE,1) +         +        for i in range(shape[0]): +            for j in range(shape[1]): +                for k in range(shape[2]): +                    doubleImg.SetScalarComponentFromDouble( +                        i,j,k,0, numpyarray[i][j][k]) +    #self.setInput3DData( numpy_support.numpy_to_vtk(numpyarray) ) +        # rescale to appropriate VTK_UNSIGNED_SHORT +        stats = vtk.vtkImageAccumulate() +        stats.SetInputData(doubleImg) +        stats.Update() +        iMin = stats.GetMin()[0] +        iMax = stats.GetMax()[0] +        scale = vtk.VTK_UNSIGNED_SHORT_MAX / (iMax - iMin) + +        shiftScaler = vtk.vtkImageShiftScale () +        shiftScaler.SetInputData(doubleImg) +        shiftScaler.SetScale(scale) +        shiftScaler.SetShift(iMin) +        shiftScaler.SetOutputScalarType(vtk.VTK_UNSIGNED_SHORT) +        shiftScaler.Update() +        return shiftScaler.GetOutput() +         +#writer = vtk.vtkMetaImageWriter() +#writer.SetFileName(alg + "_recon.mha") +#writer.SetInputData(NumpyToVTKImageData(img_cgls2)) +#writer.Write() diff --git a/src/Python/ccpi/reconstruction/__init__.py b/src/Python/ccpi/reconstruction/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/Python/ccpi/reconstruction/__init__.py diff --git a/src/Python/compile-fista.bat.in b/src/Python/compile-fista.bat.in new file mode 100644 index 0000000..b1db686 --- /dev/null +++ b/src/Python/compile-fista.bat.in @@ -0,0 +1,7 @@ +set CIL_VERSION=@CIL_VERSION@ + +set PREFIX=@CONDA_ENVIRONMENT_PREFIX@ +set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +REM activate @CONDA_ENVIRONMENT@ +conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge diff --git a/src/Python/compile-fista.sh.in b/src/Python/compile-fista.sh.in new file mode 100644 index 0000000..267f014 --- /dev/null +++ b/src/Python/compile-fista.sh.in @@ -0,0 +1,9 @@ +#!/bin/sh +# compile within the right conda environment +#module load python/anaconda +#source activate @CONDA_ENVIRONMENT@ + +export CIL_VERSION=@CIL_VERSION@ +export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +conda build fista-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi diff --git a/src/Python/compile.bat.in b/src/Python/compile.bat.in new file mode 100644 index 0000000..e5342ed --- /dev/null +++ b/src/Python/compile.bat.in @@ -0,0 +1,7 @@ +set CIL_VERSION=@CIL_VERSION@ + +set PREFIX=@CONDA_ENVIRONMENT_PREFIX@ +set LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +REM activate @CONDA_ENVIRONMENT@ +conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi -c conda-forge
\ No newline at end of file diff --git a/src/Python/compile.sh.in b/src/Python/compile.sh.in new file mode 100644 index 0000000..93fdba2 --- /dev/null +++ b/src/Python/compile.sh.in @@ -0,0 +1,9 @@ +#!/bin/sh +# compile within the right conda environment +#module load python/anaconda +#source activate @CONDA_ENVIRONMENT@ + +export CIL_VERSION=@CIL_VERSION@ +export LIBRARY_INC=@CONDA_ENVIRONMENT_LIBRARY_INC@ + +conda build conda-recipe --python=@PYTHON_VERSION_MAJOR@.@PYTHON_VERSION_MINOR@ --numpy=@NUMPY_VERSION@ -c ccpi diff --git a/src/Python/conda-recipe/bld.bat b/src/Python/conda-recipe/bld.bat new file mode 100644 index 0000000..69491de --- /dev/null +++ b/src/Python/conda-recipe/bld.bat @@ -0,0 +1,14 @@ +IF NOT DEFINED CIL_VERSION ( +ECHO CIL_VERSION Not Defined. +exit 1 +) + +mkdir "%SRC_DIR%\ccpi" +xcopy /e "%RECIPE_DIR%\..\.." "%SRC_DIR%\ccpi" + +cd %SRC_DIR%\ccpi\Python + +%PYTHON% setup.py build_ext +if errorlevel 1 exit 1 +%PYTHON% setup.py install +if errorlevel 1 exit 1 diff --git a/src/Python/conda-recipe/build.sh b/src/Python/conda-recipe/build.sh new file mode 100644 index 0000000..855047f --- /dev/null +++ b/src/Python/conda-recipe/build.sh @@ -0,0 +1,14 @@ + +if [ -z "$CIL_VERSION" ]; then +    echo "Need to set CIL_VERSION" +    exit 1 +fi   +mkdir "$SRC_DIR/ccpi" +cp -r "$RECIPE_DIR/../.." "$SRC_DIR/ccpi" + +cd $SRC_DIR/ccpi/Python + +$PYTHON setup.py build_ext +$PYTHON setup.py install + + diff --git a/src/Python/conda-recipe/meta.yaml b/src/Python/conda-recipe/meta.yaml new file mode 100644 index 0000000..7068e9d --- /dev/null +++ b/src/Python/conda-recipe/meta.yaml @@ -0,0 +1,30 @@ +package: +  name: ccpi-regularizers +  version: {{ environ['CIL_VERSION'] }} + + +build: +  preserve_egg_dir: False +  script_env: +    - CIL_VERSION    +#  number: 0 +   +requirements: +  build: +    - python +    - numpy +    - setuptools +    - boost ==1.64 +    - boost-cpp ==1.64 +    - cython + +  run: +    - python +    - numpy +    - boost ==1.64 + +	 +about: +  home: http://www.ccpi.ac.uk +  license:  BSD license +  summary: 'CCPi Core Imaging Library Quantification Toolbox' diff --git a/src/Python/fista-recipe/build.sh b/src/Python/fista-recipe/build.sh new file mode 100644 index 0000000..e3f3552 --- /dev/null +++ b/src/Python/fista-recipe/build.sh @@ -0,0 +1,10 @@ +if [ -z "$CIL_VERSION" ]; then +    echo "Need to set CIL_VERSION" +    exit 1 +fi   +mkdir "$SRC_DIR/ccpifista" +cp -r "$RECIPE_DIR/.." "$SRC_DIR/ccpifista" + +cd $SRC_DIR/ccpifista + +$PYTHON setup-fista.py install diff --git a/src/Python/fista-recipe/meta.yaml b/src/Python/fista-recipe/meta.yaml new file mode 100644 index 0000000..265541f --- /dev/null +++ b/src/Python/fista-recipe/meta.yaml @@ -0,0 +1,29 @@ +package: +  name: ccpi-fista +  version: {{ environ['CIL_VERSION'] }} + + +build: +  preserve_egg_dir: False +  script_env: +    - CIL_VERSION    +#  number: 0 +   +requirements: +  build: +    - python +    - numpy +    - setuptools +  +  run: +    - python +    - numpy +    #- astra-toolbox +    - ccpi-regularizers + + +	 +about: +  home: http://www.ccpi.ac.uk +  license:  Apache v.2.0 license +  summary: 'CCPi Core Imaging Library (Viewer)' diff --git a/src/Python/fista_module.cpp b/src/Python/fista_module.cpp new file mode 100644 index 0000000..aca3be0 --- /dev/null +++ b/src/Python/fista_module.cpp @@ -0,0 +1,1050 @@ +/* +This work is part of the Core Imaging Library developed by +Visual Analytics and Imaging System Group of the Science Technology +Facilities Council, STFC + +Copyright 2017 Daniil Kazantsev +Copyright 2017 Srikanth Nagella, Edoardo Pasca + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include <iostream> +#include <cmath> + +#include <boost/python.hpp> +#include <boost/python/numpy.hpp> +#include "boost/tuple/tuple.hpp" + +#include "SplitBregman_TV_core.h" +#include "FGP_TV_core.h" +#include "LLT_model_core.h" +#include "PatchBased_Regul_core.h" +#include "TGV_PD_core.h" +#include "utils.h" + + + +#if defined(_WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(_WIN64) +#include <windows.h> +// this trick only if compiler is MSVC +__if_not_exists(uint8_t) { typedef __int8 uint8_t; } +__if_not_exists(uint16_t) { typedef __int8 uint16_t; } +#endif + +namespace bp = boost::python; +namespace np = boost::python::numpy; + +/*! in the Matlab implementation this is called as +void mexFunction( +int nlhs, mxArray *plhs[], +int nrhs, const mxArray *prhs[]) +where: +prhs Array of pointers to the INPUT mxArrays +nrhs int number of INPUT mxArrays + +nlhs Array of pointers to the OUTPUT mxArrays +plhs int number of OUTPUT mxArrays + +*********************************************************** + +*********************************************************** +double mxGetScalar(const mxArray *pm); +args: pm Pointer to an mxArray; cannot be a cell mxArray, a structure mxArray, or an empty mxArray. +Returns: Pointer to the value of the first real (nonimaginary) element of the mxArray.	In C, mxGetScalar returns a double. +*********************************************************** +char *mxArrayToString(const mxArray *array_ptr); +args: array_ptr Pointer to mxCHAR array. +Returns: C-style string. Returns NULL on failure. Possible reasons for failure include out of memory and specifying an array that is not an mxCHAR array. +Description: Call mxArrayToString to copy the character data of an mxCHAR array into a C-style string. +*********************************************************** +mxClassID mxGetClassID(const mxArray *pm); +args: pm Pointer to an mxArray +Returns: Numeric identifier of the class (category) of the mxArray that pm points to.For user-defined types, +mxGetClassId returns a unique value identifying the class of the array contents. +Use mxIsClass to determine whether an array is of a specific user-defined type. + +mxClassID Value	  MATLAB Type   MEX Type	 C Primitive Type +mxINT8_CLASS 	  int8	        int8_T	     char, byte +mxUINT8_CLASS	  uint8	        uint8_T	     unsigned char, byte +mxINT16_CLASS	  int16	        int16_T	     short +mxUINT16_CLASS	  uint16	    uint16_T	 unsigned short +mxINT32_CLASS	  int32	        int32_T	     int +mxUINT32_CLASS	  uint32	    uint32_T	 unsigned int +mxINT64_CLASS	  int64	        int64_T	     long long +mxUINT64_CLASS	  uint64	    uint64_T 	 unsigned long long +mxSINGLE_CLASS	  single	    float	     float +mxDOUBLE_CLASS	  double	    double	     double + +**************************************************************** +double *mxGetPr(const mxArray *pm); +args: pm Pointer to an mxArray of type double +Returns: Pointer to the first element of the real data. Returns NULL in C (0 in Fortran) if there is no real data. +**************************************************************** +mxArray *mxCreateNumericArray(mwSize ndim, const mwSize *dims, +mxClassID classid, mxComplexity ComplexFlag); +args: ndimNumber of dimensions. If you specify a value for ndim that is less than 2, mxCreateNumericArray automatically sets the number of dimensions to 2. +dims Dimensions array. Each element in the dimensions array contains the size of the array in that dimension. +For example, in C, setting dims[0] to 5 and dims[1] to 7 establishes a 5-by-7 mxArray. Usually there are ndim elements in the dims array. +classid Identifier for the class of the array, which determines the way the numerical data is represented in memory. +For example, specifying mxINT16_CLASS in C causes each piece of numerical data in the mxArray to be represented as a 16-bit signed integer. +ComplexFlag  If the mxArray you are creating is to contain imaginary data, set ComplexFlag to mxCOMPLEX in C (1 in Fortran). Otherwise, set ComplexFlag to mxREAL in C (0 in Fortran). +Returns: Pointer to the created mxArray, if successful. If unsuccessful in a standalone (non-MEX file) application, returns NULL in C (0 in Fortran). +If unsuccessful in a MEX file, the MEX file terminates and returns control to the MATLAB prompt. The function is unsuccessful when there is not +enough free heap space to create the mxArray. +*/ + + + +bp::list SplitBregman_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { +	 +	// the result is in the following list +	bp::list result; +		 +	int number_of_dims, dimX, dimY, dimZ, ll, j, count; +	//const int  *dim_array; +	float *A, *U = NULL, *U_old = NULL, *Dx = NULL, *Dy = NULL, *Dz = NULL, *Bx = NULL, *By = NULL, *Bz = NULL, lambda, mu, epsil, re, re1, re_old; +	 +	//number_of_dims = mxGetNumberOfDimensions(prhs[0]); +	//dim_array = mxGetDimensions(prhs[0]); + +	number_of_dims = input.get_nd(); +	int dim_array[3]; + +	dim_array[0] = input.shape(0); +	dim_array[1] = input.shape(1); +	if (number_of_dims == 2) { +		dim_array[2] = -1; +	} +	else { +		dim_array[2] = input.shape(2); +	} + +	// Parameter handling is be done in Python +	///*Handling Matlab input data*/ +	//if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + +	///*Handling Matlab input data*/ +	//A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ +	A = reinterpret_cast<float *>(input.get_data()); + +	//mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ +	mu = (float)d_mu; + +	//iter = 35; /* default iterations number */ +	 +	//epsil = 0.0001; /* default tolerance constant */ +	epsil = (float)d_epsil; +	//methTV = 0;  /* default isotropic TV penalty */ +	//if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5))  iter = (int)mxGetScalar(prhs[2]); /* iterations number */ +	//if ((nrhs == 4) || (nrhs == 5))  epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ +	//if (nrhs == 5) { +	//	char *penalty_type; +	//	penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ +	//	if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); +	//	if (strcmp(penalty_type, "l1") == 0)  methTV = 1;  /* enable 'l1' penalty */ +	//	mxFree(penalty_type); +	//} +	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + +	lambda = 2.0f*mu; +	count = 1; +	re_old = 0.0f; +	/*Handling Matlab output data*/ +	dimY = dim_array[0]; dimX = dim_array[1]; dimZ = dim_array[2]; + +	if (number_of_dims == 2) { +		dimZ = 1; /*2D case*/ +		//U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		//U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		//Dx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		//Dy = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		//Bx = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		//By = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); +		np::dtype dtype = np::dtype::get_builtin<float>(); + +		np::ndarray npU = np::zeros(shape, dtype); +		np::ndarray npU_old = np::zeros(shape, dtype); +		np::ndarray npDx = np::zeros(shape, dtype); +		np::ndarray npDy = np::zeros(shape, dtype); +		np::ndarray npBx = np::zeros(shape, dtype); +		np::ndarray npBy = np::zeros(shape, dtype); + +		U = reinterpret_cast<float *>(npU.get_data()); +		U_old = reinterpret_cast<float *>(npU_old.get_data()); +		Dx = reinterpret_cast<float *>(npDx.get_data()); +		Dy = reinterpret_cast<float *>(npDy.get_data()); +		Bx = reinterpret_cast<float *>(npBx.get_data()); +		By = reinterpret_cast<float *>(npBy.get_data()); + + + +		copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + +										/* begin outer SB iterations */ +		for (ll = 0; ll < iter; ll++) { + +			/*storing old values*/ +			copyIm(U, U_old, dimX, dimY, dimZ); + +			/*GS iteration */ +			gauss_seidel2D(U, A, Dx, Dy, Bx, By, dimX, dimY, lambda, mu); + +			if (methTV == 1)  updDxDy_shrinkAniso2D(U, Dx, Dy, Bx, By, dimX, dimY, lambda); +			else updDxDy_shrinkIso2D(U, Dx, Dy, Bx, By, dimX, dimY, lambda); + +			updBxBy2D(U, Dx, Dy, Bx, By, dimX, dimY); + +			/* calculate norm to terminate earlier */ +			re = 0.0f; re1 = 0.0f; +			for (j = 0; j < dimX*dimY*dimZ; j++) +			{ +				re += pow(U_old[j] - U[j], 2); +				re1 += pow(U_old[j], 2); +			} +			re = sqrt(re) / sqrt(re1); +			if (re < epsil)  count++; +			if (count > 4) break; + +			/* check that the residual norm is decreasing */ +			if (ll > 2) { +				if (re > re_old) break; +			} +			re_old = re; +			/*printf("%f %i %i \n", re, ll, count); */ + +			/*copyIm(U_old, U, dimX, dimY, dimZ); */ +			 +		} +		//printf("SB iterations stopped at iteration: %i\n", ll); +		result.append<np::ndarray>(npU); +		result.append<int>(ll); +	} +	if (number_of_dims == 3) { +			/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			Dx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			Dy = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			Dz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			Bx = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			By = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +			Bz = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ +			bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); +			np::dtype dtype = np::dtype::get_builtin<float>(); + +			np::ndarray npU     = np::zeros(shape, dtype); +			np::ndarray npU_old = np::zeros(shape, dtype); +			np::ndarray npDx    = np::zeros(shape, dtype); +			np::ndarray npDy    = np::zeros(shape, dtype); +			np::ndarray npDz    = np::zeros(shape, dtype); +			np::ndarray npBx    = np::zeros(shape, dtype); +			np::ndarray npBy    = np::zeros(shape, dtype); +			np::ndarray npBz    = np::zeros(shape, dtype); + +			U     = reinterpret_cast<float *>(npU.get_data()); +			U_old = reinterpret_cast<float *>(npU_old.get_data()); +			Dx    = reinterpret_cast<float *>(npDx.get_data()); +			Dy    = reinterpret_cast<float *>(npDy.get_data()); +			Dz    = reinterpret_cast<float *>(npDz.get_data()); +			Bx    = reinterpret_cast<float *>(npBx.get_data()); +			By    = reinterpret_cast<float *>(npBy.get_data()); +			Bz    = reinterpret_cast<float *>(npBz.get_data()); + +			copyIm(A, U, dimX, dimY, dimZ); /*initialize */ + +											/* begin outer SB iterations */ +			for (ll = 0; ll<iter; ll++) { + +				/*storing old values*/ +				copyIm(U, U_old, dimX, dimY, dimZ); + +				/*GS iteration */ +				gauss_seidel3D(U, A, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda, mu); + +				if (methTV == 1) updDxDyDz_shrinkAniso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda); +				else updDxDyDz_shrinkIso3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ, lambda); + +				updBxByBz3D(U, Dx, Dy, Dz, Bx, By, Bz, dimX, dimY, dimZ); + +				/* calculate norm to terminate earlier */ +				re = 0.0f; re1 = 0.0f; +				for (j = 0; j<dimX*dimY*dimZ; j++) +				{ +					re += pow(U[j] - U_old[j], 2); +					re1 += pow(U[j], 2); +				} +				re = sqrt(re) / sqrt(re1); +				if (re < epsil)  count++; +				if (count > 4) break; + +				/* check that the residual norm is decreasing */ +				if (ll > 2) { +					if (re > re_old) break; +				} +				/*printf("%f %i %i \n", re, ll, count); */ +				re_old = re; +			} +			//printf("SB iterations stopped at iteration: %i\n", ll); +			result.append<np::ndarray>(npU); +			result.append<int>(ll); +		} +	return result; + +	} + + + +bp::list FGP_TV(np::ndarray input, double d_mu, int iter, double d_epsil, int methTV) { + +	// the result is in the following list +	bp::list result; + +	int number_of_dims, dimX, dimY, dimZ, ll, j, count; +	float *A, *D = NULL, *D_old = NULL, *P1 = NULL, *P2 = NULL, *P3 = NULL, *P1_old = NULL, *P2_old = NULL, *P3_old = NULL, *R1 = NULL, *R2 = NULL, *R3 = NULL; +	float lambda, tk, tkp1, re, re1, re_old, epsil, funcval; + +	//number_of_dims = mxGetNumberOfDimensions(prhs[0]); +	//dim_array = mxGetDimensions(prhs[0]); + +	number_of_dims = input.get_nd(); +	int dim_array[3]; + +	dim_array[0] = input.shape(0); +	dim_array[1] = input.shape(1); +	if (number_of_dims == 2) { +		dim_array[2] = -1; +	} +	else { +		dim_array[2] = input.shape(2); +	} + +	// Parameter handling is be done in Python +	///*Handling Matlab input data*/ +	//if ((nrhs < 2) || (nrhs > 5)) mexErrMsgTxt("At least 2 parameters is required: Image(2D/3D), Regularization parameter. The full list of parameters: Image(2D/3D), Regularization parameter, iterations number, tolerance, penalty type ('iso' or 'l1')"); + +	///*Handling Matlab input data*/ +	//A = (float *)mxGetData(prhs[0]); /*noisy image (2D/3D) */ +	A = reinterpret_cast<float *>(input.get_data()); + +	//mu = (float)mxGetScalar(prhs[1]); /* regularization parameter */ +	lambda = (float)d_mu; + +	//iter = 35; /* default iterations number */ + +	//epsil = 0.0001; /* default tolerance constant */ +	epsil = (float)d_epsil; +	//methTV = 0;  /* default isotropic TV penalty */ +	//if ((nrhs == 3) || (nrhs == 4) || (nrhs == 5))  iter = (int)mxGetScalar(prhs[2]); /* iterations number */ +	//if ((nrhs == 4) || (nrhs == 5))  epsil = (float)mxGetScalar(prhs[3]); /* tolerance constant */ +	//if (nrhs == 5) { +	//	char *penalty_type; +	//	penalty_type = mxArrayToString(prhs[4]); /* choosing TV penalty: 'iso' or 'l1', 'iso' is the default */ +	//	if ((strcmp(penalty_type, "l1") != 0) && (strcmp(penalty_type, "iso") != 0)) mexErrMsgTxt("Choose TV type: 'iso' or 'l1',"); +	//	if (strcmp(penalty_type, "l1") == 0)  methTV = 1;  /* enable 'l1' penalty */ +	//	mxFree(penalty_type); +	//} +	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + +	//plhs[1] = mxCreateNumericMatrix(1, 1, mxSINGLE_CLASS, mxREAL); +	bp::tuple shape1 = bp::make_tuple(dim_array[0], dim_array[1]); +	np::dtype dtype = np::dtype::get_builtin<float>(); +	np::ndarray out1 = np::zeros(shape1, dtype); +	 +	//float *funcvalA = (float *)mxGetData(plhs[1]); +	float * funcvalA = reinterpret_cast<float *>(out1.get_data()); +	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input image must be in a single precision"); } + +	/*Handling Matlab output data*/ +	dimX = dim_array[0]; dimY = dim_array[1]; dimZ = dim_array[2]; + +	tk = 1.0f; +	tkp1 = 1.0f; +	count = 1; +	re_old = 0.0f; + +	if (number_of_dims == 2) { +		dimZ = 1; /*2D case*/ +		/*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		D_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		P1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		P2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		R1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		R2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + +		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); +		np::dtype dtype = np::dtype::get_builtin<float>(); + + +		np::ndarray npD      = np::zeros(shape, dtype); +		np::ndarray npD_old  = np::zeros(shape, dtype); +		np::ndarray npP1     = np::zeros(shape, dtype); +		np::ndarray npP2     = np::zeros(shape, dtype); +		np::ndarray npP1_old = np::zeros(shape, dtype); +		np::ndarray npP2_old = np::zeros(shape, dtype); +		np::ndarray npR1     = np::zeros(shape, dtype); +		np::ndarray npR2     = np::zeros(shape, dtype); + +		D      = reinterpret_cast<float *>(npD.get_data()); +		D_old  = reinterpret_cast<float *>(npD_old.get_data()); +		P1     = reinterpret_cast<float *>(npP1.get_data()); +		P2     = reinterpret_cast<float *>(npP2.get_data()); +		P1_old = reinterpret_cast<float *>(npP1_old.get_data()); +		P2_old = reinterpret_cast<float *>(npP2_old.get_data()); +		R1     = reinterpret_cast<float *>(npR1.get_data()); +		R2     = reinterpret_cast<float *>(npR2.get_data()); + +		/* begin iterations */ +		for (ll = 0; ll<iter; ll++) { +			/* computing the gradient of the objective function */ +			Obj_func2D(A, D, R1, R2, lambda, dimX, dimY); + +			/*Taking a step towards minus of the gradient*/ +			Grad_func2D(P1, P2, D, R1, R2, lambda, dimX, dimY); + + + + +			/*updating R and t*/ +			tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; +			Rupd_func2D(P1, P1_old, P2, P2_old, R1, R2, tkp1, tk, dimX, dimY); + +			/* calculate norm */ +			re = 0.0f; re1 = 0.0f; +			for (j = 0; j<dimX*dimY*dimZ; j++) +			{ +				re += pow(D[j] - D_old[j], 2); +				re1 += pow(D[j], 2); +			} +			re = sqrt(re) / sqrt(re1); +			if (re < epsil)  count++; +			if (count > 3) { +				Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); +				funcval = 0.0f; +				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); +				//funcvalA[0] = sqrt(funcval); +				float fv = sqrt(funcval); +				std::memcpy(funcvalA, &fv, sizeof(float)); +				break; +			} + +			/* check that the residual norm is decreasing */ +			if (ll > 2) { +				if (re > re_old) { +					Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); +					funcval = 0.0f; +					for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); +					//funcvalA[0] = sqrt(funcval); +					float fv = sqrt(funcval); +					std::memcpy(funcvalA, &fv, sizeof(float)); +					break; +				} +			} +			re_old = re; +			/*printf("%f %i %i \n", re, ll, count); */ + +			/*storing old values*/ +			copyIm(D, D_old, dimX, dimY, dimZ); +			copyIm(P1, P1_old, dimX, dimY, dimZ); +			copyIm(P2, P2_old, dimX, dimY, dimZ); +			tk = tkp1; + +			/* calculating the objective function value */ +			if (ll == (iter - 1)) { +				Obj_func2D(A, D, P1, P2, lambda, dimX, dimY); +				funcval = 0.0f; +				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); +				//funcvalA[0] = sqrt(funcval); +				float fv = sqrt(funcval); +				std::memcpy(funcvalA, &fv, sizeof(float)); +			} +		} +		//printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]); +		result.append<np::ndarray>(npD); +		result.append<np::ndarray>(out1); +		result.append<int>(ll); +	} +	if (number_of_dims == 3) { +		/*D = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		D_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		P1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		P2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		P3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		P1_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		P2_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		P3_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		R1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		R2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		R3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL));*/ +		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); +		np::dtype dtype = np::dtype::get_builtin<float>(); +		 +		np::ndarray npD      = np::zeros(shape, dtype); +		np::ndarray npD_old  = np::zeros(shape, dtype); +		np::ndarray npP1     = np::zeros(shape, dtype); +		np::ndarray npP2     = np::zeros(shape, dtype); +		np::ndarray npP3     = np::zeros(shape, dtype); +		np::ndarray npP1_old = np::zeros(shape, dtype); +		np::ndarray npP2_old = np::zeros(shape, dtype); +		np::ndarray npP3_old = np::zeros(shape, dtype); +		np::ndarray npR1     = np::zeros(shape, dtype); +		np::ndarray npR2     = np::zeros(shape, dtype); +		np::ndarray npR3     = np::zeros(shape, dtype); + +		D      = reinterpret_cast<float *>(npD.get_data()); +		D_old  = reinterpret_cast<float *>(npD_old.get_data()); +		P1     = reinterpret_cast<float *>(npP1.get_data()); +		P2     = reinterpret_cast<float *>(npP2.get_data()); +		P3     = reinterpret_cast<float *>(npP3.get_data()); +		P1_old = reinterpret_cast<float *>(npP1_old.get_data()); +		P2_old = reinterpret_cast<float *>(npP2_old.get_data()); +		P3_old = reinterpret_cast<float *>(npP3_old.get_data()); +		R1     = reinterpret_cast<float *>(npR1.get_data()); +		R2     = reinterpret_cast<float *>(npR2.get_data()); +		R2     = reinterpret_cast<float *>(npR3.get_data()); +		/* begin iterations */ +		for (ll = 0; ll<iter; ll++) { + +			/* computing the gradient of the objective function */ +			Obj_func3D(A, D, R1, R2, R3, lambda, dimX, dimY, dimZ); + +			/*Taking a step towards minus of the gradient*/ +			Grad_func3D(P1, P2, P3, D, R1, R2, R3, lambda, dimX, dimY, dimZ); + +			/* projection step */ +			Proj_func3D(P1, P2, P3, dimX, dimY, dimZ); + +			/*updating R and t*/ +			tkp1 = (1.0f + sqrt(1.0f + 4.0f*tk*tk))*0.5f; +			Rupd_func3D(P1, P1_old, P2, P2_old, P3, P3_old, R1, R2, R3, tkp1, tk, dimX, dimY, dimZ); + +			/* calculate norm - stopping rules*/ +			re = 0.0f; re1 = 0.0f; +			for (j = 0; j<dimX*dimY*dimZ; j++) +			{ +				re += pow(D[j] - D_old[j], 2); +				re1 += pow(D[j], 2); +			} +			re = sqrt(re) / sqrt(re1); +			/* stop if the norm residual is less than the tolerance EPS */ +			if (re < epsil)  count++; +			if (count > 3) { +				Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); +				funcval = 0.0f; +				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); +				//funcvalA[0] = sqrt(funcval); +				float fv = sqrt(funcval); +				std::memcpy(funcvalA, &fv, sizeof(float)); +				break; +			} + +			/* check that the residual norm is decreasing */ +			if (ll > 2) { +				if (re > re_old) { +					Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); +					funcval = 0.0f; +					for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); +					//funcvalA[0] = sqrt(funcval); +					float fv = sqrt(funcval); +					std::memcpy(funcvalA, &fv, sizeof(float)); +					break; +				} +			} + +			re_old = re; +			/*printf("%f %i %i \n", re, ll, count); */ + +			/*storing old values*/ +			copyIm(D, D_old, dimX, dimY, dimZ); +			copyIm(P1, P1_old, dimX, dimY, dimZ); +			copyIm(P2, P2_old, dimX, dimY, dimZ); +			copyIm(P3, P3_old, dimX, dimY, dimZ); +			tk = tkp1; + +			if (ll == (iter - 1)) { +				Obj_func3D(A, D, P1, P2, P3, lambda, dimX, dimY, dimZ); +				funcval = 0.0f; +				for (j = 0; j<dimX*dimY*dimZ; j++) funcval += pow(D[j], 2); +				//funcvalA[0] = sqrt(funcval); +				float fv = sqrt(funcval); +				std::memcpy(funcvalA, &fv, sizeof(float)); +			} + +		} +		//printf("FGP-TV iterations stopped at iteration %i with the function value %f \n", ll, funcvalA[0]); +		result.append<np::ndarray>(npD); +		result.append<np::ndarray>(out1); +		result.append<int>(ll); +	} + +	return result; +} + +bp::list LLT_model(np::ndarray input, double d_lambda, double d_tau, int iter, double d_epsil, int switcher) { +	// the result is in the following list +	bp::list result; + +	int number_of_dims, dimX, dimY, dimZ, ll, j, count; +	//const int  *dim_array; +	float *U0, *U = NULL, *U_old = NULL, *D1 = NULL, *D2 = NULL, *D3 = NULL, lambda, tau, re, re1, epsil, re_old; +	unsigned short *Map = NULL; + +	number_of_dims = input.get_nd(); +	int dim_array[3]; + +	dim_array[0] = input.shape(0); +	dim_array[1] = input.shape(1); +	if (number_of_dims == 2) { +		dim_array[2] = -1; +	} +	else { +		dim_array[2] = input.shape(2); +	} + +	///*Handling Matlab input data*/ +	//U0 = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ +	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } +	//lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ +	//tau = (float)mxGetScalar(prhs[2]); /* time-step */ +	//iter = (int)mxGetScalar(prhs[3]); /*iterations number*/ +	//epsil = (float)mxGetScalar(prhs[4]); /* tolerance constant */ +	//switcher = (int)mxGetScalar(prhs[5]); /*switch on (1) restrictive smoothing in Z dimension*/ +	 +	U0 = reinterpret_cast<float *>(input.get_data()); +	lambda = (float)d_lambda; +	tau = (float)d_tau; +	// iter is passed as parameter +	epsil = (float)d_epsil; +	// switcher is passed as parameter +										  /*Handling Matlab output data*/ +	dimX = dim_array[0]; dimY = dim_array[1];  dimZ = 1; + +	if (number_of_dims == 2) { +		/*2D case*/ +		/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		D1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		D2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ + +		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); +		np::dtype dtype = np::dtype::get_builtin<float>(); + + +		np::ndarray npU = np::zeros(shape, dtype); +		np::ndarray npU_old = np::zeros(shape, dtype); +		np::ndarray npD1 = np::zeros(shape, dtype); +		np::ndarray npD2 = np::zeros(shape, dtype); +		 + +		U = reinterpret_cast<float *>(npU.get_data()); +		U_old = reinterpret_cast<float *>(npU_old.get_data()); +		D1 = reinterpret_cast<float *>(npD1.get_data()); +		D2 = reinterpret_cast<float *>(npD2.get_data()); +		 +		/*Copy U0 to U*/ +		copyIm(U0, U, dimX, dimY, dimZ); + +		count = 1; +		re_old = 0.0f; + +		for (ll = 0; ll < iter; ll++) { + +			copyIm(U, U_old, dimX, dimY, dimZ); + +			/*estimate inner derrivatives */ +			der2D(U, D1, D2, dimX, dimY, dimZ); +			/* calculate div^2 and update */ +			div_upd2D(U0, U, D1, D2, dimX, dimY, dimZ, lambda, tau); + +			/* calculate norm to terminate earlier */ +			re = 0.0f; re1 = 0.0f; +			for (j = 0; j<dimX*dimY*dimZ; j++) +			{ +				re += pow(U_old[j] - U[j], 2); +				re1 += pow(U_old[j], 2); +			} +			re = sqrt(re) / sqrt(re1); +			if (re < epsil)  count++; +			if (count > 4) break; + +			/* check that the residual norm is decreasing */ +			if (ll > 2) { +				if (re > re_old) break; +			} +			re_old = re; + +		} /*end of iterations*/ +		  //printf("HO iterations stopped at iteration: %i\n", ll); + +		result.append<np::ndarray>(npU); +	} +	else if (number_of_dims == 3) { +		/*3D case*/ +		dimZ = dim_array[2]; +		/*U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		U_old = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		D1 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		D2 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		D3 = (float*)mxGetPr(mxCreateNumericArray(3, dim_array, mxSINGLE_CLASS, mxREAL)); +		if (switcher != 0) { +			Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); +		}*/ +		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1], dim_array[2]); +		np::dtype dtype = np::dtype::get_builtin<float>(); + + +		np::ndarray npU = np::zeros(shape, dtype); +		np::ndarray npU_old = np::zeros(shape, dtype); +		np::ndarray npD1 = np::zeros(shape, dtype); +		np::ndarray npD2 = np::zeros(shape, dtype); +		np::ndarray npD3 = np::zeros(shape, dtype); +		np::ndarray npMap = np::zeros(shape, np::dtype::get_builtin<unsigned short>()); +		Map = reinterpret_cast<unsigned short *>(npMap.get_data()); +		if (switcher != 0) { +			//Map = (unsigned short*)mxGetPr(plhs[1] = mxCreateNumericArray(3, dim_array, mxUINT16_CLASS, mxREAL)); +			 +			Map = reinterpret_cast<unsigned short *>(npMap.get_data()); +		} + +		U = reinterpret_cast<float *>(npU.get_data()); +		U_old = reinterpret_cast<float *>(npU_old.get_data()); +		D1 = reinterpret_cast<float *>(npD1.get_data()); +		D2 = reinterpret_cast<float *>(npD2.get_data()); +		D3 = reinterpret_cast<float *>(npD2.get_data()); +		 +		/*Copy U0 to U*/ +		copyIm(U0, U, dimX, dimY, dimZ); + +		count = 1; +		re_old = 0.0f; +	 + +		if (switcher == 1) { +			/* apply restrictive smoothing */ +			calcMap(U, Map, dimX, dimY, dimZ); +			/*clear outliers */ +			cleanMap(Map, dimX, dimY, dimZ); +		} +		for (ll = 0; ll < iter; ll++) { + +			copyIm(U, U_old, dimX, dimY, dimZ); + +			/*estimate inner derrivatives */ +			der3D(U, D1, D2, D3, dimX, dimY, dimZ); +			/* calculate div^2 and update */ +			div_upd3D(U0, U, D1, D2, D3, Map, switcher, dimX, dimY, dimZ, lambda, tau); + +			/* calculate norm to terminate earlier */ +			re = 0.0f; re1 = 0.0f; +			for (j = 0; j<dimX*dimY*dimZ; j++) +			{ +				re += pow(U_old[j] - U[j], 2); +				re1 += pow(U_old[j], 2); +			} +			re = sqrt(re) / sqrt(re1); +			if (re < epsil)  count++; +			if (count > 4) break; + +			/* check that the residual norm is decreasing */ +			if (ll > 2) { +				if (re > re_old) break; +			} +			re_old = re; + +		} /*end of iterations*/ +		//printf("HO iterations stopped at iteration: %i\n", ll); +		result.append<np::ndarray>(npU); +		if (switcher != 0) result.append<np::ndarray>(npMap); + +	} +	return result; +} + + +bp::list PatchBased_Regul(np::ndarray input, double d_lambda, int SearchW_real, int SimilW,  double d_h) { +	// the result is in the following list +	bp::list result; + +	int N, M, Z, numdims, SearchW, /*SimilW, SearchW_real,*/ padXY, newsizeX, newsizeY, newsizeZ, switchpad_crop; +	//const int  *dims; +	float *A, *B = NULL, *Ap = NULL, *Bp = NULL, h, lambda; + +	numdims = input.get_nd(); +	int dims[3]; + +	dims[0] = input.shape(0); +	dims[1] = input.shape(1); +	if (numdims == 2) { +		dims[2] = -1; +	} +	else { +		dims[2] = input.shape(2); +	} +	/*numdims = mxGetNumberOfDimensions(prhs[0]); +	dims = mxGetDimensions(prhs[0]);*/ + +	N = dims[0]; +	M = dims[1]; +	Z = dims[2]; + +	//if ((numdims < 2) || (numdims > 3)) { mexErrMsgTxt("The input should be 2D image or 3D volume"); } +	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } + +	//if (nrhs != 5) mexErrMsgTxt("Five inputs reqired: Image(2D,3D), SearchW, SimilW, Threshold, Regularization parameter"); + +	///*Handling inputs*/ +	//A = (float *)mxGetData(prhs[0]);    /* the image to regularize/filter */ +	A = reinterpret_cast<float *>(input.get_data()); +	//SearchW_real = (int)mxGetScalar(prhs[1]); /* the searching window ratio */ +	//SimilW = (int)mxGetScalar(prhs[2]);  /* the similarity window ratio */ +	//h = (float)mxGetScalar(prhs[3]);  /* parameter for the PB filtering function */ +	//lambda = (float)mxGetScalar(prhs[4]); /* regularization parameter */ + +	//if (h <= 0) mexErrMsgTxt("Parmeter for the PB penalty function should be > 0"); +	//if (lambda <= 0) mexErrMsgTxt(" Regularization parmeter should be > 0"); + +	lambda = (float)d_lambda; +	h = (float)d_h; +	SearchW = SearchW_real + 2 * SimilW; + +	/* SearchW_full = 2*SearchW + 1; */ /* the full searching window  size */ +										/* SimilW_full = 2*SimilW + 1;  */  /* the full similarity window  size */ + + +	padXY = SearchW + 2 * SimilW; /* padding sizes */ +	newsizeX = N + 2 * (padXY); /* the X size of the padded array */ +	newsizeY = M + 2 * (padXY); /* the Y size of the padded array */ +	newsizeZ = Z + 2 * (padXY); /* the Z size of the padded array */ +	int N_dims[] = { newsizeX, newsizeY, newsizeZ }; +	/******************************2D case ****************************/ +	if (numdims == 2) { +		///*Handling output*/ +		//B = (float*)mxGetData(plhs[0] = mxCreateNumericMatrix(N, M, mxSINGLE_CLASS, mxREAL)); +		///*allocating memory for the padded arrays */ +		//Ap = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); +		//Bp = (float*)mxGetData(mxCreateNumericMatrix(newsizeX, newsizeY, mxSINGLE_CLASS, mxREAL)); +		///**************************************************************************/ + +		bp::tuple shape = bp::make_tuple(N, M); +		np::dtype dtype = np::dtype::get_builtin<float>(); + +		np::ndarray npB = np::zeros(shape, dtype); + +		shape = bp::make_tuple(newsizeX, newsizeY); +		np::ndarray npAp = np::zeros(shape, dtype); +		np::ndarray npBp = np::zeros(shape, dtype); +		B = reinterpret_cast<float *>(npB.get_data()); +		Ap = reinterpret_cast<float *>(npAp.get_data()); +		Bp = reinterpret_cast<float *>(npBp.get_data());		 + +		/*Perform padding of image A to the size of [newsizeX * newsizeY] */ +		switchpad_crop = 0; /*padding*/ +		pad_crop(A, Ap, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); +		 +		/* Do PB regularization with the padded array  */ +		PB_FUNC2D(Ap, Bp, newsizeY, newsizeX, padXY, SearchW, SimilW, (float)h, (float)lambda); +		 +		switchpad_crop = 1; /*cropping*/ +		pad_crop(Bp, B, M, N, 0, newsizeY, newsizeX, 0, padXY, switchpad_crop); +		 +		result.append<np::ndarray>(npB); +	} +	else +	{ +		/******************************3D case ****************************/ +		///*Handling output*/ +		//B = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)); +		///*allocating memory for the padded arrays */ +		//Ap = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); +		//Bp = (float*)mxGetPr(mxCreateNumericArray(3, N_dims, mxSINGLE_CLASS, mxREAL)); +		/**************************************************************************/ +		bp::tuple shape = bp::make_tuple(dims[0], dims[1], dims[2]); +		bp::tuple shape_AB = bp::make_tuple(N_dims[0], N_dims[1], N_dims[2]); +		np::dtype dtype = np::dtype::get_builtin<float>(); + +		np::ndarray npB = np::zeros(shape, dtype); +		np::ndarray npAp = np::zeros(shape_AB, dtype); +		np::ndarray npBp = np::zeros(shape_AB, dtype); +		B = reinterpret_cast<float *>(npB.get_data()); +		Ap = reinterpret_cast<float *>(npAp.get_data()); +		Bp = reinterpret_cast<float *>(npBp.get_data()); +		/*Perform padding of image A to the size of [newsizeX * newsizeY * newsizeZ] */ +		switchpad_crop = 0; /*padding*/ +		pad_crop(A, Ap, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + +		/* Do PB regularization with the padded array  */ +		PB_FUNC3D(Ap, Bp, newsizeY, newsizeX, newsizeZ, padXY, SearchW, SimilW, (float)h, (float)lambda); + +		switchpad_crop = 1; /*cropping*/ +		pad_crop(Bp, B, M, N, Z, newsizeY, newsizeX, newsizeZ, padXY, switchpad_crop); + +		result.append<np::ndarray>(npB); +	} /*end else ndims*/ + +	return result; +} + +bp::list TGV_PD(np::ndarray input, double d_lambda, double d_alpha1, double d_alpha0, int iter) { +	// the result is in the following list +	bp::list result; +	int number_of_dims, /*iter,*/ dimX, dimY, dimZ, ll; +	//const int  *dim_array; +	float *A, *U, *U_old, *P1, *P2, *Q1, *Q2, *Q3, *V1, *V1_old, *V2, *V2_old, lambda, L2, tau, sigma, alpha1, alpha0; + +	//number_of_dims = mxGetNumberOfDimensions(prhs[0]); +	//dim_array = mxGetDimensions(prhs[0]); +	number_of_dims = input.get_nd(); +	int dim_array[3]; + +	dim_array[0] = input.shape(0); +	dim_array[1] = input.shape(1); +	if (number_of_dims == 2) { +		dim_array[2] = -1; +	} +	else { +		dim_array[2] = input.shape(2); +	} +	/*Handling Matlab input data*/ +	//A = (float *)mxGetData(prhs[0]); /*origanal noise image/volume*/ +	//if (mxGetClassID(prhs[0]) != mxSINGLE_CLASS) { mexErrMsgTxt("The input in single precision is required"); } +	 +	A = reinterpret_cast<float *>(input.get_data()); + +	//lambda = (float)mxGetScalar(prhs[1]); /*regularization parameter*/ +	//alpha1 = (float)mxGetScalar(prhs[2]); /*first-order term*/ +	//alpha0 = (float)mxGetScalar(prhs[3]); /*second-order term*/ +	//iter = (int)mxGetScalar(prhs[4]); /*iterations number*/ +	//if (nrhs != 5) mexErrMsgTxt("Five input parameters is reqired: Image(2D/3D), Regularization parameter, alpha1, alpha0, Iterations"); +	lambda = (float)d_lambda; +	alpha1 = (float)d_alpha1; +	alpha0 = (float)d_alpha0; + +	/*Handling Matlab output data*/ +	dimX = dim_array[0]; dimY = dim_array[1]; + +	if (number_of_dims == 2) { +		/*2D case*/ +		dimZ = 1; +		bp::tuple shape = bp::make_tuple(dim_array[0], dim_array[1]); +		np::dtype dtype = np::dtype::get_builtin<float>(); + +		np::ndarray npU = np::zeros(shape, dtype); +		np::ndarray npP1 = np::zeros(shape, dtype); +		np::ndarray npP2 = np::zeros(shape, dtype); +		np::ndarray npQ1 = np::zeros(shape, dtype); +		np::ndarray npQ2 = np::zeros(shape, dtype); +		np::ndarray npQ3 = np::zeros(shape, dtype); +		np::ndarray npV1 = np::zeros(shape, dtype); +		np::ndarray npV1_old = np::zeros(shape, dtype); +		np::ndarray npV2 = np::zeros(shape, dtype); +		np::ndarray npV2_old = np::zeros(shape, dtype); +		np::ndarray npU_old = np::zeros(shape, dtype); + +		U = reinterpret_cast<float *>(npU.get_data()); +		U_old = reinterpret_cast<float *>(npU_old.get_data()); +		P1 = reinterpret_cast<float *>(npP1.get_data()); +		P2 = reinterpret_cast<float *>(npP2.get_data()); +		Q1 = reinterpret_cast<float *>(npQ1.get_data()); +		Q2 = reinterpret_cast<float *>(npQ2.get_data()); +		Q3 = reinterpret_cast<float *>(npQ3.get_data()); +		V1 = reinterpret_cast<float *>(npV1.get_data()); +		V1_old = reinterpret_cast<float *>(npV1_old.get_data()); +		V2 = reinterpret_cast<float *>(npV2.get_data()); +		V2_old = reinterpret_cast<float *>(npV2_old.get_data()); +		//U = (float*)mxGetPr(plhs[0] = mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + +		/*dual variables*/ +		/*P1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		P2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + +		Q1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		Q2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		Q3 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + +		U_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); + +		V1 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		V1_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		V2 = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL)); +		V2_old = (float*)mxGetPr(mxCreateNumericArray(2, dim_array, mxSINGLE_CLASS, mxREAL));*/ +		/*printf("%i \n", i);*/ +		L2 = 12.0; /*Lipshitz constant*/ +		tau = 1.0 / pow(L2, 0.5); +		sigma = 1.0 / pow(L2, 0.5); + +		/*Copy A to U*/ +		copyIm(A, U, dimX, dimY, dimZ); +		/* Here primal-dual iterations begin for 2D */ +		for (ll = 0; ll < iter; ll++) { + +			/* Calculate Dual Variable P */ +			DualP_2D(U, V1, V2, P1, P2, dimX, dimY, dimZ, sigma); + +			/*Projection onto convex set for P*/ +			ProjP_2D(P1, P2, dimX, dimY, dimZ, alpha1); + +			/* Calculate Dual Variable Q */ +			DualQ_2D(V1, V2, Q1, Q2, Q3, dimX, dimY, dimZ, sigma); + +			/*Projection onto convex set for Q*/ +			ProjQ_2D(Q1, Q2, Q3, dimX, dimY, dimZ, alpha0); + +			/*saving U into U_old*/ +			copyIm(U, U_old, dimX, dimY, dimZ); + +			/*adjoint operation  -> divergence and projection of P*/ +			DivProjP_2D(U, A, P1, P2, dimX, dimY, dimZ, lambda, tau); + +			/*get updated solution U*/ +			newU(U, U_old, dimX, dimY, dimZ); + +			/*saving V into V_old*/ +			copyIm(V1, V1_old, dimX, dimY, dimZ); +			copyIm(V2, V2_old, dimX, dimY, dimZ); + +			/* upd V*/ +			UpdV_2D(V1, V2, P1, P2, Q1, Q2, Q3, dimX, dimY, dimZ, tau); + +			/*get new V*/ +			newU(V1, V1_old, dimX, dimY, dimZ); +			newU(V2, V2_old, dimX, dimY, dimZ); +		} /*end of iterations*/ +	 +		result.append<np::ndarray>(npU); +	} +	 + +	 +	 +	return result; +} + +BOOST_PYTHON_MODULE(cpu_regularizers) +{ +	np::initialize(); + +	//To specify that this module is a package +	bp::object package = bp::scope(); +	package.attr("__path__") = "cpu_regularizers"; + +	np::dtype dt1 = np::dtype::get_builtin<uint8_t>(); +	np::dtype dt2 = np::dtype::get_builtin<uint16_t>(); + +	def("SplitBregman_TV", SplitBregman_TV); +	def("FGP_TV", FGP_TV); +	def("LLT_model", LLT_model); +	def("PatchBased_Regul", PatchBased_Regul); +	def("TGV_PD", TGV_PD); +} diff --git a/src/Python/setup-fista.py.in b/src/Python/setup-fista.py.in new file mode 100644 index 0000000..c5c9f4d --- /dev/null +++ b/src/Python/setup-fista.py.in @@ -0,0 +1,27 @@ +from distutils.core import setup +#from setuptools import setup, find_packages +import os + +cil_version=os.environ['CIL_VERSION'] +if  cil_version == '': +    print("Please set the environmental variable CIL_VERSION") +    sys.exit(1) + +setup( +    name="ccpi-fista", +    version=cil_version, +    packages=['ccpi','ccpi.reconstruction'], +	install_requires=['numpy'], + +     zip_safe = False, + +    # metadata for upload to PyPI +    author="Edoardo Pasca", +    author_email="edo.paskino@gmail.com", +    description='CCPi Core Imaging Library - FISTA Reconstructor module', +    license="Apache v2.0", +    keywords="tomography interative reconstruction", +    url="http://www.ccpi.ac.uk",   # project home page, if any + +    # could also include long_description, download_url, classifiers, etc. +) diff --git a/src/Python/setup.py b/src/Python/setup.py new file mode 100644 index 0000000..154f979 --- /dev/null +++ b/src/Python/setup.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform	 + +cil_version=os.environ['CIL_VERSION'] +if  cil_version == '': +    print("Please set the environmental variable CIL_VERSION") +    sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: +    library_include_path = os.environ['LIBRARY_INC'] +    library_lib_path = os.environ['LIBRARY_LIB'] +except: +    library_include_path = os.environ['PREFIX']+'/include' +    pass +     +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': +    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ]    +    extra_include_dirs += ["..\\..\\main_func\\regularizers_CPU\\","."] +    if sys.version_info.major == 3 :    +        extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] +    else: +        extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: +    extra_include_dirs += ["../../main_func/regularizers_CPU","."] +    if sys.version_info.major == 3: +        extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] +    else: +        extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( +    name='ccpi', +	description='CCPi Core Imaging Library - FISTA Reconstruction Module', +	version=cil_version, +    cmdclass = {'build_ext': build_ext}, +    ext_modules = [Extension("ccpi.imaging.cpu_regularizers", +                             sources=["fista_module.cpp", +                                      "../../main_func/regularizers_CPU/FGP_TV_core.c", +                                      "../../main_func/regularizers_CPU/SplitBregman_TV_core.c", +                                      "../../main_func/regularizers_CPU/LLT_model_core.c", +                                      "../../main_func/regularizers_CPU/PatchBased_Regul_core.c", +                                      "../../main_func/regularizers_CPU/TGV_PD_core.c", +                                      "../../main_func/regularizers_CPU/utils.c" +                                        ], +                             include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ),  +     +    ], +	zip_safe = False,	 +	packages = {'ccpi','ccpi.fistareconstruction'}, +) diff --git a/src/Python/setup.py.in b/src/Python/setup.py.in new file mode 100644 index 0000000..12e8af1 --- /dev/null +++ b/src/Python/setup.py.in @@ -0,0 +1,69 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform	 + +cil_version=@CIL_VERSION@ + +library_include_path = "" +library_lib_path = "" +try: +    library_include_path = os.environ['LIBRARY_INC'] +    library_lib_path = os.environ['LIBRARY_LIB'] +except: +    library_include_path = os.environ['PREFIX']+'/include' +    pass +     +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [os.path.join(library_include_path, "..", "lib")] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +extra_include_dirs += [os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU"), +                       	   os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_GPU") ,  +						   "@CMAKE_CURRENT_SOURCE_DIR@"] +						    +if platform.system() == 'Windows': +    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB' , '/openmp' ]    +     +    if sys.version_info.major == 3 :    +        extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] +    else: +        extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: +    if sys.version_info.major == 3: +        extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] +    else: +        extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( +    name='ccpi', +	description='CCPi Core Imaging Library - Image Regularizers', +	version=cil_version, +    cmdclass = {'build_ext': build_ext}, +    ext_modules = [Extension("ccpi.imaging.cpu_regularizers", +                             sources=[os.path.join("@CMAKE_CURRENT_SOURCE_DIR@" , "fista_module.cpp" ), +                                      os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU", "FGP_TV_core.c"), +									  os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU", "SplitBregman_TV_core.c"), +									  os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU", "LLT_model_core.c"), +									  os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU", "PatchBased_Regul_core.c"), +									  os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU", "TGV_PD_core.c"), +									  os.path.join("@CMAKE_SOURCE_DIR@" , "main_func" ,  "regularizers_CPU", "utils.c") +                                        ], +                             include_dirs=extra_include_dirs,  +							 library_dirs=extra_library_dirs,  +							 extra_compile_args=extra_compile_args,  +							 libraries=extra_libraries ),  +     +    ], +	zip_safe = False,	 +	packages = {'ccpi','ccpi.imaging'}, +) + + diff --git a/src/Python/setup_test.py b/src/Python/setup_test.py new file mode 100644 index 0000000..7c86175 --- /dev/null +++ b/src/Python/setup_test.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +import setuptools +from distutils.core import setup +from distutils.extension import Extension +from Cython.Distutils import build_ext + +import os +import sys +import numpy +import platform	 + +cil_version=os.environ['CIL_VERSION'] +if  cil_version == '': +    print("Please set the environmental variable CIL_VERSION") +    sys.exit(1) + +library_include_path = "" +library_lib_path = "" +try: +    library_include_path = os.environ['LIBRARY_INC'] +    library_lib_path = os.environ['LIBRARY_LIB'] +except: +    library_include_path = os.environ['PREFIX']+'/include' +    pass +     +extra_include_dirs = [numpy.get_include(), library_include_path] +extra_library_dirs = [library_include_path+"/../lib", "C:\\Apps\\Miniconda2\\envs\\cil27\\Library\\lib"] +extra_compile_args = ['-fopenmp','-O2', '-funsigned-char', '-Wall', '-std=c++0x'] +extra_libraries = [] +if platform.system() == 'Windows': +    extra_compile_args[0:] = ['/DWIN32','/EHsc','/DBOOST_ALL_NO_LIB']    +    #extra_include_dirs += ["..\\ContourTree\\", "..\\win32\\" , "..\\Core\\","."] +    if sys.version_info.major == 3 :    +        extra_libraries += ['boost_python3-vc140-mt-1_64', 'boost_numpy3-vc140-mt-1_64'] +    else: +        extra_libraries += ['boost_python-vc90-mt-1_64', 'boost_numpy-vc90-mt-1_64'] +else: +    #extra_include_dirs += ["../ContourTree/", "../Core/","."] +    if sys.version_info.major == 3: +        extra_libraries += ['boost_python3', 'boost_numpy3','gomp'] +    else: +        extra_libraries += ['boost_python', 'boost_numpy','gomp'] + +setup( +    name='ccpi', +	description='CCPi Core Imaging Library - FISTA Reconstruction Module', +	version=cil_version, +    cmdclass = {'build_ext': build_ext}, +    ext_modules = [Extension("prova", +                             sources=[  "Matlab2Python_utils.cpp", +                                        ], +                             include_dirs=extra_include_dirs, library_dirs=extra_library_dirs, extra_compile_args=extra_compile_args, libraries=extra_libraries ),  +     +    ], +	zip_safe = False,	 +	packages = {'ccpi','ccpi.reconstruction'}, +) diff --git a/src/Python/test.py b/src/Python/test.py new file mode 100644 index 0000000..db47380 --- /dev/null +++ b/src/Python/test.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" +Created on Thu Aug  3 14:08:09 2017 + +@author: ofn77899 +""" + +import prova +import numpy as np + +a = np.asarray([i for i in range(1*2*3)]) +a = a.reshape([1,2,3]) +print (a) +b = prova.mexFunction(a) +#print (b) +print (b[4].shape) +print (b[4]) +print (b[5]) + +def print_element(input): +	print ("f: {0}".format(input)) +	 +prova.doSomething(a, print_element, None) + +c = [] +def append_to_list(input, shouldPrint=False): +	c.append(input) +	if shouldPrint: +		print ("{0} appended to list {1}".format(input, c)) + +def element_wise_algebra(input, shouldPrint=True): +	ret = input - 7 +	if shouldPrint: +		print ("element_wise {0}".format(ret)) +	return ret +		 +prova.doSomething(a, append_to_list, None) +#print ("this is c: {0}".format(c)) + +b = prova.doSomething(a, None, element_wise_algebra) +#print (a) +print (b[5]) diff --git a/src/Python/test/astra_test.py b/src/Python/test/astra_test.py new file mode 100644 index 0000000..42c375a --- /dev/null +++ b/src/Python/test/astra_test.py @@ -0,0 +1,85 @@ +import astra +import numpy +import filefun + + +# read in the same data as the DemoRD2 +angles = filefun.dlmread("DemoRD2/angles.csv") +darks_ar = filefun.dlmread("DemoRD2/darks_ar.csv", separator=",") +flats_ar = filefun.dlmread("DemoRD2/flats_ar.csv", separator=",") + +if True: +    Sino3D = numpy.load("DemoRD2/Sino3D.npy") +else: +    sino = filefun.dlmread("DemoRD2/sino_01.csv", separator=",") +    a = map (lambda x:x, numpy.shape(sino)) +    a.append(20) + +    Sino3D = numpy.zeros(tuple(a), dtype="float") + +    for i in range(1,numpy.shape(Sino3D)[2]+1): +        print("Read file DemoRD2/sino_%02d.csv" % i) +        sino = filefun.dlmread("DemoRD2/sino_%02d.csv" % i, separator=",") +        Sino3D.T[i-1] = sino.T +     +Weights3D = numpy.asarray(Sino3D, dtype="float") + +##angles_rad = angles*(pi/180); % conversion to radians +##size_det = size(data_raw3D,1); % detectors dim +##angSize = size(data_raw3D, 2); % angles dim +##slices_tot = size(data_raw3D, 3); % no of slices +##recon_size = 950; % reconstruction size + + +angles_rad = angles * numpy.pi /180. +size_det, angSize, slices_tot = numpy.shape(Sino3D) +size_det, angSize, slices_tot = [int(i) for i in numpy.shape(Sino3D)] +recon_size = 950 +Z_slices = 3; +det_row_count = Z_slices; + +#proj_geom = astra_create_proj_geom('parallel3d', 1, 1, +# det_row_count, size_det, angles_rad); + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX +proj_geom = astra.create_proj_geom('parallel3d', +                                            detectorSpacingX, +                                            detectorSpacingY, +                                            det_row_count, +                                            size_det, +                                            angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +vol_geom = astra.create_vol_geom(recon_size,recon_size,Z_slices); + +sino = numpy.zeros((size_det, angSize, slices_tot), dtype="float") + +#weights = ones(size(sino)); +weights = numpy.ones(numpy.shape(sino)) + +##################################################################### +## PowerMethod for Lipschitz constant + +N = vol_geom['GridColCount'] +x1 = numpy.random.rand(1,N,N) +#sqweight = sqrt(weights(:,:,1)); +sqweight = numpy.sqrt(weights.T[0]).T +##proj_geomT = proj_geom; +proj_geomT = proj_geom.copy() +##proj_geomT.DetectorRowCount = 1; +proj_geomT['DetectorRowCount'] = 1 +##vol_geomT = vol_geom; +vol_geomT = vol_geom.copy() +##vol_geomT.GridSliceCount = 1; +vol_geomT['GridSliceCount'] = 1 + +##[sino_id, y] = astra_create_sino3d_cuda(x1, proj_geomT, vol_geomT); + +#sino_id, y = astra.create_sino3d_gpu(x1, proj_geomT, vol_geomT); +sino_id, y = astra.create_sino(x1, proj_geomT, vol_geomT); + +##y = sqweight.*y; +##astra_mex_data3d('delete', sino_id); +         + diff --git a/src/Python/test/readhd5.py b/src/Python/test/readhd5.py new file mode 100644 index 0000000..eff6c43 --- /dev/null +++ b/src/Python/test/readhd5.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +""" + +import h5py +import numpy + +def getEntry(nx, location): +    for item in nx[location].keys(): +        print (item) +         +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D')) +Weights3D = numpy.asarray(nx.get('/Weights3D')) +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad')) +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] + +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +#from ccpi.viewer.CILViewer2D import CILViewer2D +#v = CILViewer2D() +#v.setInputAsNumpy(Weights3D) +#v.startRenderLoop() + +import matplotlib.pyplot as plt +fig = plt.figure() + +a=fig.add_subplot(1,1,1) +a.set_title('noise') +imgplot = plt.imshow(Weights3D[0].T) +plt.show() diff --git a/src/Python/test/simple_astra_test.py b/src/Python/test/simple_astra_test.py new file mode 100644 index 0000000..905eeea --- /dev/null +++ b/src/Python/test/simple_astra_test.py @@ -0,0 +1,25 @@ +import astra +import numpy + +detectorSpacingX = 1.0 +detectorSpacingY = 1.0 +det_row_count = 128 +det_col_count = 128 + +angles_rad = numpy.asarray([i for i in range(360)], dtype=float) / 180. * numpy.pi + +proj_geom = astra.creators.create_proj_geom('parallel3d', +                                            detectorSpacingX, +                                            detectorSpacingY, +                                            det_row_count, +                                            det_col_count, +                                            angles_rad) + +image_size_x = 64 +image_size_y = 64 +image_size_z = 32 + +vol_geom = astra.creators.create_vol_geom(image_size_x,image_size_y,image_size_z) + +x1 = numpy.random.rand(image_size_z,image_size_y,image_size_x) +sino_id, y = astra.creators.create_sino3d_gpu(x1, proj_geom, vol_geom) diff --git a/src/Python/test/test_reconstructor-os.py b/src/Python/test/test_reconstructor-os.py new file mode 100644 index 0000000..3f419cf --- /dev/null +++ b/src/Python/test/test_reconstructor-os.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt +from ccpi.imaging.Regularizer import Regularizer + +def RMSE(signal1, signal2): +    '''RMSE Root Mean Squared Error''' +    if numpy.shape(signal1) == numpy.shape(signal2): +        err = (signal1 - signal2) +        err = numpy.sum( err * err )/numpy.size(signal1);  # MSE +        err = sqrt(err);                                   # RMSE +        return err +    else: +        raise Exception('Input signals must have the same shape') +   +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', +                                            detectorSpacingX, +                                            detectorSpacingY, +                                            det_row_count, +                                            det_col_count, +                                            angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, +                                           image_size_y, +                                           image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + + +#reg = Regularizer(Regularizer.Algorithm.FGP_TV) +#reg.setParameter(regularization_parameter=0.005, +#                          number_of_iterations=50) +reg = Regularizer(Regularizer.Algorithm.LLT_model) +reg.setParameter(regularization_parameter=25, +                          time_step=0.0003, +                          tolerance_constant=0.0001, +                          number_of_iterations=300) + + +## Ordered subset +if True: +    subsets = 16 +    fistaRecon.setParameter(subsets=subsets) +    fistaRecon.createOrderedSubsets() +    angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] +    #binEdges = numpy.linspace(angles.min(), +    #                          angles.max(), +    #                          subsets + 1) +    binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) +    # get rearranged subset indices +    IndicesReorg = numpy.zeros((numpy.shape(angles))) +    counterM = 0 +    for ii in range(binsDiscr.max()): +        counter = 0 +        for jj in range(subsets): +            curr_index = ii + jj  + counter +            #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) +            if binsDiscr[jj] > ii: +                if (counterM < numpy.size(IndicesReorg)): +                    IndicesReorg[counterM] = curr_index +                counterM = counterM + 1 +                 +            counter = counter + binsDiscr[jj] - 1 + + +if True: +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    print ("prepare for iteration") +    fistaRecon.prepareForIteration() +     +     + +    print("initializing  ...") +    if False: +        # if X doesn't exist +        #N = params.vol_geom.GridColCount +        N = vol_geom['GridColCount'] +        print ("N " + str(N)) +        X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) +    else: +        #X = fistaRecon.initialize() +        X = numpy.load("X.npy") +     +    print (numpy.shape(X)) +    X_t = X.copy() +    print ("initialized") +    proj_geom , vol_geom, sino , \ +        SlicesZ, weights , alpha_ring = fistaRecon.getParameter( +            ['projector_geometry' , 'output_geometry', +             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha']) +    lambdaR_L1 , alpha_ring , weights , L_const= \ +                       fistaRecon.getParameter(['ring_lambda_R_L1', +                                               'ring_alpha' , 'weights', +                                               'Lipschitz_constant']) + +    #fistaRecon.setParameter(number_of_iterations = 3) +    iterFISTA = fistaRecon.getParameter('number_of_iterations') +    # errors vector (if the ground truth is given) +    Resid_error = numpy.zeros((iterFISTA)); +    # objective function values vector +    objective = numpy.zeros((iterFISTA));  + +       +    t = 1 +     + +    ## additional for  +    proj_geomSUB = proj_geom.copy() +    fistaRecon.residual2 = numpy.zeros(numpy.shape(fistaRecon.pars['input_sinogram'])) +    residual2 = fistaRecon.residual2 +    sino_updt_FULL = fistaRecon.residual.copy() +    r_x = fistaRecon.r.copy() +     +    print ("starting iterations") +##    % Outer FISTA iterations loop +    for i in range(fistaRecon.getParameter('number_of_iterations')): +##        % With OS approach it becomes trickier to correlate independent subsets, hence additional work is required +##        % one solution is to work with a full sinogram at times +##        if ((i >= 3) && (lambdaR_L1 > 0)) +##            [sino_id2, sino_updt2] = astra_create_sino3d_cuda(X, proj_geom, vol_geom); +##            astra_mex_data3d('delete', sino_id2); +##        end +        # With OS approach it becomes trickier to correlate independent subsets, +        # hence additional work is required one solution is to work with a full +        # sinogram at times + +        r_old = fistaRecon.r.copy() +        t_old = t +        SlicesZ, anglesNumb, Detectors = \ +                    numpy.shape(fistaRecon.getParameter('input_sinogram'))        ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4 +        if (i > 1 and lambdaR_L1 > 0) : +            for kkk in range(anglesNumb): +                  +                 residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ +                                       ((sino_updt_FULL[:,kkk,:]).squeeze() - \ +                                        (sino[:,kkk,:]).squeeze() -\ +                                        (alpha_ring * r_x) +                                        ) +             +            vec = fistaRecon.residual.sum(axis = 1) +            #if SlicesZ > 1: +            #    vec = vec[:,1,:] # 1 or 0? +            r_x = fistaRecon.r_x +            fistaRecon.r = (r_x - (1./L_const) * vec).copy() + +        # subset loop +        counterInd = 1 +        geometry_type = fistaRecon.getParameter('projector_geometry')['type'] +        if geometry_type == 'parallel' or \ +           geometry_type == 'fanflat' or \ +           geometry_type == 'fanflat_vec' : +             +            for kkk in range(SlicesZ): +                sino_id, sinoT[kkk] = \ +                         astra.creators.create_sino3d_gpu( +                             X_t[kkk:kkk+1], proj_geomSUB, vol_geom) +                sino_updt_Sub[kkk] = sinoT.T.copy() +                 +        else: +            sino_id, sino_updt_Sub = \ +                astra.creators.create_sino3d_gpu(X_t, proj_geomSUB, vol_geom) +         +        astra.matlab.data3d('delete', sino_id) +   +        for ss in range(fistaRecon.getParameter('subsets')): +            print ("Subset {0}".format(ss)) +            X_old = X.copy() +            t_old = t +             +            # the number of projections per subset +            numProjSub = fistaRecon.getParameter('os_bins')[ss] +            CurrSubIndices = fistaRecon.getParameter('os_indices')\ +                             [counterInd:counterInd+numProjSub] +            #print ("Len CurrSubIndices {0}".format(numProjSub)) +            mask = numpy.zeros(numpy.shape(angles), dtype=bool) +            cc = 0 +            for j in range(len(CurrSubIndices)): +                mask[int(CurrSubIndices[j])] = True +            proj_geomSUB['ProjectionAngles'] = angles[mask] + +            shape = list(numpy.shape(fistaRecon.getParameter('input_sinogram'))) +            shape[1] = numProjSub +            sino_updt_Sub = numpy.zeros(shape) + +            if geometry_type == 'parallel' or \ +               geometry_type == 'fanflat' or \ +               geometry_type == 'fanflat_vec' : + +                for kkk in range(SlicesZ): +                    sino_id, sinoT = astra.creators.create_sino3d_gpu ( +                        X_t[kkk:kkk+1] , proj_geomSUB, vol_geom) +                    sino_updt_Sub[kkk] = sinoT.T.copy() +                     +            else: +                # for 3D geometry (watch the GPU memory overflow in ASTRA < 1.8) +                sino_id, sino_updt_Sub = \ +                     astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom) +                 +            astra.matlab.data3d('delete', sino_id) +                 +             + + +            ## RING REMOVAL +            residual = fistaRecon.residual +             +             +            if lambdaR_L1 > 0 : +                 print ("ring removal") +                 residualSub = numpy.zeros(shape) +    ##             for a chosen subset +    ##                for kkk = 1:numProjSub +    ##                    indC = CurrSubIndeces(kkk); +    ##                    residualSub(:,kkk,:) =  squeeze(weights(:,indC,:)).*(squeeze(sino_updt_Sub(:,kkk,:)) - (squeeze(sino(:,indC,:)) - alpha_ring.*r_x)); +    ##                    sino_updt_FULL(:,indC,:) = squeeze(sino_updt_Sub(:,kkk,:)); % filling the full sinogram +    ##                end +                 for kkk in range(numProjSub): +                     #print ("ring removal indC ... {0}".format(kkk)) +                     indC = int(CurrSubIndices[kkk]) +                     residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \ +                            (sino_updt_Sub[:,kkk,:].squeeze() - \ +                             sino[:,indC,:].squeeze() - alpha_ring * r_x) +                     # filling the full sinogram +                     sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze() + +            else: +                #PWLS model +                residualSub = weights[:,CurrSubIndices,:] * \ +                              ( sino_updt_Sub - sino[:,CurrSubIndices,:].squeeze() ) +                objective[i] = 0.5 * numpy.linalg.norm(residualSub) + +            if geometry_type == 'parallel' or \ +               geometry_type == 'fanflat' or \ +               geometry_type == 'fanflat_vec' : +                # if geometry is 2D use slice-by-slice projection-backprojection +                # routine +                x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32) +                for kkk in range(SlicesZ): +                     +                    x_id, x_temp[kkk] = \ +                             astra.creators.create_backprojection3d_gpu( +                                 residualSub[kkk:kkk+1], +                                 proj_geomSUB, vol_geom) +                     +            else: +                x_id, x_temp = \ +                      astra.creators.create_backprojection3d_gpu( +                          residualSub, proj_geomSUB, vol_geom) + +            astra.matlab.data3d('delete', x_id) +            X = X_t - (1/L_const) * x_temp + +         + +        ## REGULARIZATION +        ## SKIPPING FOR NOW +        ## Should be simpli +        # regularizer = fistaRecon.getParameter('regularizer') +        # for slices: +        # out = regularizer(input=X) +        print ("regularizer") +        X = reg(input=X)[0] + + +        ## FINAL +        print ("final") +        lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') +        if lambdaR_L1 > 0: +            fistaRecon.r = numpy.max( +                numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ +                numpy.sign(fistaRecon.r) +            # updating r +            r_x = fistaRecon.r + ((t_old-1)/t) * (fistaRecon.r - r_old) +         + +        if fistaRecon.getParameter('region_of_interest') is None: +            string = 'Iteration Number {0} | Objective {1} \n' +            print (string.format( i, objective[i])) +        else: +            ROI , X_ideal = fistaRecon.getParameter('region_of_interest', +                                                    'ideal_image') +             +            Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) +            string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' +            print (string.format(i,Resid_error[i], objective[i])) + +    numpy.save("X_out_os.npy", X) + +else: +    fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D) + +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    fistaRecon.setParameter(number_of_iterations = 12) +    fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +    fistaRecon.setParameter(ring_alpha = 21) +    fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +    fistaRecon.prepareForIteration() +    X = fistaRecon.iterate(numpy.load("X.npy")) diff --git a/src/Python/test/test_reconstructor.py b/src/Python/test/test_reconstructor.py new file mode 100644 index 0000000..3342301 --- /dev/null +++ b/src/Python/test/test_reconstructor.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.reconstruction.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt + +def RMSE(signal1, signal2): +    '''RMSE Root Mean Squared Error''' +    if numpy.shape(signal1) == numpy.shape(signal2): +        err = (signal1 - signal2) +        err = numpy.sum( err * err )/numpy.size(signal1);  # MSE +        err = sqrt(err);                                   # RMSE +        return err +    else: +        raise Exception('Input signals must have the same shape') +   +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', +                                            detectorSpacingX, +                                            detectorSpacingY, +                                            det_row_count, +                                            det_col_count, +                                            angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, +                                           image_size_y, +                                           image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +reg = Regularizer(Regularizer.Algorithm.LLT_model) +reg.setParameter(regularization_parameter=25, +                          time_step=0.0003, +                          tolerance_constant=0.0001, +                          number_of_iterations=300) +fistaRecon.setParameter(regularizer = reg) + +## Ordered subset +if False: +    subsets = 16 +    angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] +    #binEdges = numpy.linspace(angles.min(), +    #                          angles.max(), +    #                          subsets + 1) +    binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) +    # get rearranged subset indices +    IndicesReorg = numpy.zeros((numpy.shape(angles))) +    counterM = 0 +    for ii in range(binsDiscr.max()): +        counter = 0 +        for jj in range(subsets): +            curr_index = ii + jj  + counter +            #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) +            if binsDiscr[jj] > ii: +                if (counterM < numpy.size(IndicesReorg)): +                    IndicesReorg[counterM] = curr_index +                counterM = counterM + 1 +                 +            counter = counter + binsDiscr[jj] - 1 + + +if False: +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    print ("prepare for iteration") +    fistaRecon.prepareForIteration() +     +     + +    print("initializing  ...") +    if False: +        # if X doesn't exist +        #N = params.vol_geom.GridColCount +        N = vol_geom['GridColCount'] +        print ("N " + str(N)) +        X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) +    else: +        #X = fistaRecon.initialize() +        X = numpy.load("X.npy") +     +    print (numpy.shape(X)) +    X_t = X.copy() +    print ("initialized") +    proj_geom , vol_geom, sino , \ +                      SlicesZ = fistaRecon.getParameter(['projector_geometry' , +                                                            'output_geometry', +                                                            'input_sinogram', +                                                         'SlicesZ']) + +    #fistaRecon.setParameter(number_of_iterations = 3) +    iterFISTA = fistaRecon.getParameter('number_of_iterations') +    # errors vector (if the ground truth is given) +    Resid_error = numpy.zeros((iterFISTA)); +    # objective function values vector +    objective = numpy.zeros((iterFISTA));  + +       +    t = 1 + + +    print ("starting iterations") +##    % Outer FISTA iterations loop +    for i in range(fistaRecon.getParameter('number_of_iterations')): +        X_old = X.copy() +        t_old = t +        r_old = fistaRecon.r.copy() +        if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : +            # if the geometry is parallel use slice-by-slice +            # projection-backprojection routine +            #sino_updt = zeros(size(sino),'single'); +            proj_geomT = proj_geom.copy() +            proj_geomT['DetectorRowCount'] = 1 +            vol_geomT = vol_geom.copy() +            vol_geomT['GridSliceCount'] = 1; +            sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) +            for kkk in range(SlicesZ): +                sino_id, sino_updt[kkk] = \ +                         astra.creators.create_sino3d_gpu( +                             X_t[kkk:kkk+1], proj_geom, vol_geom) +                astra.matlab.data3d('delete', sino_id) +        else: +            # for divergent 3D geometry (watch the GPU memory overflow in +            # ASTRA versions < 1.8) +            #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); +            sino_id, sino_updt = astra.creators.create_sino3d_gpu( +                X_t, proj_geom, vol_geom) + +        ## RING REMOVAL +        residual = fistaRecon.residual +        lambdaR_L1 , alpha_ring , weights , L_const= \ +                   fistaRecon.getParameter(['ring_lambda_R_L1', +                                           'ring_alpha' , 'weights', +                                           'Lipschitz_constant']) +        r_x = fistaRecon.r_x +        SlicesZ, anglesNumb, Detectors = \ +                    numpy.shape(fistaRecon.getParameter('input_sinogram')) +        if lambdaR_L1 > 0 : +             print ("ring removal") +             for kkk in range(anglesNumb): +                  +                 residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ +                                       ((sino_updt[:,kkk,:]).squeeze() - \ +                                        (sino[:,kkk,:]).squeeze() -\ +                                        (alpha_ring * r_x) +                                        ) +             vec = residual.sum(axis = 1) +             #if SlicesZ > 1: +             #    vec = vec[:,1,:].squeeze() +             fistaRecon.r = (r_x - (1./L_const) * vec).copy() +             objective[i] = (0.5 * (residual ** 2).sum()) +##            % the ring removal part (Group-Huber fidelity) +##            for kkk = 1:anglesNumb +##                residual(:,kkk,:) =  squeeze(weights(:,kkk,:)).* +##                 (squeeze(sino_updt(:,kkk,:)) - +##                 (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); +##            end +##            vec = sum(residual,2); +##            if (SlicesZ > 1) +##                vec = squeeze(vec(:,1,:)); +##            end +##            r = r_x - (1./L_const).*vec; +##            objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + +         + +        # Projection/Backprojection Routine +        if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': +            x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) +            print ("Projection/Backprojection Routine") +            for kkk in range(SlicesZ): +                 +                x_id, x_temp[kkk] = \ +                         astra.creators.create_backprojection3d_gpu( +                             residual[kkk:kkk+1], +                             proj_geomT, vol_geomT) +                astra.matlab.data3d('delete', x_id) +        else: +            x_id, x_temp = \ +                  astra.creators.create_backprojection3d_gpu( +                      residual, proj_geom, vol_geom)             + +        X = X_t - (1/L_const) * x_temp +        astra.matlab.data3d('delete', sino_id) +        astra.matlab.data3d('delete', x_id) +         + +        ## REGULARIZATION +        ## SKIPPING FOR NOW +        ## Should be simpli +        # regularizer = fistaRecon.getParameter('regularizer') +        # for slices: +        # out = regularizer(input=X) +        print ("skipping regularizer") + + +        ## FINAL +        print ("final") +        lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') +        if lambdaR_L1 > 0: +            fistaRecon.r = numpy.max( +                numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ +                numpy.sign(fistaRecon.r) +        t = (1 + numpy.sqrt(1 + 4 * t**2))/2 +        X_t = X + (((t_old -1)/t) * (X - X_old)) + +        if lambdaR_L1 > 0: +            fistaRecon.r_x = fistaRecon.r + \ +                             (((t_old-1)/t) * (fistaRecon.r - r_old)) + +        if fistaRecon.getParameter('region_of_interest') is None: +            string = 'Iteration Number {0} | Objective {1} \n' +            print (string.format( i, objective[i])) +        else: +            ROI , X_ideal = fistaRecon.getParameter('region_of_interest', +                                                    'ideal_image') +             +            Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) +            string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' +            print (string.format(i,Resid_error[i], objective[i])) +             +##        if (lambdaR_L1 > 0) +##            r =  max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector +##        end +##         +##        t = (1 + sqrt(1 + 4*t^2))/2; % updating t +##        X_t = X + ((t_old-1)/t).*(X - X_old); % updating X +##         +##        if (lambdaR_L1 > 0) +##            r_x = r + ((t_old-1)/t).*(r - r_old); % updating r +##        end +##         +##        if (show == 1) +##            figure(10); imshow(X(:,:,slice), [0 maxvalplot]); +##            if (lambdaR_L1 > 0) +##                figure(11); plot(r); title('Rings offset vector') +##            end +##            pause(0.01); +##        end +##        if (strcmp(X_ideal, 'none' ) == 0) +##            Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); +##            fprintf('%s %i %s %s %.4f  %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); +##        else +##            fprintf('%s %i  %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); +##        end +else: +    fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D) + +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    fistaRecon.setParameter(number_of_iterations = 12) +    fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +    fistaRecon.setParameter(ring_alpha = 21) +    fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +    fistaRecon.prepareForIteration() +    X = fistaRecon.iterate(numpy.load("X.npy")) +    numpy.save("X_out.npy", X) diff --git a/src/Python/test/test_regularizers.py b/src/Python/test/test_regularizers.py new file mode 100644 index 0000000..27e4ed3 --- /dev/null +++ b/src/Python/test/test_regularizers.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug  4 11:10:05 2017 + +@author: ofn77899 +""" + +#from ccpi.viewer.CILViewer2D import Converter +#import vtk + +import matplotlib.pyplot as plt +import numpy as np +import os     +from enum import Enum +import timeit +#from PIL import Image +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE:  +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): +    a, b = im1.shape +    rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) +    max_val = max(np.max(im1), np.max(im2)) +    min_val = min(np.min(im1), np.min(im2)) +    return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +#  2D Regularizers +# +############################################################################### +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" +#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' + +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +Im = plt.imread(filename)                      +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + + + + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +## plot  +fig = plt.figure() +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,1) +a.set_title('noise') +imgplot = plt.imshow(u0,cmap="gray") + +reg_output = [] +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +use_object = True +if use_object: +    reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +    print (reg.pars) +    reg.setParameter(input=u0)     +    reg.setParameter(regularization_parameter=10.) +    # or  +    # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, +              #tolerance_constant=1e-4,  +              #TV_Penalty=Regularizer.TotalVariationPenalty.l1) +    plotme = reg() [0] +    pars = reg.pars +    textstr = reg.printParametersToString()  +     +    #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +              #tolerance_constant=1e-4,  +    #          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +     +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +#          tolerance_constant=1e-4,  +#          TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +else: +    out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +    pars = out2[2] +    reg_output.append(out2) +    plotme = reg_output[-1][0] +    textstr = out2[-1] + +a=fig.add_subplot(2,3,2) + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(plotme,cmap="gray") + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.0005, +                          number_of_iterations=50) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +#input, regularization_parameter , time_step, number_of_iterations, +#                  tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +                          time_step=0.0003, +                          tolerance_constant=0.0001, +                          number_of_iterations=300) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab:    +# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# #   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +                       searching_window_ratio=3, +                       similarity_window_ratio=1, +                       PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab:    +# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# #   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +                           first_order_term=1.3, +                           second_order_term=1, +                           number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +plt.show() + +################################################################################ +## +##  3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255;  % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +#                                        reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot  +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +##          #tolerance_constant=1e-4,  +##          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +#          tolerance_constant=1e-4,  +#          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +#                          number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +##input, regularization_parameter , time_step, number_of_iterations, +##                  tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +#                          time_step=0.0003, +#                          tolerance_constant=0.0001, +#                          number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab:    +##   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +##   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +##   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +#                          searching_window_ratio=3, +#                          similarity_window_ratio=1, +#                          PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab:    +#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +#   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +#                          first_order_term=1.3, +#                          second_order_term=1, +#                          number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) diff --git a/src/Python/test/test_regularizers_3d.py b/src/Python/test/test_regularizers_3d.py new file mode 100644 index 0000000..2d11a7e --- /dev/null +++ b/src/Python/test/test_regularizers_3d.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug  4 11:10:05 2017 + +@author: ofn77899 +""" + +#from ccpi.viewer.CILViewer2D import Converter +#import vtk + +import matplotlib.pyplot as plt +import numpy as np +import os     +from enum import Enum +import timeit +#from PIL import Image +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE:  +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): +    a, b = im1.shape +    rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) +    max_val = max(np.max(im1), np.max(im2)) +    min_val = min(np.min(im1), np.min(im2)) +    return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +#  2D Regularizers +# +############################################################################### +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" +#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' + +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +Im = plt.imread(filename)                      +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + +# create a 3D image by stacking N of this images + + +#imgplot = plt.imshow(Im) +perc = 0.05 +u_n = Im + (perc* np.random.normal(size=np.shape(Im))) +y,z = np.shape(u_n) +u_n = np.reshape(u_n , (1,y,z)) + +u0 = u_n.copy() +for i in range (19): +    u_n = Im + (perc* np.random.normal(size=np.shape(Im))) +    u_n = np.reshape(u_n , (1,y,z)) + +    u0 = np.vstack ( (u0, u_n) ) + +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +print ("Passed image shape {0}".format(np.shape(u0))) + +## plot  +fig = plt.figure() +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) +sliceno = 10 + +a=fig.add_subplot(2,3,1) +a.set_title('noise') +imgplot = plt.imshow(u0[sliceno],cmap="gray") + +reg_output = [] +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +use_object = True +if use_object: +    reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +    print (reg.pars) +    reg.setParameter(input=u0)     +    reg.setParameter(regularization_parameter=10.) +    # or  +    # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, +              #tolerance_constant=1e-4,  +              #TV_Penalty=Regularizer.TotalVariationPenalty.l1) +    plotme = reg() [0] +    pars = reg.pars +    textstr = reg.printParametersToString()  +     +    #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +              #tolerance_constant=1e-4,  +    #          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +     +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +#          tolerance_constant=1e-4,  +#          TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +else: +    out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +    pars = out2[2] +    reg_output.append(out2) +    plotme = reg_output[-1][0] +    textstr = out2[-1] + +a=fig.add_subplot(2,3,2) + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(plotme[sliceno],cmap="gray") + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.0005, +                          number_of_iterations=50) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +#input, regularization_parameter , time_step, number_of_iterations, +#                  tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +                          time_step=0.0003, +                          tolerance_constant=0.0001, +                          number_of_iterations=300) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + + +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab:    +# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# #   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +                       searching_window_ratio=3, +                       similarity_window_ratio=1, +                       PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab:    +# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# #   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +                           first_order_term=1.3, +                           second_order_term=1, +                           number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0][sliceno],cmap="gray") + + +plt.show() + +################################################################################ +## +##  3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255;  % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +#                                        reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot  +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +##          #tolerance_constant=1e-4,  +##          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +#          tolerance_constant=1e-4,  +#          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +#                          number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +##input, regularization_parameter , time_step, number_of_iterations, +##                  tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +#                          time_step=0.0003, +#                          tolerance_constant=0.0001, +#                          number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab:    +##   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +##   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +##   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +#                          searching_window_ratio=3, +#                          similarity_window_ratio=1, +#                          PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab:    +#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +#   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +#                          first_order_term=1.3, +#                          second_order_term=1, +#                          number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) diff --git a/src/Python/test_reconstructor.py b/src/Python/test_reconstructor.py new file mode 100644 index 0000000..07668ba --- /dev/null +++ b/src/Python/test_reconstructor.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 23 16:34:49 2017 + +@author: ofn77899 +Based on DemoRD2.m +""" + +import h5py +import numpy + +from ccpi.fista.FISTAReconstructor import FISTAReconstructor +import astra +import matplotlib.pyplot as plt + +def RMSE(signal1, signal2): +    '''RMSE Root Mean Squared Error''' +    if numpy.shape(signal1) == numpy.shape(signal2): +        err = (signal1 - signal2) +        err = numpy.sum( err * err )/numpy.size(signal1);  # MSE +        err = sqrt(err);                                   # RMSE +        return err +    else: +        raise Exception('Input signals must have the same shape') +   +filename = r'/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/demos/DendrData.h5' +nx = h5py.File(filename, "r") +#getEntry(nx, '/') +# I have exported the entries as children of / +entries = [entry for entry in nx['/'].keys()] +print (entries) + +Sino3D = numpy.asarray(nx.get('/Sino3D'), dtype="float32") +Weights3D = numpy.asarray(nx.get('/Weights3D'), dtype="float32") +angSize = numpy.asarray(nx.get('/angSize'), dtype=int)[0] +angles_rad = numpy.asarray(nx.get('/angles_rad'), dtype="float32") +recon_size = numpy.asarray(nx.get('/recon_size'), dtype=int)[0] +size_det = numpy.asarray(nx.get('/size_det'), dtype=int)[0] +slices_tot = numpy.asarray(nx.get('/slices_tot'), dtype=int)[0] + +Z_slices = 20 +det_row_count = Z_slices +# next definition is just for consistency of naming +det_col_count = size_det + +detectorSpacingX = 1.0 +detectorSpacingY = detectorSpacingX + + +proj_geom = astra.creators.create_proj_geom('parallel3d', +                                            detectorSpacingX, +                                            detectorSpacingY, +                                            det_row_count, +                                            det_col_count, +                                            angles_rad) + +#vol_geom = astra_create_vol_geom(recon_size,recon_size,Z_slices); +image_size_x = recon_size +image_size_y = recon_size +image_size_z = Z_slices +vol_geom = astra.creators.create_vol_geom( image_size_x, +                                           image_size_y, +                                           image_size_z) + +## First pass the arguments to the FISTAReconstructor and test the +## Lipschitz constant + +fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D) + +print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +fistaRecon.setParameter(number_of_iterations = 12) +fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +fistaRecon.setParameter(ring_alpha = 21) +fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) + +## Ordered subset +if False: +    subsets = 16 +    angles = fistaRecon.getParameter('projector_geometry')['ProjectionAngles'] +    #binEdges = numpy.linspace(angles.min(), +    #                          angles.max(), +    #                          subsets + 1) +    binsDiscr, binEdges = numpy.histogram(angles, bins=subsets) +    # get rearranged subset indices +    IndicesReorg = numpy.zeros((numpy.shape(angles))) +    counterM = 0 +    for ii in range(binsDiscr.max()): +        counter = 0 +        for jj in range(subsets): +            curr_index = ii + jj  + counter +            #print ("{0} {1} {2}".format(binsDiscr[jj] , ii, counterM)) +            if binsDiscr[jj] > ii: +                if (counterM < numpy.size(IndicesReorg)): +                    IndicesReorg[counterM] = curr_index +                counterM = counterM + 1 +                 +            counter = counter + binsDiscr[jj] - 1 + + +if False: +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    print ("prepare for iteration") +    fistaRecon.prepareForIteration() +     +     + +    print("initializing  ...") +    if False: +        # if X doesn't exist +        #N = params.vol_geom.GridColCount +        N = vol_geom['GridColCount'] +        print ("N " + str(N)) +        X = numpy.zeros((N,N,SlicesZ), dtype=numpy.float) +    else: +        #X = fistaRecon.initialize() +        X = numpy.load("X.npy") +     +    print (numpy.shape(X)) +    X_t = X.copy() +    print ("initialized") +    proj_geom , vol_geom, sino , \ +                      SlicesZ = fistaRecon.getParameter(['projector_geometry' , +                                                            'output_geometry', +                                                            'input_sinogram', +                                                         'SlicesZ']) + +    #fistaRecon.setParameter(number_of_iterations = 3) +    iterFISTA = fistaRecon.getParameter('number_of_iterations') +    # errors vector (if the ground truth is given) +    Resid_error = numpy.zeros((iterFISTA)); +    # objective function values vector +    objective = numpy.zeros((iterFISTA));  + +       +    t = 1 + + +    print ("starting iterations") +##    % Outer FISTA iterations loop +    for i in range(fistaRecon.getParameter('number_of_iterations')): +        X_old = X.copy() +        t_old = t +        r_old = fistaRecon.r.copy() +        if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or \ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec' : +            # if the geometry is parallel use slice-by-slice +            # projection-backprojection routine +            #sino_updt = zeros(size(sino),'single'); +            proj_geomT = proj_geom.copy() +            proj_geomT['DetectorRowCount'] = 1 +            vol_geomT = vol_geom.copy() +            vol_geomT['GridSliceCount'] = 1; +            sino_updt = numpy.zeros(numpy.shape(sino), dtype=numpy.float) +            for kkk in range(SlicesZ): +                sino_id, sino_updt[kkk] = \ +                         astra.creators.create_sino3d_gpu( +                             X_t[kkk:kkk+1], proj_geom, vol_geom) +                astra.matlab.data3d('delete', sino_id) +        else: +            # for divergent 3D geometry (watch the GPU memory overflow in +            # ASTRA versions < 1.8) +            #[sino_id, sino_updt] = astra_create_sino3d_cuda(X_t, proj_geom, vol_geom); +            sino_id, sino_updt = astra.creators.create_sino3d_gpu( +                X_t, proj_geom, vol_geom) + +        ## RING REMOVAL +        residual = fistaRecon.residual +        lambdaR_L1 , alpha_ring , weights , L_const= \ +                   fistaRecon.getParameter(['ring_lambda_R_L1', +                                           'ring_alpha' , 'weights', +                                           'Lipschitz_constant']) +        r_x = fistaRecon.r_x +        SlicesZ, anglesNumb, Detectors = \ +                    numpy.shape(fistaRecon.getParameter('input_sinogram')) +        if lambdaR_L1 > 0 : +             print ("ring removal") +             for kkk in range(anglesNumb): +                  +                 residual[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \ +                                       ((sino_updt[:,kkk,:]).squeeze() - \ +                                        (sino[:,kkk,:]).squeeze() -\ +                                        (alpha_ring * r_x) +                                        ) +             vec = residual.sum(axis = 1) +             #if SlicesZ > 1: +             #    vec = vec[:,1,:].squeeze() +             fistaRecon.r = (r_x - (1./L_const) * vec).copy() +             objective[i] = (0.5 * (residual ** 2).sum()) +##            % the ring removal part (Group-Huber fidelity) +##            for kkk = 1:anglesNumb +##                residual(:,kkk,:) =  squeeze(weights(:,kkk,:)).* +##                 (squeeze(sino_updt(:,kkk,:)) - +##                 (squeeze(sino(:,kkk,:)) - alpha_ring.*r_x)); +##            end +##            vec = sum(residual,2); +##            if (SlicesZ > 1) +##                vec = squeeze(vec(:,1,:)); +##            end +##            r = r_x - (1./L_const).*vec; +##            objective(i) = (0.5*sum(residual(:).^2)); % for the objective function output + +         + +        # Projection/Backprojection Routine +        if fistaRecon.getParameter('projector_geometry')['type'] == 'parallel' or \ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat' or\ +           fistaRecon.getParameter('projector_geometry')['type'] == 'fanflat_vec': +            x_temp = numpy.zeros(numpy.shape(X),dtype=numpy.float32) +            print ("Projection/Backprojection Routine") +            for kkk in range(SlicesZ): +                 +                x_id, x_temp[kkk] = \ +                         astra.creators.create_backprojection3d_gpu( +                             residual[kkk:kkk+1], +                             proj_geomT, vol_geomT) +                astra.matlab.data3d('delete', x_id) +        else: +            x_id, x_temp = \ +                  astra.creators.create_backprojection3d_gpu( +                      residual, proj_geom, vol_geom)             + +        X = X_t - (1/L_const) * x_temp +        astra.matlab.data3d('delete', sino_id) +        astra.matlab.data3d('delete', x_id) +         + +        ## REGULARIZATION +        ## SKIPPING FOR NOW +        ## Should be simpli +        # regularizer = fistaRecon.getParameter('regularizer') +        # for slices: +        # out = regularizer(input=X) +        print ("skipping regularizer") + + +        ## FINAL +        print ("final") +        lambdaR_L1 = fistaRecon.getParameter('ring_lambda_R_L1') +        if lambdaR_L1 > 0: +            fistaRecon.r = numpy.max( +                numpy.abs(fistaRecon.r) - lambdaR_L1 , 0) * \ +                numpy.sign(fistaRecon.r) +        t = (1 + numpy.sqrt(1 + 4 * t**2))/2 +        X_t = X + (((t_old -1)/t) * (X - X_old)) + +        if lambdaR_L1 > 0: +            fistaRecon.r_x = fistaRecon.r + \ +                             (((t_old-1)/t) * (fistaRecon.r - r_old)) + +        if fistaRecon.getParameter('region_of_interest') is None: +            string = 'Iteration Number {0} | Objective {1} \n' +            print (string.format( i, objective[i])) +        else: +            ROI , X_ideal = fistaRecon.getParameter('region_of_interest', +                                                    'ideal_image') +             +            Resid_error[i] = RMSE(X*ROI, X_ideal*ROI) +            string = 'Iteration Number {0} | RMS Error {1} | Objective {2} \n' +            print (string.format(i,Resid_error[i], objective[i])) +             +##        if (lambdaR_L1 > 0) +##            r =  max(abs(r)-lambdaR_L1, 0).*sign(r); % soft-thresholding operator for ring vector +##        end +##         +##        t = (1 + sqrt(1 + 4*t^2))/2; % updating t +##        X_t = X + ((t_old-1)/t).*(X - X_old); % updating X +##         +##        if (lambdaR_L1 > 0) +##            r_x = r + ((t_old-1)/t).*(r - r_old); % updating r +##        end +##         +##        if (show == 1) +##            figure(10); imshow(X(:,:,slice), [0 maxvalplot]); +##            if (lambdaR_L1 > 0) +##                figure(11); plot(r); title('Rings offset vector') +##            end +##            pause(0.01); +##        end +##        if (strcmp(X_ideal, 'none' ) == 0) +##            Resid_error(i) = RMSE(X(ROI), X_ideal(ROI)); +##            fprintf('%s %i %s %s %.4f  %s %s %f \n', 'Iteration Number:', i, '|', 'Error RMSE:', Resid_error(i), '|', 'Objective:', objective(i)); +##        else +##            fprintf('%s %i  %s %s %f \n', 'Iteration Number:', i, '|', 'Objective:', objective(i)); +##        end +else: +    fistaRecon = FISTAReconstructor(proj_geom, +                                vol_geom, +                                Sino3D , +                                weights=Weights3D) + +    print ("Lipschitz Constant {0}".format(fistaRecon.pars['Lipschitz_constant'])) +    fistaRecon.setParameter(number_of_iterations = 12) +    fistaRecon.setParameter(Lipschitz_constant = 767893952.0) +    fistaRecon.setParameter(ring_alpha = 21) +    fistaRecon.setParameter(ring_lambda_R_L1 = 0.002) +    fistaRecon.prepareForIteration() +    X = fistaRecon.iterate(numpy.load("X.npy")) diff --git a/src/Python/test_regularizers.py b/src/Python/test_regularizers.py new file mode 100644 index 0000000..e76262c --- /dev/null +++ b/src/Python/test_regularizers.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +""" +Created on Fri Aug  4 11:10:05 2017 + +@author: ofn77899 +""" + +#from ccpi.viewer.CILViewer2D import Converter +#import vtk + +import matplotlib.pyplot as plt +import numpy as np +import os     +from enum import Enum +import timeit +#from PIL import Image +#from Regularizer import Regularizer +from ccpi.imaging.Regularizer import Regularizer + +############################################################################### +#https://stackoverflow.com/questions/13875989/comparing-image-in-url-to-image-in-filesystem-in-python/13884956#13884956 +#NRMSE a normalization of the root of the mean squared error +#NRMSE is simply 1 - [RMSE / (maxval - minval)]. Where maxval is the maximum +# intensity from the two images being compared, and respectively the same for +# minval. RMSE is given by the square root of MSE:  +# sqrt[(sum(A - B) ** 2) / |A|], +# where |A| means the number of elements in A. By doing this, the maximum value +# given by RMSE is maxval. + +def nrmse(im1, im2): +    a, b = im1.shape +    rmse = np.sqrt(np.sum((im2 - im1) ** 2) / float(a * b)) +    max_val = max(np.max(im1), np.max(im2)) +    min_val = min(np.min(im1), np.min(im2)) +    return 1 - (rmse / (max_val - min_val)) +############################################################################### + +############################################################################### +# +#  2D Regularizers +# +############################################################################### +#Example: +# figure; +# Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + + +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\lena_gray_512.tif" +filename = r"/home/ofn77899/Reconstruction/CCPi-FISTA_Reconstruction/data/lena_gray_512.tif" +#filename = r'/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif' + +#reader = vtk.vtkTIFFReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +Im = plt.imread(filename)                      +#Im = Image.open('/home/algol/Documents/Python/STD_test_images/lena_gray_512.tif')/255 +#img.show() +Im = np.asarray(Im, dtype='float32') + + + + +#imgplot = plt.imshow(Im) +perc = 0.05 +u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +# map the u0 u0->u0>0 +f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +u0 = f(u0).astype('float32') + +## plot  +fig = plt.figure() +#a=fig.add_subplot(3,3,1) +#a.set_title('Original') +#imgplot = plt.imshow(Im) + +a=fig.add_subplot(2,3,1) +a.set_title('noise') +imgplot = plt.imshow(u0,cmap="gray") + +reg_output = [] +############################################################################## +# Call regularizer + +####################### SplitBregman_TV ##################################### +# u = SplitBregman_TV(single(u0), 10, 30, 1e-04); + +use_object = True +if use_object: +    reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +    print (reg.pars) +    reg.setParameter(input=u0)     +    reg.setParameter(regularization_parameter=10.) +    # or  +    # reg.setParameter(input=u0, regularization_parameter=10., #number_of_iterations=30, +              #tolerance_constant=1e-4,  +              #TV_Penalty=Regularizer.TotalVariationPenalty.l1) +    plotme = reg() [0] +    pars = reg.pars +    textstr = reg.printParametersToString()  +     +    #out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +              #tolerance_constant=1e-4,  +    #          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +     +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +#          tolerance_constant=1e-4,  +#          TV_Penalty=Regularizer.TotalVariationPenalty.l1) + +else: +    out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10. ) +    pars = out2[2] +    reg_output.append(out2) +    plotme = reg_output[-1][0] +    textstr = out2[-1] + +a=fig.add_subplot(2,3,2) + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(plotme,cmap="gray") + +###################### FGP_TV ######################################### +# u = FGP_TV(single(u0), 0.05, 100, 1e-04); +out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +                          number_of_iterations=200) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,3) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0]) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + +###################### LLT_model ######################################### +# * u0 = Im + .03*randn(size(Im)); % adding noise +# [Den] = LLT_model(single(u0), 10, 0.1, 1); +#Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +#input, regularization_parameter , time_step, number_of_iterations, +#                  tolerance_constant, restrictive_Z_smoothing=0 +out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +                          time_step=0.0003, +                          tolerance_constant=0.0001, +                          number_of_iterations=300) +pars = out2[-2] + +reg_output.append(out2) + +a=fig.add_subplot(2,3,4) + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +# ###################### PatchBased_Regul ######################################### +# # Quick 2D denoising example in Matlab:    +# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# #   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  + +out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +                       searching_window_ratio=3, +                       similarity_window_ratio=1, +                       PB_filtering_parameter=0.08) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,5) + + +textstr = out2[-1] + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +        verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +# ###################### TGV_PD ######################################### +# # Quick 2D denoising example in Matlab:    +# #   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +# #   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +# #   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +                           first_order_term=1.3, +                           second_order_term=1, +                           number_of_iterations=550) +pars = out2[-2] +reg_output.append(out2) + +a=fig.add_subplot(2,3,6) + + +textstr = out2[-1] + + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +# place a text box in upper left in axes coords +a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +         verticalalignment='top', bbox=props) +imgplot = plt.imshow(reg_output[-1][0],cmap="gray") + + +plt.show() + +################################################################################ +## +##  3D Regularizers +## +################################################################################ +##Example: +## figure; +## Im = double(imread('lena_gray_256.tif'))/255;  % loading image +## u0 = Im + .05*randn(size(Im)); u0(u0 < 0) = 0; +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Reconstruction\python\test\reconstruction_example.mha" +#filename = r"C:\Users\ofn77899\Documents\GitHub\CCPi-Simpleflex\data\head.mha" +# +#reader = vtk.vtkMetaImageReader() +#reader.SetFileName(os.path.normpath(filename)) +#reader.Update() +##vtk returns 3D images, let's take just the one slice there is as 2D +#Im = Converter.vtk2numpy(reader.GetOutput()) +#Im = Im.astype('float32') +##imgplot = plt.imshow(Im) +#perc = 0.05 +#u0 = Im + (perc* np.random.normal(size=np.shape(Im))) +## map the u0 u0->u0>0 +#f = np.frompyfunc(lambda x: 0 if x < 0 else x, 1,1) +#u0 = f(u0).astype('float32') +#converter = Converter.numpy2vtkImporter(u0, reader.GetOutput().GetSpacing(), +#                                        reader.GetOutput().GetOrigin()) +#converter.Update() +#writer = vtk.vtkMetaImageWriter() +#writer.SetInputData(converter.GetOutput()) +#writer.SetFileName(r"C:\Users\ofn77899\Documents\GitHub\CCPi-FISTA_reconstruction\data\noisy_head.mha") +##writer.Write() +# +# +### plot  +#fig3D = plt.figure() +##a=fig.add_subplot(3,3,1) +##a.set_title('Original') +##imgplot = plt.imshow(Im) +#sliceNo = 32 +# +#a=fig3D.add_subplot(2,3,1) +#a.set_title('noise') +#imgplot = plt.imshow(u0.T[sliceNo]) +# +#reg_output3d = [] +# +############################################################################### +## Call regularizer +# +######################## SplitBregman_TV ##################################### +## u = SplitBregman_TV(single(u0), 10, 30, 1e-04); +# +##reg = Regularizer(Regularizer.Algorithm.SplitBregman_TV) +# +##out = reg(input=u0, regularization_parameter=10., #number_of_iterations=30, +##          #tolerance_constant=1e-4,  +##          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +#out2 = Regularizer.SplitBregman_TV(input=u0, regularization_parameter=10., number_of_iterations=30, +#          tolerance_constant=1e-4,  +#          TV_Penalty=Regularizer.TotalVariationPenalty.l1) +# +# +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### FGP_TV ######################################### +## u = FGP_TV(single(u0), 0.05, 100, 1e-04); +#out2 = Regularizer.FGP_TV(input=u0, regularization_parameter=0.005, +#                          number_of_iterations=200) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### LLT_model ######################################### +## * u0 = Im + .03*randn(size(Im)); % adding noise +## [Den] = LLT_model(single(u0), 10, 0.1, 1); +##Den = LLT_model(single(u0), 25, 0.0003, 300, 0.0001, 0);  +##input, regularization_parameter , time_step, number_of_iterations, +##                  tolerance_constant, restrictive_Z_smoothing=0 +#out2 = Regularizer.LLT_model(input=u0, regularization_parameter=25, +#                          time_step=0.0003, +#                          tolerance_constant=0.0001, +#                          number_of_iterations=300) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# +####################### PatchBased_Regul ######################################### +## Quick 2D denoising example in Matlab:    +##   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +##   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +##   ImDen = PB_Regul_CPU(single(u0), 3, 1, 0.08, 0.05);  +# +#out2 = Regularizer.PatchBased_Regul(input=u0, regularization_parameter=0.05, +#                          searching_window_ratio=3, +#                          similarity_window_ratio=1, +#                          PB_filtering_parameter=0.08) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo]) +# + +###################### TGV_PD ######################################### +# Quick 2D denoising example in Matlab:    +#   Im = double(imread('lena_gray_256.tif'))/255;  % loading image +#   u0 = Im + .03*randn(size(Im)); u0(u0<0) = 0; % adding noise +#   u = PrimalDual_TGV(single(u0), 0.02, 1.3, 1, 550); + + +#out2 = Regularizer.TGV_PD(input=u0, regularization_parameter=0.05, +#                          first_order_term=1.3, +#                          second_order_term=1, +#                          number_of_iterations=550) +#pars = out2[-2] +#reg_output3d.append(out2) +# +#a=fig3D.add_subplot(2,3,2) +# +# +#textstr = out2[-1] +# +# +## these are matplotlib.patch.Patch properties +#props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) +## place a text box in upper left in axes coords +#a.text(0.05, 0.95, textstr, transform=a.transAxes, fontsize=14, +#        verticalalignment='top', bbox=props) +#imgplot = plt.imshow(reg_output3d[-1][0].T[sliceNo])  | 
