Source code for cr.sparse._src.dict.grass

# Copyright 2022 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Grassmannian frames
"""
import math

import jax.numpy as jnp
from jax import random, lax, jit

import cr.nimble as crn



def minimum_coherence(m, n):
    """Minimum achievable coherence for a Grassmannian frame

    Evaluates the Welch bound ``sqrt((n - m) / (m * (n - 1)))``.

    Args:
        m (int): Dimension of the space in which the atoms live
        n (int): Number of atoms in the frame

    Returns:
        (float) Lower bound on the coherence of the frame. It is ``0.0``
        whenever ``n <= m`` since an orthonormal set of atoms can be
        chosen in that case.

    Raises:
        ValueError: If ``m`` or ``n`` is not positive.
    """
    if m <= 0 or n <= 0:
        raise ValueError(
            f"m and n must be positive, got m={m}, n={n}")
    # The closed form below is valid only for n > m. Outside that range it
    # would either take the square root of a negative number (m > n) or
    # divide by zero (n == 1).
    if n <= m:
        return 0.0
    numer = n - m
    denom = m * (n - 1)
    return math.sqrt(numer / denom)


def build_grassmannian_frame(init, frac=0.9, shrink_mu=0.9, iterations=25):
    """Builds a Grassmannian frame starting from a random matrix

    Args:
        init (jax.numpy.ndarray): initial frame of shape (m, n)
        frac (float): Threshold for fraction of off diagonal entries to keep/change
        shrink_mu (float): Factor by which to shrink or expand off diagonal entries
        iterations (int): Number of iterations for alternate projections

    Returns:
        (jax.numpy.ndarray) A frame which is approximately Grassmannian

    It uses an alternate projections based algorithm. Each iteration

    * shrinks the largest off diagonal entries of the Gram matrix by
      ``shrink_mu`` and expands the smallest ones by ``1 / shrink_mu``,
    * truncates the Gram matrix back to rank ``m`` through an SVD,
    * rescales it symmetrically so that its diagonal is 1 again.

    Note:
        The threshold positions are computed as if the ``n`` largest
        absolute entries of the Gram matrix were its diagonal, so ``init``
        is presumably expected to have unit norm columns (TODO confirm
        against callers).
    """
    m, n = init.shape
    # initial gram matrix
    g = init.T @ init
    # number of off diagonal entries in the Gram matrix
    n_off = n**2 - n
    # Positions of the two thresholds inside the sorted absolute values.
    # Clamp at 0: when the count rounds to zero the index would be -1,
    # which silently picks the *largest* entry as the threshold.
    upper_ind = max(round(frac * n_off) - 1, 0)
    lower_ind = max(round((1 - frac) * n_off) - 1, 0)
    # Flattened mask of off diagonal entries. It is built from positions
    # rather than from values, so it stays correct when a column is not
    # exactly unit norm or an inner product happens to be close to 1.
    off_ind = jnp.logical_not(jnp.eye(n, dtype=bool)).ravel()

    def body_fun(k, g):
        # flatten the gram matrix
        gg = g.ravel()
        # Absolute values of gram matrix
        abs_g = jnp.abs(gg)
        # sort the inner products by their absolute values
        sorted_g = jnp.sort(abs_g)

        ## Shrink the high inner products
        # identify coherence values above the threshold
        upper_th = sorted_g[upper_ind]
        above_indices = jnp.logical_and(off_ind, abs_g > upper_th)
        # Update off diagonal entries
        gg = jnp.where(above_indices, gg * shrink_mu, gg)

        ## Expand the near zero products
        lower_th = sorted_g[lower_ind]
        below_indices = jnp.logical_and(off_ind, abs_g < lower_th)
        gg = jnp.where(below_indices, gg / shrink_mu, gg)

        # make the new gram matrix
        g = jnp.reshape(gg, (n, n))

        ## Reduce the rank of g back to m
        # perform SVD
        U, s, VT = jnp.linalg.svd(g)
        # Ensure that all higher singular values are set to 0
        s = s.at[m:].set(0)
        # Reconstruct the Gram matrix, U * s scales the columns of U
        g = jnp.dot(U * s, VT)

        # Ensure that the diagonal elements of G continue to be 1.
        # The scaling must be applied to rows AND columns
        # (D^{-1/2} G D^{-1/2}); multiplying by the 1-D vector twice would
        # scale the columns twice and leave g asymmetric.
        d2 = 1. / jnp.sqrt(jnp.diag(g))
        g = d2[:, None] * g * d2[None, :]
        return g

    # run the alternate projections
    g = lax.fori_loop(0, iterations, body_fun, g)
    # final dictionary: keep the m leading singular pairs of the Gram matrix
    U, s, VT = jnp.linalg.svd(g)
    s = s[:m]
    # NOTE(review): crn.diag_premultiply presumably computes diag(d) @ A,
    # which makes frame.T @ frame the rank m part of g -- confirm.
    frame = crn.diag_premultiply(s ** 0.5, U[:, :m].T)
    return frame