.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/0200_cs/cosamp_step_by_step.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_0200_cs_cosamp_step_by_step.py:


CoSaMP step by step
==========================

This example walks through the CoSaMP (Compressive Sampling Matching Pursuit)
algorithm for sparse recovery one step at a time. It then shows how to use
the official implementation of CoSaMP in ``CR-Sparse``.

The CoSaMP algorithm has the following inputs:

* A sensing matrix or dictionary ``Phi`` which has been used for data measurements.
* A measurement vector ``y``.
* The sparsity level ``K``.

The objective of the algorithm is to estimate a K-sparse solution ``x`` such
that ``y`` is approximately equal to ``Phi x``.

A key quantity in the algorithm is the residual ``r = y - Phi x``. Each
iteration of the algorithm improves the estimate ``x`` so that the energy of
the residual ``r`` decreases.

The algorithm proceeds as follows:

* Initialize the solution ``x`` with zero.
* Maintain an index set ``I`` (initially empty) of atoms selected as part of the solution.
* While the residual energy is above a threshold:

  * **Match**: Compute the inner product of each atom in ``Phi`` with the current residual ``r``.
  * **Identify**: Select the indices of the 2K atoms of ``Phi`` with the largest correlation (in magnitude) with the residual.
  * **Merge**: Merge these 2K indices with the currently selected indices in ``I`` to form ``I_sub``.
  * **LS**: Compute the least squares solution of ``Phi[:, I_sub] z = y``.
  * **Prune**: Pick the K largest (in magnitude) entries of this least squares solution; keep their indices in ``I`` and their values in ``x_I``.
  * **Update residual**: Compute ``r = y - Phi_I x_I``.

It is time to see the algorithm in action.

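
The loop described above fits in a few lines of NumPy. The following is a minimal sketch under stated assumptions, not the CR-Sparse implementation: ``cosamp_numpy`` is a hypothetical name, NumPy stands in for ``jax.numpy``, atoms are ranked by correlation magnitude, and every iteration (including the first) identifies 2K atoms, whereas the walk-through in this example picks 3K atoms in its first iteration. The stopping rule mirrors the relative residual tolerance used later in this example.

.. code-block:: python

    import numpy as np


    def cosamp_numpy(Phi, y, K, rtol=1e-3, max_iters=20):
        """Hypothetical NumPy sketch of CoSaMP; not the CR-Sparse API.

        Returns the support I, the values x_I on it and the iteration count.
        """
        max_r_norm_sqr = float(y @ y) * rtol ** 2   # stop when ||r|| <= rtol * ||y||
        I = np.array([], dtype=int)                 # selected support (initially empty)
        x_I = np.zeros(0)                           # values on the support
        r = y.copy()                                # residual of the zero solution
        iterations = 0
        while float(r @ r) > max_r_norm_sqr and iterations < max_iters:
            h = Phi.T @ r                               # Match
            I_2k = np.argsort(np.abs(h))[::-1][:2 * K]  # Identify: 2K largest |h|
            I_sub = np.union1d(I, I_2k)                 # Merge: sorted, at most 3K
            z, *_ = np.linalg.lstsq(Phi[:, I_sub], y, rcond=None)  # LS
            keep = np.argsort(np.abs(z))[::-1][:K]      # Prune: K largest |z|
            I, x_I = I_sub[keep], z[keep]
            r = y - Phi[:, I] @ x_I                     # Update residual
            iterations += 1
        return I, x_I, iterations


    # Synthetic test problem; sizes, seed and coefficient range are arbitrary
    rng = np.random.default_rng(0)
    M, N, K = 128, 256, 4
    Phi = rng.standard_normal((M, N))
    Phi /= np.linalg.norm(Phi, axis=0)                 # unit-norm atoms
    omega = np.sort(rng.choice(N, K, replace=False))   # true support
    x0 = np.zeros(N)
    x0[omega] = rng.choice([-1.0, 1.0], K) * rng.uniform(1.0, 2.0, K)
    y = Phi @ x0

    I, x_I, iters = cosamp_numpy(Phi, y, K)
    x_hat = np.zeros(N)
    x_hat[I] = x_I
    print(np.array_equal(np.sort(I), omega), iters)

The sparsity here (4 nonzeros with magnitudes between 1 and 2) is deliberately an easy case. The ``max_iters`` cap is an added safeguard; the walk-through itself relies only on the residual test.
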

.. GENERATED FROM PYTHON SOURCE LINES 41-42

Let's import the necessary libraries.

.. GENERATED FROM PYTHON SOURCE LINES 42-61

.. code-block:: default


    import jax
    from jax import random
    import jax.numpy as jnp
    # Some keys for generating random numbers
    key = random.PRNGKey(0)
    keys = random.split(key, 4)

    # For plotting diagrams
    import matplotlib.pyplot as plt
    # CR-Sparse modules
    import cr.sparse as crs
    import cr.sparse.dict as crdict
    import cr.sparse.data as crdata
    from cr.nimble.dsp import (
        nonzero_indices,
        nonzero_values,
        largest_indices
    )

.. GENERATED FROM PYTHON SOURCE LINES 62-64

Problem Setup
------------------

.. GENERATED FROM PYTHON SOURCE LINES 64-72

.. code-block:: default

    # Number of measurements
    M = 128
    # Ambient dimension
    N = 256
    # Sparsity level
    K = 8

.. GENERATED FROM PYTHON SOURCE LINES 73-75

The Sensing Matrix
''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 75-78

.. code-block:: default

    Phi = crdict.gaussian_mtx(key, M,N)
    print(Phi.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (128, 256)

.. GENERATED FROM PYTHON SOURCE LINES 79-80

Coherence of the atoms in the sensing matrix, i.e. the largest absolute inner
product between two distinct (normalized) atoms

.. GENERATED FROM PYTHON SOURCE LINES 80-82

.. code-block:: default

    print(crdict.coherence(Phi))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.3881940752728321

.. GENERATED FROM PYTHON SOURCE LINES 83-85

A sparse model vector
''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 85-89

.. code-block:: default

    x0, omega = crdata.sparse_normal_representations(key, N, K)
    plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
    plt.plot(x0)

.. image-sg:: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_001.png
   :alt: cosamp step by step
   :srcset: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 90-91

``omega`` contains the set of indices at which ``x0`` is nonzero (the support of ``x0``).

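
For readers without CR-Sparse, rough NumPy analogues of two ideas used above can be sketched. The coherence of a matrix is the largest absolute inner product between two distinct unit-norm atoms, that is, the largest off-diagonal entry of the absolute Gram matrix; and the support ``omega`` of a sparse vector is exactly the set of its nonzero positions. The helpers ``mutual_coherence`` and ``sparse_normal_vector`` are hypothetical stand-ins, not the ``crdict.coherence`` or ``crdata.sparse_normal_representations`` API. Since it is not stated here whether the library normalizes the columns, the sketch normalizes them explicitly.

.. code-block:: python

    import numpy as np


    def mutual_coherence(Phi):
        """Largest |<a_i, a_j>| over distinct columns after scaling them to unit norm."""
        A = Phi / np.linalg.norm(Phi, axis=0)   # normalize each atom (column)
        G = np.abs(A.T @ A)                     # absolute Gram matrix
        np.fill_diagonal(G, 0.0)                # drop the trivial <a_i, a_i> = 1 terms
        return float(G.max())


    def sparse_normal_vector(rng, N, K):
        """Hypothetical helper: K-sparse vector with Gaussian nonzeros and its support."""
        omega = np.sort(rng.choice(N, K, replace=False))
        x = np.zeros(N)
        x[omega] = rng.standard_normal(K)
        return x, omega


    # Two orthogonal atoms plus a third at 45 degrees to both: coherence 1/sqrt(2)
    Phi_small = np.array([[1.0, 0.0, 1.0],
                          [0.0, 1.0, 1.0]])
    print(mutual_coherence(Phi_small))

    # The support is exactly the set of nonzero positions
    rng = np.random.default_rng(0)
    x0, omega = sparse_normal_vector(rng, 256, 8)
    print(omega, np.array_equal(np.flatnonzero(x0), omega))
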

.. GENERATED FROM PYTHON SOURCE LINES 91-93

.. code-block:: default

    print(omega)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [ 41 60 68 89 99 198 232 244]

.. GENERATED FROM PYTHON SOURCE LINES 94-96

Compressive measurements
''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 96-100

.. code-block:: default

    y = Phi @ x0
    plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
    plt.plot(y)

.. image-sg:: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_002.png
   :alt: cosamp step by step
   :srcset: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 101-103

Development of CoSaMP algorithm
---------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 103-114

In the following, we walk through the steps of the CoSaMP algorithm.
Since we have access to ``x0`` and ``omega``, we can measure the progress
made by each step by comparing the estimates with the actual ``x0`` and
``omega``. Note, however, that a real implementation of the algorithm has
no access to the original model vector.

Initialization
''''''''''''''''''''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 115-117

We start from the zero solution, so the residual ``r = y - Phi x`` equals
the measurements ``y``.

.. GENERATED FROM PYTHON SOURCE LINES 117-118

.. code-block:: default

    r = y

.. GENERATED FROM PYTHON SOURCE LINES 119-120

Squared norm (energy) of the residual

.. GENERATED FROM PYTHON SOURCE LINES 120-124

.. code-block:: default

    y_norm_sqr = float(y.T @ y)
    r_norm_sqr = y_norm_sqr
    print(f"{r_norm_sqr=}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    r_norm_sqr=7.401212029141624

.. GENERATED FROM PYTHON SOURCE LINES 125-126

A boolean array to track the indices selected for the least squares steps

.. GENERATED FROM PYTHON SOURCE LINES 126-127

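
Tracking the selection with a boolean array, rather than a list of indices, keeps the merge step simple: marking an index twice is harmless, and reading the array back yields the union in sorted order without duplicates. Selecting columns with the mask also keeps them in that same ascending order, so positions in a least squares solution over ``Phi[:, flags]`` line up with ``I_sub``. A small NumPy sketch with made-up indices; NumPy's in-place assignment stands in for JAX's ``flags.at[idx].set(True)``, which the walk-through needs because JAX arrays are immutable.

.. code-block:: python

    import numpy as np

    N = 16
    flags = np.zeros(N, dtype=bool)

    I_prev = np.array([9, 2, 5])       # e.g. indices kept by the previous prune step
    I_new = np.array([5, 14, 0, 9])    # e.g. indices from an identify step; overlaps I_prev

    flags[I_prev] = True               # mark the current support
    flags[I_new] = True                # marking an index twice is harmless
    I_sub = np.flatnonzero(flags)      # sorted union without duplicates
    print(I_sub)

    # Boolean column selection keeps the same ascending order as I_sub
    Phi = np.arange(3 * N).reshape(3, N)
    same_columns = np.array_equal(Phi[:, flags], Phi[:, I_sub])
    print(same_columns)
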

.. code-block:: default

    flags = jnp.zeros(N, dtype=bool)

.. GENERATED FROM PYTHON SOURCE LINES 128-129

During each matching step, 2K atoms are picked.

.. GENERATED FROM PYTHON SOURCE LINES 129-130

.. code-block:: default

    K2 = 2*K

.. GENERATED FROM PYTHON SOURCE LINES 131-132

At any time, up to 3K atoms may be selected (after the merge step).

.. GENERATED FROM PYTHON SOURCE LINES 132-134

.. code-block:: default

    K3 = K + K2

.. GENERATED FROM PYTHON SOURCE LINES 135-136

Number of iterations completed so far

.. GENERATED FROM PYTHON SOURCE LINES 136-139

.. code-block:: default

    iterations = 0

.. GENERATED FROM PYTHON SOURCE LINES 140-141

The stopping threshold: iterations continue while the residual norm exceeds
``res_norm_rtol`` times the norm of ``y``, i.e. while the residual energy
exceeds ``max_r_norm_sqr = y_norm_sqr * res_norm_rtol**2``.

.. GENERATED FROM PYTHON SOURCE LINES 141-145

.. code-block:: default

    res_norm_rtol = 1e-3
    max_r_norm_sqr = y_norm_sqr * (res_norm_rtol ** 2)
    print(f"{max_r_norm_sqr=:.2e}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    max_r_norm_sqr=7.40e-06

.. GENERATED FROM PYTHON SOURCE LINES 146-148

First iteration
''''''''''''''''''''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 148-149

.. code-block:: default

    print("First iteration:")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    First iteration:

.. GENERATED FROM PYTHON SOURCE LINES 150-151

Match the current residual with the atoms in ``Phi``

.. GENERATED FROM PYTHON SOURCE LINES 151-153

.. code-block:: default

    h = Phi.T @ r

.. GENERATED FROM PYTHON SOURCE LINES 154-155

Pick the indices of the 3K atoms with the largest matches (in magnitude) with the residual

.. GENERATED FROM PYTHON SOURCE LINES 155-163

.. code-block:: default

    # Since no atoms have been selected so far, we can be more aggressive
    # and pick 3K atoms (instead of 2K) in the first iteration.
    I_sub = largest_indices(h, K3)
    # Update the flags array
    flags = flags.at[I_sub].set(True)
    # Read the indices back from the flags array in sorted order
    I_sub, = jnp.where(flags)
    print(f"{I_sub=}")

.. rst-class:: sphx-glr-script-out


.. code-block:: none

    I_sub=Array([ 14, 30, 44, 60, 64, 78, 84, 89, 99, 116, 118, 127, 128, 149, 157, 158, 162, 168, 184, 192, 198, 203, 232, 244], dtype=int64)

.. GENERATED FROM PYTHON SOURCE LINES 164-165

Check which indices of ``omega`` are present in ``I_sub``.

.. GENERATED FROM PYTHON SOURCE LINES 165-166

.. code-block:: default

    print(jnp.intersect1d(omega, I_sub))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [ 60 89 99 198 232 244]

.. GENERATED FROM PYTHON SOURCE LINES 167-168

Select the subdictionary of ``Phi`` consisting of the atoms indexed by ``I_sub``

.. GENERATED FROM PYTHON SOURCE LINES 168-169

.. code-block:: default

    Phi_sub = Phi[:, flags]

.. GENERATED FROM PYTHON SOURCE LINES 170-171

Compute the least squares solution of ``y`` over this subdictionary

.. GENERATED FROM PYTHON SOURCE LINES 171-175

.. code-block:: default

    x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
    # Pick the indices of the K largest entries in ``x_sub``
    Ia = largest_indices(x_sub, K)
    print(f"{Ia=}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Ia=Array([ 3, 7, 23, 20, 22, 8, 15, 18], dtype=int64)

.. GENERATED FROM PYTHON SOURCE LINES 176-177

The indices in ``Ia`` are positions within ``I_sub``; we need to map them to
the actual indices of the atoms in ``Phi``.

.. GENERATED FROM PYTHON SOURCE LINES 177-179

.. code-block:: default

    I = I_sub[Ia]
    print(f"{I=}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I=Array([ 60, 89, 244, 198, 232, 99, 158, 184], dtype=int64)

.. GENERATED FROM PYTHON SOURCE LINES 180-181

Select the corresponding values from the LS solution

.. GENERATED FROM PYTHON SOURCE LINES 181-182

.. code-block:: default

    x_I = x_sub[Ia]

.. GENERATED FROM PYTHON SOURCE LINES 183-184

We now have our first estimate of the solution.

.. GENERATED FROM PYTHON SOURCE LINES 184-189

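
The prune step reduces the least squares solution over ``I_sub`` to K coefficients, maps their positions back to atom indices, and scatters them into a length-N estimate. A NumPy sketch on toy numbers; ``largest_magnitude_indices`` and ``prune`` are hypothetical helpers. Ranking by absolute value is an assumption about ``largest_indices``, though it is consistent with the outputs in this example, which list the support in order of decreasing coefficient magnitude (a coefficient near -0.826 is ranked fourth in the final solution).

.. code-block:: python

    import numpy as np


    def largest_magnitude_indices(v, k):
        """Indices of the k entries of v with the largest magnitude, largest first."""
        return np.argsort(np.abs(v))[::-1][:k]


    def prune(x_sub, I_sub, K, N):
        """Keep the K largest LS coefficients and scatter them into a length-N vector."""
        Ia = largest_magnitude_indices(x_sub, K)   # positions within the subdictionary
        I = I_sub[Ia]                              # map back to atom indices in Phi
        x_I = x_sub[Ia]
        x = np.zeros(N)
        x[I] = x_I                                 # dense estimate of the solution
        return I, x_I, x


    # Toy LS solution over a subdictionary made of atoms 3, 7, 11 and 20 of Phi
    I_sub = np.array([3, 7, 11, 20])
    x_sub = np.array([0.1, -2.0, 0.5, 1.5])
    I, x_I, x = prune(x_sub, I_sub, K=2, N=24)
    print(I, x_I)

The negative coefficient -2.0 is kept ahead of 1.5 because only magnitudes matter when pruning.
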

.. code-block:: default

    x = jnp.zeros(N).at[I].set(x_I)
    plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
    plt.plot(x0, label="Original vector")
    plt.plot(x, '--', label="Estimated solution")
    plt.legend()

.. image-sg:: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_003.png
   :alt: cosamp step by step
   :srcset: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 190-191

We can check how many of the selected indices belong to the actual support of the signal.

.. GENERATED FROM PYTHON SOURCE LINES 191-193

.. code-block:: default

    found = jnp.intersect1d(omega, I)
    print("Found indices: ", found)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Found indices:  [ 60 89 99 198 232 244]

.. GENERATED FROM PYTHON SOURCE LINES 194-195

We found 6 of the 8 indices in the support. Here are the missing ones.

.. GENERATED FROM PYTHON SOURCE LINES 195-197

.. code-block:: default

    missing = jnp.setdiff1d(omega, I)
    print("Missing indices: ", missing)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Missing indices:  [41 68]

.. GENERATED FROM PYTHON SOURCE LINES 198-199

Now compute the residual after the first iteration.

.. GENERATED FROM PYTHON SOURCE LINES 199-201

.. code-block:: default

    Phi_I = Phi[:, I]
    r = y - Phi_I @ x_I

.. GENERATED FROM PYTHON SOURCE LINES 202-203

Compute the residual energy and verify that it is still larger than the allowed threshold

.. GENERATED FROM PYTHON SOURCE LINES 203-205

.. code-block:: default

    r_norm_sqr = float(r.T @ r)
    print(f"{r_norm_sqr=:.2e} > {max_r_norm_sqr=:.2e}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    r_norm_sqr=8.28e-02 > max_r_norm_sqr=7.40e-06

.. GENERATED FROM PYTHON SOURCE LINES 206-207

Store the selected K indices in the flags array

.. GENERATED FROM PYTHON SOURCE LINES 207-210

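
The residual test compares squared norms, which avoids a square root:

.. math::

   \|r\|_2^2 \le \rho^2 \|y\|_2^2 \quad\Longleftrightarrow\quad \|r\|_2 \le \rho \, \|y\|_2,

where :math:`\rho` is ``res_norm_rtol``. With the values printed above, the threshold is 7.40e-06, and after the first iteration the relative residual norm is about ``sqrt(8.28e-02 / 7.40) ≈ 0.106``, roughly a hundred times the 1e-3 tolerance. A sketch of this check; ``residual_converged`` is a hypothetical helper fed with the printed (rounded) values.

.. code-block:: python

    import math


    def residual_converged(r_norm_sqr, y_norm_sqr, rtol=1e-3):
        """True when ||r|| <= rtol * ||y||, tested on squared norms."""
        return r_norm_sqr <= y_norm_sqr * rtol ** 2


    y_norm_sqr = 7.401212029141624           # energy of y printed above
    threshold = y_norm_sqr * 1e-3 ** 2       # corresponds to max_r_norm_sqr
    print(f"{threshold:.2e}")

    # Relative residual norm after the first iteration (printed energy 8.28e-02)
    rel_res = math.sqrt(8.28e-2 / y_norm_sqr)
    print(rel_res)

    print(residual_converged(8.28e-2, y_norm_sqr),
          residual_converged(7.09e-30, y_norm_sqr))
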

.. code-block:: default

    flags = flags.at[:].set(False)
    flags = flags.at[I].set(True)
    print(jnp.where(flags))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (Array([ 60, 89, 99, 158, 184, 198, 232, 244], dtype=int64),)

.. GENERATED FROM PYTHON SOURCE LINES 211-212

Mark the completion of the iteration

.. GENERATED FROM PYTHON SOURCE LINES 212-214

.. code-block:: default

    iterations += 1

.. GENERATED FROM PYTHON SOURCE LINES 215-217

Second iteration
''''''''''''''''''''''''''''''''''''''''''''

.. GENERATED FROM PYTHON SOURCE LINES 217-218

.. code-block:: default

    print("Second iteration:")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Second iteration:

.. GENERATED FROM PYTHON SOURCE LINES 219-220

Match the current residual with the atoms in ``Phi``

.. GENERATED FROM PYTHON SOURCE LINES 220-221

.. code-block:: default

    h = Phi.T @ r

.. GENERATED FROM PYTHON SOURCE LINES 222-223

Pick the indices of the 2K atoms with the largest matches (in magnitude) with the
residual. The expression ``K2 if iterations else K3`` picks 3K atoms only when
``iterations`` is still zero, i.e. in the first iteration.

.. GENERATED FROM PYTHON SOURCE LINES 223-226

.. code-block:: default

    I_2k = largest_indices(h, K2 if iterations else K3)
    # Check whether these include the atoms missed in the first iteration.
    print(jnp.intersect1d(omega, I_2k))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [41 68]

.. GENERATED FROM PYTHON SOURCE LINES 227-228

Merge (union) the set of previous K indices with the new 2K indices

.. GENERATED FROM PYTHON SOURCE LINES 228-231

.. code-block:: default

    flags = flags.at[I_2k].set(True)
    I_sub, = jnp.where(flags)
    print(f"{I_sub=}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I_sub=Array([ 8, 25, 41, 42, 60, 66, 67, 68, 72, 89, 99, 111, 129, 158, 164, 184, 190, 195, 198, 216, 220, 232, 233, 244], dtype=int64)

.. GENERATED FROM PYTHON SOURCE LINES 232-233

Check whether ``I_sub`` now contains all the atoms of the actual support

.. GENERATED FROM PYTHON SOURCE LINES 233-234

.. code-block:: default

    print("Found in I_sub: ", jnp.intersect1d(omega, I_sub))

.. rst-class:: sphx-glr-script-out


.. code-block:: none

    Found in I_sub:  [ 41 60 68 89 99 198 232 244]

.. GENERATED FROM PYTHON SOURCE LINES 235-236

Indeed it does: the set difference is empty.

.. GENERATED FROM PYTHON SOURCE LINES 236-238

.. code-block:: default

    print("Missing in I_sub: ", jnp.setdiff1d(omega, I_sub))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Missing in I_sub:  []

.. GENERATED FROM PYTHON SOURCE LINES 239-240

Select the subdictionary of ``Phi`` consisting of the atoms indexed by ``I_sub``

.. GENERATED FROM PYTHON SOURCE LINES 240-241

.. code-block:: default

    Phi_sub = Phi[:, flags]

.. GENERATED FROM PYTHON SOURCE LINES 242-243

Compute the least squares solution of ``y`` over this subdictionary

.. GENERATED FROM PYTHON SOURCE LINES 243-247

.. code-block:: default

    x_sub, r_sub_norms, rank_sub, s_sub = jnp.linalg.lstsq(Phi_sub, y)
    # Pick the indices of the K largest entries in ``x_sub``
    Ia = largest_indices(x_sub, K)
    print(Ia)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [ 4 9 23 18 21 10 7 2]

.. GENERATED FROM PYTHON SOURCE LINES 248-249

Map the indices in ``Ia`` to the actual indices of the atoms in ``Phi``

.. GENERATED FROM PYTHON SOURCE LINES 249-251

.. code-block:: default

    I = I_sub[Ia]
    print(I)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [ 60 89 244 198 232 99 68 41]

.. GENERATED FROM PYTHON SOURCE LINES 252-253

Check whether the final K indices in ``I`` include all the indices in ``omega``

.. GENERATED FROM PYTHON SOURCE LINES 253-254

.. code-block:: default

    jnp.setdiff1d(omega, I)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Array([], dtype=int64)

.. GENERATED FROM PYTHON SOURCE LINES 255-256

Select the corresponding values from the LS solution

.. GENERATED FROM PYTHON SOURCE LINES 256-257

.. code-block:: default

    x_I = x_sub[Ia]

.. GENERATED FROM PYTHON SOURCE LINES 258-259

Here is our updated estimate of the solution.

.. GENERATED FROM PYTHON SOURCE LINES 259-264

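
Because this demo knows ``x0``, the quality of an estimate can be measured numerically as well as visually: the relative error ``||x - x0|| / ||x0||`` and a comparison of supports (found, missing and spurious indices). A real solver has no such oracle. A sketch with a hypothetical helper ``recovery_report`` on a toy pair of vectors, where one true atom is missed and one spurious atom is picked:

.. code-block:: python

    import numpy as np


    def recovery_report(x_hat, x0, tol=1e-8):
        """Hypothetical helper: relative error and support comparison of an estimate."""
        rel_err = np.linalg.norm(x_hat - x0) / np.linalg.norm(x0)
        est_support = np.flatnonzero(np.abs(x_hat) > tol)
        true_support = np.flatnonzero(np.abs(x0) > tol)
        found = np.intersect1d(true_support, est_support)
        missing = np.setdiff1d(true_support, est_support)
        spurious = np.setdiff1d(est_support, true_support)
        return rel_err, found, missing, spurious


    x0 = np.zeros(10)
    x0[[1, 4, 7]] = [2.0, -1.0, 0.5]
    x_hat = np.zeros(10)
    x_hat[[1, 4, 8]] = [2.0, -1.0, 0.3]
    rel_err, found, missing, spurious = recovery_report(x_hat, x0)
    print(rel_err, found, missing, spurious)
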

.. code-block:: default

    x = jnp.zeros(N).at[I].set(x_I)
    plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
    plt.plot(x0, label="Original vector")
    plt.plot(x, '--', label="Estimated solution")
    plt.legend()

.. image-sg:: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_004.png
   :alt: cosamp step by step
   :srcset: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 265-267

The algorithm has no direct way of knowing that it has found the solution; it
can only monitor the residual. Let us compute the residual after the second
iteration.

.. GENERATED FROM PYTHON SOURCE LINES 267-269

.. code-block:: default

    Phi_I = Phi[:, I]
    r = y - Phi_I @ x_I

.. GENERATED FROM PYTHON SOURCE LINES 270-271

Compute the residual energy and verify that it is now below the allowed threshold

.. GENERATED FROM PYTHON SOURCE LINES 271-274

.. code-block:: default

    r_norm_sqr = float(r.T @ r)
    # It turns out that it is now below the tolerance threshold
    print(f"{r_norm_sqr=:.2e} < {max_r_norm_sqr=:.2e}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    r_norm_sqr=7.09e-30 < max_r_norm_sqr=7.40e-06

.. GENERATED FROM PYTHON SOURCE LINES 275-276

The signal has been recovered, so we can stop iterating.

.. GENERATED FROM PYTHON SOURCE LINES 276-278

.. code-block:: default

    iterations += 1

.. GENERATED FROM PYTHON SOURCE LINES 279-283

CR-Sparse official implementation
----------------------------------------

The JIT compiled version of this algorithm is available in the
``cr.sparse.pursuit.cosamp`` module.

.. GENERATED FROM PYTHON SOURCE LINES 285-286

Import the module

.. GENERATED FROM PYTHON SOURCE LINES 286-287

.. code-block:: default

    from cr.sparse.pursuit import cosamp

.. GENERATED FROM PYTHON SOURCE LINES 288-289

Run the solver

.. GENERATED FROM PYTHON SOURCE LINES 289-293

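
Since the algorithm judges progress only through the residual, the same check can be applied to any solver output given as a support and the values on it (for CR-Sparse, ``solution.I`` and ``solution.x_I``). A NumPy sketch; ``residual_energy`` and ``to_dense`` are hypothetical helpers, and the 20 x 40 Gaussian matrix is a toy stand-in for ``Phi``.

.. code-block:: python

    import numpy as np


    def residual_energy(Phi, y, I, x_I):
        """Residual energy ||y - Phi_I x_I||^2 of a sparse solution (support, values)."""
        r = y - Phi[:, I] @ x_I
        return float(r @ r)


    def to_dense(I, x_I, N):
        """Scatter the values x_I into a length-N vector at positions I."""
        x = np.zeros(N)
        x[I] = x_I
        return x


    rng = np.random.default_rng(1)
    Phi = rng.standard_normal((20, 40))
    I = np.array([3, 17, 25])
    x_I = np.array([1.0, -0.5, 2.0])
    y = Phi @ to_dense(I, x_I, 40)          # exact measurements of a 3-sparse vector

    print(residual_energy(Phi, y, I, x_I))                    # correct support
    print(residual_energy(Phi, y, np.array([3, 17, 26]), x_I))  # one wrong atom

With the correct support the residual energy is at round-off level, comparable to the values near 1e-30 reported in this example; a single wrong atom leaves a large residual.
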

.. code-block:: default

    solution = cosamp.matrix_solve_jit(Phi, y, K)
    # The support for the sparse solution
    I = solution.I
    print(I)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [ 60 89 244 198 232 99 68 41]

.. GENERATED FROM PYTHON SOURCE LINES 294-295

The nonzero values on the support

.. GENERATED FROM PYTHON SOURCE LINES 295-297

.. code-block:: default

    x_I = solution.x_I
    print(x_I)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [ 1.9097652 1.12094818 1.04348768 -0.82606793 0.64812788 0.33432345 0.29561749 0.08482584]

.. GENERATED FROM PYTHON SOURCE LINES 298-299

Verify that we successfully recovered the support

.. GENERATED FROM PYTHON SOURCE LINES 299-300

.. code-block:: default

    print(jnp.setdiff1d(omega, I))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    []

.. GENERATED FROM PYTHON SOURCE LINES 301-302

Print the residual energy and the number of iterations at convergence.

.. GENERATED FROM PYTHON SOURCE LINES 302-303

.. code-block:: default

    print(solution.r_norm_sqr, solution.iterations)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    7.726387804898689e-30 3

.. GENERATED FROM PYTHON SOURCE LINES 304-305

Let's plot the solution.

.. GENERATED FROM PYTHON SOURCE LINES 305-310

.. code-block:: default

    x = jnp.zeros(N).at[I].set(x_I)
    plt.figure(figsize=(8,6), dpi= 100, facecolor='w', edgecolor='k')
    plt.plot(x0, label="Original vector")
    plt.plot(x, '--', label="Estimated solution")
    plt.legend()

.. image-sg:: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_005.png
   :alt: cosamp step by step
   :srcset: /gallery/0200_cs/images/sphx_glr_cosamp_step_by_step_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.595 seconds)

.. _sphx_glr_download_gallery_0200_cs_cosamp_step_by_step.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: cosamp_step_by_step.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: cosamp_step_by_step.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_