{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline, below the cell that creates them.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
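        "\n# Chirp CWT with Ricker\n\nIn this example, we analyze a chirp signal with the continuous wavelet transform (CWT), using a Ricker wavelet (a.k.a. Mexican Hat wavelet). We compare it with the ordinary power spectrum, which reveals the frequencies present in the signal but not when they occur.\n"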
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Enable 64-bit floating point precision in JAX.\n# JAX documents that this flag should be set at startup, before any arrays\n# are created, which is why this cell runs before the imports.\n# `jax.config.update` is the form documented by current JAX; the older\n# `from jax.config import config` import is deprecated and has been removed\n# in recent JAX releases.\nimport jax\njax.config.update(\"jax_enable_x64\", True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Plotting\nimport matplotlib.pyplot as plt\n# Numerical arrays: NumPy and its JAX counterpart\nimport numpy as np\nimport jax.numpy as jnp\n# CR.Sparse: top-level utilities and the wavelet module\nimport cr.sparse as crs\nimport cr.sparse.wt as wt\n# Utility functions to construct test signals (e.g. chirps)\nimport cr.sparse.dsp.signals as signals"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
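        "## Test signal generation\nWe construct a chirp whose instantaneous frequency sweeps from `f0` to `f1` over a duration of `T` seconds, sampled at `fs` Hz.\n\n"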
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Sampling frequency in Hz\nfs = 100\n# Signal duration in seconds\nT = 10\n# Initial instantaneous frequency for the chirp (Hz)\nf0 = 1\n# Final instantaneous frequency for the chirp (Hz)\nf1 = 4\n# Construct the chirp signal (t: time axis, x: signal samples)\nt, x = signals.chirp(fs, T, f0, f1, initial_phase=0)\n# Plot the chirp signal\nfig, ax = plt.subplots(figsize=(12, 4))\nax.plot(t, x)\nax.set_title('Chirp signal')\nax.set_xlabel('Time (s)')\nax.set_ylabel('Amplitude')\n# Pass an explicit boolean, as documented for Axes.grid(visible=...)\nax.grid(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Power spectrum\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compute the power spectrum (dt = 1/fs is the sampling interval in seconds)\nf, sxx = crs.power_spectrum(x, dt=1/fs)\n# Plot the power spectrum\nfig, ax = plt.subplots(1, figsize=(12, 4))\nax.plot(f, sxx)\nax.grid(True)\nax.set_title('Power spectrum of the chirp')\nax.set_xlabel('Frequency (Hz)')\n# The trailing semicolon suppresses the repr of the returned Text object\nax.set_ylabel('Power');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As expected, the power spectrum identifies the frequencies in the chirp's\n1 Hz to 4 Hz band.\nHowever, it cannot localize the changes in frequency over time.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Ricker/Mexican Hat Wavelet\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 'mexh' selects the Ricker (Mexican Hat) wavelet\nwavelet = wt.build_wavelet('mexh')\n# Sample the wavelet function on its default time range\n# (expected to be [-8, 8] for 'mexh')\npsi, t_psi = wavelet.wavefun()\n# Plot the wavelet\nfig, ax = plt.subplots(figsize=(12, 4))\nax.plot(t_psi, psi)\nax.grid(True)\nax.set_title('Ricker (Mexican Hat) wavelet')\nax.set_xlabel('t')\n# The trailing semicolon suppresses the repr of the returned Text object\nax.set_ylabel('psi(t)');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Wavelet Analysis\nSelect a set of scales for the wavelet analysis, specified in terms of the\nnumber of voices per octave (`nu`).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Number of voices (scales) per octave\nnu = 8\n# Number of scales; with nu voices per octave, 32 scales presumably cover\n# 32 / nu = 4 octaves (TODO confirm the index-to-scale mapping in cr.sparse.wt)\nn_scales = 32\nscales = wt.scales_from_voices_per_octave(nu, jnp.arange(n_scales))\n# Compute the wavelet analysis\noutput = wt.cwt(x, scales, wavelet)\n# Frequency (in Hz) associated with each scale\nfrequencies = wt.scale2frequency(wavelet, scales) * fs\n# Plot the analysis\ncmap = plt.cm.seismic\nfig, ax = plt.subplots(1, figsize=(10, 10))\n\ntitle = 'Wavelet Transform (Power Spectrum) of signal'\nylabel = 'Frequency (Hz)'\nxlabel = 'Time (s)'\n\npower = jnp.abs(output) ** 2\n# Contour levels are powers of two; power is shown on a log2 scale\nlevels = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8]\ncontourlevels = np.log2(levels)\n\nim = ax.contourf(t, jnp.log2(frequencies), jnp.log2(power), contourlevels,\n                 extend='both', cmap=cmap)\n# Without a colorbar the contour colours cannot be read as power values\nfig.colorbar(im, ax=ax, label='log2(power)')\n\nax.set_title(title, fontsize=20)\nax.set_ylabel(ylabel, fontsize=18)\nax.set_xlabel(xlabel, fontsize=18)\n\n# Tick every power of two inside the analysed frequency band. Using\n# floor(...) + 1 as the end keeps the top power of two when the maximum\n# frequency is exactly a power of two (ceil(...) would drop it).\nyticks = 2.0 ** np.arange(np.ceil(np.log2(frequencies.min())),\n                          np.floor(np.log2(frequencies.max())) + 1)\nax.set_yticks(np.log2(yticks))\n# The trailing semicolon suppresses the repr of the returned tick labels\nax.set_yticklabels([f'{v:g}' for v in yticks]);"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}