Source code for cr.sparse._src.signal

# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp
from jax import random, jit
import jax.numpy.fft as jfft

from .norm import sqr_norms_l2_cw, sqr_norms_l2_rw
from .matrix import is_matrix
from .discrete.number import next_pow_of_2

def find_first_signal_with_energy_le_rw(X, energy):
    """Returns the index of the first row whose energy is less than or equal to the specified threshold

    Args:
        X: A matrix (checked with ``is_matrix``); each row is one signal.
        energy: Threshold compared against the per-row energies returned
            by ``sqr_norms_l2_rw``.

    Returns:
        A scalar integer array: the index of the first row with
        ``row_energy <= energy``, or ``-1`` if no row qualifies.
    """
    assert is_matrix(X)
    energies = sqr_norms_l2_rw(X)
    within = energies <= energy
    # argmax over a boolean mask gives the position of the first True entry.
    # It also gives 0 when no entry is True, so the hit is re-checked below.
    index = jnp.argmax(within)
    # jnp.where instead of a Python conditional: a Python ``if`` on an array
    # value cannot be evaluated on traced values (e.g. inside jax.jit).
    return jnp.where(within[index], index, -1)

def find_first_signal_with_energy_le_cw(X, energy):
    """Returns the index of the first column whose energy is less than or equal to the specified threshold

    Args:
        X: A matrix (checked with ``is_matrix``); each column is one signal.
        energy: Threshold compared against the per-column energies returned
            by ``sqr_norms_l2_cw``.

    Returns:
        A scalar integer array: the index of the first column with
        ``column_energy <= energy``, or ``-1`` if no column qualifies.
    """
    assert is_matrix(X)
    energies = sqr_norms_l2_cw(X)
    within = energies <= energy
    # argmax over a boolean mask gives the position of the first True entry.
    # It also gives 0 when no entry is True, so the hit is re-checked below.
    index = jnp.argmax(within)
    # jnp.where instead of a Python conditional: a Python ``if`` on an array
    # value cannot be evaluated on traced values (e.g. inside jax.jit).
    return jnp.where(within[index], index, -1)


def randomize_rows(key, X):
    """Returns X with its rows rearranged by a random permutation

    Args:
        key: PRNG key handed to ``random.permutation``.
        X: Matrix (checked with ``is_matrix``) whose rows are shuffled.
    """
    assert is_matrix(X)
    num_rows, _ = X.shape
    shuffled = random.permutation(key, num_rows)
    return X[shuffled, :]
def randomize_cols(key, X):
    """Returns X with its columns rearranged by a random permutation

    Args:
        key: PRNG key handed to ``random.permutation``.
        X: Matrix (checked with ``is_matrix``) whose columns are shuffled.
    """
    assert is_matrix(X)
    _, num_cols = X.shape
    shuffled = random.permutation(key, num_cols)
    return X[:, shuffled]
def largest_indices(x, K):
    """Returns the indices of the K largest-magnitude entries of x

    The indices are ordered from the largest magnitude downwards.
    """
    ascending = jnp.argsort(jnp.abs(x))
    # Reverse to descending magnitude, then keep the leading K positions.
    descending = ascending[::-1]
    return descending[:K]
def largest_indices_rw(X, K):
    """Returns, for every row of X, the indices of its K largest-magnitude entries

    Within each row the indices are ordered from the largest magnitude downwards.
    """
    ascending = jnp.argsort(jnp.abs(X), axis=1)
    # Flip each row to descending magnitude, then keep the leading K columns.
    descending = ascending[:, ::-1]
    return descending[:, :K]
def largest_indices_cw(X, K):
    """Returns, for every column of X, the indices of its K largest-magnitude entries

    Within each column the indices are ordered from the largest magnitude downwards.
    """
    ascending = jnp.argsort(jnp.abs(X), axis=0)
    # Flip each column to descending magnitude, then keep the leading K rows.
    descending = ascending[::-1, :]
    return descending[:K, :]
def take_along_rows(X, indices):
    """Gathers entries from each row of X at the column positions given by the indices matrix

    Thin wrapper over ``jnp.take_along_axis`` with ``axis=1``.
    """
    picked = jnp.take_along_axis(X, indices, axis=1)
    return picked
def take_along_cols(X, indices):
    """Gathers entries from each column of X at the row positions given by the indices matrix

    Thin wrapper over ``jnp.take_along_axis`` with ``axis=0``.
    """
    picked = jnp.take_along_axis(X, indices, axis=0)
    return picked
