.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/0100_lop/deblurring.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_0100_lop_deblurring.py:


.. _gallery:lop:image_deblurring:

Image Deblurring
===========================

.. contents::
    :depth: 2
    :local:

This example demonstrates the following features:

- ``cr.sparse.lop.convolve2D``: a 2D convolution linear operator
- ``cr.sparse.sls.lsqr``: the LSQR algorithm for solving a least squares problem on 2D images
- ``cr.sparse.lop.dwt2D``: a 2D discrete wavelet basis operator
- ``cr.sparse.sls.fista``: the Fast Iterative Shrinkage and Thresholding Algorithm (FISTA) on 2D images

Image deblurring can be treated as a deconvolution problem if the filter used
to blur the image is known. Please see the deconvolution example for some background.

.. GENERATED FROM PYTHON SOURCE LINES 26-27

Let's import the necessary libraries.

.. GENERATED FROM PYTHON SOURCE LINES 27-46

.. code-block:: default


    import jax
    import jax.numpy as jnp
    # For plotting diagrams
    import matplotlib.pyplot as plt
    ## CR-Sparse modules
    import cr.nimble as crn
    # Linear operators
    from cr.sparse import lop
    # Image processing utilities
    from cr.sparse import vision
    # Solvers for sparse linear systems
    from cr.sparse import sls
    # Several thresholding functions are available in this module
    from cr.sparse import geo
    # Sample images
    import skimage.data
    # Configure JAX for 64-bit computing
    jax.config.update("jax_enable_x64", True)

.. GENERATED FROM PYTHON SOURCE LINES 47-49

Problem Setup
------------------

.. GENERATED FROM PYTHON SOURCE LINES 49-53

.. code-block:: default


    image = skimage.data.checkerboard()
    print(image.shape)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (200, 200)
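The introduction frames deblurring as deconvolution with a known filter. The minimal NumPy sketch below, which does not use ``cr.sparse``, makes this concrete on a tiny 12 x 12 random image: blurring with a known kernel is a linear map ``b = H x``, and when ``H`` is well conditioned and there is no noise, least squares recovers ``x``. The helpers ``gaussian_kernel`` and ``blur`` are hypothetical, and the zero-padded, same-size boundary handling is an assumption that may differ from ``lop.convolve2D``.

```python
import numpy as np


def gaussian_kernel(shape, sigma):
    """Hypothetical helper: 2D Gaussian kernel normalized to sum to 1."""
    rows = np.arange(shape[0]) - (shape[0] - 1) / 2
    cols = np.arange(shape[1]) - (shape[1] - 1) / 2
    h = np.exp(-rows[:, None] ** 2 / (2 * sigma[0] ** 2)
               - cols[None, :] ** 2 / (2 * sigma[1] ** 2))
    return h / h.sum()


def blur(x, h, offset):
    """Hypothetical helper: zero-padded 2D convolution with a same-size output.

    Kernel entry h[offset] multiplies the input pixel at the output location.
    """
    m, n = x.shape
    y = np.zeros((m, n))
    for i in range(h.shape[0]):
        for j in range(h.shape[1]):
            di, dj = i - offset[0], j - offset[1]
            if abs(di) >= m or abs(dj) >= n:
                continue  # this kernel entry never overlaps the image
            # y[r, c] += h[i, j] * x[r - di, c - dj] wherever both indices are valid
            y[max(0, di):m + min(0, di), max(0, dj):n + min(0, dj)] += (
                h[i, j] * x[max(0, -di):m - max(0, di), max(0, -dj):n - max(0, dj)])
    return y


rng = np.random.default_rng(0)
x = rng.random((12, 12))                          # tiny "sharp" image
h = gaussian_kernel((5, 5), (1.0, 1.0))           # the known blur kernel
offset = np.unravel_index(np.argmax(h), h.shape)  # kernel center (largest entry)
b = blur(x, h, offset)                            # blurred observation b = H x

# Materialize H one column at a time by blurring each unit image
H = np.column_stack([blur(e.reshape(x.shape), h, offset).ravel()
                     for e in np.eye(x.size)])
# Deconvolution as least squares: minimize || H x - b ||_2
x_hat = np.linalg.lstsq(H, b.ravel(), rcond=None)[0].reshape(x.shape)
print("max reconstruction error:", np.max(np.abs(x_hat - x)))
```

Materializing ``H`` column by column is only practical for tiny images; for the 200 x 200 checkerboard it would be a 40,000 x 40,000 matrix. That is one reason to work with matrix-free operators such as ``lop.convolve2D`` and iterative solvers such as LSQR.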
.. GENERATED FROM PYTHON SOURCE LINES 54-56

Gaussian blur kernel
''''''''''''''''''''''''''''''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 56-64

.. code-block:: default


    # A 15 x 25 Gaussian blur kernel
    h = vision.kernel_gaussian((15,25), (8,4))
    # plot the kernel
    fig, ax = plt.subplots(1, 1, figsize=(5, 3))
    him = ax.imshow(h)
    ax.set_title('Blurring kernel')
    fig.colorbar(him, ax=ax)
    ax.axis('tight')




.. image-sg:: /gallery/0100_lop/images/sphx_glr_deblurring_001.png
   :alt: Blurring kernel
   :srcset: /gallery/0100_lop/images/sphx_glr_deblurring_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (-0.5, 24.5, 14.5, -0.5)



.. GENERATED FROM PYTHON SOURCE LINES 65-68

The linear operator for the blur kernel
''''''''''''''''''''''''''''''''''''''''''''''''''''''

Locate the center of the filter, i.e., the position of its largest entry.

.. GENERATED FROM PYTHON SOURCE LINES 68-75

.. code-block:: default


    offset = crn.arr_largest_index(h)
    print(offset)
    # Construct a 2D convolution operator based on the kernel
    H = lop.convolve2D(image.shape, h, offset=offset)
    # JIT compile the convolution operator for efficiency
    H = lop.jit(H)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (Array(7, dtype=int64), Array(12, dtype=int64))



.. GENERATED FROM PYTHON SOURCE LINES 76-79

The blurring
''''''''''''''''''''''''''''''''''''''''''''''''''''''

Apply the blurring operator to the original image.

.. GENERATED FROM PYTHON SOURCE LINES 79-89

.. code-block:: default


    blurred_image = H.times(image)
    # Measure the PSNR
    print("Blurred PSNR: ", crn.peak_signal_noise_ratio(image, blurred_image), 'dB')
    # plot the original and the blurred images
    fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
    ax[0].imshow(image, cmap=plt.cm.gray)
    ax[0].set_title('Original')
    ax[1].imshow(blurred_image, cmap=plt.cm.gray)
    ax[1].set_title('After blurring')
.. image-sg:: /gallery/0100_lop/images/sphx_glr_deblurring_002.png
   :alt: Original, After blurring
   :srcset: /gallery/0100_lop/images/sphx_glr_deblurring_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Blurred PSNR:  14.592858290084076 dB

    Text(0.5, 1.0, 'After blurring')



.. GENERATED FROM PYTHON SOURCE LINES 90-93

The deblurring using LSQR algorithm
-------------------------------------------------

The initial guess for the deblurred image is all zeros.

.. GENERATED FROM PYTHON SOURCE LINES 93-111

.. code-block:: default


    x0 = jnp.zeros_like(blurred_image)
    # Run the LSQR algorithm for 50 iterations to deblur the image
    sol = sls.lsqr(H, blurred_image, x0, max_iters=50)
    deblurred_image = sol.x
    # Measure the PSNR
    print("Deblurred PSNR: ", crn.peak_signal_noise_ratio(image, deblurred_image), 'dB')
    # Plot the original, blurred and deblurred images
    fig, ax = plt.subplots(ncols=3, figsize=(15, 5))
    ax[0].imshow(image, cmap=plt.cm.gray)
    ax[0].set_title('Original')
    ax[1].imshow(blurred_image, cmap=plt.cm.gray)
    ax[1].set_title('After blurring')
    ax[2].imshow(deblurred_image, cmap=plt.cm.gray)
    ax[2].set_title('After deblurring')
    print(sol)



.. image-sg:: /gallery/0100_lop/images/sphx_glr_deblurring_003.png
   :alt: Original, After blurring, After deblurring
   :srcset: /gallery/0100_lop/images/sphx_glr_deblurring_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Deblurred PSNR:  21.206455209076868 dB
    x: (200, 200)
    A_norm: 4.946758875785462
    A_cond: 167.30699088560272
    x_norm: 34971.66794553033
    r_norm: 28.394130441229127
    atr_norm: 3.81619135390487
    iterations: 50
    n_times: 50
    n_trans: 50




.. GENERATED FROM PYTHON SOURCE LINES 112-115

A wavelet basis for the images
---------------------------------------------------------------------------

Construct a 3-level Haar wavelet basis operator for images of this shape.
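As background for the basis operator constructed next, here is a minimal NumPy sketch of a 3-level orthonormal 2D Haar transform, independent of ``cr.sparse``. ``analysis`` plays the role that ``DWT_basis.trans`` has in this example (image to coefficients) and ``synthesis`` the role of ``DWT_basis.times`` (coefficients to image). The helper names, the coefficient layout (approximation kept in the top-left corner at every level), and the synthetic 25-pixel checkerboard standing in for ``skimage.data.checkerboard()`` are all assumptions and may not match ``lop.dwt2D`` exactly.

```python
import numpy as np

S = np.sqrt(2.0)


def haar2_level(x):
    """One level of the orthonormal 2D Haar transform (even-sized input)."""
    a = (x[0::2, :] + x[1::2, :]) / S   # scaled sums of row pairs
    d = (x[0::2, :] - x[1::2, :]) / S   # scaled differences of row pairs
    # Approximation in the top-left quadrant, detail coefficients elsewhere
    return np.block([
        [(a[:, 0::2] + a[:, 1::2]) / S, (a[:, 0::2] - a[:, 1::2]) / S],
        [(d[:, 0::2] + d[:, 1::2]) / S, (d[:, 0::2] - d[:, 1::2]) / S],
    ])


def ihaar2_level(c):
    """Inverse of haar2_level."""
    m, n = c.shape
    p, q = m // 2, n // 2
    tl, tr, bl, br = c[:p, :q], c[:p, q:], c[p:, :q], c[p:, q:]
    a = np.empty((p, n))
    d = np.empty((p, n))
    a[:, 0::2], a[:, 1::2] = (tl + tr) / S, (tl - tr) / S
    d[:, 0::2], d[:, 1::2] = (bl + br) / S, (bl - br) / S
    x = np.empty((m, n))
    x[0::2, :], x[1::2, :] = (a + d) / S, (a - d) / S
    return x


def analysis(x, level):
    """Image -> coefficients; re-transforms the approximation at each level."""
    c = np.array(x, dtype=float)
    m, n = c.shape
    for _ in range(level):
        c[:m, :n] = haar2_level(c[:m, :n])
        m, n = m // 2, n // 2
    return c


def synthesis(c, level):
    """Coefficients -> image; undoes analysis from the coarsest level up."""
    x = np.array(c, dtype=float)
    m, n = x.shape
    for k in reversed(range(level)):
        x[:m >> k, :n >> k] = ihaar2_level(x[:m >> k, :n >> k])
    return x


# Synthetic 200 x 200 checkerboard with 25-pixel squares (an assumed stand-in)
idx = np.arange(200) // 25
image = 255.0 * ((idx[:, None] + idx[None, :]) % 2)

coefs = analysis(image, level=3)
recon = synthesis(coefs, level=3)
print("exact reconstruction:", np.allclose(recon, image))
print("energy preserved:", np.isclose(np.linalg.norm(coefs), np.linalg.norm(image)))
print("fraction of zero coefficients:", np.mean(np.abs(coefs) < 1e-9))
```

Because each step is orthonormal, reconstruction is exact and the coefficients carry the same energy as the image. Most coefficients of the piecewise-constant checkerboard are zero, which is the kind of sparsity that the soft-thresholding step in FISTA is designed to exploit.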
.. code-block:: default


    DWT_basis = lop.dwt2D(image.shape, wavelet='haar', level=3, basis=True)
    DWT_basis = lop.jit(DWT_basis)
    # Compute and visualize the wavelet coefficients of the image
    coefs = DWT_basis.trans(image)
    fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
    ax[0].imshow(image, cmap=plt.cm.gray)
    ax[0].set_title('Image')
    ax[1].imshow(coefs, cmap=plt.cm.gray)
    ax[1].set_title('Wavelet coefficients')




.. image-sg:: /gallery/0100_lop/images/sphx_glr_deblurring_004.png
   :alt: Image, Wavelet coefficients
   :srcset: /gallery/0100_lop/images/sphx_glr_deblurring_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Text(0.5, 1.0, 'Wavelet coefficients')



.. GENERATED FROM PYTHON SOURCE LINES 127-130

Deblurring with Fast Iterative Shrinkage and Thresholding Algorithm
---------------------------------------------------------------------------

We combine the convolution operator and the wavelet basis operator. The composite
operator ``A = H @ DWT_basis`` maps wavelet coefficients to a blurred image:
``DWT_basis`` synthesizes an image from the coefficients and ``H`` blurs it.
FISTA then looks for sparse coefficients whose blurred image matches the observation.

.. GENERATED FROM PYTHON SOURCE LINES 130-163

.. code-block:: default


    A = H @ DWT_basis
    # Step size for the FISTA algorithm
    step_size = 1.
    # Thresholding function for the FISTA algorithm
    threshold_func = lambda i, x: geo.soft_threshold(x, 0.02)
    # Initial guess for the wavelet coefficients: all zeros
    x0 = jnp.zeros(DWT_basis.shape[1])
    # Solve the || A x - b ||_2^2 + lambda || x ||_1 problem
    sol = sls.fista_jit(
        # The combined convolution+wavelet basis operator
        A,
        # The blurred image as input
        b=blurred_image,
        # Initial guess for the coefficients
        x0=x0,
        # Step size for the FISTA algorithm
        step_size=step_size,
        # Thresholding function to be used for FISTA
        threshold_func=threshold_func,
        # Maximum number of iterations for which the algorithm will be run
        max_iters=50)
    print(f"Number of FISTA iterations {sol.iterations}")
    # Compute the deblurred image from the coefficients given by FISTA
    deblurred_image = DWT_basis.times(sol.x)
    # Measure the PSNR
    print("Deblurred PSNR: ", crn.peak_signal_noise_ratio(image, deblurred_image), 'dB')
    fig, ax = plt.subplots(ncols=3, figsize=(15, 5))
    ax[0].imshow(image, cmap=plt.cm.gray)
    ax[0].set_title('Original')
    ax[1].imshow(blurred_image, cmap=plt.cm.gray)
    ax[1].set_title('After blurring')
    ax[2].imshow(deblurred_image, cmap=plt.cm.gray)
    ax[2].set_title('FISTA deblurring')




.. image-sg:: /gallery/0100_lop/images/sphx_glr_deblurring_005.png
   :alt: Original, After blurring, FISTA deblurring
   :srcset: /gallery/0100_lop/images/sphx_glr_deblurring_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Number of FISTA iterations 50
    Deblurred PSNR:  20.411790224661722 dB

    Text(0.5, 1.0, 'FISTA deblurring')




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 10.518 seconds)


.. _sphx_glr_download_gallery_0100_lop_deblurring.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: deblurring.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: deblurring.ipynb `
.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_