Source code for cr.sparse._src.lop.basic

# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp

from .impl import _hermitian
from .lop import Operator

###########################################################################################
#  Basic operators
###########################################################################################

def identity(m, n=None):
    """Returns an identity linear operator with shape (m, n)

    Both the forward and the adjoint actions hand back their input
    unchanged. When ``n`` is not given, it defaults to ``m``.
    """
    if n is None:
        n = m

    def times(x):
        return x

    def trans(x):
        return x

    return Operator(times=times, trans=trans, shape=(m, n))
def matrix(A):
    """Wraps a two-dimensional matrix A as a linear operator

    The forward action computes ``A @ x``. The adjoint action is evaluated
    as ``_hermitian(_hermitian(x) @ A)``. The operator shape is ``A.shape``.
    """
    rows, cols = A.shape

    def times(x):
        return A @ x

    def trans(x):
        # Multiply from the left, then undo the Hermitian transpose
        left_product = _hermitian(x) @ A
        return _hermitian(left_product)

    return Operator(times=times, trans=trans, shape=(rows, cols))
def diagonal(d):
    """Returns a linear operator which acts like a diagonal matrix

    ``d`` must be one-dimensional; it holds the diagonal entries. The
    forward action multiplies the input elementwise by ``d`` and the
    adjoint action multiplies it elementwise by ``_hermitian(d)``.
    """
    assert d.ndim == 1
    size = d.shape[0]

    def times(x):
        return d * x

    def trans(x):
        return _hermitian(d) * x

    return Operator(times=times, trans=trans, shape=(size, size))
def zero(m, n=None):
    """Returns a linear operator which maps every input to a zero array

    The forward action yields zeros with ``m`` entries along the first
    axis, the adjoint action yields zeros with ``n`` entries along the
    first axis. Trailing axes and the dtype are taken from the input.
    When ``n`` is not given, it defaults to ``m``.
    """
    if n is None:
        n = m

    def times(x):
        out_shape = (m,) + x.shape[1:]
        return jnp.zeros(out_shape, dtype=x.dtype)

    def trans(x):
        out_shape = (n,) + x.shape[1:]
        return jnp.zeros(out_shape, dtype=x.dtype)

    return Operator(times=times, trans=trans, shape=(m, n))
def flipud(n):
    """Returns an operator which reverses the entries of its input upside down

    The same ``jnp.flipud`` call serves as both the forward and the
    adjoint action. The operator shape is (n, n).
    """
    def times(x):
        return jnp.flipud(x)

    def trans(x):
        return jnp.flipud(x)

    return Operator(times=times, trans=trans, shape=(n, n))
def sum(n):
    """Returns an operator which adds up the entries of a vector

    The forward action sums along the first axis and keeps that axis with
    length one. The adjoint action repeats its input ``n`` times along
    the first axis. The operator shape is (1, n).
    """
    def times(x):
        return jnp.sum(x, axis=0, keepdims=True)

    def trans(x):
        return jnp.repeat(x, n, axis=0)

    return Operator(times=times, trans=trans, shape=(1, n))
def pad_zeros(n, before, after):
    """Adds zeros before and after a vector.

    The forward action pads the input with ``before`` leading and ``after``
    trailing zeros via ``jnp.pad``. The adjoint action drops the padding by
    slicing out the ``n`` entries starting at offset ``before``.

    The operator has shape ``(before + n + after, n)`` and is created with
    ``matrix_safe=False``.

    Note:

        This operator is not JIT compliant
    """
    pad_width = (before, after)
    m = before + n + after

    def times(x):
        return jnp.pad(x, pad_width)

    def trans(x):
        return x[before:before + n]

    return Operator(times=times, trans=trans, shape=(m, n), matrix_safe=False)
def real(n):
    """Returns an operator which extracts the real parts of its input

    Note:

        The forward and adjoint actions are the same ``jnp.real`` call.
        This is not a linear operator; it is created with ``linear=False``.
    """
    def times(x):
        return jnp.real(x)

    def trans(x):
        return jnp.real(x)

    return Operator(times=times, trans=trans, shape=(n, n), linear=False)
def symmetrize(n):
    """An operator which builds a symmetric vector by pre-pending
    the input in reversed order

    The forward action concatenates the flipped input with the input
    itself. The adjoint action adds the trailing part ``x[n:]`` to the
    leading ``n`` entries read in reverse. The operator shape is (2*n, n).
    """
    def times(x):
        mirrored = jnp.flipud(x)
        return jnp.concatenate((mirrored, x))

    def trans(x):
        tail = x[n:]
        # leading n entries, walked backwards from index n-1 down to 0
        reversed_head = x[n-1::-1]
        return tail + reversed_head

    return Operator(times=times, trans=trans, shape=(2 * n, n))
def restriction(n, indices):
    """An operator which computes y = x[I] over an index set I

    The forward action picks the entries of the input at ``indices``.
    The adjoint action scatters its input into a zero array with ``n``
    entries along the first axis, at the positions given by ``indices``.
    The operator shape is ``(len(indices), n)``.
    """
    k = len(indices)

    def times(x):
        return x[indices]

    def trans(x):
        # Allocate the buffer in the input's dtype so that complex or
        # integer data is not cast to the default float type on scatter.
        buffer = jnp.zeros((n,) + x.shape[1:], dtype=x.dtype)
        return buffer.at[indices].set(x)

    return Operator(times=times, trans=trans, shape=(k, n))