{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Recovering spikes via TNIPM \n\nThis example has the following features:\n\n* The sparse signal consists of a small number of spikes.\n* The sensing matrix is a random dictionary with orthonormal rows.\n* The number of measurements is one fourth of the ambient dimension.\n* The measurements are corrupted by noise.\n* The Truncated Newton Interior-Point Method (TNIPM), a.k.a. the l1-ls algorithm, is used for recovery.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\nfrom jax import random\nnorm = jnp.linalg.norm\n\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nimport cr.sparse as crs\nimport cr.sparse.data as crdata\nimport cr.sparse.lop as lop\nimport cr.sparse.cvx.l1ls as l1ls"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Ambient dimension\nn = 2**12\n# Number of measurements: one fourth of the ambient dimension\nm = n // 4\n# Number of spikes (sparsity level)\nk = 160\nprint(f'{m=}, {n=}')\n\n# Split the root PRNG key into independent sub-keys for the random steps below\nkey = random.PRNGKey(0)\nkeys = random.split(key, 4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## The Spikes\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Generate the ground-truth sparse signal and plot it\nxs, omega = crdata.sparse_spikes(keys[0], n, k)\nplt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.plot(xs)\nplt.title(f'Original sparse signal ({k} spikes)')\nplt.xlabel('index')\nplt.ylabel('amplitude');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## The Sensing Matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Sensing operator A (m x n): a random dictionary with orthonormal rows, drawn from keys[1]\nA = lop.random_orthonormal_rows_dict(keys[1], m, n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Measurement process\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Clean measurements\nbs = A.times(xs)\n# Additive i.i.d. Gaussian noise with standard deviation sigma\nsigma = 0.01\nnoise = sigma * random.normal(keys[2], (m,))\n# Noisy measurements\nb = bs + noise\nplt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.plot(b)\nplt.title('Noisy measurements b = A xs + noise')\nplt.xlabel('measurement index')\nplt.ylabel('value');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Recovery using TNIPM\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# We need to estimate the regularization parameter:\n# a fixed fraction (0.1) of max|A^T b| is used as a heuristic\nAtb = A.trans(b)\ntau = float(0.1 * jnp.max(jnp.abs(Atb)))\nprint(f'{tau=}')\n# Now run the solver\nsol = l1ls.solve_jit(A, b, tau)\n\n# number of L1-LS iterations\niterations = int(sol.iterations)\n# number of A x operations\nn_times = int(sol.n_times)\n# number of A^H y operations\nn_trans = int(sol.n_trans)\nprint(f'{iterations=} {n_times=} {n_trans=}')\n\n# residual norm ||b - A x|| of the recovered solution (not the norm of x itself)\nr_norm = norm(b - A.times(sol.x))\nprint(f'{r_norm=:.3e}')\n\n# relative error with respect to the true signal\nrel_error = norm(xs - sol.x) / norm(xs)\nprint(f'{rel_error=:.3e}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Solution \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compare the ground truth (top) with the raw l1-ls solution (bottom)\nplt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.subplot(211)\nplt.plot(xs)\nplt.title('Original signal')\nplt.subplot(212)\nplt.plot(sol.x)\nplt.title('TNIPM (l1-ls) solution')\nplt.tight_layout();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### The magnitudes of non-zero values \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Sorted magnitudes reveal the gap between large entries and near-zero ones\nplt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.plot(jnp.sort(jnp.abs(sol.x)))\nplt.title('Sorted magnitudes of the recovered coefficients')\nplt.xlabel('rank (ascending)')\nplt.ylabel('|x_i|');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Thresholding for large values \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Hard-threshold the solution at 0.5, a cut-off read off the sorted-magnitude plot\n# (presumably zeroes entries whose magnitude is below 0.5 -- TODO confirm)\nx = crs.hard_threshold_by(sol.x, 0.5)\nplt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.subplot(211)\nplt.plot(xs)\nplt.title('Original signal')\nplt.subplot(212)\nplt.plot(x)\nplt.title('Thresholded solution')\nplt.tight_layout();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Verifying the support recovery \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "support_xs = crs.support(xs)\nsupport_x = crs.support(x)\n# array_equal returns False when the supports differ in size, where\n# jnp.equal would raise a broadcasting error instead\n# (assumes crs.support returns indices in a consistent order -- TODO confirm)\njnp.array_equal(support_xs, support_x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Improvement using least squares over the support \n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Identify the sub-matrix of columns for the support of the recovered solution's large entries\n# (a new name, so the support_x computed in the previous cell is not silently overwritten)\nsupport_large = crs.largest_indices_by(sol.x, 0.5)\nAI = A.columns(support_large)\nprint(AI.shape)\n\n# Solve the least squares problem over these columns\nx_I, residuals, rank, s = jnp.linalg.lstsq(AI, b)\n# Scatter the least squares coefficients into an otherwise-zero length-n vector\nx_ls = jnp.zeros_like(xs).at[support_large].set(x_I)\n\n# relative error with respect to the true signal\nls_rel_error = norm(xs - x_ls) / norm(xs)\nprint(f'{ls_rel_error=:.3e}')\n\nplt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.subplot(211)\nplt.plot(xs)\nplt.title('Original signal')\nplt.subplot(212)\nplt.plot(x_ls)\nplt.title('Least squares over the recovered support')\nplt.tight_layout();"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}