# Source code for the cr.sparse matching pursuit (MP) module

# Copyright 2021 CR-Suite Development Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import jax
from jax import lax
from typing import NamedTuple, List, Dict
import jax.numpy as jnp
from jax.numpy.linalg import norm
import cr.sparse as crs
import cr.nimble as cnb
import cr.nimble.dsp as crdsp

class MPState(NamedTuple):
    """State of the matching pursuit algorithm.

    Instances are carried through ``lax.while_loop`` by :func:`solve` and
    :func:`matrix_solve`, so every field must be a JAX-compatible value.
    """
    # Full-length coefficient vector; the non-zero entries hold the
    # accumulated coefficients of the selected atoms.
    x: jnp.ndarray
    # The current residual ``y - Phi x``.
    r: jnp.ndarray
    # The squared Euclidean norm of the residual ``r``.
    r_norm_sqr: float
    # Number of iterations performed so far.
    iterations: int

    @property
    def I(self):
        """Indices of the non-zero entries of ``x`` (the selected atoms)."""
        return crdsp.nonzero_indices(self.x)

    @property
    def x_I(self):
        """Values of the non-zero entries of ``x``."""
        return crdsp.nonzero_values(self.x)

    @property
    def length(self):
        """Length ``n`` of the coefficient vector ``x``."""
        return self.x.size

    def __str__(self):
        """Returns a multi-line, human readable summary of the state."""
        r_norm = math.sqrt(float(self.r_norm_sqr))
        x_norm = float(norm(self.x))
        lines = [
            f"m={len(self.r)}, n={self.length}, k={len(self.I)}",
            f"r_norm={r_norm:e}",
            f"x_norm={x_norm:e}",
            f"iterations={self.iterations}",
        ]
        return '\n'.join(line.rstrip() for line in lines)

def solve(Phi, y, max_iters=100, res_norm_rtol=1e-4, tracker=crs.noop_tracker):
    r"""Solves the sparse recovery problem :math:`y = \Phi x + e` using matching pursuit algorithm

    Args:
        Phi: A linear operator (must provide ``shape``, ``trans`` and ``times``)
        y(jax.numpy.ndarray): Measurements
        max_iters(int): Maximum number of iterations
        res_norm_rtol(float): Acceptable residual norm, relative to the norm
            of ``y``. Iterations stop once
            :math:`\| r \|_2 \leq \text{res\_norm\_rtol} \cdot \| y \|_2`.
        tracker: Callback invoked through ``jax.debug.callback`` with the
            current :class:`MPState` every time the loop condition is checked.

    Returns:
        MPState: A named tuple containing matching pursuit solution state

    MPState interface is similar to
    :py:class:`cr.sparse.pursuit.RecoverySolution`.

    Examples:

        - :ref:`gallery:cs:mp:1`

    Note:
        If the expected sparsity of the solution is known then, one should
        set max_iters as about 4 times that value.

    Note:
        The update ``r - coeff * atom`` is the exact projection step only
        when the columns of ``Phi`` have unit norm; presumably callers pass a
        column-normalized operator.
    """
    ## Initialize some constants for the algorithm
    _, n = Phi.shape
    # squared norm of the signal
    y_norm_sqr = y.T @ y
    # stopping threshold on the squared residual norm
    max_r_norm_sqr = y_norm_sqr * (res_norm_rtol ** 2)
    # template used to build the canonical basis vector e_i for Phi.times
    zero_vec = jnp.zeros(n)

    def init_func():
        # start from x = 0, so the residual is the measurement itself
        return MPState(x=jnp.zeros(n), r=y, r_norm_sqr=y_norm_sqr, iterations=0)

    def body_func(state):
        # correlate the residual with every atom
        h = Phi.trans(state.r)
        # find the atom with the largest absolute correlation
        best_match_index = jnp.argmax(jnp.abs(h))
        # pick corresponding correlation value
        coeff = h[best_match_index]
        # update the representation
        x = state.x.at[best_match_index].add(coeff)
        # materialize the best match atom as Phi @ e_i
        atom = Phi.times(zero_vec.at[best_match_index].set(1))
        # update the residual
        r = state.r - coeff * atom
        # the residual norm must track the new residual, otherwise the
        # tolerance-based stopping criterion can never be met
        r_norm_sqr = r.T @ r
        return MPState(x=x, r=r, r_norm_sqr=r_norm_sqr,
                       iterations=state.iterations + 1)

    def cond_func(state):
        jax.debug.callback(tracker, state)
        # limit on residual norm
        a = state.r_norm_sqr > max_r_norm_sqr
        # limit on number of iterations
        b = state.iterations < max_iters
        return jnp.logical_and(a, b)

    return lax.while_loop(cond_func, body_func, init_func())
def matrix_solve(Phi, y, max_iters=100, res_norm_rtol=1e-4, tracker=crs.noop_tracker):
    r"""Solves the sparse recovery problem :math:`y = \Phi x + e` using matching pursuit algorithm

    Args:
        Phi(jax.numpy.ndarray): A sensing matrix / dictionary
        y(jax.numpy.ndarray): Measurements
        max_iters(int): Maximum number of iterations
        res_norm_rtol(float): Acceptable residual norm, relative to the norm
            of ``y``. Iterations stop once
            :math:`\| r \|_2 \leq \text{res\_norm\_rtol} \cdot \| y \|_2`.
        tracker: Callback invoked through ``jax.debug.callback`` with the
            current :class:`MPState` every time the loop condition is checked.

    Returns:
        MPState: A named tuple containing matching pursuit solution state

    MPState interface is similar to
    :py:class:`cr.sparse.pursuit.RecoverySolution`.

    Examples:

        - :ref:`gallery:cs:mp:1`

    Note:
        If the expected sparsity of the solution is known then, one should
        set max_iters as about 4 times that value.

    Note:
        The update ``r - coeff * atom`` is the exact projection step only
        when the columns of ``Phi`` have unit norm; presumably callers pass a
        column-normalized dictionary.
    """
    ## Initialize some constants for the algorithm
    _, n = Phi.shape
    # squared norm of the signal
    y_norm_sqr = y.T @ y
    # stopping threshold on the squared residual norm
    max_r_norm_sqr = y_norm_sqr * (res_norm_rtol ** 2)

    def init_func():
        # start from x = 0, so the residual is the measurement itself
        return MPState(x=jnp.zeros(n), r=y, r_norm_sqr=y_norm_sqr, iterations=0)

    def body_func(state):
        # correlate the residual with every column of Phi
        h = Phi.T @ state.r
        # find the column with the largest absolute correlation
        best_match_index = jnp.argmax(jnp.abs(h))
        # pick corresponding correlation value
        coeff = h[best_match_index]
        # update the representation
        x = state.x.at[best_match_index].add(coeff)
        # find the best match atom
        atom = Phi[:, best_match_index]
        # update the residual
        r = state.r - coeff * atom
        # the residual norm must track the new residual, otherwise the
        # tolerance-based stopping criterion can never be met
        r_norm_sqr = r.T @ r
        return MPState(x=x, r=r, r_norm_sqr=r_norm_sqr,
                       iterations=state.iterations + 1)

    def cond_func(state):
        jax.debug.callback(tracker, state)
        # limit on residual norm
        a = state.r_norm_sqr > max_r_norm_sqr
        # limit on number of iterations
        b = state.iterations < max_iters
        return jnp.logical_and(a, b)

    return lax.while_loop(cond_func, body_func, init_func())