Source code for cr.sparse._src.lop.lop

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
from typing import NamedTuple, Callable, Tuple
import jax
import jax.numpy as jnp

from .impl import _hermitian

def column(T, i):
    """Return the ``i``-th column of the matrix representation of ``T``.

    The column is obtained by applying ``T.times`` to the ``i``-th
    standard basis vector of length ``T.shape[1]``.

    Args:
        T: An operator exposing ``shape`` and ``times``.
        i: Index of the requested column.

    Returns:
        The result of ``T.times`` applied to the unit vector.
    """
    num_cols = T.shape[1]
    unit = jnp.zeros(num_cols)
    unit = unit.at[i].set(1.)
    return T.times(unit)


# Both the operator and the index are compile-time constants for the JIT.
column = jax.jit(column, static_argnums=(0, 1))
def columns(T, indices):
    """Return the columns of ``T`` selected by ``indices``.

    A selector matrix with ``T.shape[1]`` rows and ``len(indices)`` columns
    is built so that its ``j``-th column is the unit vector for
    ``indices[j]``; ``T.times`` is then applied to that matrix.

    Args:
        T: An operator exposing ``shape`` and ``times``; ``times`` must
            accept a 2D input here.
        indices: Indices of the requested columns.

    Returns:
        The result of ``T.times`` applied to the selector matrix.
    """
    num_selected = len(indices)
    selector = jnp.zeros((T.shape[1], num_selected))
    # Put a single 1 in column j at row indices[j].
    selector = selector.at[indices, jnp.arange(num_selected)].set(1.)
    return T.times(selector)


# Only the operator is static; the index array is traced.
columns = jax.jit(columns, static_argnums=(0,))
class Operator(NamedTuple):
    """
    Represents a finite linear operator :math:`T : A -> B` where :math:`A` and :math:`B` are finite vector spaces.

    Parameters:
        times: A function implementing :math:`T(x)`
        trans: A function implementing :math:`T^H (x)`
        shape: The pair ``(m, n)`` where ``m`` describes the destination
            vector space :math:`B` and ``n`` the source vector space :math:`A`
        linear: Indicates if the operator is linear or not
        jit_safe: Indicates if the operator can be safely JIT compiled
        matrix_safe: Indicates if the operator can accept a matrix of vectors
        real: Indicates if a linear operator is real i.e. has a matrix representation of real numbers

    Note:
        While most of the operators in the library are linear operators,
        some are not. Prominent examples include operators
        like real part operator, imaginary part operator.
        These operators are provided for convenience.

        Most operators in this collection are real.
    """
    times : Callable[[jnp.ndarray], jnp.ndarray]
    """A linear function mapping from A to B """
    trans : Callable[[jnp.ndarray], jnp.ndarray]
    """Corresponding adjoint linear function mapping from B to A"""
    shape : Tuple[int, int]
    """Dimension of the linear operator (m, n)"""
    linear : bool = True
    """Indicates if the operator is linear or not"""
    jit_safe: bool = True
    """Indicates if the times and trans functions can be safely jit compiled"""
    matrix_safe: bool = True
    """Indicates if the operator can accept a matrix of vectors"""
    real: bool = True
    """Indicates if a linear operator is real i.e.
    has a matrix representation of real numbers"""

    column = column
    """Returns a specific column of the matrix representation of the operator"""
    columns = columns
    """Returns a subset of columns of the matrix representation of the operator"""

    def __neg__(self):
        """Returns the negative of this linear operator"""
        return neg(self)

    def __add__(self, other):
        """Returns the sum of this linear operator with another linear operator"""
        return add(self, other)

    def __sub__(self, other):
        """Returns the subtraction of this linear operator with another linear operator"""
        return subtract(self, other)

    def __matmul__(self, other):
        """Returns the composition of this linear operator with another linear operator"""
        return compose(self, other)

    def __pow__(self, n):
        """Returns a linear operator which works like applying :math:`T` n times"""
        return power(self, n)

    def times_2d(self, X):
        """Computes Y = T X where y = T x for each column in X"""
        # Map ``times`` over axis 1 of the input and stack results on axis 1.
        return jax.vmap(self.times, 1, 1)(X)

    def trans_2d(self, Y):
        """Computes X = T^H Y where x = T^H y for each column in Y"""
        return jax.vmap(self.trans, 1, 1)(Y)

    def apply_columns(self, x, I=None):
        """Computes y = T_I x where I is an index set selecting a subset of columns of T

        Args:
            x: Input array; only the entries ``x[I]`` contribute to the result.
            I: Index set of the columns to keep. When omitted, every column
                is used and the result is simply ``self.times(x)``.

        Returns:
            ``self.times`` applied to a copy of ``x`` in which all entries
            outside ``I`` are zero.
        """
        if I is None:
            # No restriction requested: apply the full operator.
            return self.times(x)
        # Zero out every entry that is not selected by I.
        xr = jnp.zeros(x.shape, x.dtype)
        xr = xr.at[I].set(x[I])
        return self.times(xr)

    @property
    def input_ndim(self):
        """Returns the number of dimensions of input to the operator
        """
        shape = self.shape[1]
        if isinstance(shape, int):
            # it appears to be a 1D operator
            return 1
        return len(shape)

    @property
    def output_ndim(self):
        """Returns the number of dimensions of output of the operator
        """
        shape = self.shape[0]
        if isinstance(shape, int):
            # it appears to be a 1D operator
            return 1
        return len(shape)

    @property
    def input_size(self):
        """Returns the size of input to the operator
        """
        shape = self.shape[1]
        if isinstance(shape, int):
            # it appears to be a 1D operator
            return shape
        return reduce(lambda x, y : x * y, shape)

    @property
    def output_size(self):
        """Returns the size of output of the operator
        """
        shape = self.shape[0]
        if isinstance(shape, int):
            # it appears to be a 1D operator
            return shape
        return reduce(lambda x, y : x * y, shape)

    @property
    def input_shape(self):
        """Returns the shape of input to the operator as a tuple
        """
        shape = self.shape[1]
        if isinstance(shape, int):
            # it appears to be a 1D operator
            return (shape,)
        return shape

    @property
    def output_shape(self):
        """Returns the shape of output of the operator as a tuple
        """
        shape = self.shape[0]
        if isinstance(shape, int):
            # it appears to be a 1D operator
            return (shape,)
        return shape