def sparse_approximation(x, K):
    """Zeroes every entry of the vector x except its K largest by magnitude

    With ``K == 0`` the whole vector is zeroed.
    """
    if K == 0:
        return x.at[:].set(0)
    ascending = jnp.argsort(jnp.abs(x))
    # Everything before the last K positions of the ascending order is dropped.
    discard = ascending[:-K]
    return x.at[discard].set(0)
def sparse_approximation_cw(X, K):
    """Keeps only largest K non-zero entries by magnitude in each column of X

    Args:
        X: 2-D array; every column is treated as a separate signal.
        K: Number of entries to keep per column. ``K == 0`` zeroes X.

    Returns:
        An array shaped like X in which, per column, every entry outside the
        K largest magnitudes has been set to zero.
    """
    if K == 0:
        return X.at[:].set(0)
    order = jnp.argsort(jnp.abs(X), axis=0)
    # Row positions to clear in each column: all but the last K of the
    # ascending-magnitude order.
    drop_rows = order[:-K, :]
    # Column c of drop_rows is paired with column index c via broadcasting,
    # so one scatter replaces a Python loop over the columns.
    cols = jnp.arange(X.shape[1])
    return X.at[drop_rows, cols].set(0)
def sparse_approximation_rw(X, K):
    """Keeps only largest K non-zero entries by magnitude in each row of X

    Args:
        X: 2-D array; every row is treated as a separate signal.
        K: Number of entries to keep per row. ``K == 0`` zeroes X.

    Returns:
        An array shaped like X in which, per row, every entry outside the
        K largest magnitudes has been set to zero.
    """
    if K == 0:
        return X.at[:].set(0)
    order = jnp.argsort(jnp.abs(X), axis=1)
    # Column positions to clear in each row: all but the last K of the
    # ascending-magnitude order.
    drop_cols = order[:, :-K]
    # Row r of drop_cols is paired with row index r via broadcasting,
    # so one scatter replaces a Python loop over the rows.
    rows = jnp.arange(X.shape[0])[:, None]
    return X.at[rows, drop_cols].set(0)
def build_signal_from_indices_and_values(length, indices, values):
    """Creates a signal of the given length that is zero except at the given indices

    Args:
        length: Size passed to ``jnp.zeros`` for the result.
        indices: Positions of the non-zero entries.
        values: Values written at those positions.
    """
    locations = jnp.asarray(indices)
    entries = jnp.asarray(values)
    blank = jnp.zeros(length)
    return blank.at[locations].set(entries)
def nonzero_values(x):
    """Extracts the entries of x that differ from zero"""
    is_nonzero = x != 0
    return x[is_nonzero]
def nonzero_indices(x):
    """Gives the positions (along the first axis) at which x is non-zero"""
    per_axis = jnp.nonzero(x)
    # jnp.nonzero returns one index array per dimension; keep the first.
    return per_axis[0]
def support(x):
    """Gives the support of x: the positions (along the first axis) of its non-zero entries"""
    per_axis = jnp.nonzero(x)
    # jnp.nonzero returns one index array per dimension; keep the first.
    return per_axis[0]
def hard_threshold(x, K):
    """Picks the K largest-magnitude entries of the vector x

    Returns:
        A pair ``(I, x_I)``: the indices of the selected entries, ordered
        from the largest magnitude downwards, and the matching values of x.
    """
    ascending = jnp.argsort(jnp.abs(x))
    # Reverse to descending magnitude and keep the leading K positions.
    I = ascending[::-1][:K]
    return I, x[I]
def hard_threshold_sorted(x, K):
    """Picks the K largest-magnitude entries of the vector x, reported in index order

    Returns:
        A pair ``(I, x_I)``: the indices of the selected entries sorted in
        ascending order, and the matching values of x.
    """
    # Rank positions by magnitude, smallest first.
    ascending = jnp.argsort(jnp.abs(x))
    # The K largest magnitudes sit at the tail; take them from the back.
    top = ascending[::-1][:K]
    # Report the chosen positions in increasing index order.
    I = jnp.sort(top)
    return I, x[I]
def hard_threshold_by(x, t):
    """Zeroes the entries of x whose magnitude is below t

    Entries with ``|x| >= t`` are kept as they are.
    """
    keep = jnp.abs(x) >= t
    # Multiplying by the boolean mask zeroes the rejected entries.
    return x * keep


