{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# CoSaMP step by step\n\nThis example walks through the step-by-step development of the\nCoSaMP (Compressive Sampling Matching Pursuit) algorithm\nfor sparse recovery. It then shows how to use the\nofficial implementation of CoSaMP in ``CR-Sparse``.\n\n\nThe CoSaMP algorithm has the following inputs:\n\n* A sensing matrix or dictionary ``Phi`` which has been used for data measurements.\n* A measurement vector ``y``.\n* The sparsity level ``K``.\n\nThe objective of the algorithm is to estimate a K-sparse solution ``x``\nsuch that ``y`` is approximately equal to ``Phi x``.\n\nA key quantity in the algorithm is the residual ``r = y - Phi x``. Each\niteration of the algorithm improves the estimate ``x`` so\nthat the energy of the residual ``r`` decreases.\n\nThe algorithm proceeds as follows:\n\n* Initialize the solution ``x`` with zero.\n* Maintain an index set ``I`` (initially empty) of atoms selected as part of the solution.\n* While the residual energy is above a threshold:\n\n  * **Match**: Compute the inner product of each atom in ``Phi`` with the current residual ``r``.\n  * **Identify**: Select the indices of the 2K atoms from ``Phi`` with the largest correlation (in magnitude) with the residual.\n  * **Merge**: Merge these 2K indices with the currently selected indices in ``I`` to form ``I_sub``.\n  * **LS**: Compute the least squares solution of ``Phi[:, I_sub] z = y``.\n  * **Prune**: Pick the K largest-magnitude entries of this least squares solution and keep their indices in ``I``.\n  * **Update residual**: Compute ``r = y - Phi_I x_I``.\n\nIt is time to see the algorithm in action.\n"
      ]
    },
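    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before running the code, here is one pass of the loop above in compact form. This is only a sketch, and the notation is introduced here for illustration: $\\Phi_T$ denotes the columns of ``Phi`` indexed by a set $T$ (``I_sub`` in the code below), $\\operatorname{supp}_k(v)$ denotes the indices of the $k$ largest-magnitude entries of $v$, and $z$ is a length-$N$ vector that is zero outside $T$.\n\n$$\n\\begin{aligned}\nh &= \\Phi^\\top r && \\text{(match)} \\\\\nJ &= \\operatorname{supp}_{2K}(h) && \\text{(identify)} \\\\\nT &= J \\cup I && \\text{(merge)} \\\\\nz_T &= \\arg\\min_{u} \\| y - \\Phi_T u \\|_2 && \\text{(LS)} \\\\\nI &\\leftarrow \\operatorname{supp}_K(z), \\quad x_I = z_I && \\text{(prune)} \\\\\nr &\\leftarrow y - \\Phi_I x_I && \\text{(update residual)}\n\\end{aligned}\n$$\n\nThe loop stops once the residual energy $\\| r \\|_2^2$ falls below the chosen threshold.\n\n"
      ]
    },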
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\nfrom jax import random\nimport jax.numpy as jnp\n# Some keys for generating random numbers\nkey = random.PRNGKey(0)\nkeys = random.split(key, 4)\n\n# For plotting diagrams\nimport matplotlib.pyplot as plt\n# CR-Sparse modules\nimport cr.sparse as crs\nimport cr.sparse.dict as crdict\nimport cr.sparse.data as crdata\nfrom cr.nimble.dsp import (\n    nonzero_indices,\n    nonzero_values,\n    largest_indices\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Problem Setup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Number of measurements\nM = 128\n# Ambient dimension\nN = 256\n# Sparsity level\nK = 8"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### The sensing matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi = crdict.gaussian_mtx(key, M, N)\nprint(Phi.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Coherence of atoms in the sensing matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(crdict.coherence(Phi))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### A sparse model vector\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x0, omega = crdata.sparse_normal_representations(key, N, K)\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(x0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "``omega`` contains the set of indices at which ``x0`` is nonzero (the support of ``x0``)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(omega)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Compressive measurements\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "y = Phi @ x0\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Development of CoSaMP algorithm\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the following, we walk through the steps of the CoSaMP algorithm.\nSince we have access to ``x0`` and ``omega``, we can measure the\nprogress made by each step by comparing the estimates\nwith the actual ``x0`` and ``omega``. Note, however, that a\nreal implementation of the algorithm has no access to the original model\nvector.\n\n### Initialization\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We take the initial solution to be zero, so\nthe residual ``r = y - Phi x`` equals the measurements ``y``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "r = y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Squared norm/energy of the residual\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "y_norm_sqr = float(y.T @ y)\nr_norm_sqr = y_norm_sqr\nprint(f\"{r_norm_sqr=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A boolean array to track the indices selected for least squares steps\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "flags = jnp.zeros(N, dtype=bool)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "During the matching steps, 2K atoms will be picked.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "K2 = 2*K"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "At any time, up to 3K atoms may be selected (after the merge step).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "K3 = K + K2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Number of iterations completed so far\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "iterations = 0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "An upper limit on the residual energy, set relative to the energy of ``y``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "res_norm_rtol = 1e-3\nmax_r_norm_sqr = y_norm_sqr * (res_norm_rtol ** 2)\nprint(f\"{max_r_norm_sqr=:.2e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### First iteration\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"First iteration:\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Match the current residual with the atoms in ``Phi``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "h = Phi.T @ r"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pick the indices of the 3K atoms that best match the residual\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Since no atoms have been selected so far, we can be more aggressive\n# and pick 3K atoms in the first iteration.\nI_sub = largest_indices(h, K3)\n# Update the flags array\nflags = flags.at[I_sub].set(True)\n# Sort the ``I_sub`` array with the help of the flags array\nI_sub, = jnp.where(flags)\nprint(f\"{I_sub=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Check which indices from ``omega`` are present in ``I_sub``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(jnp.intersect1d(omega, I_sub))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Select the subdictionary of ``Phi`` consisting of atoms indexed by ``I_sub``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi_sub = Phi[:, flags]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compute the least squares solution of ``y`` over this subdictionary\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)\n# Pick the indices of the K largest entries in ``x_sub``\nIa = largest_indices(x_sub, K)\nprint(f\"{Ia=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We need to map the indices in ``Ia`` to the actual indices of atoms in ``Phi``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "I = I_sub[Ia]\nprint(f\"{I=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Select the corresponding values from the LS solution\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_I = x_sub[Ia]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We now have our first estimate of the solution\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x = jnp.zeros(N).at[I].set(x_I)\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(x0, label=\"Original vector\")\nplt.plot(x, '--', label=\"Estimated solution\")\nplt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can check how many of the selected indices belong to the actual support of the signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "found = jnp.intersect1d(omega, I)\nprint(\"Found indices: \", found)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We found 6 out of 8 indices in the support. Here are the remaining ones.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "missing = jnp.setdiff1d(omega, I)\nprint(\"Missing indices: \", missing)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It is time to compute the residual after the first iteration\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi_I = Phi[:, I]\nr = y - Phi_I @ x_I"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compute the residual energy and verify that it is still above the allowed threshold\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "r_norm_sqr = float(r.T @ r)\nprint(f\"{r_norm_sqr=:.2e} > {max_r_norm_sqr=:.2e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Store the selected K indices in the flags array\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "flags = flags.at[:].set(False)\nflags = flags.at[I].set(True)\nprint(jnp.where(flags))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Mark the completion of the iteration\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "iterations += 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Second iteration\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Second iteration:\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Match the current residual with the atoms in ``Phi``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "h = Phi.T @ r"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pick the indices of the 2K atoms that best match the residual\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "I_2k = largest_indices(h, K2 if iterations else K3)\n# We can check if these include the atoms missed in the first iteration.\nprint(jnp.intersect1d(omega, I_2k))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Merge (union) the set of previous K indices with the new 2K indices\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "flags = flags.at[I_2k].set(True)\nI_sub, = jnp.where(flags)\nprint(f\"{I_sub=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can check if we found all the actual atoms\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Found in I_sub: \", jnp.intersect1d(omega, I_sub))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Indeed we did. The set difference is empty.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Missing in I_sub: \", jnp.setdiff1d(omega, I_sub))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Select the subdictionary of ``Phi`` consisting of atoms indexed by ``I_sub``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi_sub = Phi[:, flags]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compute the least squares solution of ``y`` over this subdictionary\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)\n# Pick the indices of the K largest entries in ``x_sub``\nIa = largest_indices(x_sub, K)\nprint(Ia)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We need to map the indices in ``Ia`` to the actual indices of atoms in ``Phi``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "I = I_sub[Ia]\nprint(I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Check if the final K indices in ``I`` include all the indices in ``omega``\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "jnp.setdiff1d(omega, I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Select the corresponding values from the LS solution\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_I = x_sub[Ia]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is our updated estimate of the solution\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x = jnp.zeros(N).at[I].set(x_I)\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(x0, label=\"Original vector\")\nplt.plot(x, '--', label=\"Estimated solution\")\nplt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The algorithm has no direct way of knowing that it has indeed found the solution.\nIt can only check the residual, so it is time to compute the residual after the second iteration.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi_I = Phi[:, I]\nr = y - Phi_I @ x_I"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compute the residual energy and verify that it is now below the allowed threshold\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "r_norm_sqr = float(r.T @ r)\n# It turns out that it is now below the tolerance threshold\nprint(f\"{r_norm_sqr=:.2e} < {max_r_norm_sqr=:.2e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We have completed the signal recovery. We can stop iterating now.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "iterations += 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## CR-Sparse official implementation\n\nA JIT-compiled version of this algorithm is available in the\n``cr.sparse.pursuit.cosamp`` module.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import the module\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from cr.sparse.pursuit import cosamp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Run the solver\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "solution =  cosamp.matrix_solve_jit(Phi, y, K)\n# The support for the sparse solution\nI = solution.I\nprint(I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The non-zero values on the support\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_I = solution.x_I\nprint(x_I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Verify that we successfully recovered the support\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(jnp.setdiff1d(omega, I))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Print the residual energy and the number of iterations when the algorithm converged.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(solution.r_norm_sqr, solution.iterations)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's plot the solution\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x = jnp.zeros(N).at[I].set(x_I)\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(x0, label=\"Original vector\")\nplt.plot(x, '--', label=\"Estimated solution\")\nplt.legend()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}