def jit(operator):
    """Return an equivalent operator whose ``times`` and ``trans`` are JIT compiled.

    Args:
        operator (Operator): The operator to compile.

    Returns:
        (Operator): A new operator with the same ``shape``, ``matrix_safe``,
        ``real`` and ``linear`` attributes, wrapping ``jax.jit`` versions of
        the original functions.

    Raises:
        Exception: If ``operator.jit_safe`` is false.
    """
    if not operator.jit_safe:
        raise Exception("This operator is not suitable for JIT compilation.")
    compiled_times = jax.jit(operator.times)
    compiled_trans = jax.jit(operator.trans)
    return Operator(
        times=compiled_times,
        trans=compiled_trans,
        shape=operator.shape,
        matrix_safe=operator.matrix_safe,
        real=operator.real,
        linear=operator.linear,
    )
########################################################################################### # # Operator algebra # ########################################################################################### ########################################################################################### # Unary operations on linear operators ###########################################################################################
def neg(A):
    r"""Returns the negative of a linear operator :math:`T = -A`

    Args:
        A (Operator): A given linear operator

    Returns:
        (Operator): A linear operator T such that :math:`T x = - A x`
        and :math:`T^H x = - A^H x`. The ``shape``, ``jit_safe``,
        ``matrix_safe`` and ``real`` attributes are copied from ``A``.
    """
    def times(x):
        return -A.times(x)

    def trans(x):
        return -A.trans(x)

    return Operator(
        times=times,
        trans=trans,
        shape=A.shape,
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )
def scale(A, alpha):
    r"""Returns the linear operator :math:`T = \alpha A` for the operator :math:`A`

    Args:
        A (Operator): A given linear operator
        alpha: The scaling factor; may be a Python, NumPy or JAX scalar,
            real or complex

    Returns:
        (Operator): A linear operator T such that :math:`T x = \alpha A x`
        and :math:`T^H x = \bar{\alpha} A^H x`. The result is flagged as
        real only when ``A`` is real and ``alpha`` has a non-complex type.
    """
    # jnp.iscomplexobj inspects the dtype, so complex64 scalars and complex
    # JAX/NumPy arrays are detected as well as Python ``complex`` values.
    real = A.real and not jnp.iscomplexobj(alpha)
    alpha_c = jnp.conjugate(alpha)
    times = lambda x : alpha * A.times(x)
    # The adjoint of a scaled operator uses the conjugated factor.
    trans = lambda x : alpha_c * A.trans(x)
    return Operator(times=times, trans=trans, shape=A.shape,
        jit_safe=A.jit_safe, matrix_safe=A.matrix_safe, real=real)
