{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Wavelet Decomposition of an Image\n\nThis is a simple example which demonstrates how the\n`cr.sparse.wt.dwt2` function can be used to\nperform 2D wavelet decomposition.\n\nIts interface is identical to that of the corresponding function\nin the PyWavelets library.\n\nThis example is adapted from the\n[PyWavelets documentation](https://pywavelets.readthedocs.io/en/latest/).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Configure JAX to work with 64-bit floating point precision. \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n\n# JAX defaults to 32-bit floats; switch to 64-bit before any arrays are created.\n# `jax.config` is accessed through the top-level package because the old\n# `from jax.config import config` import path is no longer available in newer JAX releases.\njax.config.update(\"jax_enable_x64\", True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import the necessary libraries\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n# CR.Sparse libraries\nimport cr.sparse.wt as wt\n# We use PyWavelets only for sample data\nimport pywt.data\n# Plotting\nimport matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load the Cameraman image\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "original = pywt.data.camera()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Perform wavelet decomposition \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "coeffs2 = wt.dwt2(original, 'bior1.3')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Split the coefficients tuple into individual parts\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "LL, (LH, HL, HH) = coeffs2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the decomposition\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "titles = ['Approximation', 'Horizontal detail',\n          'Vertical detail', 'Diagonal detail']\nsubbands = [LL, LH, HL, HH]\n\n# One row with a panel per subband, listed in the same order as `titles`.\nfig, axes = plt.subplots(1, len(subbands), figsize=(12, 3))\nfor ax, band, title in zip(axes, subbands, titles):\n    ax.imshow(band, interpolation=\"nearest\", cmap=plt.cm.gray)\n    ax.set_title(title, fontsize=10)\n    # Hide the tick marks on both axes of each image panel.\n    ax.set_xticks([])\n    ax.set_yticks([])\n\nfig.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}