{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# 1-bit Compressive Sensing\n\nThis example demonstrates the following features:\n\n- Making 1-bit quantized compressive measurements of a sparse signal.\n- Recovering the original signal using the BIHT (Binary Iterative Hard Thresholding) algorithm.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\nfrom jax.numpy.linalg import norm\n\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nimport cr.nimble as cnb\nimport cr.sparse as crs\nimport cr.sparse.dict as crdict\nimport cr.sparse.data as crdata\nimport cr.sparse.cs.cs1bit as cs1bit\n\nfrom cr.nimble.dsp import (\n    build_signal_from_indices_and_values\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Number of measurements\nM = 256\n# Ambient dimension\nN = 512\n# Sparsity level\nK = 4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sensing Matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Random Gaussian sensing matrix (atoms are left unnormalized)\nPhi = crdict.gaussian_mtx(cnb.KEYS[0], M, N, normalize_atoms=False)\n# Upper frame bound of Phi\ns0 = crdict.upper_frame_bound(Phi)\nprint(s0)\nfig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\n# Select the colormap per image: plt.gray() would also switch the default\n# colormap for every image drawn later in the session.\nplt.imshow(Phi, extent=[0, 2, 0, 1], cmap='gray')\nplt.colorbar()\n# The trailing semicolon keeps the Text repr out of the cell output\nplt.title(r'$\\Phi$');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## K-sparse signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Draw a K-sparse signal x together with its support omega\nx, omega = crdata.sparse_normal_representations(cnb.KEYS[1], N, K)\n# Scale the signal to unit norm\nx = x / norm(x)\n# Show the support indices\nprint(omega)\nfig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.stem(x, markerfmt='.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Measurement process\n\nTake 1-bit measurements of the sparse signal with the sensing matrix.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 1-bit measurements of x taken with the sensing matrix Phi\ny = cs1bit.measure_1bit(Phi, x)\nfig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\nplt.stem(y, markerfmt='.')\n# Also show the raw measurement values\nprint(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Signal Reconstruction using BIHT\n\nChoose the solver step size and run BIHT to recover the sparse signal.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Solver step size: 0.98 times the frame bound s0 of Phi\ntau = 0.98 * s0\n# Run BIHT on the 1-bit measurements y with sparsity level K\nsol = cs1bit.biht_jit(\n    Phi, y, K, tau\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build the reconstructed signal from the recovered indices and values.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Assemble the full signal from the solver's indices (sol.I) and values (sol.x_I)\nx_rec = build_signal_from_indices_and_values(\n    N, sol.I, sol.x_I\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Verification\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig = plt.figure(figsize=(8, 6), dpi=100, facecolor='w', edgecolor='k')\n# Top panel: the original sparse signal\nplt.subplot(211)\nplt.stem(x, markerfmt='.', linefmt='gray')\nplt.title('original')\n# Bottom panel: the signal rebuilt from the BIHT solution\nplt.subplot(212)\nplt.stem(x_rec, markerfmt='.')\nplt.title('reconstruction')\n\n# Recovered support, sorted in ascending order\nI = jnp.sort(sol.I)\nprint(I)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Check whether the support is recovered correctly and measure the reconstruction error.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compare the supports regardless of the order in which the true indices\n# are stored (I is already sorted)\nprint(jnp.array_equal(jnp.sort(omega), I))\n# Normalize into a new variable so that x_rec from the earlier cell is not\n# overwritten and this cell can be re-run safely\nx_rec_unit = x_rec / norm(x_rec)\n# Norm of the error between the unit-norm original and the reconstruction\nprint(norm(x - x_rec_unit))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}