diff options
author | dkazanc <dkazanc@hotmail.com> | 2018-12-04 16:13:38 +0000 |
---|---|---|
committer | dkazanc <dkazanc@hotmail.com> | 2018-12-04 16:13:38 +0000 |
commit | c9ee9ecc84881595b04f19280c93bcd587171270 (patch) | |
tree | 5074abd308c3e2f4425ee27251d242f3273f1dd4 /Wrappers/Python/demos | |
parent | 8b8dfc68fa6b70ec7eefcdfb928fb383196bec97 (diff) | |
download | regularization-c9ee9ecc84881595b04f19280c93bcd587171270.tar.gz regularization-c9ee9ecc84881595b04f19280c93bcd587171270.tar.bz2 regularization-c9ee9ecc84881595b04f19280c93bcd587171270.tar.xz regularization-c9ee9ecc84881595b04f19280c93bcd587171270.zip |
GPU version, this completes implementation of nltv #68
Diffstat (limited to 'Wrappers/Python/demos')
-rw-r--r-- | Wrappers/Python/demos/demo_cpu_regularisers.py | 18 | ||||
-rw-r--r-- | Wrappers/Python/demos/demo_cpu_vs_gpu_regularisers.py | 55 | ||||
-rw-r--r-- | Wrappers/Python/demos/demo_gpu_regularisers.py | 74 |
3 files changed, 141 insertions, 6 deletions
diff --git a/Wrappers/Python/demos/demo_cpu_regularisers.py b/Wrappers/Python/demos/demo_cpu_regularisers.py index 31e4cad..78e9aff 100644 --- a/Wrappers/Python/demos/demo_cpu_regularisers.py +++ b/Wrappers/Python/demos/demo_cpu_regularisers.py @@ -400,20 +400,29 @@ plt.title('{}'.format('CPU results')) print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") print ("___Nonlocal patches pre-calculation____") print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +start_time = timeit.default_timer() # set parameters pars = {'algorithm' : PatchSelect, \ 'input' : u0,\ 'searchwindow': 7, \ 'patchwindow': 2,\ 'neighbours' : 15 ,\ - 'edge_parameter':0.23} + 'edge_parameter':0.18} H_i, H_j, Weights = PatchSelect(pars['input'], pars['searchwindow'], pars['patchwindow'], pars['neighbours'], pars['edge_parameter'],'cpu') - + +txtstr = printParametersToString(pars) +txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +print (txtstr) +""" +plt.figure() +plt.imshow(Weights[0,:,:],cmap="gray",interpolation="nearest",vmin=0, vmax=1) +plt.show() +""" #%% print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") print ("___Nonlocal Total Variation penalty____") @@ -431,10 +440,9 @@ pars2 = {'algorithm' : NLTV, \ 'H_j': H_j,\ 'H_k' : 0,\ 'Weights' : Weights,\ - 'regularisation_parameter': 0.085,\ - 'iterations': 2 + 'regularisation_parameter': 0.04,\ + 'iterations': 3 } -#%% start_time = timeit.default_timer() nltv_cpu = NLTV(pars2['input'], pars2['H_i'], diff --git a/Wrappers/Python/demos/demo_cpu_vs_gpu_regularisers.py b/Wrappers/Python/demos/demo_cpu_vs_gpu_regularisers.py index 3d6e92f..616eab0 100644 --- a/Wrappers/Python/demos/demo_cpu_vs_gpu_regularisers.py +++ b/Wrappers/Python/demos/demo_cpu_vs_gpu_regularisers.py @@ -13,6 +13,7 @@ import numpy as np import os import timeit from ccpi.filters.regularisers import ROF_TV, FGP_TV, SB_TV, TGV, LLT_ROF, FGP_dTV, NDF, DIFF4th +from ccpi.filters.regularisers import PatchSelect from qualitymetrics import rmse ############################################################################### def printParametersToString(pars): @@ -732,4 +733,58 @@ if (diff_im.sum() > 1): print ("Arrays do not match!") else: print ("Arrays match") +#%% +print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +print ("____Non-local regularisation bench_________") +print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") + +## plot +fig = plt.figure() +plt.suptitle('Comparison of Nonlocal TV regulariser using CPU and GPU implementations') +a=fig.add_subplot(1,2,1) +a.set_title('Noisy Image') +imgplot = plt.imshow(u0,cmap="gray") + +pars = {'algorithm' : PatchSelect, \ + 'input' : u0,\ + 'searchwindow': 7, \ + 'patchwindow': 2,\ + 'neighbours' : 15 ,\ + 'edge_parameter':0.18} + +print ("############## Nonlocal Patches on CPU##################") +start_time = timeit.default_timer() +H_i, H_j, WeightsCPU = PatchSelect(pars['input'], + pars['searchwindow'], + pars['patchwindow'], + pars['neighbours'], + pars['edge_parameter'],'cpu') +txtstr = printParametersToString(pars) +txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +print (txtstr) + +print ("############## Nonlocal Patches on GPU##################") +start_time = timeit.default_timer() +start_time = timeit.default_timer() +H_i, H_j, WeightsGPU = PatchSelect(pars['input'], + pars['searchwindow'], + pars['patchwindow'], + pars['neighbours'], + pars['edge_parameter'],'gpu') +txtstr = printParametersToString(pars) +txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +print (txtstr) + +print ("--------Compare the results--------") +tolerance = 1e-05 +diff_im = np.zeros(np.shape(u0)) +diff_im = abs(WeightsCPU[0,:,:] - WeightsGPU[0,:,:]) +diff_im[diff_im > tolerance] = 1 +a=fig.add_subplot(1,2,2) +imgplot = plt.imshow(diff_im, vmin=0, vmax=1, cmap="gray") +plt.title('{}'.format('Pixels larger threshold difference')) +if (diff_im.sum() > 1): + print ("Arrays do not match!") +else: + print ("Arrays match") #%%
\ No newline at end of file diff --git a/Wrappers/Python/demos/demo_gpu_regularisers.py b/Wrappers/Python/demos/demo_gpu_regularisers.py index de0cbde..2ada559 100644 --- a/Wrappers/Python/demos/demo_gpu_regularisers.py +++ b/Wrappers/Python/demos/demo_gpu_regularisers.py @@ -13,6 +13,7 @@ import numpy as np import os import timeit from ccpi.filters.regularisers import ROF_TV, FGP_TV, SB_TV, TGV, LLT_ROF, FGP_dTV, NDF, DIFF4th +from ccpi.filters.regularisers import PatchSelect, NLTV from qualitymetrics import rmse ############################################################################### def printParametersToString(pars): @@ -84,7 +85,7 @@ pars = {'algorithm': ROF_TV, \ 'input' : u0,\ 'regularisation_parameter':0.04,\ 'number_of_iterations': 1200,\ - 'time_marching_parameter': 0.0025 + 'time_marching_parameter': 0.0025 } print ("##############ROF TV GPU##################") start_time = timeit.default_timer() @@ -394,6 +395,77 @@ plt.title('{}'.format('GPU results')) #%% print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +print ("___Nonlocal patches pre-calculation____") +print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +start_time = timeit.default_timer() +# set parameters +pars = {'algorithm' : PatchSelect, \ + 'input' : u0,\ + 'searchwindow': 7, \ + 'patchwindow': 2,\ + 'neighbours' : 15 ,\ + 'edge_parameter':0.18} + +H_i, H_j, Weights = PatchSelect(pars['input'], + pars['searchwindow'], + pars['patchwindow'], + pars['neighbours'], + pars['edge_parameter'],'gpu') + +txtstr = printParametersToString(pars) +txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +print (txtstr) +""" +plt.figure() +plt.imshow(Weights[0,:,:],cmap="gray",interpolation="nearest",vmin=0, vmax=1) +plt.show() +""" +#%% +print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +print ("___Nonlocal Total Variation penalty____") +print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") +## plot +fig = plt.figure() +plt.suptitle('Performance of NLTV regulariser using the CPU') +a=fig.add_subplot(1,2,1) +a.set_title('Noisy Image') +imgplot = plt.imshow(u0,cmap="gray") + +pars2 = {'algorithm' : NLTV, \ + 'input' : u0,\ + 'H_i': H_i, \ + 'H_j': H_j,\ + 'H_k' : 0,\ + 'Weights' : Weights,\ + 'regularisation_parameter': 0.02,\ + 'iterations': 3 + } +start_time = timeit.default_timer() +nltv_cpu = NLTV(pars2['input'], + pars2['H_i'], + pars2['H_j'], + pars2['H_k'], + pars2['Weights'], + pars2['regularisation_parameter'], + pars2['iterations']) + +rms = rmse(Im, nltv_cpu) +pars['rmse'] = rms + +txtstr = printParametersToString(pars) +txtstr += "%s = %.3fs" % ('elapsed time',timeit.default_timer() - start_time) +print (txtstr) +a=fig.add_subplot(1,2,2) + +# these are matplotlib.patch.Patch properties +props = dict(boxstyle='round', facecolor='wheat', alpha=0.75) +# place a text box in upper left in axes coords +a.text(0.15, 0.25, txtstr, transform=a.transAxes, fontsize=14, + verticalalignment='top', bbox=props) +imgplot = plt.imshow(nltv_cpu, cmap="gray") +plt.title('{}'.format('CPU results')) +#%% +print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") print ("____________FGP-dTV bench___________________") print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") |