Source code for cr.sparse._src.opt.projectors.basic

# Copyright 2021 CR-Suite Development Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from jax import jit, lax

import jax.numpy as jnp
from jax.numpy.linalg import qr, norm
from jax.scipy.linalg import cholesky

import cr.nimble as cnb

def proj_zero():
    r"""Build the projector onto the trivial set :math:`\{0\} \subset \RR^n`.

    Returns:
        A jit-compiled function that maps any input ``x`` to an all-zero
        array with the same shape and dtype as ``jnp.asarray(x)``.
    """

    def projector(x):
        # The projection of every point onto {0} is the origin itself.
        return jnp.zeros_like(jnp.asarray(x))

    return jit(projector)
def proj_identity():
    r"""Build the projector onto the whole space :math:`\RR^n` (x => x).

    Returns:
        A jit-compiled function that converts its argument to a JAX array,
        passes it through ``cnb.promote_arg_dtypes`` and returns the result
        otherwise unchanged.
    """

    def projector(x):
        # No constraint to enforce; only the array/dtype normalization runs.
        return cnb.promote_arg_dtypes(jnp.asarray(x))

    return jit(projector)
def proj_singleton(c):
    r"""Build the projector onto the singleton set :math:`\{c\}`.

    Args:
        c: The only point of the set. It is converted to a JAX array and
            passed through ``cnb.promote_arg_dtypes`` once, at construction.

    Returns:
        A jit-compiled function that ignores the values of its input ``x``
        and returns ``c`` broadcast to the shape of ``x``.
    """
    target = cnb.promote_arg_dtypes(jnp.asarray(c))

    def projector(x):
        # Only the shape of x matters: every point projects to c.
        shape = jnp.asarray(x).shape
        return jnp.broadcast_to(target, shape)

    return jit(projector)
def proj_affine(A, b=0):
    """Build the projector onto the affine solution set ``{x : A x = b}``.

    The projection of ``x`` is ``x + A^T (A A^T)^{-1} (b - A x)``. With the
    factorization ``A^T = Q R`` we have ``A A^T = R^T R``, so the inverse is
    applied through two triangular solves against ``R``.

    Args:
        A: Constraint matrix (required).
        b: Right-hand side; defaults to 0, i.e. the null space of ``A``.

    Returns:
        A jit-compiled function mapping ``x`` to its projection.
    """
    A = cnb.promote_arg_dtypes(jnp.asarray(A))
    b = cnb.promote_arg_dtypes(jnp.asarray(b))
    # Only the triangular factor of the QR decomposition of A^T is needed;
    # it is computed once here and reused by every call of the projector.
    R = qr(A.T, 'r')

    def projector(x):
        x = cnb.promote_arg_dtypes(jnp.asarray(x))
        residual = b - A @ x
        # The cnb helpers are assumed to solve R^T v = residual and then
        # R w = v, which together apply (R^T R)^{-1}.
        multipliers = cnb.solve_Ux_b(R, cnb.solve_UTx_b(R, residual))
        correction = A.T @ multipliers
        return x + correction

    return jit(projector)
def proj_box(l=None, u=None):
    """Build the projector for element-wise lower and/or upper bounds.

    Args:
        l: Lower bound, or ``None`` for no lower bound.
        u: Upper bound, or ``None`` for no upper bound.

    Returns:
        A jit-compiled function that clamps its input from below by ``l``
        and/or from above by ``u``, depending on which bounds were given.

    Raises:
        ValueError: If both ``l`` and ``u`` are ``None``.
    """
    if l is None and u is None:
        raise ValueError("At least lower or upper bound must be defined.")
    if l is not None:
        l = cnb.promote_arg_dtypes(jnp.asarray(l))
    if u is not None:
        u = cnb.promote_arg_dtypes(jnp.asarray(u))

    # Only an upper bound was supplied.
    if l is None:
        @jit
        def upper_bound(x):
            x = cnb.promote_arg_dtypes(jnp.asarray(x))
            return jnp.minimum(x, u)

        return upper_bound

    # Only a lower bound was supplied.
    if u is None:
        @jit
        def lower_bound(x):
            x = cnb.promote_arg_dtypes(jnp.asarray(x))
            return jnp.maximum(x, l)

        return lower_bound

    # Both bounds: clamp from below first, then from above.
    @jit
    def box_bound(x):
        x = cnb.promote_arg_dtypes(jnp.asarray(x))
        raised = jnp.maximum(x, l)
        return jnp.minimum(raised, u)

    return box_bound
