{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n\n# Block Sparse Bayesian Learning\n\nIn this example, we demonstrate the\nBSBL (Block Sparse Bayesian Learning) algorithm\n:cite:`zhang2012recovery,zhang2013extension`\nfor the reconstruction of block sparse signals\nwith intra-block correlation from their\ncompressive measurements. In particular,\nwe show\n\n- Creation of block sparse signals with intra-block correlation.\n- Compressive sampling of the signal with Gaussian and sparse\n  binary sensing matrices.\n- Reconstruction using the BSBL-EM algorithm (Expectation Maximization).\n- Reconstruction using the BSBL-BO algorithm (Bound Optimization).\n- Reconstruction in the presence of high measurement noise.\n\n\nOur implementation of BSBL is fully JIT compilable.\nTo achieve this, we limit ourselves to equal-sized\nblocks, where the block size is user defined. This\nis rarely a problem in practice: as shown in\n:cite:`zhang2012compressed`, the reconstruction\nof real-life signals from compressive measurements\nis not much affected by the choice of block size.\n\n\nThe basic compressive sensing model is given by\n\n\\begin{align}\\by = \\Phi \\bx + \\be\\end{align}\n\nwhere $\\by$ is a known measurement vector,\n$\\Phi$ is a known sensing matrix,\n$\\bx$ is a sparse signal to be recovered\nfrom the measurements, and $\\be$ is the measurement noise.\n\nWe introduce the block/group structure on $\\bx$\nas\n\n\\begin{align}\\bx = \\begin{pmatrix}\n    \\bx_1^T & \\bx_2^T & \\dots & \\bx_g^T\n    \\end{pmatrix}^T\\end{align}\n\nwhere each $\\bx_i$ is a block of $b$\nvalues. The signal $\\bx$ consists of $g$\nsuch blocks/groups.  
We only consider the case of\nequal-sized blocks in our implementation.\nUnder the block sparsity model, only a few\n($k \\ll g$) blocks are nonzero (active)\nin the signal $\\bx$; however, the locations\nof these blocks are unknown.\n\nBy splitting the sensing matrix into the blocks of columns $\\Phi_i$\nthat multiply the corresponding blocks $\\bx_i$, we can rewrite the sensing equation as\n\n\\begin{align}\\by = \\sum_{i=1}^g \\Phi_i \\bx_i + \\be.\\end{align}\n\nUnder the sparse Bayesian framework, each block\nis assumed to follow a parametrized multivariate\nGaussian distribution:\n\n\\begin{align}\\PP(\\bx_i ; \\gamma_i, \\bB_i) = \\NNN(\\bzero, \\gamma_i \\bB_i), \\Forall i=1,\\dots,g.\\end{align}\n\nThe nonnegative parameter $\\gamma_i$ controls the block sparsity:\nwhen $\\gamma_i = 0$, the block $\\bx_i$ is zero.\nThe covariance matrix $\\bB_i$ captures the intra-block correlation.\n\nWe further assume that the blocks are mutually uncorrelated.\nThe prior of $\\bx$ can then be written as\n\n\\begin{align}\\PP(\\bx; \\{ \\gamma_i, \\bB_i\\}_i ) = \\NNN(\\bzero, \\Sigma_0)\\end{align}\n\nwhere\n\n\\begin{align}\\Sigma_0 = \\text{diag}\\{\\gamma_1 \\bB_1, \\dots, \\gamma_g \\bB_g \\}.\\end{align}\n\n\nWe also model the correlation among the values\nwithin each active block as an AR-1 process. Under this\nassumption, the matrices $\\bB_i$ take the form of a Toeplitz\nmatrix\n\n\\begin{align}\\bB = \\begin{bmatrix}\n    1 & r & \\dots & r^{b-1}\\\\\n    r & 1 & \\dots & r^{b-2}\\\\\n    \\vdots &  & \\ddots & \\vdots\\\\\n    r^{b-1} & r^{b-2} & \\dots & 1\n    \\end{bmatrix}\\end{align}\n\nwhere $r$ is the AR-1 model coefficient. This constraint\nsignificantly reduces the number of model parameters to be learned.\n\nMeasurement noise is modeled as independent zero-mean Gaussian\nnoise, $\\PP(\\be; \\lambda) = \\NNN(\\bzero, \\lambda \\bI)$.\nBSBL doesn't require you to provide the value of the noise variance\nas input. 
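It can estimate $\\lambda$ as part of the algorithm.\n\nUnder the Bayesian learning framework, the estimate of $\\bx$\nis the posterior mean of $\\bx$ given the measurements $\\by$:\n\n\\begin{align}\\hat{\\bx} = \\Sigma_0 \\Phi^T \\left(\\lambda \\bI + \\Phi \\Sigma_0 \\Phi^T\\right)^{-1} \\by.\\end{align}\n\nThe hyperparameters ($\\gamma_i$, $r$ and, optionally, $\\lambda$) are estimated from the measurements;\nBSBL-EM and BSBL-BO differ in how they update them.\n\n\nPlease also refer to the\n[BSBL website](http://dsp.ucsd.edu/~zhilin/BSBL.html)\nby the authors of the original algorithm for further information.\n\nRelated Examples\n\n- `gallery:cs:sparse_binary_sensor`\n"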
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Configure JAX for 64-bit computing\nimport jax\njax.config.update(\"jax_enable_x64\", True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's import the necessary libraries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from matplotlib import pyplot as plt\n# jax imports\nimport jax.numpy as jnp\nfrom jax import random, jit\n# cr-suite imports\nimport cr.nimble as crn\nimport cr.nimble.dsp as crdsp\nimport cr.sparse as crs\nimport cr.sparse.dict as crdict\nimport cr.sparse.data as crdata\nimport cr.sparse.plots as crplot\n\nimport cr.sparse.block.bsbl as bsbl"
      ]
    },
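    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before generating any data, the prior covariance $\\Sigma_0$ and the posterior mean described in the introduction can be written in a few lines of `jax.numpy`. This is an illustrative sketch, not the `cr.sparse.block.bsbl` implementation: the helpers `ar1_toeplitz`, `prior_covariance` and `posterior_mean` are hypothetical, the hyperparameters are fixed by hand instead of being learned, and all blocks are assumed to share the same AR-1 coefficient $r$. The usage lines show that a block with $\\gamma_i = 0$ is exactly zero in the posterior mean.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax.scipy.linalg import block_diag\n",
        "\n",
        "def ar1_toeplitz(b, r):\n",
        "    # Hypothetical helper: B[i, j] = r ** |i - j| (AR-1 intra-block correlation)\n",
        "    idx = jnp.arange(b)\n",
        "    return r ** jnp.abs(idx[:, None] - idx[None, :])\n",
        "\n",
        "def prior_covariance(gammas, r, b):\n",
        "    # Sigma_0 = diag{gamma_1 B, ..., gamma_g B}, assuming one B shared by all blocks\n",
        "    B = ar1_toeplitz(b, r)\n",
        "    return block_diag(*[g * B for g in gammas])\n",
        "\n",
        "def posterior_mean(Phi, y, Sigma0, lam):\n",
        "    # x_hat = Sigma_0 Phi^T (lam I + Phi Sigma_0 Phi^T)^{-1} y\n",
        "    m = Phi.shape[0]\n",
        "    S = lam * jnp.eye(m) + Phi @ Sigma0 @ Phi.T\n",
        "    return Sigma0 @ Phi.T @ jnp.linalg.solve(S, y)\n",
        "\n",
        "# Usage: 3 blocks of size 2 with hand-picked hyperparameters; gamma_2 = 0\n",
        "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
        "Phi_demo = jax.random.normal(k1, (4, 6))\n",
        "y_demo = jax.random.normal(k2, (4,))\n",
        "Sigma0_demo = prior_covariance([1.0, 0.0, 2.0], 0.9, 2)\n",
        "mu_demo = posterior_mean(Phi_demo, y_demo, Sigma0_demo, lam=0.1)\n",
        "print(mu_demo.reshape(3, 2))  # the middle block (gamma_2 = 0) is exactly zero"
      ]
    },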
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Problem Configuration\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# ambient dimension\nn = 300\n# block length\nb = 4\n# number of blocks\nnb = n // b\n# Block sparsity: number of nonzero blocks\nk = 6\n# Number of measurements\nm = 100"
      ]
    },
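    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick arithmetic check of this configuration helps when changing the parameters. The hypothetical helper `describe_config` (not part of `cr.sparse`) asserts that the block length divides the ambient dimension, since our BSBL implementation requires equal-sized blocks, and reports a few derived ratios. With the values above there are $300/4 = 75$ blocks, $6 \\times 4 = 24$ nonzero entries and $100/24 \\approx 4.17$ measurements per nonzero entry.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def describe_config(n, b, k, m):\n",
        "    # Equal-sized blocks are required, so b must divide n exactly\n",
        "    assert n % b == 0, 'block length must divide the ambient dimension'\n",
        "    nb = n // b\n",
        "    return {\n",
        "        'blocks': nb,\n",
        "        'nonzero_entries': k * b,\n",
        "        'block_sparsity': k / nb,\n",
        "        'undersampling_ratio': m / n,\n",
        "        'measurements_per_nonzero': m / (k * b),\n",
        "    }\n",
        "\n",
        "# Values from the configuration cell above\n",
        "print(describe_config(n=300, b=4, k=6, m=100))"
      ]
    },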
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Block Sparse Signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Block sparse signal with intra-block correlation\nx, blocks, indices = crdata.sparse_normal_blocks(\n    crn.KEYS[2], n, k, b, cor=0.9, normalize_blocks=True)\nax = crplot.one_plot()\nax.stem(x);"
      ]
    },
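    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition, here is one way such a signal could be generated. This is a hypothetical sketch; it does not reproduce how `crdata.sparse_normal_blocks` is implemented. It picks $k$ distinct active blocks, draws each one from $\\NNN(\\bzero, \\bB)$ with the AR-1 matrix $\\bB$ through its Cholesky factor, and, assuming this is what `normalize_blocks=True` does, scales each active block to unit $\\ell_2$ norm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def block_sparse_ar1(key, n, k, b, r):\n",
        "    # Hypothetical generator, not the crdata.sparse_normal_blocks implementation\n",
        "    nb = n // b\n",
        "    key_blocks, key_vals = jax.random.split(key)\n",
        "    # Choose k distinct active blocks\n",
        "    active = jax.random.choice(key_blocks, nb, shape=(k,), replace=False)\n",
        "    # AR-1 covariance B[i, j] = r ** |i - j| and its Cholesky factor L (B = L L^T)\n",
        "    idx = jnp.arange(b)\n",
        "    B = r ** jnp.abs(idx[:, None] - idx[None, :])\n",
        "    L = jnp.linalg.cholesky(B)\n",
        "    # Each row of z is N(0, I); the matching row of z @ L.T is then N(0, B)\n",
        "    z = jax.random.normal(key_vals, (k, b))\n",
        "    vals = z @ L.T\n",
        "    # Scale every active block to unit l2 norm (assumed meaning of normalize_blocks)\n",
        "    vals = vals / jnp.linalg.norm(vals, axis=1, keepdims=True)\n",
        "    x = jnp.zeros((nb, b)).at[active].set(vals).reshape(n)\n",
        "    return x, jnp.sort(active)\n",
        "\n",
        "x_demo, active_demo = block_sparse_ar1(jax.random.PRNGKey(42), 300, 6, 4, 0.9)\n",
        "print(active_demo)"
      ]
    },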
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Gaussian Sensing\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Sensing matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi = crdict.gaussian_mtx(crn.KEYS[0], m, n, normalize_atoms=True)\nax = crplot.one_plot()\nax.imshow(Phi);"
      ]
    },
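    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We assume that `normalize_atoms=True` scales every column (atom) of the sensing matrix to unit $\\ell_2$ norm. Under that assumption, a comparable matrix can be sketched directly with `jax.random`. The function name `gaussian_unit_columns` is hypothetical, and its random stream will not match the one used by `crdict.gaussian_mtx`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def gaussian_unit_columns(key, m, n):\n",
        "    # i.i.d. standard normal entries, then every column scaled to unit l2 norm\n",
        "    A = jax.random.normal(key, (m, n))\n",
        "    return A / jnp.linalg.norm(A, axis=0, keepdims=True)\n",
        "\n",
        "Phi_demo = gaussian_unit_columns(jax.random.PRNGKey(0), 100, 300)\n",
        "col_norms = jnp.linalg.norm(Phi_demo, axis=0)\n",
        "# All column norms are 1 up to floating point rounding\n",
        "print(col_norms.min(), col_norms.max())"
      ]
    },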
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Measurements\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "y = Phi @ x\nax = crplot.one_plot()\ncrplot.plot_signal(ax, y)"
      ]
    },
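    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The product `Phi @ x` is exactly the block form $\\by = \\sum_{i=1}^g \\Phi_i \\bx_i$ from the introduction, with $\\be = \\bzero$ in this noiseless case. Reshaping the columns of $\\Phi$ into $g$ groups of $b$ makes the identity explicit. The helper `measure_by_blocks` is hypothetical and only illustrates the identity on random data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def measure_by_blocks(Phi, x, b):\n",
        "    # y = sum_i Phi_i x_i, where Phi_i holds the b columns of block i\n",
        "    m, n = Phi.shape\n",
        "    g = n // b\n",
        "    Phi_blocks = Phi.reshape(m, g, b)  # Phi_blocks[:, i, :] is Phi_i\n",
        "    x_blocks = x.reshape(g, b)         # x_blocks[i] is x_i\n",
        "    return jnp.einsum('mgb,gb->m', Phi_blocks, x_blocks)\n",
        "\n",
        "k1, k2 = jax.random.split(jax.random.PRNGKey(1))\n",
        "Phi_demo = jax.random.normal(k1, (20, 12))\n",
        "x_demo = jax.random.normal(k2, (12,))\n",
        "y_blocks = measure_by_blocks(Phi_demo, x_demo, 4)\n",
        "print(jnp.allclose(y_blocks, Phi_demo @ x_demo, atol=1e-5))"
      ]
    },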
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction using BSBL EM\nWe need to provide the sensing matrix, the measurements,\nand the block size as parameters to the\nreconstruction algorithm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "options = bsbl.bsbl_em_options(y, learn_lambda=0)\nsol = bsbl.bsbl_em_jit(Phi, y, b, options)\nprint(sol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_hat = sol.x\nprint(f'BSBL-EM: PRD: {crn.prd(x, x_hat):.1f} %, SNR: {crn.signal_noise_ratio(x, x_hat):.2f} dB.' )"
      ]
    },
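    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "PRD (percentage root-mean-square difference) and reconstruction SNR are two views of the same relative error. The sketch below uses the standard definitions $\\text{PRD} = 100 \\|\\bx - \\hat{\\bx}\\|_2 / \\|\\bx\\|_2$ and $\\text{SNR} = 20 \\log_{10} (\\|\\bx\\|_2 / \\|\\bx - \\hat{\\bx}\\|_2)$ dB, so that $\\text{SNR} = -20 \\log_{10}(\\text{PRD}/100)$. We assume, without having checked, that `crn.prd` and `crn.signal_noise_ratio` follow these definitions; `prd_percent` and `snr_db` are hypothetical names. A perfect reconstruction gives an infinite SNR in this sketch.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "\n",
        "def prd_percent(x, x_hat):\n",
        "    # PRD = 100 * ||x - x_hat|| / ||x||\n",
        "    return 100 * jnp.linalg.norm(x - x_hat) / jnp.linalg.norm(x)\n",
        "\n",
        "def snr_db(x, x_hat):\n",
        "    # SNR = 20 log10(||x|| / ||x - x_hat||); inf when x_hat == x\n",
        "    return 20 * jnp.log10(jnp.linalg.norm(x) / jnp.linalg.norm(x - x_hat))\n",
        "\n",
        "x_demo = jnp.array([3.0, 4.0])      # ||x|| = 5\n",
        "x_hat_demo = jnp.array([2.7, 3.6])  # ||x - x_hat|| = 0.5\n",
        "print(prd_percent(x_demo, x_hat_demo))  # about 10 %\n",
        "print(snr_db(x_demo, x_hat_demo))       # about 20 dB"
      ]
    },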
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the original and reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_hat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction using BSBL BO\nWe need to provide the sensing matrix, the measurements,\nand the block size as parameters to the\nreconstruction algorithm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "options = bsbl.bsbl_bo_options(y, learn_lambda=0)\nsol = bsbl.bsbl_bo_jit(Phi, y, b, options)\nprint(sol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_hat = sol.x\nprint(f'BSBL-BO: PRD: {crn.prd(x, x_hat):.1f} %, SNR: {crn.signal_noise_ratio(x, x_hat):.2f} dB.' )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the original and reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_hat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Observations:\n\n* We specified ``learn_lambda=0`` since we know that this problem is noiseless.\n* Note the count of nonzero blocks in the printed solution: the active blocks have been identified correctly.\n* Recovery is perfect for both algorithms. In other words, both the\n  locations and the values of the nonzero coefficients have been\n  correctly identified and estimated, respectively.\n* BSBL-BO is faster than BSBL-EM: it finishes in far fewer iterations.\n  This is expected, since BSBL-BO accelerates convergence\n  using bound optimization (a.k.a. majorization-minimization).\n\n"
      ]
    },
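    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to check that the active blocks were identified correctly is to compare the sets of blocks with nonzero $\\ell_2$ norm in $\\bx$ and $\\hat{\\bx}$. The helpers `active_blocks` and `same_support` are hypothetical, and the tolerance `tol` is an arbitrary choice; in this notebook one could call `same_support(x, x_hat, b)` after either reconstruction.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "\n",
        "def active_blocks(x, b, tol=1e-6):\n",
        "    # Indices of the blocks whose l2 norm exceeds tol\n",
        "    norms = jnp.linalg.norm(x.reshape(-1, b), axis=1)\n",
        "    return jnp.flatnonzero(norms > tol)\n",
        "\n",
        "def same_support(x, x_hat, b, tol=1e-6):\n",
        "    # True when both signals have exactly the same set of active blocks\n",
        "    a = active_blocks(x, b, tol)\n",
        "    a_hat = active_blocks(x_hat, b, tol)\n",
        "    return a.shape == a_hat.shape and bool(jnp.all(a == a_hat))\n",
        "\n",
        "x_demo = jnp.zeros(12).at[4:8].set(1.0)  # only block 1 (of size 4) is active\n",
        "x_hat_demo = x_demo.at[5].add(1e-3)      # small error inside the active block\n",
        "print(active_blocks(x_demo, 4))             # [1]\n",
        "print(same_support(x_demo, x_hat_demo, 4))  # True"
      ]
    },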
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sparse Binary Sensing\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We shall have just 12 ones in each column of the sensing matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "d = 12"
      ]
    },
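    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A sparse binary sensing matrix with $d$ ones per column can be sketched as follows. This is a hypothetical construction, not the `crdict.sparse_binary_mtx` code: for each column it picks $d$ distinct rows at random, and it assumes that `normalize_atoms=True` corresponds to scaling by $1/\\sqrt{d}$, which gives every column unit $\\ell_2$ norm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def sparse_binary_sketch(key, m, n, d):\n",
        "    # For every column, take the first d entries of a random permutation of the rows\n",
        "    keys = jax.random.split(key, n)\n",
        "    rows = jax.vmap(lambda kk: jax.random.permutation(kk, m)[:d])(keys)  # shape (n, d)\n",
        "    cols = jnp.repeat(jnp.arange(n), d)  # column index for each entry of rows.ravel()\n",
        "    A = jnp.zeros((m, n)).at[rows.ravel(), cols].set(1.0)\n",
        "    # Each column has d ones, so dividing by sqrt(d) gives unit l2 norm\n",
        "    return A / jnp.sqrt(d)\n",
        "\n",
        "Phi_demo = sparse_binary_sketch(jax.random.PRNGKey(0), 100, 300, 12)\n",
        "print(jnp.count_nonzero(Phi_demo, axis=0)[:5])  # 12 nonzeros in every column"
      ]
    },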
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Build the sensing matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi = crdict.sparse_binary_mtx(crn.KEYS[0], m, n, d,\n    normalize_atoms=True, dense=True)\nax = crplot.one_plot()\nax.spy(Phi);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Measurements\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "y = Phi @ x\nax = crplot.one_plot()\ncrplot.plot_signal(ax, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction using BSBL EM\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "options = bsbl.bsbl_em_options(y, learn_lambda=0)\nsol = bsbl.bsbl_em_jit(Phi, y, b, options)\nprint(sol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_hat = sol.x\nprint(f'BSBL-EM: PRD: {crn.prd(x, x_hat):.1f} %, SNR: {crn.signal_noise_ratio(x, x_hat):.2f} dB.' )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the original and reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_hat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction using BSBL BO\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "options = bsbl.bsbl_bo_options(y, learn_lambda=0)\nsol = bsbl.bsbl_bo_jit(Phi, y, b, options)\nprint(sol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_hat = sol.x\nprint(f'BSBL-BO: PRD: {crn.prd(x, x_hat):.1f} %, SNR: {crn.signal_noise_ratio(x, x_hat):.2f} dB.' )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the original and reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_hat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Observations:\n\n* Recovery is perfect for both algorithms.\n* BSBL-BO is much faster.\n* Both algorithms converge in the same number of iterations\n  as with the Gaussian sensing matrix.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Noisy Measurements\nWe now consider an example where the compressive measurements\nare corrupted by a significant amount of noise.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# rows (measurements)\nm = 80\n# columns (signal space)\nn = 162\n# block length\nb = 6\n# number of nonzero blocks\nk = 5\n# number of signals\ns = 1\n# number of blocks\nnb = n // b\n# signal-to-noise ratio in dB\nsnr = 15"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate a block sparse signal with high intra-block correlation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x, blocks, indices = crdata.sparse_normal_blocks(\n    crn.KEYS[1], n, k, b, s, cor=0.95, normalize_blocks=True)\n\nax = crplot.one_plot()\nax.stem(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Sensing matrix\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Phi = crdict.gaussian_mtx(crn.KEYS[2], m, n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Noiseless measurements\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "y0 = Phi @ x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Noise at an SNR of 15 dB\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "noise = crdsp.awgn_at_snr_std(crn.KEYS[3], y0, snr)"
      ]
    },
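    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, additive white Gaussian noise at a target SNR is commonly generated by choosing the noise standard deviation $\\sigma$ so that the ratio of the mean signal power to $\\sigma^2$ equals $10^{\\text{SNR}/10}$. We have not checked that `crdsp.awgn_at_snr_std` uses exactly this rule; `awgn_at_snr_sketch` is a hypothetical stand-in. Because the realized noise is random, the measured SNR only approximates the target, and more loosely for short vectors such as our $m = 80$ measurements.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def awgn_at_snr_sketch(key, signal, snr_db):\n",
        "    # sigma^2 = (mean signal power) / 10^(snr_db / 10)\n",
        "    signal_power = jnp.mean(signal ** 2)\n",
        "    sigma = jnp.sqrt(signal_power / 10 ** (snr_db / 10))\n",
        "    return sigma * jax.random.normal(key, signal.shape)\n",
        "\n",
        "k1, k2 = jax.random.split(jax.random.PRNGKey(3))\n",
        "y0_demo = jax.random.normal(k1, (100_000,))\n",
        "noise_demo = awgn_at_snr_sketch(k2, y0_demo, 15)\n",
        "measured = 20 * jnp.log10(jnp.linalg.norm(y0_demo) / jnp.linalg.norm(noise_demo))\n",
        "print(f'measured SNR: {float(measured):.2f} dB')  # close to 15 dB for a long vector"
      ]
    },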
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Addition of noise to measurements\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "y = y0 + noise\nprint(f'measurement SNR: {crn.signal_noise_ratio(y0, y):.2f} dB')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the noiseless and noisy measurements\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].plot(y0)\nax[1].plot(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction Benchmark\nAn oracle reconstruction is possible if one knows the\nnonzero indices of x. One can then compute a least\nsquares solution over these indices.\nThe reconstruction SNR of this solution gives us\na good benchmark against which we can evaluate\nthe quality of reconstruction by any other algorithm.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Find the least squares solution over the known support\nx_ls_coeffs = jnp.linalg.pinv(Phi[:, indices]) @ y\nx_ls = crdsp.build_signal_from_indices_and_values(n, indices, x_ls_coeffs)\nprint(f'Benchmark rec: PRD: {crn.prd(x, x_ls):.1f} %, SNR: {crn.signal_noise_ratio(x, x_ls):.2f} dB')"
      ]
    },
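    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The oracle estimate can also be sketched with `jnp.linalg.lstsq` followed by a scatter into a zero vector. We assume that `crdsp.build_signal_from_indices_and_values` performs the same scatter; the helper `oracle_least_squares` is hypothetical. In the noiseless case, with a full column rank submatrix, this oracle recovers $\\bx$ exactly, which makes a useful sanity check.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def oracle_least_squares(Phi, y, support):\n",
        "    # Least squares using only the columns in the known support, zeros elsewhere\n",
        "    coeffs = jnp.linalg.lstsq(Phi[:, support], y)[0]\n",
        "    return jnp.zeros(Phi.shape[1], dtype=coeffs.dtype).at[support].set(coeffs)\n",
        "\n",
        "k1, k2 = jax.random.split(jax.random.PRNGKey(7))\n",
        "Phi_demo = jax.random.normal(k1, (40, 100))\n",
        "support_demo = jnp.array([10, 11, 12, 50, 51, 52])\n",
        "x_demo = jnp.zeros(100).at[support_demo].set(jax.random.normal(k2, (6,)))\n",
        "y_demo = Phi_demo @ x_demo  # noiseless measurements\n",
        "x_ls_demo = oracle_least_squares(Phi_demo, y_demo, support_demo)\n",
        "print(jnp.max(jnp.abs(x_ls_demo - x_demo)))  # tiny: exact up to rounding"
      ]
    },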
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the oracle reconstruction\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_ls)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction using BSBL EM\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "options = bsbl.bsbl_em_options(y)\nsol = bsbl.bsbl_em_jit(Phi, y, b, options)\nprint(sol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_hat = sol.x\nprint(f'BSBL-EM: PRD: {crn.prd(x, x_hat):.1f} %, SNR: {crn.signal_noise_ratio(x, x_hat):.2f} dB.' )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the original and reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_hat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Reconstruction using BSBL BO\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "options = bsbl.bsbl_bo_options(y)\nsol = bsbl.bsbl_bo_jit(Phi, y, b, options)\nprint(sol)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "x_hat = sol.x\nprint(f'BSBL-BO: PRD: {crn.prd(x, x_hat):.1f} %, SNR: {crn.signal_noise_ratio(x, x_hat):.2f} dB.' )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot the original and reconstructed signal\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ax = crplot.h_plots(2)\nax[0].stem(x)\nax[1].stem(x_hat)"
      ]
    },
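    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why exploiting the intra-block correlation can beat the oracle least squares benchmark, the idealized experiment below compares two oracle estimators on synthetic data drawn exactly from the AR-1 prior, with the support known to both: least squares, and the posterior mean computed with the true $r$ and $\\lambda$. This is not a BSBL run (BSBL has to learn the support, $r$ and $\\lambda$ itself), the sizes only loosely mirror this section ($m = 80$, $k = 5$ blocks of length $b = 6$, $r = 0.95$), and the helper names `ar1` and `trial` are hypothetical. The comparison is over averages: on a single trial either estimator may win.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def ar1(b, r):\n",
        "    # AR-1 Toeplitz matrix B[i, j] = r ** |i - j|\n",
        "    idx = jnp.arange(b)\n",
        "    return r ** jnp.abs(idx[:, None] - idx[None, :])\n",
        "\n",
        "def trial(key, m=80, k=5, b=6, r=0.95, lam=0.01):\n",
        "    # One synthetic trial restricted to a known support of k blocks\n",
        "    k_phi, k_x, k_e = jax.random.split(key, 3)\n",
        "    s = k * b\n",
        "    Phi_s = jax.random.normal(k_phi, (m, s)) / jnp.sqrt(m)\n",
        "    B_s = jnp.kron(jnp.eye(k), ar1(b, r))  # block diagonal prior covariance\n",
        "    x_s = jnp.linalg.cholesky(B_s) @ jax.random.normal(k_x, (s,))  # x_s ~ N(0, B_s)\n",
        "    y = Phi_s @ x_s + jnp.sqrt(lam) * jax.random.normal(k_e, (m,))\n",
        "    # Oracle least squares via the normal equations; ignores the correlation\n",
        "    x_ls = jnp.linalg.solve(Phi_s.T @ Phi_s, Phi_s.T @ y)\n",
        "    # Oracle posterior mean: B Phi^T (lam I + Phi B Phi^T)^{-1} y\n",
        "    S = lam * jnp.eye(m) + Phi_s @ B_s @ Phi_s.T\n",
        "    x_pm = B_s @ Phi_s.T @ jnp.linalg.solve(S, y)\n",
        "    return jnp.sum((x_ls - x_s) ** 2), jnp.sum((x_pm - x_s) ** 2)\n",
        "\n",
        "keys = jax.random.split(jax.random.PRNGKey(0), 200)\n",
        "err_ls, err_pm = jax.vmap(trial)(keys)\n",
        "print(f'average squared error, oracle least squares: {float(err_ls.mean()):.4f}')\n",
        "print(f'average squared error, oracle posterior mean: {float(err_pm.mean()):.4f}')"
      ]
    },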
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Observations:\n\n* The benchmark SNR is slightly better than the measurement SNR,\n  because we use our knowledge of the support of x.\n* Both BSBL-EM and BSBL-BO detect the correct support of x.\n* The SNR for both BSBL-EM and BSBL-BO is better than the benchmark SNR,\n  because BSBL exploits the intra-block correlation\n  modeled as an AR-1 process.\n* The ordinary least squares solution makes no attempt to exploit the\n  intra-block correlation structure.\n* In this example, BSBL-BO is only slightly faster than BSBL-EM.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}