# Source code for cr.sparse._src.pursuit.defs

# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional

import jax.numpy as jnp
from jax.tree_util import register_pytree_node

@dataclass
class SingleRecoverySolution:
    """Mutable container for the outputs of a sparse recovery run.

    Every field defaults to ``None``, so a producer can populate only the
    quantities it actually computes.

    Attributes:
        signals: recovered signal(s).
        representations: estimated sparse representation(s).
        residuals: residual vector(s) left after recovery.
        residual_norms: norm(s) of the residuals (whether squared or not
            is not established here -- NOTE(review): confirm with producers).
        iterations: number of iterations the algorithm ran.
        support: index set(s) of the non-zero entries of the representation.
    """

    # ``jnp.ndarray`` (an alias of ``jax.Array``) replaces the removed
    # ``jnp.DeviceArray`` alias, which made this class fail at import time
    # on current JAX releases.
    signals: Optional[jnp.ndarray] = None
    representations: Optional[jnp.ndarray] = None
    residuals: Optional[jnp.ndarray] = None
    residual_norms: Optional[jnp.ndarray] = None
    iterations: Optional[int] = None
    support: Optional[jnp.ndarray] = None

class RecoverySolution(NamedTuple):
    r"""Represents the solution of a sparse recovery problem

    Consider a sparse recovery problem :math:`y=\Phi x + e`.

    Assume that :math:`x` is supported on an index set :math:`I`,
    i.e. the non-zero values of :math:`x` are in the sub-vector :math:`x_I`.
    Then the equation can be rewritten as :math:`y = \Phi_I x_I + e`.

    Solving the sparse recovery problem given :math:`\Phi` and :math:`y`
    involves identifying :math:`I` and estimating :math:`x_I`.
    Then, the residual is :math:`r = y - \Phi_I x_I`.

    An important quantity during the sparse recovery is the (squared) norm
    of the residual :math:`\| r \|_2^2`, which is an estimate of the energy
    of the error :math:`e`.

    This type combines all of this information together.

    Parameters:

        x_I : estimate(s) of :math:`x_I`
        I : identified index set(s) :math:`I`
        r : residual(s) :math:`r = y - \Phi_I x_I`
        r_norm_sqr: squared norm of residual :math:`\| r \|_2^2`
        iterations: Number of iterations required for the algorithm to converge

    Note:

        The tuple can also hold the solution of a multiple measurement
        vector problem. In this case, each column (of individual parameters)
        represents the solution of the corresponding single vector problem.
    """
    # The docstring is raw (r"""...""") because it contains LaTeX commands
    # such as \Phi and \| that are invalid escape sequences in a normal
    # string literal (a SyntaxWarning on Python 3.12+).

    # The non-zero values
    x_I: jnp.ndarray
    """Non-zero values"""
    I: jnp.ndarray
    """The support for non-zero values"""
    r: jnp.ndarray
    """The residuals"""
    r_norm_sqr: jnp.ndarray
    """The residual norm squared"""
    iterations: int
    """The number of iterations it took to complete"""
class PTConfig(NamedTuple):
    """One problem configuration in a grid of configurations.

    NOTE(review): "PT" presumably stands for "phase transition"; the field
    meanings below are inferred from conventional naming -- confirm.
    """
    # presumably the sparsity level
    K: int
    # presumably the number of measurements
    M: int
    # NOTE(review): conventionally a ratio (e.g. M/N); the ``int``
    # annotation may be too narrow -- confirm against producers.
    eta: int
    # NOTE(review): conventionally a ratio (e.g. K/M); see ``eta``.
    rho: int


class PTConfigurations(NamedTuple):
    """A collection of :class:`PTConfig` entries plus per-axis arrays."""
    # presumably the ambient dimension shared by all configurations
    N: int
    configurations: List[PTConfig]
    # ``jnp.ndarray`` (an alias of ``jax.Array``) replaces the removed
    # ``jnp.DeviceArray`` alias, which made this module fail at import
    # time on current JAX releases.
    Ms: jnp.ndarray
    etas: jnp.ndarray
    rhos: jnp.ndarray
    # NOTE(review): key/value schema is not visible here -- confirm.
    reverse_map: Dict


class HTPState(NamedTuple):
    """Iteration state for hard thresholding style pursuit algorithms.

    Holds the same current-solution fields as :class:`RecoverySolution`
    plus a snapshot of the previous iteration. The same type is reused
    as ``IHTState`` and ``CoSaMPState``.
    """
    # The non-zero values
    x_I: jnp.ndarray
    """Non-zero values"""
    I: jnp.ndarray
    """The support for non-zero values"""
    r: jnp.ndarray
    """The residuals"""
    r_norm_sqr: jnp.ndarray
    """The residual norm squared"""
    iterations: int
    """The number of iterations it took to complete"""
    # Information from previous iteration
    # support from the previous iteration
    I_prev: jnp.ndarray
    # non-zero values from the previous iteration
    x_I_prev: jnp.ndarray
    # residual norm squared from the previous iteration
    r_norm_sqr_prev: jnp.ndarray


# IHT and CoSaMP carry exactly the same state as HTP, so they share the type.
IHTState = HTPState
CoSaMPState = HTPState