# Source code for cr.sparse._src.lop.calculus

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import jax.numpy as jnp

from .impl import _hermitian
from .lop import Operator


# Two-tap difference stencil with coefficients [1, -1]. The derivative
# helpers below do not reference it; they call jnp.diff directly.
FORWARD_DERIVATIVE_FILTER = jnp.array([1.0, -1.0])

def _derivative_fwd(x, dx):
    append = jnp.array([x[-1]])
    return jnp.diff(x, append=append) / dx

def _derivative_fwd_adj(x, dx):
    x = x.at[-1].set(0)
    prepend = jnp.array([0])
    return jnp.diff(-x, prepend=prepend) / dx

def _derivative_bwd(x, dx):
    prepend = jnp.array([x[0]])
    return jnp.diff(x, prepend=prepend) / dx

def _derivative_bwd_adj(x, dx):
    x = x.at[0].set(0)
    append = jnp.array([0])
    return jnp.diff(-x, append=append) / dx

def _derivative_centered(x, dx):
    diffs = (0.5 * x[2:] - 0.5 * x[:-2]) / dx
    return jnp.pad(diffs, (1,1))

def _derivative_centered_adj(x, dx):
    y = jnp.zeros(x.shape)
    y = y.at[0:-2].add(-0.5*x[1:-1])
    y = y.at[2:].add(0.5*x[1:-1])
    return y


def first_derivative(n, dx=1., kind='centered'):
    """Returns a linear operator computing the first derivative of a 1-D signal.

    Args:
        n (int): Length of the input (and output) vector; the operator has
            shape ``(n, n)``.
        dx (float): Sampling step; differences are divided by ``dx``.
        kind (str): Finite-difference scheme:

            * ``'forward'``: ``y[i] = (x[i+1] - x[i]) / dx``, with ``y[n-1] = 0``
            * ``'backward'``: ``y[i] = (x[i] - x[i-1]) / dx``, with ``y[0] = 0``
            * ``'centered'`` (default): ``y[i] = (x[i+1] - x[i-1]) / (2 dx)``,
              with ``y[0] = y[n-1] = 0``

    Returns:
        Operator: ``times`` is the chosen difference scheme and ``trans`` is the
        matching ``_derivative_*_adj`` helper; both are bound to ``dx``.

    Raises:
        NotImplementedError: If ``kind`` is not one of the supported schemes.
    """
    if kind == 'forward':
        times = partial(_derivative_fwd, dx=dx)
        trans = partial(_derivative_fwd_adj, dx=dx)
    elif kind == 'backward':
        times = partial(_derivative_bwd, dx=dx)
        trans = partial(_derivative_bwd_adj, dx=dx)
    elif kind == 'centered':
        times = partial(_derivative_centered, dx=dx)
        trans = partial(_derivative_centered_adj, dx=dx)
    else:
        # Name the offending value so callers can see what went wrong.
        raise NotImplementedError(
            f"Unsupported derivative kind {kind!r}; "
            "expected 'forward', 'backward' or 'centered'")
    return Operator(times=times, trans=trans, shape=(n, n))
def second_derivative(n, dx=1.):
    """Returns a linear operator computing the second derivative of a 1-D signal.

    Interior outputs use the three-point stencil
    ``y[i] = (x[i-1] - 2 x[i] + x[i+1]) / dx**2``. The first and last outputs
    are zero-padded. The operator has shape ``(n, n)``.
    """
    # Symmetric stencil, so convolution and correlation coincide.
    stencil = jnp.array([1., -2., 1.]) / dx / dx

    def times(x):
        interior = jnp.convolve(x, stencil, 'valid')
        return jnp.pad(interior, (1, 1))

    def trans(x):
        # Only rows 1 .. n-2 of the forward operator are non-zero. Full
        # convolution of those entries with the symmetric stencil yields the
        # length-n transpose product.
        return jnp.convolve(x[1:-1], stencil, 'full')

    return Operator(times=times, trans=trans, shape=(n, n))