Source code for cr.sparse._src.opt.proximal_ops.lpnorms

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from jax import jit, lax

import jax.numpy as jnp
import cr.nimble as cnb
import cr.sparse.opt as opt

from .prox import build, build_from_ind_proj


def prox_l2(q=1.):
    r"""Returns a prox-capable wrapper for the function :math:`f(x) = \| q x \|_2`

    The function value is computed as ``q * ||x||_2``. This equals
    :math:`\| q x \|_2` when ``q`` is a non-negative scalar.

    The proximal operator scales the whole vector ``x`` by the factor
    :math:`\max(1 - t q / \| x \|_2, 0)` (for :math:`t q \geq 0`). This is
    block soft thresholding.

    Args:
        q (float): Scale factor of the norm. Defaults to 1.

    Returns:
        ProxCapable: A prox-capable function
    """
    q = jnp.asarray(q)

    @jit
    def func(x):
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        v = cnb.arr_l2norm(x)
        return q * v

    @jit
    def proximal_op(x, t):
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        v = cnb.arr_l2norm(x)
        tq = t * q
        # When v == 0, x is the zero vector, so the output is 0 for any finite
        # ratio. Dividing by 1 there avoids 0/0 = NaN when t*q is also 0
        # (q == 0 or t == 0), which jnp.maximum would otherwise propagate.
        ratio = v / jnp.where(v == 0, 1., tq)
        # Shrinkage factor: 0 when v <= t*q, otherwise 1 - t*q / v.
        s = 1 - 1 / jnp.maximum(ratio, 1.)
        return x * s

    return build(func, proximal_op)
def prox_l1(q=1.):
    r"""Returns a prox-capable wrapper for the function :math:`f(x) = \| q x \|_1`

    The proximal operator is elementwise soft thresholding. Each entry
    :math:`x_i` is scaled by :math:`\max(1 - t q / |x_i|, 0)` (for
    :math:`t q \geq 0`). Because only the magnitude is shrunk, the sign (or
    phase, for complex input) of every entry is preserved.

    Args:
        q (float or array): Scale factor (or elementwise weights) applied
            to ``x`` inside the l1-norm. Defaults to 1.

    Returns:
        ProxCapable: A prox-capable function
    """
    q = jnp.asarray(q)

    @jit
    def func(x):
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        v = cnb.arr_l1norm(q * x)
        return v

    @jit
    def proximal_op(x, t):
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        tq = t * q
        ax = jnp.abs(x)
        # At zero entries the output is 0 for any finite factor, so divide by
        # 1 there. This avoids 0/0 = NaN when t*q is also 0 at that entry
        # (q == 0, zero weights in q, or t == 0).
        safe_ax = jnp.where(ax == 0, 1, ax)
        # shrinkage coefficients
        s = 1 - jnp.minimum(tq / safe_ax, 1)
        # shrink x
        return x * s

    return build(func, proximal_op)
def prox_l1_pos(q=1.):
    r"""Returns a prox-capable wrapper for the function :math:`f(x) = \| q x \|_1 + I({x \geq 0})`

    The domain of :math:`f` is restricted to non-negative vectors. This is
    enforced by the indicator function component :math:`I({x \geq 0})` in
    the definition of :math:`f`. The function evaluates to ``+inf`` for
    any ``x`` with a negative entry.

    The proximal operator is one-sided soft thresholding,
    :math:`\max(x - t q, 0)`. Entries are shifted down by :math:`t q` and
    anything that would become negative is clipped to 0.

    Args:
        q (float or array): Scale factor (or elementwise weights) applied
            to ``x`` inside the l1-norm. Defaults to 1.

    Returns:
        ProxCapable: A prox-capable function
    """
    q = jnp.asarray(q)

    @jit
    def func(x):
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        # check if any of the entries in x is negative
        is_invalid = jnp.any(x < 0)
        # jnp.where (rather than lax.cond) lets the weakly typed inf adopt the
        # norm's dtype. lax.cond requires both branches to return identical
        # dtypes, and a Python-float inf need not match the norm's dtype.
        return jnp.where(is_invalid, jnp.inf, cnb.arr_l1norm(q * x))

    @jit
    def proximal_op(x, t):
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        tq = t * q
        # shrinkage only applies on the positive side. negative values are mapped to 0
        return jnp.maximum(0, x - tq)

    return build(func, proximal_op)
def prox_l1_ball(q=1.):
    r"""Returns a prox-capable wrapper for the indicator of the l1-ball :math:`\{ x : \| x \|_1 \leq q \}`

    The wrapper is assembled from the library's l1-ball indicator
    (``opt.indicator_l1_ball``) and the matching projection
    (``opt.proj_l1_ball``), both constructed with the same radius ``q``.

    Args:
        q (float): Radius of the l1-ball. Defaults to 1.

    Returns:
        ProxCapable: A prox-capable function
    """
    return build_from_ind_proj(
        opt.indicator_l1_ball(q=q),
        opt.proj_l1_ball(q=q),
    )