def hermitian(A):
    r"""Returns the Hermitian transpose of a given operator :math:`T = A^H`

    Args:
        A (Operator): A given linear operator

    Returns:
        (Operator): A linear operator T such that :math:`T x = A^H x`;
        its ``times`` is ``A.trans``, its ``trans`` is ``A.times`` and the
        two entries of the shape are swapped.

    Note:
        Deprecated. Use `adjoint` instead.
    """
    rows, cols = A.shape
    swapped_shape = (cols, rows)
    return Operator(
        times=A.trans,
        trans=A.times,
        shape=swapped_shape,
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )
def adjoint(A):
    r"""Returns the adjoint of a given operator :math:`T = A^H`

    Args:
        A (Operator): A given linear operator

    Returns:
        (Operator): A linear operator T such that :math:`T x = A^H x`.
        The forward and adjoint functions of ``A`` exchange roles and the
        shape ``(m, n)`` becomes ``(n, m)``.
    """
    out_dim, in_dim = A.shape
    return Operator(
        times=A.trans,
        trans=A.times,
        shape=(in_dim, out_dim),
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )
def transpose(A):
    r"""Returns the transpose of a given operator :math:`T = A^T`

    For an operator flagged as real, the forward and adjoint functions are
    simply exchanged. Otherwise both the input and the output of each
    exchanged function are passed through ``_hermitian`` (imported from
    ``.impl``; presumably a conjugate transpose -- confirm against that
    module).

    Args:
        A (Operator): A given linear operator

    Returns:
        (Operator): The transposed operator with shape ``(n, m)`` for an
        input of shape ``(m, n)``.
    """
    m, n = A.shape
    if not A.real:
        def times(x):
            return _hermitian(A.trans(_hermitian(x)))

        def trans(x):
            return _hermitian(A.times(_hermitian(x)))
    else:
        times, trans = A.trans, A.times
    return Operator(
        times=times,
        trans=trans,
        shape=(n, m),
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )
def apply_n(func, n, x):
    """Apply ``func`` to ``x`` repeatedly, ``n`` times in total.

    The iteration runs inside ``jax.lax.while_loop`` with a
    ``(value, counter)`` state; the loop continues while the counter is
    below ``n``, so a non-positive ``n`` returns ``x`` unchanged.

    Args:
        func: Function mapping a value to a new value.
        n: Number of applications.
        x: Initial value.

    Returns:
        The value after ``n`` applications of ``func``.
    """
    def keep_going(carry):
        _, count = carry
        return count < n

    def step(carry):
        value, count = carry
        return func(value), count + 1

    final_value, _ = jax.lax.while_loop(keep_going, step, (x, 0))
    return final_value


