Source code for cr.sparse._src.matrix

# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp

from .util import promote_arg_dtypes


def transpose(A):
    """Returns the transpose of an array

    The last two axes of the array are exchanged. For a 2D array this is
    the ordinary matrix transpose; any leading axes are left in place.

    Args:
        A (jax.numpy.ndarray): A JAX array

    Returns:
        jax.numpy.ndarray: The array with its last two axes swapped
    """
    last_axis, second_last_axis = -1, -2
    return jnp.swapaxes(A, last_axis, second_last_axis)
def hermitian(a):
    """Returns the conjugate transpose of an array

    The last two axes are exchanged and every element is replaced by its
    complex conjugate.

    Args:
        a (jax.numpy.ndarray): A JAX array

    Returns:
        jax.numpy.ndarray: Conjugate transpose of the array
    """
    swapped = jnp.swapaxes(a, -1, -2)
    return jnp.conjugate(swapped)
def is_matrix(A):
    """Checks if an array is a matrix

    An array counts as a matrix when it has exactly two dimensions.

    Args:
        A (jax.numpy.ndarray): A JAX array

    Returns:
        bool: True if the array is two-dimensional, False otherwise.
    """
    n_dims = A.ndim
    return n_dims == 2
def is_square(A):
    """Checks if an array is a square matrix

    Args:
        A (jax.numpy.ndarray): A JAX array

    Returns:
        bool: True if the array is two-dimensional with as many rows as
        columns, False otherwise.
    """
    # Anything that is not 2D cannot be a square matrix.
    if A.ndim != 2:
        return False
    n_rows, n_cols = A.shape
    return n_rows == n_cols
def is_symmetric(A):
    """Checks if an array is a symmetric matrix

    The comparison is exact (element-wise equality with the transpose);
    no numerical tolerance is applied.

    Args:
        A (jax.numpy.ndarray): A JAX array

    Returns:
        bool: True if the array is a symmetric matrix, False otherwise.
        For arrays that are not two-dimensional the result is the Python
        value ``False``; otherwise it is whatever ``jnp.array_equal``
        returns for the comparison.
    """
    # Only 2D arrays can be symmetric matrices.
    if A.ndim != 2:
        return False
    return jnp.array_equal(A, A.T)
def is_hermitian(A):
    """Checks if an array is a Hermitian matrix

    The comparison is exact (element-wise equality with the conjugate
    transpose); no numerical tolerance is applied.

    Args:
        A (jax.numpy.ndarray): A JAX array

    Returns:
        bool: True if the array is a Hermitian matrix, False otherwise.
        For arrays that are not two-dimensional the result is the Python
        value ``False``; otherwise it is whatever ``jnp.array_equal``
        returns for the comparison.
    """
    # Only 2D arrays can be Hermitian matrices.
    if A.ndim != 2:
        return False
    return jnp.array_equal(A, hermitian(A))
def is_positive_definite(A):
    """Checks if an array is a symmetric positive definite matrix

    Symmetric positive definite matrices have real and positive eigen
    values. This function first checks that the matrix is symmetric and
    then checks that all the eigen values are positive.

    Args:
        A (jax.numpy.ndarray): A JAX array

    Returns:
        bool: True if the array is a symmetric positive definite matrix,
        False otherwise.
    """
    if not is_symmetric(A):
        return False
    # promote_arg_dtypes is defined in .util; presumably it converts the
    # input to an inexact dtype suitable for the eigen solvers -- confirm
    # against that module.
    A = promote_arg_dtypes(A)
    if jnp.iscomplexobj(A):
        # A complex symmetric matrix is not Hermitian in general, so the
        # symmetric/Hermitian solver does not apply. Keep the general
        # solver and compare the real parts, as before.
        eigen_values = jnp.real(jnp.linalg.eigvals(A))
    else:
        # The matrix is real and already known to be symmetric, so use the
        # solver meant for symmetric matrices; it returns real eigen values.
        eigen_values = jnp.linalg.eigvalsh(A)
    return jnp.all(eigen_values > 0)
def has_orthogonal_columns(A, atol=1e-6):
    """Checks if a matrix has orthogonal columns

    The Gram matrix ``A.T @ A`` is compared against the identity, so the
    columns must be mutually orthogonal and of unit norm for the check to
    pass.

    Args:
        A (jax.numpy.ndarray): A JAX real 2D array
        atol (float): Base absolute tolerance. The tolerance handed to
            ``jnp.allclose`` is ``m * m * atol`` where ``m`` is the number
            of columns of ``A``.

    Returns:
        bool: True if the matrix has orthogonal columns, False otherwise.
    """
    gram = A.T @ A
    n_cols = gram.shape[0]
    identity = jnp.eye(n_cols)
    # Scale the tolerance with the number of entries in the Gram matrix.
    scaled_atol = n_cols * n_cols * atol
    return jnp.allclose(gram, identity, atol=scaled_atol)
def has_orthogonal_rows(A, atol=1e-6):
    """Checks if a matrix has orthogonal rows

    The Gram matrix ``A @ A.T`` is compared against the identity, so the
    rows must be mutually orthogonal and of unit norm for the check to
    pass.

    Args:
        A (jax.numpy.ndarray): A JAX real 2D array
        atol (float): Base absolute tolerance. The tolerance handed to
            ``jnp.allclose`` is ``m * m * atol`` where ``m`` is the number
            of rows of ``A``.

    Returns:
        bool: True if the matrix has orthogonal rows, False otherwise.
    """
    gram = A @ A.T
    n_rows = gram.shape[0]
    identity = jnp.eye(n_rows)
    # Scale the tolerance with the number of entries in the Gram matrix.
    scaled_atol = n_rows * n_rows * atol
    return jnp.allclose(gram, identity, atol=scaled_atol)
def has_unitary_columns(A, atol=1e-6):
    """Checks if a matrix has unitary columns

    The Gram matrix ``hermitian(A) @ A`` is compared against the identity,
    so the columns must be mutually orthogonal and of unit norm for the
    check to pass.

    Args:
        A (jax.numpy.ndarray): A JAX real or complex 2D array
        atol (float): Base absolute tolerance. The tolerance handed to
            ``jnp.allclose`` is ``m * atol`` where ``m`` is the number of
            columns of ``A``. The default reproduces the previously
            hard-coded tolerance.

    Returns:
        bool: True if the matrix has unitary columns, False otherwise.
    """
    G = hermitian(A) @ A
    m = G.shape[0]
    identity = jnp.eye(m)
    return jnp.allclose(G, identity, atol=m * atol)
def has_unitary_rows(A, atol=1e-6):
    """Checks if a matrix has unitary rows

    The Gram matrix ``A @ hermitian(A)`` is compared against the identity,
    so the rows must be mutually orthogonal and of unit norm for the check
    to pass.

    Args:
        A (jax.numpy.ndarray): A JAX real or complex 2D array
        atol (float): Base absolute tolerance. The tolerance handed to
            ``jnp.allclose`` is ``m * atol`` where ``m`` is the number of
            rows of ``A``. The default reproduces the previously
            hard-coded tolerance.

    Returns:
        bool: True if the matrix has unitary rows, False otherwise.
    """
    G = A @ hermitian(A)
    m = G.shape[0]
    identity = jnp.eye(m)
    return jnp.allclose(G, identity, atol=m * atol)