{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Random Orthogonal Measurements, Cosine Basis, ADMM\n\nThis example has the following features:\n\n* The signal being measured is not sparse by itself.\n* It does have a sparse representation in the discrete cosine basis.\n* Measurements are taken by a partial Walsh-Hadamard sensing matrix\n  with a small number of orthonormal rows.\n* The number of measurements is 8 times lower than the dimension of\n  the signal space.\n* ADMM-based basis pursuit denoising (BPDN) is used to solve the recovery problem.\n\nThis example is adapted from the YALL1 package.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\nfrom jax import random\nnorm = jnp.linalg.norm\n\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nfrom cr.sparse import lop\nfrom cr.sparse.cvx.adm import yall1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Number of measurements\nm = 1024\n# Ambient dimension\nn = m*8\n\nkey = random.PRNGKey(0)\nkeys = random.split(key, 4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Non-sparse signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "xs = 100 * jnp.cumsum(random.normal(keys[0], (n,)))\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(xs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## The Sparsifying Basis\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Psi  = lop.jit(lop.cosine_basis(n))\n\nalpha = Psi.trans(xs)\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(alpha)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Partial Walsh-Hadamard Measurement Operator\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# indices of the measurements to be picked\np = random.permutation(keys[1], n)\npicks = jnp.sort(p[:m])\n# Make sure that DC component is always picked up\npicks = picks.at[0].set(0)\nprint(f\"{picks=}\")\n\n# a random permutation of input\nperm = random.permutation(keys[2], n)\nprint(f\"{perm=}\")\n\n# Walsh Hadamard Basis operator\nTwh = lop.walsh_hadamard_basis(n)\n\n# Wrap it with picks and perm\nTpwh = lop.jit(lop.partial_op(Twh, picks, perm))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Measurement process\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Perform exact measurement\nbs = Tpwh.times(xs)\n\n# Add some noise\nsigma = 0.2\nnoise = sigma * random.normal(keys[3], (m,))\nb = bs + noise\n\nplt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(b)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Recovery using ADMM\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# tolerance for solution convergence\ntol = 5e-4\n# BPDN parameter\nrho = 5e-4\n# Run the solver\nsol = yall1.solve(Tpwh, b, rho=rho, tolerance=tol, W=Psi)\niterations = int(sol.iterations)\n#Number of iterations\nprint(f'{iterations=}')\n# Relative error\nrel_error = norm(sol.x-xs)/norm(xs)\nprint(f'{rel_error=:.4e}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Solution\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')\nplt.plot(xs, label='original')\nplt.plot(sol.x, label='recovered')\nplt.legend()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}