def largest_indices_by(x, t):
    """Finds the positions (along the first axis) of entries of x with ``|x| >= t``"""
    above = jnp.abs(x) >= t
    positions = jnp.where(above)
    return positions[0]
def dynamic_range(x):
    """Computes the ratio, in dB, between the largest and smallest magnitudes in x

    The value is ``20 * log10(max|x| / min|x|)``.
    """
    magnitudes = jnp.sort(jnp.abs(x))
    smallest = magnitudes[0]
    largest = magnitudes[-1]
    ratio = largest / smallest
    return 20 * jnp.log10(ratio)
def nonzero_dynamic_range(x):
    """Computes the dynamic range (dB) of x over its non-zero entries only

    Zero entries are removed with ``nonzero_values`` before the remaining
    values are handed to ``dynamic_range``.
    """
    return dynamic_range(nonzero_values(x))
def normalize(data, axis=-1):
    """Normalizes data along an axis: (data - mu) / sigma

    Args:
        data: Array to normalize.
        axis: Axis along which the mean and variance are computed.

    Returns:
        An array shaped like ``data`` with zero mean and unit variance along
        ``axis``. There is no guard against zero variance.
    """
    # keepdims keeps the reduced axis as size 1 so the statistics broadcast
    # against data for any rank and any axis (not only for 1-D input).
    mu = jnp.mean(data, axis, keepdims=True)
    data = data - mu
    variance = jnp.var(data, axis, keepdims=True)
    return data / jnp.sqrt(variance)


# axis selects the reduction dimension, so it is marked static for tracing.
normalize_jit = jit(normalize, static_argnums=(1,))


def frequency_spectrum(x, dt=1.):
    """Frequency spectrum of 1D data using FFT

    Args:
        x: 1D data.
        dt: Sample spacing passed to ``fftfreq``.

    Returns:
        A pair ``(f, X)``: the frequency bins and the FFT values, both passed
        through ``fftshift``. The FFT length is ``next_pow_of_2(len(x))``.
    """
    n = len(x)
    nn = next_pow_of_2(n)
    X = jfft.fft(x, nn)
    f = jfft.fftfreq(nn, d=dt)
    X = jfft.fftshift(X)
    f = jfft.fftshift(f)
    return f, X


def power_spectrum(x, dt=1.):
    """Power spectrum of 1D data using FFT

    Args:
        x: 1D data.
        dt: Sample spacing.

    Returns:
        A pair ``(f, sxx)``: the second half of the shifted frequency bins
        from ``frequency_spectrum`` and ``|X|^2 / T`` at those bins, where
        ``T = dt * len(x)``.
    """
    n = len(x)
    T = dt * n
    f, X = frequency_spectrum(x, dt)
    nn = len(f)
    n2 = nn // 2
    # Keep the upper half of the shifted spectrum.
    f = f[n2:]
    X = X[n2:]
    sxx = (X * jnp.conj(X)) / T
    sxx = jnp.abs(sxx)
    return f, sxx


def energy(data, axis=-1):
    """Computes the energy of the signal along the specified axis

    The energy is the sum of ``|data|**2`` over ``axis``.
    """
    power = jnp.abs(data) ** 2
    return jnp.sum(power, axis)


def interpft(x, N):
    """Interpolates x to N points in Fourier Transform domain

    The FFT of x is zero-padded in the middle of the spectrum to length N
    and transformed back.

    Args:
        x: 1D data of length n.
        N: Number of output points; must satisfy ``n < N``.

    Returns:
        An array of N points; real-valued when x is real.
    """
    n = len(x)
    assert n < N
    a = jfft.fft(x)
    # Index of the highest non-negative frequency bin. For even n this is the
    # Nyquist bin; for odd n it is the last strictly positive frequency, so
    # every negative-frequency bin stays on the far side of the padding.
    nyqst = n // 2
    z = jnp.zeros(N - n)
    a1 = a[:nyqst + 1]
    a2 = a[nyqst + 1:]
    b = jnp.concatenate((a1, z, a2))
    if n % 2 == 0:
        # Split the Nyquist bin evenly between its positive- and
        # negative-frequency positions in the longer spectrum.
        b = b.at[nyqst].set(b[nyqst] / 2)
        b = b.at[nyqst + N - n].set(b[nyqst])
    y = jfft.ifft(b)
    if jnp.isrealobj(x):
        y = jnp.real(y)
    # scale it up
    y = y * (N / n)
    return y