# Source code for cr.sparse._src.opt.smooth.smooth

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import reduce
from typing import NamedTuple, Callable, Tuple
import jax
from jax import jit, grad
import jax.numpy as jnp

import cr.nimble as cnb


class SmoothFunction(NamedTuple):
    r"""Represents a smooth function

    Let `op` be a variable of type `SmoothFunction` which represents some
    smooth function :math:`f`. Then:

    * `op.func(x)` returns the function value :math:`f(x)`.
    * `op.grad(x)` returns the gradient :math:`g(x) = \nabla f(x)`.
    * `op.grad_val(x)` returns the pair :math:`(g(x), f(x))`.
    """
    # Definition of the smooth function: maps x to the scalar value f(x).
    func: Callable[[jnp.ndarray], float]
    # Definition of the gradient of the function: maps x to g(x) = \nabla f(x).
    # It takes the single argument x, like `func`.
    grad: Callable[[jnp.ndarray], jnp.ndarray]
    # Evaluates the gradient vector and the function value together.
    # The result is the pair (g(x), f(x)), with the gradient first.
    grad_val: Callable[[jnp.ndarray], Tuple[jnp.ndarray, float]]
def build(func):
    r"""Creates a smooth function based on function definition :math:`f(x)`

    The gradient is obtained by automatic differentiation with
    :func:`jax.grad`. The combined ``grad_val`` evaluation uses
    :func:`jax.value_and_grad`, so the value and the gradient come out of a
    single forward/backward pass. Evaluating them separately would run the
    forward pass twice.

    Args:
        func: Definition of the smooth function :math:`f : \RR^n \to \RR`

    Returns:
        SmoothFunction: A smooth function wrapper
    """
    gradient = jit(grad(func))
    value_and_gradient = jax.value_and_grad(func)

    @jit
    def grad_val(x):
        # jax.value_and_grad yields (f(x), g(x)); SmoothFunction.grad_val
        # must return (g(x), f(x)), so swap the order.
        v, g = value_and_gradient(x)
        return g, v

    return SmoothFunction(func=jit(func), grad=gradient, grad_val=grad_val)


def build2(func, grad):
    r"""Creates a smooth function with user defined :math:`f(x)` and gradient :math:`g(x)`

    Args:
        func: Definition of the smooth function :math:`f : \RR^n \to \RR`
        grad: Definition of the gradient :math:`g = \nabla f : \RR^n \to \RR^n`

    Returns:
        SmoothFunction: A smooth function wrapper
    """
    func = jit(func)
    grad = jit(grad)
    # The user-supplied gradient cannot be fused with the function value,
    # so grad_val simply calls both.
    grad_val = build_grad_val_func(func, grad)
    return SmoothFunction(func=func, grad=grad, grad_val=grad_val)


def build3(func, grad, grad_val):
    r"""Creates a smooth function with user defined grad and grad_val functions

    Args:
        func: Definition of the smooth function :math:`f : \RR^n \to \RR`
        grad: Definition of the gradient :math:`g = \nabla f : \RR^n \to \RR^n`
        grad_val: Definition of a combined function which computes the pair
            :math:`(g(x), f(x))`

    Returns:
        SmoothFunction: A smooth function wrapper
    """
    func = jit(func)
    grad = jit(grad)
    grad_val = jit(grad_val)
    return SmoothFunction(func=func, grad=grad, grad_val=grad_val)
def build_grad_val_func(func, grad):
    r"""Constructs a `grad_val` function from the definitions of function :math:`f(x)` and gradient :math:`g(x)`

    The returned callable is JIT-compiled. It evaluates the gradient first and
    then the function value, and hands back both as a tuple.

    Args:
        func: Definition of the smooth function :math:`f : \RR^n \to \RR`
        grad: Definition of the gradient :math:`g = \nabla f : \RR^n \to \RR^n`

    Returns:
        A function which computes the pair :math:`(g(x), f(x))` for input :math:`x`
    """
    def gradient_and_value(point):
        # Tuple elements are evaluated left to right: gradient, then value.
        return grad(point), func(point)

    return jit(gradient_and_value)
def smooth_func_translate(smooth_func, b):
    r"""Returns a smooth function :math:`g` for a smooth function :math:`f` s.t. :math:`g(x) = f(x + b)`

    By the chain rule :math:`\nabla g(x) = \nabla f(x + b)`. Each member of the
    returned wrapper therefore evaluates the matching member of
    ``smooth_func`` at the shifted point :math:`x + b`.

    Args:
        smooth_func (SmoothFunction): Wrapper for smooth function :math:`f : \RR^n \to \RR`
        b (jax.numpy.ndarray): The offset/translation vector :math:`b \in \RR^n`

    Returns:
        SmoothFunction: A smooth function wrapper for the function :math:`g`
        such that :math:`g(x) = f(x+b)`
    """
    offset = cnb.promote_arg_dtypes(jnp.asarray(b))

    def shifted(x):
        # Convert and dtype-promote the input the same way as the offset,
        # then translate it.
        return cnb.promote_arg_dtypes(jnp.asarray(x)) + offset

    @jit
    def func(x):
        return smooth_func.func(shifted(x))

    @jit
    def grad(x):
        return smooth_func.grad(shifted(x))

    @jit
    def grad_val(x):
        return smooth_func.grad_val(shifted(x))

    return SmoothFunction(func=func, grad=grad, grad_val=grad_val)