def proj_box_affine(l, u, a, alpha=0., tol=1e-6):
    """Projector function for the constraints l <= x <= u and a' x = alpha

    NOTE: this function is unfinished (original marker: TODO complete this).
    As written, the inner ``projector`` is never returned, so calling
    ``proj_box_affine`` yields ``None``. ``alpha`` and ``tol`` are accepted
    but not yet used by the code below.

    Args:
        l: Lower bound, or ``None`` (replaced by an array of ``-inf`` shaped
            like ``a``).
        u: Upper bound, or ``None`` (replaced by an array of ``+inf`` shaped
            like ``a``).
        a: Coefficient vector of the affine constraint (required).
        alpha: Right-hand side of the affine constraint (currently unused).
        tol: Tolerance (currently unused).

    Raises:
        ValueError: If ``a`` is ``None``.
    """
    if a is None:
        raise ValueError("a is required")
    a = jnp.asarray(a)
    a = cnb.promote_arg_dtypes(a)
    # Number of entries in a; the search below works on 2 * n turning points.
    n = a.size
    # A missing bound is modeled as an infinite one.
    # NOTE(review): user-supplied l / u are used as given, without the
    # asarray / promote step applied to a -- confirm this is intended.
    if l is None:
        l = jnp.full_like(a, -jnp.inf)
    if u is None:
        u = jnp.full_like(a, jnp.inf)

    @jit
    def box_bound(x):
        # Clamp x element-wise into [l, u].
        x = jnp.maximum(x, l)
        x = jnp.minimum(x, u)
        return x

    def projector(x):
        # Turning points for the lower bound: values of beta at which
        # component i of (x - beta * a) equals l[i].
        T1 = (x -l) / a
        # Turning points for the upper bound: values of beta at which
        # component i of (x - beta * a) equals u[i].
        T2 = (x - u) / a
        # All 2 * n turning points in ascending order.
        T = jnp.concatenate((T1, T2))
        T = jnp.sort(T)
        # Index range and iteration count for a bisection over T.
        lower_bound = 0
        upper_bound = 2 * n
        k = math.ceil(math.log2(2*n))
        for i in range(k):
            # NOTE(review): lower_bound / upper_bound are never updated, so
            # every iteration currently picks the same index.
            index = (lower_bound + upper_bound) // 2
            beta = T[index]
            # Trial solution for this beta. It is computed but not yet
            # compared against alpha or returned (unfinished).
            y = box_bound(x - beta * a)
def proj_conic():
    r"""Build the projector onto the Lorentz/ice-cream cone
    {(x,t): \| x \|_2 \leq t}.

    The input is split into its leading entries ``x`` and its last entry
    ``t`` (presumably a 1-D vector ``(x, t)`` -- confirm against callers).

    Returns:
        A jit-compiled function which

        * maps a point with ``\| x \|_2 > |t|`` to
          ``((\| x \|_2 + t) / 2) * (x / \| x \|_2, 1)`` on the cone boundary,
        * returns the point unchanged when ``\| x \|_2 <= |t|`` and ``t > 0``,
        * returns the origin otherwise.
    """

    def proj(x):
        x = cnb.promote_arg_dtypes(jnp.asarray(x))
        head, t = x[:-1], x[-1]
        head_norm = norm(head)

        def onto_boundary(_):
            # Nearest point on the boundary of the cone.
            scale = (head_norm + t) / 2
            direction = jnp.append(head / head_norm, 1)
            return scale * direction

        def keep_point(_):
            # Upper half space and within the cone: already feasible.
            return x

        def to_origin(_):
            # Lower half space: the nearest point of the cone is the origin.
            return jnp.zeros_like(x)

        def within_abs_t(_):
            # Here ||x|| <= |t|, so the sign of t decides the outcome.
            return lax.cond(t > 0, keep_point, to_origin, None)

        return lax.cond(head_norm > jnp.abs(t), onto_boundary, within_abs_t, None)

    return jit(proj)