# The function and the repetition count are compile-time constants.
apply_n = jax.jit(apply_n, static_argnums=(0, 1))
def power(A, p):
    """Returns the linear operator :math:`T = A^p`

    Args:
        A (Operator): A square linear operator
        p: Number of times ``A`` is applied

    Returns:
        (Operator): An operator whose ``times`` applies ``A.times`` ``p``
        times and whose ``trans`` applies ``A.trans`` ``p`` times, both via
        ``apply_n``.

    Raises:
        AssertionError: If the two entries of ``A.shape`` differ.
    """
    rows, cols = A.shape
    assert rows == cols

    def times(x):
        return apply_n(A.times, p, x)

    def trans(x):
        return apply_n(A.trans, p, x)

    return Operator(
        times=times,
        trans=trans,
        shape=A.shape,
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )
def partial_op(A, picks, perm=None):
    """Returns the linear operator T that computes (A x[perm])[picks]

    Two kinds of randomization are supported:

    - entries of the model vector x can be permuted (as per ``perm``)
    - from the result y = A x[perm], a limited number of entries can be
      picked (as per ``picks``)

    Args:
        A (Operator): The underlying operator of shape ``(m, n)``
        picks: 1D index array selecting entries of the output
        perm: Optional 1D index array applied to the input

    Returns:
        (Operator): An operator of shape ``(k, n)`` where ``k`` is the
        number of picks.

    Raises:
        AssertionError: If ``picks`` or a given ``perm`` is not 1D.
    """
    assert picks.ndim == 1
    m, n = A.shape
    k = picks.shape[0]
    if perm is not None:
        assert perm.ndim == 1

    def padded_adjoint(x):
        # Expand the input to m leading rows, filling unpicked rows with zeros,
        # then apply the adjoint of A.
        full_shape = (m,) + x.shape[1:]
        padded = jnp.zeros(full_shape, dtype=x.dtype)
        padded = padded.at[picks].set(x)
        return A.trans(padded)

    if perm is not None:
        def times(x):
            return (A.times(x[perm]))[picks]

        def trans(x):
            y = padded_adjoint(x)
            # Place the adjoint output back at the permuted positions.
            out = jnp.empty(y.shape, dtype=y.dtype)
            return out.at[perm].set(y)
    else:
        def times(x):
            return (A.times(x))[picks]

        def trans(x):
            return padded_adjoint(x)

    return Operator(
        times=times,
        trans=trans,
        shape=(k, n),
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )


###########################################################################################
#  Binary operations on linear operators
###########################################################################################
def add(A, B):
    """Returns the sum of two linear operators :math:`T = A + B`

    Args:
        A (Operator): First operand
        B (Operator): Second operand with the same shape as ``A``

    Returns:
        (Operator): An operator whose ``times`` and ``trans`` add the
        corresponding results of ``A`` and ``B``. It is JIT safe, matrix
        safe and real only if both operands are.

    Raises:
        AssertionError: If the shapes of ``A`` and ``B`` differ.
    """
    rows_a, cols_a = A.shape
    rows_b, cols_b = B.shape
    assert rows_a == rows_b
    assert cols_a == cols_b

    def times(x):
        return A.times(x) + B.times(x)

    def trans(x):
        return A.trans(x) + B.trans(x)

    return Operator(
        times=times,
        trans=trans,
        shape=A.shape,
        jit_safe=A.jit_safe and B.jit_safe,
        matrix_safe=A.matrix_safe and B.matrix_safe,
        real=A.real and B.real,
    )
def subtract(A, B):
    """Returns a linear operator :math:`T = A - B`

    Args:
        A (Operator): The operator to subtract from
        B (Operator): The operator being subtracted; same shape as ``A``

    Returns:
        (Operator): An operator whose ``times`` and ``trans`` subtract the
        results of ``B`` from those of ``A``. It is JIT safe, matrix safe
        and real only if both operands are.

    Raises:
        AssertionError: If the shapes of ``A`` and ``B`` differ.
    """
    rows_a, cols_a = A.shape
    rows_b, cols_b = B.shape
    assert rows_a == rows_b
    assert cols_a == cols_b

    def times(x):
        return A.times(x) - B.times(x)

    def trans(x):
        return A.trans(x) - B.trans(x)

    return Operator(
        times=times,
        trans=trans,
        shape=A.shape,
        jit_safe=A.jit_safe and B.jit_safe,
        matrix_safe=A.matrix_safe and B.matrix_safe,
        real=A.real and B.real,
    )
