Source code for cr.sparse._src.pursuit.defs

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import NamedTuple, List, Dict
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node
from cr.nimble.dsp import build_signal_from_indices_and_values
# Module-level shorthand for jnp.linalg.norm.
norm = jnp.linalg.norm

@dataclass
class SingleRecoverySolution:
    """Mutable record holding the outputs of one sparse recovery run.

    Every field defaults to ``None``, so an instance can be created empty
    and filled in field by field.

    NOTE(review): the per-field comments below are inferred from the field
    names only -- confirm against the code that populates this record.
    """

    # presumably the recovered signal(s)
    signals: jax.Array = None
    # presumably the sparse representation(s) of the signal(s)
    representations: jax.Array = None
    # presumably the residual(s) left after recovery
    residuals: jax.Array = None
    # presumably the norm(s) of the residual(s)
    residual_norms: jax.Array = None
    # presumably the number of iterations the solver ran
    iterations: int = None
    # presumably the index set of the non-zero entries
    support: jax.Array = None
class RecoverySolution(NamedTuple):
    r"""Represents the solution of a sparse recovery problem

    Consider a sparse recovery problem :math:`y=\Phi x + e`.

    Assume that :math:`x` is supported on an index set :math:`I`
    i.e. the non-zero values of :math:`x` are in the sub-vector
    :math:`x_I`, then the equation can be rewritten as
    :math:`y = \Phi_I x_I + e`.

    Solving the sparse recovery problem given :math:`\Phi` and
    :math:`y` involves identifying :math:`I` and estimating :math:`x_I`.

    Then, the residual is :math:`r = y - \Phi_I x_I`. An important
    quantity during the sparse recovery is the (squared) norm of the
    residual :math:`\| r \|_2^2` which is an estimate of the energy
    of error :math:`e`.

    This type combines all of this information together.

    Parameters:

        x_I : estimate(s) of :math:`x_I`
        I : identified index set(s) :math:`I`
        r : residual(s) :math:`r = y - \Phi_I x_I`
        r_norm_sqr: squared norm of residual :math:`\| r \|_2^2`
        iterations: Number of iterations required for the algorithm to converge
        length: length of the sparse signal :math:`x`

    Note:

        The tuple can be used to solve multiple measurement vector
        problems also. In this case, each column (of individual parameters)
        represents the solution of corresponding single vector problems.
    """
    # The non-zero values
    x_I: jnp.ndarray
    """Non-zero values"""
    I: jnp.ndarray
    """The support for non-zero values"""
    r: jnp.ndarray
    """The residuals"""
    r_norm_sqr: jnp.ndarray
    """The residual norm squared"""
    iterations: int
    """The number of iterations it took to complete"""
    length: int
    """The length of the sparse signal"""

    @property
    def x(self):
        """Signal assembled from ``length``, ``I`` and ``x_I``.

        Delegates to ``build_signal_from_indices_and_values``; presumably the
        result is the full-length signal with ``x_I`` placed at the indices
        ``I`` -- confirm against that helper.
        """
        return build_signal_from_indices_and_values(self.length, self.I, self.x_I)

    def __str__(self):
        """Returns the string representation

        The summary has four lines: the iteration count, the sizes
        ``m=len(r), n=length, k=len(I)``, the residual norm and the norm
        of ``x_I``.
        """
        # NOTE(review): float() needs a scalar-convertible r_norm_sqr, so this
        # summary appears to target single-vector solutions -- confirm.
        r_norm = math.sqrt(float(self.r_norm_sqr))
        x_norm = float(jnp.linalg.norm(self.x_I))
        lines = [
            "iterations %s" % self.iterations,
            f"m={len(self.r)}, n={self.length}, k={len(self.I)}",
            "r_norm %e" % r_norm,
            "x_norm %e" % x_norm,
        ]
        return "\n".join(line.rstrip() for line in lines)
class PTConfig(NamedTuple):
    """One configuration described by the four integers K, M, eta and rho.

    NOTE(review): "PT" presumably stands for phase transition, with K the
    sparsity level, M the number of measurements and eta/rho grid
    coordinates -- confirm against the code that builds these tuples.
    """
    K: int
    M: int
    eta: int
    rho: int


class PTConfigurations(NamedTuple):
    """A collection of :class:`PTConfig` entries plus companion lookup data.

    NOTE(review): presumably ``N`` is the common signal dimension,
    ``Ms``/``etas``/``rhos`` are the axis values used to generate
    ``configurations`` and ``reverse_map`` maps back into them -- confirm
    against the producer of this tuple.
    """
    N: int
    configurations: List[PTConfig]
    Ms: jax.Array
    etas: jax.Array
    rhos: jax.Array
    reverse_map: Dict


class HTPState(NamedTuple):
    """Solver state: current estimate plus a snapshot of the previous iteration.

    The same tuple type is reused under the names ``IHTState``,
    ``CoSaMPState`` and ``SPState``.
    """
    # The non-zero values
    x_I: jnp.ndarray
    """Non-zero values"""
    I: jnp.ndarray
    """The support for non-zero values"""
    r: jnp.ndarray
    """The residuals"""
    r_norm_sqr: jnp.ndarray
    """The residual norm squared"""
    iterations: int
    """The number of iterations it took to complete"""
    # Information from previous iteration
    I_prev: jnp.ndarray
    x_I_prev: jnp.ndarray
    r_norm_sqr_prev: jnp.ndarray

    def __str__(self):
        """Returns the string representation

        Only the iteration count and the squared residual norm are
        reported, one per line.
        """
        lines = [
            "iterations %s" % self.iterations,
            "r_norm_sqr %e" % self.r_norm_sqr,
        ]
        return "\n".join(line.rstrip() for line in lines)


# These names are aliases of HTPState, not separate classes.
IHTState = HTPState
CoSaMPState = HTPState
SPState = HTPState