{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Display matplotlib figures inline, below the cells that create them\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
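        "# Wavelet Transform Operators\n\nThis example demonstrates the following features:\n\n- `cr.sparse.lop.dwt`: a 1D discrete wavelet transform linear operator\n- `cr.sparse.lop.dwt2D`: a 2D discrete wavelet transform linear operator\n- `cr.sparse.lop.jit`: JIT compilation of a linear operator\n\nIn both the 1D and 2D cases, we compress the input by zeroing most of its wavelet coefficients. We then reconstruct it by applying the transpose of the wavelet operator to the remaining coefficients.\n"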
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Configure JAX for 64-bit computing. JAX expects this flag to be set at\n",
        "# startup, so it is set before the other JAX-based libraries are imported.\n",
        "# (`from jax.config import config` is deprecated; use `jax.config` directly.)\n",
        "import jax\n",
        "jax.config.update(\"jax_enable_x64\", True)\n",
        "import jax.numpy as jnp\n",
        "# For plotting diagrams\n",
        "import matplotlib.pyplot as plt\n",
        "# Sample images\n",
        "import skimage.data\n",
        "# CR-Sparse modules\n",
        "import cr.nimble as cnb\n",
        "# Utilities\n",
        "from cr.nimble.dsp import time_values\n",
        "# Linear operators\n",
        "from cr.sparse import lop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 1D Wavelet Transform Operator\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### A signal consisting of multiple sinusoids\nThe individual sinusoids have different frequencies and amplitudes. The signal is sampled at the sampling frequency `fs` over a time duration `T`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Sampling frequency\n",
        "fs = 1000.\n",
        "# Time duration\n",
        "T = 2\n",
        "# Time values\n",
        "t = time_values(fs, T)\n",
        "# Number of samples\n",
        "n = t.size\n",
        "# Frequencies and amplitudes of the individual sinusoids\n",
        "freqs = [25, 7, 9]\n",
        "amps = [1, -3, .8]\n",
        "# Accumulate the sinusoids into the composite signal\n",
        "x = jnp.zeros(n)\n",
        "for f, amp in zip(freqs, amps):\n",
        "    x = x + amp * jnp.sin(2 * jnp.pi * f * t)\n",
        "# Plot the signal (the legend is needed for the label to be shown;\n",
        "# the trailing ';' suppresses the artist repr)\n",
        "plt.figure(figsize=(8, 2))\n",
        "plt.plot(t, x, 'k', label='Composite signal')\n",
        "plt.xlabel('Time')\n",
        "plt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 1D wavelet transform operator\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 1D DWT operator for length-n signals: 'dmey' wavelet, 5 decomposition levels\n",
        "DWT_op = lop.dwt(n, level=5, wavelet='dmey')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Wavelet coefficients\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Forward transform: wavelet coefficients of the composite signal\n",
        "alpha = DWT_op.times(x)\n",
        "plt.figure(figsize=(8, 2))\n",
        "plt.plot(alpha, label='Wavelet coefficients')\n",
        "plt.xlabel('Coefficient index')\n",
        "plt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Compression\nLet's keep only the first 10 percent of the wavelet coefficients and set the rest to zero.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Keep the first 10% of the coefficients (by position) and zero the rest.\n",
        "# The cutoff is based on the number of coefficients rather than the signal\n",
        "# length n, since the operator's output length may differ from n\n",
        "# (e.g. if the input is padded) — the two agree when no padding occurs.\n",
        "cutoff = alpha.size // 10\n",
        "alpha2 = alpha.at[cutoff:].set(0)\n",
        "plt.figure(figsize=(8, 2))\n",
        "plt.plot(alpha2, label='Wavelet coefficients after compression')\n",
        "plt.xlabel('Coefficient index')\n",
        "plt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Apply the transpose of the DWT operator to the compressed coefficients\n",
        "x_rec = DWT_op.trans(alpha2)\n",
        "# Reconstruction quality\n",
        "rmse = cnb.root_mse(x, x_rec)\n",
        "snr = cnb.signal_noise_ratio(x, x_rec)\n",
        "print(f'RMSE: {rmse}')\n",
        "print(f'SNR: {snr}')\n",
        "plt.figure(figsize=(8, 2))\n",
        "plt.plot(x, 'k', label='Original')\n",
        "plt.plot(x_rec, 'r', label='Reconstructed')\n",
        "plt.title('Reconstructed signal')\n",
        "plt.legend();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 2D Wavelet Transform Operator\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Sample image\n",
        "image = skimage.data.grass()\n",
        "# 5-level 2D Haar DWT operator sized to the image, wrapped with lop.jit\n",
        "DWT2_op = lop.jit(lop.dwt2D(image.shape, wavelet='haar', level=5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Wavelet coefficients\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Forward 2D transform: wavelet coefficients of the image\ncoefs = DWT2_op.times(image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
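        "### Compression\nLet's keep only the top-left block of the coefficient array (one quarter of the rows by one quarter of the columns, i.e. 1/16 of the coefficients) and set the rest to zero.\n\n"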
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "h, w = coefs.shape\n",
        "# Index of the top-left (h//4 x w//4) block: the only coefficients we keep\n",
        "keep = (slice(None, h // 4), slice(None, w // 4))\n",
        "# Start from all zeros and copy back just the kept block\n",
        "coefs2 = jnp.zeros_like(coefs).at[keep].set(coefs[keep])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Apply the transpose of the 2D DWT operator to the compressed coefficients\n",
        "image_rec = DWT2_op.trans(coefs2)\n",
        "# Reconstruction quality\n",
        "rmse = cnb.root_mse(image, image_rec)\n",
        "psnr = cnb.peak_signal_noise_ratio(image, image_rec)\n",
        "print(f'RMSE: {rmse}')\n",
        "print(f'PSNR: {psnr}')\n",
        "\n",
        "# Both coefficient panels share one fixed colour range of [-100, 100]\n",
        "coef_style = dict(cmap='gray_r', vmin=-1e2, vmax=1e2)\n",
        "panels = [\n",
        "    (image, 'Image', dict(cmap='gray')),\n",
        "    (coefs, 'DWT2 coefficients', coef_style),\n",
        "    (coefs2, 'After compression', coef_style),\n",
        "    (image_rec, 'Reconstructed image', dict(cmap='gray')),\n",
        "]\n",
        "\n",
        "# Plot everything side by side\n",
        "fig, axs = plt.subplots(1, 4, figsize=(16, 3))\n",
        "for ax, (img, title, style) in zip(axs, panels):\n",
        "    ax.imshow(img, **style)\n",
        "    ax.set_title(title)\n",
        "    ax.axis('tight')\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}