# Source code for cr.sparse._src.cluster.ssc.omp

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import jit, vmap
import jax.numpy as jnp

import cr.nimble as cnb

# Batched wrappers around the cr.nimble sub-matrix routines. In both, the
# first argument is shared by every mapped call (``None`` in ``in_axes``),
# while the second and third arguments are mapped over their axis 1.

# Results of the mapped multiplication are stacked along axis 1.
submat_multiplier = vmap(
    cnb.mult_with_submatrix,
    in_axes=(None, 1, 1),
    out_axes=1,
)

# The mapped solver yields two outputs; each one is stacked along axis 1.
submat_solver = vmap(
    cnb.solve_on_submatrix,
    in_axes=(None, 1, 1),
    out_axes=(1, 1),
)


def build_representation_omp(X, K):
    """Build K-sparse self-expressive representations of the columns of X.

    The dictionary is ``X`` with l2-normalized columns. For every data point
    the atoms are picked greedily (OMP): at each iteration the residuals are
    correlated with the dictionary, the diagonal of the correlation matrix is
    zeroed so that a point cannot pick itself, and the atom with the largest
    absolute correlation is added to that point's support. The coefficients
    and residuals are then recomputed by ``submat_solver`` over the enlarged
    support.

    Args:
        X: 2-D array whose columns are the data points.
        K: Requested sparsity level (number of OMP iterations). It must be a
            Python integer because it drives a Python-level loop. At least
            one iteration is always performed, even when ``K < 1``.

    Returns:
        A tuple ``(Z, I, R)``:

        * ``Z`` - coefficients returned by the last ``submat_solver`` call.
        * ``I`` - selected atom indices, one row per iteration and one
          column per data point.
        * ``R`` - residuals returned by the last ``submat_solver`` call.

    Note:
        NOTE(review): atoms are selected using the normalized columns, but
        the solver is given the *unnormalized* ``X`` as the signals, whereas
        ``batch_build_representation_omp`` solves against the normalized
        columns. Confirm which scaling of ``Z`` and ``R`` is intended.
    """
    # The dictionary is the data itself with unit-norm columns.
    dictionary = cnb.normalize_l2_cw(X)
    # Loop invariant: transpose once instead of on every iteration.
    dictionary_t = dictionary.T

    def best_match(residual):
        # Correlate every atom with every residual, then zero the diagonal
        # so that no data point selects itself.
        correlations = cnb.set_diagonal(dictionary_t @ residual, 0)
        return cnb.abs_max_idx_cw(correlations)

    # First iteration: the initial residual is the normalized data itself.
    # The selected indices become the first row of the support array.
    support = best_match(dictionary)[jnp.newaxis, :]
    Z, R = submat_solver(dictionary, support, X)

    # Remaining OMP iterations: grow the support by one row and re-solve.
    for _ in range(1, K):
        support = jnp.vstack((support, best_match(R)))
        Z, R = submat_solver(dictionary, support, X)
    return Z, support, R


# K controls a Python-level loop, so it is marked static for jit.
build_representation_omp_jit = jit(build_representation_omp, static_argnums=(1,))
def batch_build_representation_omp(X, K, batch_size):
    """Build K-sparse self-expressive representations of the columns of X in batches.

    Same greedy OMP scheme as ``build_representation_omp``, but the data
    points are processed ``batch_size`` columns at a time so that the
    correlation matrix computed per step has only as many columns as the
    current batch. The dictionary is always the full column-normalized ``X``.

    Args:
        X: 2-D array whose columns are the data points.
        K: Requested sparsity level (number of OMP iterations per point).
            Must be a Python integer.
        batch_size: Number of data points handled per batch. Must be a
            positive Python integer. The final batch may be shorter.

    Returns:
        A tuple ``(Z_res, I_res, R_res)``:

        * ``Z_res`` - array of shape ``(K, N)`` with the coefficients.
        * ``I_res`` - integer array of shape ``(K, N)`` with the selected
          atom indices (row ``k`` holds the atom chosen at iteration ``k``).
        * ``R_res`` - array of shape ``(D, N)`` with the final residuals.

        Here ``D, N`` are the first two dimensions of ``X``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.

    Note:
        NOTE(review): the signals passed to the solver are the *normalized*
        columns, whereas ``build_representation_omp`` passes the
        unnormalized ``X``. Confirm which scaling is intended.

        ``Z_res`` and ``R_res`` are allocated with ``jnp.empty``'s default
        floating dtype rather than the dtype of ``X``.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    # Ambient dimension
    D = X.shape[0]
    # Number of data points
    N = X.shape[1]
    # We normalize the columns of X; the normalized data is the dictionary
    Xn = cnb.normalize_l2_cw(X)
    Dict = Xn
    # Loop invariant: transpose once instead of on every iteration.
    DictT = Dict.T
    Z_res = jnp.empty((K, N))
    I_res = jnp.empty((K, N), dtype=int)
    R_res = jnp.empty((D, N))
    for start in range(0, N, batch_size):
        # The last batch ends at N and may hold fewer than batch_size points.
        end = min(start + batch_size, N)
        # number of signals in the batch
        n2 = end - start
        # Positions of the "self" inner products inside the (N, n2)
        # correlation matrix: column j of the batch is data point start + j.
        # Sized to the actual batch so every index is in bounds, including
        # for a shorter final batch.
        c = jnp.arange(n2)
        r = c + start
        # Let's restrict our attention to this batch only
        X_batch = Xn[:, start:end]
        # First iteration: the initial residual is the batch itself.
        # Zero the self inner products so a point cannot select itself.
        H = (DictT @ X_batch).at[(r, c)].set(0)
        # Index of best match
        indices = cnb.abs_max_idx_cw(H)
        # The first selected atoms go into row 0 of the index array.
        I_res = I_res.at[0, start:end].set(indices)
        Z, R = submat_solver(Dict, I_res[:1, start:end], X_batch)
        # conduct the remaining OMP iterations
        for k in range(1, K):
            # Correlate the residuals and zero the self inner products.
            H = (DictT @ R).at[(r, c)].set(0)
            # Index of best match
            indices = cnb.abs_max_idx_cw(H)
            # Update the set of indices
            I_res = I_res.at[k, start:end].set(indices)
            # Solve over the first k + 1 selected atoms
            Z, R = submat_solver(Dict, I_res[:k+1, start:end], X_batch)
        Z_res = Z_res.at[:, start:end].set(Z)
        R_res = R_res.at[:, start:end].set(R)
    return Z_res, I_res, R_res


# K and batch_size drive Python-level loops and slice bounds, so both are
# marked static for jit.
batch_build_representation_omp_jit = jit(batch_build_representation_omp, static_argnums=(1,2))