# Source code for cr.sparse._src.tools.performance

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import jax.numpy as jnp

from cr.nimble.dsp import (
    nonzero_indices, 
    largest_indices, 
    build_signal_from_indices_and_values,
    dynamic_range,
    nonzero_dynamic_range
)

class RecoveryPerformance:
    """Performance of a sparse signal recovery operation

    * Synthesis: :math:`y = \\Phi x + e`
    * Recovery: :math:`y = \\Phi \\hat{x} + r`
    * Representation error: :math:`h = x - \\hat{x}`
    * Residual: :math:`r = y - \\Phi \\hat{x}`

    The reconstruction :math:`\\hat{x}` is supplied either directly through
    ``x_hat`` or through a solution object ``sol``.
    """

    M : int = 0
    """Signal/Measurement space dimension, number of rows in :math:`\\Phi`"""
    N : int = 0
    """Representation space dimension, number of atoms/columns in :math:`\\Phi`"""
    K: int = 0
    """Number of non-zero entries in :math:`x`"""
    T0 = []
    """The index set of K non-zero coefficients in :math:`x`"""
    x_norm: float = 0
    """norm of representation :math:`x`"""
    y_norm: float = 0
    """norm of measurement/signal :math:`y`"""
    x_hat_norm: float = 0
    """norm of the reconstruction :math:`\\hat{x}`"""
    x_dr: float = 0
    """ Dynamic range of x """
    y_dr: float = 0
    """ Dynamic range of y """
    x_hat_dr: float = 0
    """ Dynamic range of x_hat """
    h = []
    """Recovery/reconstruction error :math:`h = x - \\hat{x}`"""
    h_norm: float = 0
    """Norm of reconstruction error :math:`h`"""
    recovery_snr: float = 0
    """Reconstruction/recovery SNR (dB) in representation space :math:`20 \\log (\\| x \\|_2 / \\| h \\|_2)`"""
    R0 = []
    """Index set of K largest (magnitude) entries in the reconstruction :math:`\\hat{x}`"""
    overlap = []
    """Indices overlapping between T0 and R0 :math:`T_0 \\cap R_0`"""
    num_correct_atoms : int = 0
    """Number of entries in the overlap, i.e. number of indices of the support correctly recovered"""
    r = []
    """The residual :math:`r = y - \\Phi \\hat{x}`"""
    r_norm : float = 0
    """Norm of the residual"""
    measurement_snr: float = 0
    """Measurement SNR (dB) in measurement/signal space :math:`20 \\log (\\| y \\|_2 / \\| r \\|_2)`"""

    def __init__(self, Phi, y, x, x_hat=None, sol=None):
        """Computes all parameters related to the quality of reconstruction

        Args:
            Phi: The dictionary/sensing matrix. It must expose ``shape``.
                If it has a ``times`` method (a linear operator), that method
                is used to apply it; otherwise ``Phi @ x_hat`` is used.
            y: The measurement/signal vector.
            x: The true representation vector.
            x_hat: The reconstructed representation. It is ignored when
                ``sol`` is provided.
            sol: An optional solution object. If it has an ``x`` attribute,
                that attribute is used as the reconstruction. Otherwise the
                object must provide ``I`` (support indices) and ``x_I``
                (values on the support), from which an N length
                reconstruction is built.

        Raises:
            ValueError: If neither ``x_hat`` nor ``sol`` is provided.
        """
        # Shape of the dictionary/sensing matrix
        M, N = Phi.shape
        x_hat = self._resolve_x_hat(N, x_hat, sol)
        self.T0 = nonzero_indices(x)
        K = self.T0.size
        self.M = M
        self.N = N
        self.K = K
        # Norm of representation
        self.x_norm = jnp.linalg.norm(x)
        # Norm of measurement/signal
        self.y_norm = jnp.linalg.norm(y)
        # Norm of the reconstructed representation
        self.x_hat_norm = jnp.linalg.norm(x_hat)
        self.x_dr = nonzero_dynamic_range(x)
        self.y_dr = dynamic_range(y)
        self.x_hat_dr = nonzero_dynamic_range(x_hat)
        # Recovery error vector. N length vector
        h = x - x_hat
        self.h = h
        # l_2 norm of representation error
        self.h_norm = jnp.linalg.norm(h)
        # Recovery SNR
        self.recovery_snr = 20 * jnp.log10(self.x_norm / self.h_norm)
        # The portion of recovery error over T0. K length vector
        self.h_T0 = h[self.T0]
        # Positions of other places (set of indices)
        index_set = jnp.arange(N)
        self.T0C = jnp.setdiff1d(index_set, self.T0)
        # Recovery error at T0C places. N length vector with T0 entries zeroed
        hT0C = h.at[self.T0].set(0)
        self.h_T0C = hT0C
        # The K largest indices after T0 in recovery error (set of indices)
        self.T1 = largest_indices(hT0C, K)
        # The recovery error component over T1. [K] length vector.
        self.h_T1 = h[self.T1]
        # Remaining indices; nominally [N - 2K] of them
        self.TRest = jnp.setdiff1d(self.T0C, self.T1)
        # Recovery error over the remaining indices
        self.h_TRest = h[self.TRest]
        # Largest indices of the recovered vector
        self.R0 = jnp.sort(largest_indices(x_hat, K))
        # Support overlap
        self.overlap = jnp.intersect1d(self.T0, self.R0)
        self.num_correct_atoms = self.overlap.size
        # Measurement/signal residual vector. [M] length vector
        r = y - (Phi.times(x_hat) if hasattr(Phi, 'times') else Phi @ x_hat)
        self.r = r
        # Norm of measurement error. This must be less than epsilon
        self.r_norm = jnp.linalg.norm(r)
        # Measurement SNR
        self.measurement_snr = 20 * jnp.log10(self.y_norm / self.r_norm)
        # Ratio between the norm of recovery error and measurement error
        self.h_by_r_norms = self.h_norm / self.r_norm

    @staticmethod
    def _resolve_x_hat(N, x_hat, sol):
        """Returns the reconstruction to evaluate; ``sol`` takes precedence over ``x_hat``."""
        if sol is not None:
            # The reconstruction is read by attribute (``sol.x``), so test for
            # the attribute. ``'x' in sol`` would search the *values* of a
            # NamedTuple-style solution, not its field names.
            if hasattr(sol, 'x'):
                return sol.x
            return build_signal_from_indices_and_values(N, sol.I, sol.x_I)
        if x_hat is None:
            raise ValueError("Either x_hat or sol must be provided")
        return x_hat

    def print(self, details=False):
        """Prints metrics related to reconstruction quality"""
        print(f'M: {self.M}, N: {self.N}, K: {self.K}')
        print(f'x_norm: {self.x_norm:.3f}, y_norm: {self.y_norm:.3f}')
        print(f'x_hat_norm: {self.x_hat_norm:.3f}, h_norm: {self.h_norm:.2e}, r_norm: {self.r_norm:.2e}')
        print(f'recovery_snr: {self.recovery_snr:.2f} dB, measurement_snr: {self.measurement_snr:.2f} dB')
        print(f'x_dr: {self.x_dr:.2f} dB, y_dr: {self.y_dr:.2f} dB, x_hat_dr: {self.x_hat_dr:.3f} dB')
        print(f'Correct atoms: {self.num_correct_atoms}. Ratio: {self.support_recovery_ratio:.2f}, perfect_support_recovery: {self.perfect_support_recovery}')
        print(f'success: {self.success}')
        if details:
            print(f'T0: {self.T0}')
            print(f'R0: {self.R0}')
            print(f'Overlap: {self.overlap}')

    @property
    def support_recovery_ratio(self):
        """Returns the ratio of correctly recovered atoms

        When ``x`` has no non-zero entries (``K == 0``), there is nothing to
        recover. The ratio is then reported as 1.0 instead of dividing by
        zero, which agrees with ``perfect_support_recovery``.
        """
        if self.K == 0:
            return 1.0
        return self.num_correct_atoms / self.K

    @property
    def perfect_support_recovery(self):
        """Returns if the support has been recovered perfectly"""
        return self.num_correct_atoms >= self.K

    @property
    def success(self):
        """Returns True if more than 75% indices are correctly identified and recovery SNR is high (> 30 dB)"""
        return bool(self.support_recovery_ratio > 0.75) and bool(self.recovery_snr > 30)