def compose(A, B, ignore_compatibility=False, shape=None):
    """Returns the composite linear operator :math:`T = AB` such that :math:`T(x)= A(B(x))`

    Args:
        A (Operator): The outer operator, applied second
        B (Operator): The inner operator, applied first
        ignore_compatibility (bool): When true, skip the check that the
            input shape of ``A`` equals the output shape of ``B``
        shape: Optional explicit shape for the result; defaults to
            ``(A.shape[0], B.shape[1])``

    Returns:
        (Operator): The composite operator. It is flagged as JIT safe,
        matrix safe and real only if both ``A`` and ``B`` are.

    Raises:
        AssertionError: If the shapes are incompatible and
            ``ignore_compatibility`` is false.
    """
    ma, na = A.shape
    mb, nb = B.shape
    if not ignore_compatibility:
        assert na == mb, "Input shape of A must match the output shape of B"
    jit_safe = A.jit_safe and B.jit_safe
    matrix_safe = A.matrix_safe and B.matrix_safe
    real = A.real and B.real
    times = lambda x: A.times(B.times(x))
    # (AB)^H = B^H A^H
    trans = lambda x: B.trans(A.trans(x))
    if shape is None:
        shape = (ma, nb)
    # Forward the combined safety flags; leaving them out would silently
    # reset them to the Operator defaults (True).
    return Operator(times=times, trans=trans, shape=shape,
        jit_safe=jit_safe, matrix_safe=matrix_safe, real=real)
def hcat(A, B):
    r"""Returns the linear operator :math:`T = [A \, B]`

    Args:
        A (Operator): Left block of shape ``(m, na)``
        B (Operator): Right block of shape ``(m, nb)``

    Returns:
        (Operator): An operator of shape ``(m, na + nb)``. Its ``times``
        applies ``A`` to the first ``na`` entries of the input and ``B`` to
        the rest and adds the results; its ``trans`` concatenates
        ``A.trans(x)`` and ``B.trans(x)``.

    Raises:
        AssertionError: If the first shape entries of ``A`` and ``B`` differ.
    """
    rows_a, cols_a = A.shape
    rows_b, cols_b = B.shape
    assert rows_a == rows_b

    def times(x):
        left, right = x[:cols_a], x[cols_a:]
        return A.times(left) + B.times(right)

    def trans(x):
        return jnp.concatenate((A.trans(x), B.trans(x)))

    return Operator(
        times=times,
        trans=trans,
        shape=(rows_a, cols_a + cols_b),
        jit_safe=A.jit_safe and B.jit_safe,
        matrix_safe=A.matrix_safe and B.matrix_safe,
        real=A.real and B.real,
    )
def gram(A):
    """Returns the gram of a given operator :math:`T = A^H A`

    Args:
        A (Operator): A linear operator of shape ``(m, n)``

    Returns:
        (Operator): An operator of shape ``(n, n)`` whose ``times`` and
        ``trans`` are the same function ``x -> A.trans(A.times(x))``.
    """
    _, n = A.shape

    def apply(x):
        return A.trans(A.times(x))

    return Operator(
        times=apply,
        trans=apply,
        shape=(n, n),
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )


def frame(A):
    """Returns the frame of a given operator :math:`T = A A^H`

    Args:
        A (Operator): A linear operator of shape ``(m, n)``

    Returns:
        (Operator): An operator of shape ``(m, m)`` whose ``times`` and
        ``trans`` are the same function ``x -> A.times(A.trans(x))``.
    """
    m, _ = A.shape

    def apply(x):
        return A.times(A.trans(x))

    return Operator(
        times=apply,
        trans=apply,
        shape=(m, m),
        jit_safe=A.jit_safe,
        matrix_safe=A.matrix_safe,
        real=A.real,
    )