Source code for cr.sparse._src.opt.proximal_ops.prox_sorted_l1
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import jit, lax
from jax.ops import segment_sum
import jax.numpy as jnp
import cr.nimble as cnb
import cr.sparse.opt as opt
from .prox import build, build_from_ind_proj
"""
See https://github.com/google/jax/discussions/8862 for the trick, used in prox_ordered_l1_b, of averaging over increasing segments with segment sums.
"""
def prox_ordered_l1_b(state):
    """Run one averaging pass over the runs in which ``y - l`` increases.

    Args:
        state: Pair ``(y, l)`` of equal-length 1-D arrays.

    Returns:
        tuple: ``(y, l)`` in which every entry has been replaced by the mean
        of its segment. A segment is a maximal run of consecutive positions
        where each entry of ``y - l`` is strictly larger than the one before
        it; entries outside such runs form segments of length one and are
        therefore left unchanged.
    """
    y, lam = state
    residual = y - lam
    size = len(residual)
    # True wherever an entry is strictly larger than its predecessor.
    # The first entry has no predecessor, so it is never marked.
    rising = jnp.zeros(size, dtype=bool).at[1:].set(jnp.diff(residual) > 0)
    # Each unmarked entry opens a new segment; marked entries join the
    # segment of the entry before them.
    seg = jnp.cumsum(jnp.logical_not(rising)) - 1
    # Segment sizes. Unused segment slots have size zero, but they are never
    # read back because no entry of ``seg`` points at them.
    counts = segment_sum(jnp.ones_like(residual), seg, num_segments=size)
    y_mean = segment_sum(y, seg, num_segments=size) / counts
    lam_mean = segment_sum(lam, seg, num_segments=size) / counts
    # Scatter each segment mean back to the members of that segment.
    return y_mean[seg], lam_mean[seg]
def is_not_nonincreasing(state):
    """Report whether ``y - l`` still fails the non-increasing check.

    Args:
        state: Pair ``(y, l)`` of arrays.

    Returns:
        The logical negation of ``cnb.is_nonincreasing_vec(y - l)``.
    """
    y, lam = state
    residual = y - lam
    ordered = cnb.is_nonincreasing_vec(residual)
    return jnp.logical_not(ordered)
def prox_ordered_l1_a(y, l):
    """Evaluate the proximal operator of the ordered weighted l1 norm at ``y``.

    The magnitudes of ``y`` are sorted in descending order, ``l`` is
    subtracted after repeatedly averaging every run in which ``y - l``
    increases (``prox_ordered_l1_b``) until ``is_not_nonincreasing`` reports
    that no such run is left, the result is clipped at zero, and finally the
    original positions and signs are restored.

    Args:
        y: Input array; it is flattened before processing.
        l: Weights; flattened and broadcast to the flattened shape of ``y``.
            NOTE(review): presumably expected to be non-negative and sorted
            in non-increasing order - this function does not verify it.

    Returns:
        A 1-D array with one entry per entry of the flattened ``y``, in the
        original (unsorted) order.
    """
    # convert them to 1d arrays
    y = jnp.ravel(y)
    l = jnp.ravel(l)
    # get the sign vector of y
    sgn = jnp.sign(y)
    # take the absolute values
    y = jnp.abs(y)
    # sort entries in y by magnitude, in descending order
    idx = jnp.argsort(y)[::-1]
    y = y[idx]
    # The loop body divides, so it always returns floating point values, and
    # lax.while_loop requires the carried state to keep a fixed dtype.
    # Promote both operands to one common floating dtype up front so that
    # integer inputs or mixed float widths do not fail at trace time.
    dtype = jnp.result_type(y, l, float)
    y = y.astype(dtype)
    l = l.astype(dtype)
    # make sure that lambda and y are in same shape
    l = jnp.broadcast_to(l, y.shape)
    y, l = lax.while_loop(is_not_nonincreasing, prox_ordered_l1_b, (y, l))
    x = y - l
    # keep only the positive part
    x = jnp.where(x > 0, x, 0)
    # restore x at the original indices
    x = jnp.zeros_like(x).at[idx].set(x)
    # restore the sign
    return sgn * x
def prox_owl1(lambda_=1.):
    r"""Returns a prox-capable wrapper for the ordered and weighted l1-norm function ``f(x) = sum(lambda * sort(abs(x), 'descend'))``

    Args:
        lambda_ (jax.numpy.ndarray): A strictly positive vector which is sorted in decreasing order.
            A scalar (such as the default ``1.``) is also accepted; it is
            converted to a one-element array and broadcast against the input.

    Returns:
        ProxCapable: A prox-capable function

    Let :math:`x \in \RR^n`. Let :math:`|x|` represent a vector of absolute values of entries in :math:`x`.
    Let :math:`|x|_{\downarrow}` represent a vector consisting of entries in :math:`|x|` sorted in descending order.
    Let :math:`|x|_{(1)} \geq |x|_{(2)} \geq |x|_{(3)} \geq \dots \geq |x|_{(n)}` represent the order statistic of :math:`x`,
    i.e. entries in :math:`x` arranged in descending order by magnitude.
    Let :math:`\lambda \in \RR^n_{+}` be a weight vector such that
    :math:`\lambda_1 \geq \lambda_2 \geq \dots \geq \lambda_n` and :math:`\lambda \neq 0` i.e.
    not all entries in :math:`\lambda` are zero.

    Then the ordered weighted :math:`\ell_1` norm of :math:`x` w.r.t. the weight vector :math:`\lambda` is defined as:

    .. math::

        J_{\lambda} (x) = \sum_{i=1}^n \lambda_i | x |_{(i)}

    The function is computed in following steps:

    - Take absolute values of entries in x
    - Sort the entries of x in descending order
    - Multiply the sorted entries with entries in lambda (component wise)
    - Compute the sum of the entries

    For the derivation of the proximal operator for the ordered and weighted l1 norm, see :cite:`lgorzata2013statistical`.
    """
    # normalise the weights once; both closures below capture this array
    lambda_ = jnp.asarray(lambda_)
    lambda_ = cnb.promote_arg_dtypes(lambda_)
    lambda_ = jnp.ravel(lambda_)

    @jit
    def func(x):
        """Evaluate the ordered weighted l1 norm of ``x``."""
        x = jnp.asarray(x)
        x = cnb.promote_arg_dtypes(x)
        # take absolute values
        x = jnp.abs(x)
        # convert x to 1d array
        x = jnp.ravel(x)
        # sort the entries in ascending order
        x = jnp.sort(x)
        # reverse the order
        x = x[::-1]
        # compute element wise product
        x = lambda_ * x
        # return the sum
        return jnp.sum(x)

    @jit
    def proximal_op(x, t):
        """Apply the proximal operator with step size ``t``, keeping the shape of ``x``."""
        # make sure that x is a JAX array
        x = jnp.asarray(x)
        # make sure that x is float
        x = cnb.promote_arg_dtypes(x)
        # capture original shape
        shape = x.shape
        # convert x to 1d array
        x = jnp.ravel(x)
        # compute the proximal vector; the step size scales the weights
        z = prox_ordered_l1_a(x, t*lambda_)
        # put it back into original shape
        z = jnp.reshape(z, shape)
        return z

    return build(func, proximal_op)