{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Image Deblurring\n\nThis example demonstrates the following features:\n\n- ``cr.sparse.lop.convolve2D``: a 2D convolution linear operator\n- ``cr.sparse.sls.lsqr``: the LSQR algorithm for solving a least-squares problem on 2D images\n- ``cr.sparse.lop.dwt2D``: a 2D discrete wavelet basis operator\n- ``cr.sparse.sls.fista``: the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) on 2D images\n\nImage deblurring can be treated as a deconvolution problem if the filter used\nfor blurring the image is known.\n\nPlease see the deconvolution example for some background.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\nimport jax.numpy as jnp\n# Configure JAX for 64-bit computing (before creating any arrays)\njax.config.update(\"jax_enable_x64\", True)\n# For plotting diagrams\nimport matplotlib.pyplot as plt\n# CR-Nimble utilities (used here for arr_largest_index and peak_signal_noise_ratio)\nimport cr.nimble as cnb\n## CR-Sparse modules\n# Linear operators\nfrom cr.sparse import lop\n# Image processing utilities\nfrom cr.sparse import vision\n# Solvers for sparse linear systems\nfrom cr.sparse import sls\n# Several thresholding functions are available in this module\nfrom cr.sparse import geo\n# Sample images\nimport skimage.data"
      ]
    },
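    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why a known filter turns deblurring into a solvable linear problem, here is a toy 1D sketch (not part of CR-Sparse; ``conv_matrix_same`` and the ``demo_*`` names are hypothetical). It builds the matrix of a same-size convolution column by column and shows that, with the kernel known and no noise, solving the linear system recovers the original signal. The 2D operators used later in this example play the same role without the matrix being written out.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef conv_matrix_same(kernel, n):\n    # Column j is the 'same'-mode convolution of the j-th unit vector with the kernel\n    cols = [jnp.convolve(jnp.zeros(n).at[j].set(1.0), kernel, mode='same') for j in range(n)]\n    return jnp.stack(cols, axis=1)\n\n# A small blur kernel and a piecewise-constant signal\ndemo_kernel = jnp.array([0.25, 0.5, 0.25])\ndemo_signal = jnp.array([0., 0., 1., 1., 1., 0., 0., 0.])\ndemo_H = conv_matrix_same(demo_kernel, demo_signal.shape[0])\n# Blurring is a matrix-vector product\ndemo_blurred = demo_H @ demo_signal\n# Kernel known and no noise: solving the linear system undoes the blur\ndemo_recovered = jnp.linalg.lstsq(demo_H, demo_blurred)[0]\nprint('max recovery error:', float(jnp.max(jnp.abs(demo_recovered - demo_signal))))"
      ]
    },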
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Problem Setup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "image = skimage.data.checkerboard()\nprint(image.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Gaussian blur kernel\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "h = vision.kernel_gaussian((15,25), (8,4))\n# Plot the kernel\nfig, ax = plt.subplots(1, 1, figsize=(5, 3))\nhim = ax.imshow(h)\nax.set_title('Blurring kernel')\nfig.colorbar(him, ax=ax)\nax.axis('tight')"
      ]
    },
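    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The exact formula inside ``vision.kernel_gaussian`` is not shown in this example. As a hedged sketch, assume the arguments are the kernel size and the per-axis standard deviations, and that the kernel is a sampled separable Gaussian normalized to sum to one (``gaussian_kernel_sketch`` is a hypothetical helper, not part of CR-Sparse). A unit-sum kernel averages neighboring pixels without changing the brightness of flat regions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef gaussian_kernel_sketch(size, sigma):\n    # Hypothetical re-implementation: outer product of two sampled 1D Gaussians\n    rows, cols = size\n    sigma_r, sigma_c = sigma\n    r = jnp.arange(rows) - (rows - 1) / 2\n    c = jnp.arange(cols) - (cols - 1) / 2\n    g_r = jnp.exp(-r ** 2 / (2 * sigma_r ** 2))\n    g_c = jnp.exp(-c ** 2 / (2 * sigma_c ** 2))\n    k = jnp.outer(g_r, g_c)\n    # Normalize so that the weights sum to one\n    return k / jnp.sum(k)\n\ndemo_h = gaussian_kernel_sketch((15, 25), (8, 4))\nprint(demo_h.shape, float(jnp.sum(demo_h)))"
      ]
    },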
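    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### The linear operator for the blur kernel\nLocate the center of the filter: the index of the kernel's largest entry is used as the offset of the convolution operator, so that the blurred image is not shifted relative to the original.\n\n"
      ]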
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "offset = cnb.arr_largest_index(h)\nprint(offset)\n# Construct a 2D convolution operator based on the kernel\nH = lop.convolve2D(image.shape, h, offset=offset)\n# JIT compile the convolution operator for efficiency\nH = lop.jit(H)"
      ]
    },
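    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here ``cnb.arr_largest_index`` is assumed to return the (row, column) index of the largest entry of ``h``; for a symmetric, single-peaked kernel that is its center. A minimal equivalent sketch with plain JAX (``largest_index_sketch`` is a hypothetical name):\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef largest_index_sketch(arr):\n    # Flat index of the maximum, converted back to an N-dimensional index\n    return tuple(int(i) for i in jnp.unravel_index(jnp.argmax(arr), arr.shape))\n\n# A small separable kernel whose peak (2 * 5 = 10) is at row 1, column 2\ndemo_small_kernel = jnp.outer(jnp.array([1., 2., 1.]), jnp.array([1., 3., 5., 3., 1.]))\nprint(largest_index_sketch(demo_small_kernel))  # -> (1, 2)"
      ]
    },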
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Blurring the image\nApply the blurring operator to the original image and measure the peak signal-to-noise ratio (PSNR) of the result.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "blurred_image = H.times(image)\n# Measure the PSNR\nprint(\"Blurred PSNR: \", cnb.peak_signal_noise_ratio(image, blurred_image), 'dB')\n# plot the original and the blurred images\nfig, ax = plt.subplots(ncols=2, figsize=(10, 5))\nax[0].imshow(image, cmap=plt.cm.gray)\nax[0].set_title('Original')\nax[1].imshow(blurred_image, cmap=plt.cm.gray)\nax[1].set_title('After blurring')"
      ]
    },
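    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "PSNR compares an estimate with the reference image on a logarithmic scale: ``PSNR = 10 log10(peak^2 / MSE)`` in dB, where MSE is the mean squared error. Which peak value ``cnb.peak_signal_noise_ratio`` uses is an assumption here; the sketch below (``psnr_sketch`` is hypothetical) takes the largest absolute value of the reference image.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef psnr_sketch(reference, estimate):\n    # Convert to the default floating point type before computing errors\n    reference = jnp.asarray(reference, dtype=float)\n    estimate = jnp.asarray(estimate, dtype=float)\n    mse = jnp.mean((reference - estimate) ** 2)\n    # Assumption: the peak is the largest absolute value in the reference\n    peak = jnp.max(jnp.abs(reference))\n    return 10 * jnp.log10(peak ** 2 / mse)\n\ndemo_ref = jnp.array([[0., 255.], [255., 0.]])\n# An error of 1 at every pixel gives MSE = 1, so PSNR = 20 log10(255), about 48.13 dB\ndemo_est = demo_ref + 1.0\nprint(float(psnr_sketch(demo_ref, demo_est)))"
      ]
    },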
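    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deblurring with the LSQR algorithm\nLSQR iteratively minimizes the residual norm ``|| H x - b ||_2``, where ``b`` is the blurred image. The initial guess for the deblurred image is all zeros.\n\n"
      ]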
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x0 = jnp.zeros_like(blurred_image)\n# Run LSQR for at most 50 iterations to deblur the image\nsol = sls.lsqr(H, blurred_image, x0, max_iters=50)\ndeblurred_image = sol.x\n# Measure the PSNR\nprint(\"Deblurred PSNR: \", cnb.peak_signal_noise_ratio(image, deblurred_image), 'dB')\n# Plot the original, blurred and deblurred images\nfig, ax = plt.subplots(ncols=3, figsize=(15, 5))\nax[0].imshow(image, cmap=plt.cm.gray)\nax[0].set_title('Original')\nax[1].imshow(blurred_image, cmap=plt.cm.gray)\nax[1].set_title('After blurring')\nax[2].imshow(deblurred_image, cmap=plt.cm.gray)\nax[2].set_title('After deblurring')\n# Print the full solution object returned by LSQR\nprint(sol)"
      ]
    },
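    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "LSQR only needs products with ``H`` and its adjoint. As a hedged illustration of that matrix-free idea, the sketch below swaps in CGLS (conjugate gradients on the normal equations ``H^T H x = H^T b``), which in exact arithmetic produces the same iterates as LSQR but is numerically less robust on ill-conditioned problems. ``cgls_sketch`` and its parameters are hypothetical names, not the ``sls.lsqr`` implementation; it is checked on a small dense matrix against ``jnp.linalg.lstsq``. Assuming ``H.trans`` applies the adjoint of the blur, ``cgls_sketch(H.times, H.trans, blurred_image, jnp.zeros_like(blurred_image))`` would be the matrix-free analogue of the LSQR call above.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef cgls_sketch(times, trans, b, x0, max_iters=50, tol=1e-20):\n    # times(x) computes A x and trans(r) computes A^T r; A is never formed\n    x = x0\n    r = b - times(x)\n    s = trans(r)\n    p = s\n    gamma = jnp.vdot(s, s)\n    for _ in range(max_iters):\n        # Stop once the normal-equation residual A^T r has vanished\n        if gamma <= tol:\n            break\n        q = times(p)\n        alpha = gamma / jnp.vdot(q, q)\n        x = x + alpha * p\n        r = r - alpha * q\n        s = trans(r)\n        gamma_new = jnp.vdot(s, s)\n        p = s + (gamma_new / gamma) * p\n        gamma = gamma_new\n    return x\n\n# Small overdetermined system whose least-squares solution is [0, 0.25]\ndemo_A_ls = jnp.array([[1., 2.], [3., 4.], [5., 6.]])\ndemo_b_ls = jnp.array([1., 0., 2.])\ndemo_x_ls = cgls_sketch(lambda v: demo_A_ls @ v, lambda v: demo_A_ls.T @ v, demo_b_ls, jnp.zeros(2), max_iters=10)\nprint(demo_x_ls, jnp.linalg.lstsq(demo_A_ls, demo_b_ls)[0])"
      ]
    },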
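    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## A wavelet basis for the images\nConstruct a 3-level Haar wavelet basis operator for images of this size, then visualize the wavelet coefficients of the original image.\n\n"
      ]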
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "DWT_basis = lop.dwt2D(image.shape, wavelet='haar', level=3, basis=True)\nDWT_basis = lop.jit(DWT_basis)\n# Visualize the wavelet transform of the image\ncoefs = DWT_basis.trans(image)\nfig, ax = plt.subplots(ncols=2, figsize=(10, 5))\nax[0].imshow(image, cmap=plt.cm.gray)\nax[0].set_title('Image')\nax[1].imshow(coefs, cmap=plt.cm.gray)\nax[1].set_title('Wavelet coefficients')"
      ]
    },
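    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why a wavelet basis helps: piecewise-constant images such as the checkerboard have few significant Haar wavelet coefficients. A hedged sketch of one level of an orthonormal 2D Haar transform (``haar2d_level_sketch`` is hypothetical, and we assume without checking the CR-Sparse source that ``basis=True`` gives an orthonormal basis) shows that a block-constant image produces all-zero detail coefficients and that the transform preserves energy. Applying the same step again to the approximation band gives a multi-level transform such as ``level=3``.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef haar2d_level_sketch(x):\n    # Split the image into the four pixels of every 2 x 2 block (even sizes assumed)\n    a = x[0::2, 0::2]\n    b = x[0::2, 1::2]\n    c = x[1::2, 0::2]\n    d = x[1::2, 1::2]\n    # Orthonormal combinations: block average and three kinds of differences\n    approx = (a + b + c + d) / 2\n    diff_cols = (a - b + c - d) / 2\n    diff_rows = (a + b - c - d) / 2\n    diff_diag = (a - b - c + d) / 2\n    return approx, diff_cols, diff_rows, diff_diag\n\n# An 8 x 8 checkerboard made of 4 x 4 constant squares\ndemo_board = jnp.kron(jnp.array([[0., 1.], [1., 0.]]), jnp.ones((4, 4)))\ndemo_bands = haar2d_level_sketch(demo_board)\n# Energy is preserved by an orthonormal transform\nprint(float(jnp.sum(demo_board ** 2)), float(sum(jnp.sum(band ** 2) for band in demo_bands)))\n# Every detail coefficient is zero for this block-constant image\nprint([float(jnp.max(jnp.abs(band))) for band in demo_bands[1:]])"
      ]
    },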
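    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deblurring with the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA)\nWe combine the convolution operator and the wavelet basis operator into a single operator ``A = H @ DWT_basis`` that maps wavelet coefficients to a blurred image. FISTA then searches for wavelet coefficients that explain the blurred image, while its soft-thresholding step keeps those coefficients sparse.\n\n"
      ]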
    },
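    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about what the FISTA solver iterates, here is a hedged, textbook FISTA sketch (Beck and Teboulle) for ``min_x 1/2 || A x - b ||_2^2 + lambda || x ||_1`` on a tiny dense problem. The names ``fista_sketch`` and ``soft_threshold_sketch`` are hypothetical, and this is not the CR-Sparse implementation. Each iteration takes a gradient step, soft-thresholds at ``step * lambda``, and adds a momentum term; the step size should not exceed ``1 / ||A||_2^2``. If the blur kernel sums to one and the wavelet basis is orthonormal, then ``||A||_2 <= 1``, which would make a step size of 1 admissible for the deblurring problem.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n\ndef soft_threshold_sketch(x, t):\n    # Proximal operator of t * ||x||_1: shrink every entry toward zero by t\n    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - t, 0.0)\n\ndef fista_sketch(A, b, lam, step, max_iters=200):\n    x = jnp.zeros(A.shape[1])\n    y = x\n    t = 1.0\n    for _ in range(max_iters):\n        # Gradient of 1/2 ||A y - b||^2 at the extrapolated point y\n        grad = A.T @ (A @ y - b)\n        x_new = soft_threshold_sketch(y - step * grad, step * lam)\n        # Momentum update\n        t_new = (1 + (1 + 4 * t * t) ** 0.5) / 2\n        y = x_new + ((t - 1) / t_new) * (x_new - x)\n        x, t = x_new, t_new\n    return x\n\n# Diagonal test problem with a known answer: x_i = soft(b_i / a_i, lam / a_i^2)\ndemo_A_f = jnp.diag(jnp.array([1., 2.]))\ndemo_b_f = jnp.array([1.0, 0.1])\n# ||A||_2^2 = 4, so the step size is 1/4; the exact minimizer is [0.7, 0.0]\ndemo_x_f = fista_sketch(demo_A_f, demo_b_f, lam=0.3, step=0.25)\nprint(demo_x_f)"
      ]
    },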
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "A = H @ DWT_basis\n# Step size for the FISTA algorithm\nstep_size = 1.\n# Thresholding function for the FISTA algorithm: soft thresholding at 0.02\nthreshold_func = lambda i, x : geo.soft_threshold(x, 0.02)\n# Initial guess for the wavelet coefficients is all zeros\nx0 = jnp.zeros(DWT_basis.shape[1])\n# Solve the l1-regularized least-squares problem min_x || A x - b ||_2^2 + lambda || x ||_1\nsol = sls.fista_jit(\n    # The combined convolution+wavelet basis operator\n    A,\n    # The blurred image as input\n    b=blurred_image,\n    # Initial guess for the coefficients\n    x0=x0,\n    # Step size for the FISTA algorithm\n    step_size=step_size,\n    # Thresholding function to be used for FISTA\n    threshold_func=threshold_func,\n    # Maximum number of iterations for which the algorithm will be run\n    max_iters=50)\nprint(f\"Number of FISTA iterations {sol.iterations}\")\n# Compute the deblurred image from the coefficients given by FISTA\ndeblurred_image = DWT_basis.times(sol.x)\n# Measure the PSNR\nprint(\"Deblurred PSNR: \", cnb.peak_signal_noise_ratio(image, deblurred_image), 'dB')\nfig, ax = plt.subplots(ncols=3, figsize=(15, 5))\nax[0].imshow(image, cmap=plt.cm.gray)\nax[0].set_title('Original')\nax[1].imshow(blurred_image, cmap=plt.cm.gray)\nax[1].set_title('After blurring')\nax[2].imshow(deblurred_image, cmap=plt.cm.gray)\nax[2].set_title('FISTA deblurring')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}