Source code for cr.sparse._src.cvx.spgl1

# Copyright 2022 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
JAX based implementation of the SPG-L1 algorithm.


References

* E. van den Berg and M. P. Friedlander, "Probing the Pareto frontier
  for basis pursuit solutions", SIAM J. on Scientific Computing,
  31(2):890-912. (2008).
"""

import math

from typing import NamedTuple

import jax
from jax import lax, jit
import jax.numpy as jnp

import cr.sparse as crs
import cr.nimble as crn
norm = crn.arr_l2norm

############################################################################
#  Constants
############################################################################

EPS = jnp.finfo(jnp.float32).eps


############################################################################
#  Data Types for this module
############################################################################

class SPGL1Options(NamedTuple):
    """Options for the SPGL1 algorithm
    """
    bp_tol: float = 1e-6
    "Tolerance for basis pursuit solution"
    ls_tol: float = 1e-6
    "Tolerance for least squares solution"
    opt_tol: float = 1e-4
    "Optimality tolerance"
    dec_tol: float = 1e-4
    "Required relative change in primal objective for Newton steps"
    gamma: float = 1e-4
    "Sufficient decrease parameter"
    alpha_min: float = 1e-16
    "Minimum spectral step"
    alpha_max: float = 1e5
    "Maximum spectral step"
    memory: int = 3
    "Number of past objective values to be retained"
    max_matvec: int = 100000
    "Maximum number of A x and A^T x products to be computed"
    max_iters: int = 100
    "Maximum number of iterations"
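# Usage sketch (illustrative, not part of the original module): since the
# options are a NamedTuple, any subset of fields can be overridden at
# construction time, e.g.
#
#   options = SPGL1Options(max_iters=500, opt_tol=1e-5)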
############################################################################
#  L1-Ball Projections
############################################################################

def _project_to_l1_ball(x, q):
    """Projects a vector inside an l1 norm ball
    """
    # sort the absolute values in descending order
    u = jnp.sort(jnp.abs(x))[::-1]
    # compute the cumulative sums
    cu = jnp.cumsum(u)
    # find the index where the cumulative sum is below the threshold
    cu_diff = cu - q
    u_scaled = u * jnp.arange(1, 1 + len(u))
    flags = cu_diff > u_scaled
    K = jnp.argmax(flags)
    K = jnp.where(K == 0, len(flags), K)
    # compute the shrinkage threshold
    kappa = (cu[K-1] - q) / K
    # perform shrinkage
    if jnp.iscomplexobj(x):
        return jnp.maximum(jnp.abs(x) - kappa, 0.) * jnp.exp(1j * jnp.angle(x))
    else:
        return jnp.maximum(0, x - kappa) + jnp.minimum(0, x + kappa)


def project_to_l1_ball(x, q=1.):
    """Projects a vector inside an l1 norm ball
    """
    x = jnp.asarray(x)
    shape = x.shape
    x = jnp.ravel(x)
    invalid = crn.arr_l1norm(x) > q
    return lax.cond(invalid,
        # find the shrinkage threshold and shrink
        lambda x: _project_to_l1_ball(x, q),
        # no changes necessary
        lambda x: x,
        x).reshape(shape)


def project_to_l1_ball_at(x, b, q=1.):
    """Projects a vector inside an l1 norm ball centered at b
    """
    x = jnp.asarray(x)
    # compute difference from the center
    r = x - b
    r = project_to_l1_ball(r, q)
    # translate back to the center
    return r + b
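# Illustrative sanity check (a sketch, not part of the original module):
# the projection should map an infeasible point onto the ball and leave a
# feasible point unchanged.
def _demo_project_to_l1_ball():
    x = jnp.array([0.8, -0.6, 0.4])   # l1 norm 1.8, outside the unit ball
    p = project_to_l1_ball(x, q=1.)
    assert crn.arr_l1norm(p) <= 1. + 1e-6
    y = jnp.array([0.2, -0.3, 0.1])   # l1 norm 0.6, already inside
    assert jnp.allclose(project_to_l1_ball(y, q=1.), y)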
############################################################################
# Weighted and unweighted primal and dual norms
############################################################################

primal_norm = crn.arr_l1norm
dual_norm = crn.norm_linf


def weighted_primal_l1_norm(x, w):
    return crn.norm_l1(x * w)


def weighted_dual_linf_norm(x, w):
    return crn.norm_linf(x / w)


def obj_val(r):
    """Objective value is half of the squared norm of the residual
    """
    return 0.5 * jnp.abs(jnp.vdot(r, r))


############################################################################
# Curvilinear line search
############################################################################

class CurvyLineSearchState(NamedTuple):
    """State for the line search algorithm
    """
    alpha: float
    scale: float
    x_new: jnp.ndarray
    r_new: jnp.ndarray
    d_new: jnp.ndarray
    gtd: float
    d_norm_old: float
    f_val: float
    f_lim: float
    n_iters: int
    n_safe: int

    def __str__(self):
        """Returns the string representation of the state
        """
        s = []
        x_norm = norm(self.x_new)
        r_norm = norm(self.r_new)
        d_norm = norm(self.d_new)
        for x in [
            f'step: {self.alpha}, scale: {self.scale}',
            f'gtd: {self.gtd}',
            f'f_val: {self.f_val:.2f}, f_lim: {self.f_lim:.2f}',
            # f'x_norm: {x_norm:.2f}, r_norm: {r_norm:.2f}',
            # f'd_norm: {d_norm:.2f}',
            # f'n_iters: {self.n_iters}, n_safe: {self.n_safe}',
        ]:
            s.append(x.rstrip())
        return u' '.join(s)


def curvy_line_search(A, b, x, g, alpha0, f_max, proj, tau, gamma):
    """Curvilinear line search
    """
    max_iters = 10
    g = alpha0 * g
    n = x.size
    n2 = math.sqrt(n)
    g_norm = norm(g) / n2

    def candidate(alpha, scale):
        x_new = proj(x - alpha * scale * g, tau)
        r_new = b - A.times(x_new)
        d_new = x_new - x
        gtd = scale * jnp.real(jnp.vdot(g, d_new))
        f_val = obj_val(r_new)
        f_lim = f_max + gamma * alpha * gtd
        return x_new, r_new, d_new, gtd, f_val, f_lim

    def init():
        alpha = 1.
        scale = 1.
        x_new, r_new, d_new, gtd, f_val, f_lim = candidate(alpha, scale)
        return CurvyLineSearchState(alpha=alpha, scale=scale,
            x_new=x_new, r_new=r_new, d_new=d_new, gtd=gtd,
            d_norm_old=0., f_val=f_val, f_lim=f_lim,
            n_iters=0, n_safe=0)

    def next_func(state):
        alpha = state.alpha
        # reduce alpha size
        alpha /= 2.
        # check if the scale needs to be reduced
        d_norm = norm(state.d_new) / n2
        d_norm_old = state.d_norm_old
        # check if the iterates of x are too close to each other
        too_close = jnp.abs(d_norm - d_norm_old) <= 1e-6 * d_norm
        scale = state.scale
        n_safe = state.n_safe
        scale, n_safe = lax.cond(too_close,
            lambda _: (d_norm / g_norm / (2. ** n_safe), n_safe + 1),
            lambda _: (scale, n_safe),
            None)
        x_new, r_new, d_new, gtd, f_val, f_lim = candidate(alpha, scale)
        return CurvyLineSearchState(alpha=alpha, scale=scale,
            x_new=x_new, r_new=r_new, d_new=d_new, gtd=gtd,
            d_norm_old=d_norm, f_val=f_val, f_lim=f_lim,
            n_iters=state.n_iters+1, n_safe=n_safe)

    def cond_func(state):
        # print(state)
        a = state.n_iters < max_iters
        b = state.gtd < 0
        c = state.f_val >= state.f_lim
        a = jnp.logical_and(a, b)
        a = jnp.logical_and(a, c)
        return a

    state = init()
    state = lax.while_loop(cond_func, next_func, state)
    # print(state)
    # while cond_func(state):
    #     state = next_func(state)
    #     print(state)
    return state


############################################################################
# SPG-L1 Solver for LASSO problem
############################################################################

def lasso_metrics(b, x, g, r, f, tau):
    # dual norm of the gradient
    g_dnorm = dual_norm(g)
    # norm of the residual
    r_norm = norm(r)
    # duality gap
    gap = jnp.dot(jnp.conj(r), r - b) + tau * g_dnorm
    # relative duality gap
    f_m = jnp.maximum(1, f)
    r_gap = jnp.abs(gap) / f_m
    return r_norm, r_gap
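# Note (editorial, derived from the code above): with r = b - A x and
# g = -A^T r, the dual feasible point y = r yields the lower bound
# <b, r> - ||r||^2 / 2 - tau * ||A^T r||_inf on the optimal LASSO objective,
# so the duality gap works out to <r, r - b> + tau * ||g||_inf, which is
# exactly the `gap` expression computed in `lasso_metrics`.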
class SPGL1LassoState(NamedTuple):
    """Solution state of the SPGL1 algorithm for LASSO problem
    """
    x: jnp.ndarray
    "Solution vector"
    g: jnp.ndarray
    "Gradient vector"
    r: jnp.ndarray
    "Residual vector"
    f_past: jnp.ndarray
    "Past function values"
    r_norm: float
    "Residual norm"
    r_gap: float
    "Relative duality gap"
    alpha: float
    "Step size in the current iteration"
    alpha_next: float
    "Step size for the next iteration"
    # counters
    iterations: int
    n_times: int
    "Number of multiplications with A"
    n_trans: int
    "Number of multiplications with A^T"
    n_ls_iters: int
    "Number of line search iterations in the current iteration"

    def __str__(self):
        """Returns the string representation of the state
        """
        s = []
        f_val = self.f_past[0]
        g_norm = norm(self.g)
        for x in [
            f'[{self.iterations}] ',
            f'f_val: {f_val:.3f} r_gap: {self.r_gap:.3f}',
            f'g_norm: {g_norm:.3f} r_norm: {self.r_norm:.3f}',
            f'lsi: {self.n_ls_iters}, alpha: {self.alpha:.3f}, alpha_n: {self.alpha_next:.3f}',
            # f'x: {format(self.x)}',
            # f'g: {format(self.g)}',
        ]:
            s.append(x.rstrip())
        return u' '.join(s)
def solve_lasso_from(A, b: jnp.ndarray, tau: float, x0: jnp.ndarray,
    options: SPGL1Options = SPGL1Options(),
    tracker=crs.noop_tracker):
    """Solves the LASSO problem using the SPGL1 algorithm with an initial solution
    """
    # shape of the linear operator
    m, n = A.shape
    alpha_min = options.alpha_min
    alpha_max = options.alpha_max
    opt_tol = options.opt_tol
    b_norm = norm(b)

    def init():
        x = jnp.asarray(x0)
        x = project_to_l1_ball(x, tau)
        # initial residual
        r = b - A.times(x)
        # initial gradient
        g = -A.trans(r)
        # objective value
        f = obj_val(r)
        # prepare the memory of past function values
        f_past = jnp.full(options.memory, f)
        # projected gradient direction
        d = project_to_l1_ball(x - g, tau) - x
        # initial step length calculation
        d_norm = crn.norm_linf(d)
        alpha = 1. / d_norm
        alpha = jnp.clip(alpha, alpha_min, alpha_max)
        r_norm, r_gap = lasso_metrics(b, x, g, r, f, tau)
        return SPGL1LassoState(x=x, g=g, r=r, f_past=f_past,
            r_norm=r_norm, r_gap=r_gap,
            alpha=alpha, alpha_next=alpha,
            iterations=1, n_times=1, n_trans=1, n_ls_iters=0)

    def body_func(state):
        f_max = jnp.max(state.f_past)
        lsearch = curvy_line_search(A, b, state.x, state.g,
            state.alpha_next, f_max, project_to_l1_ball, tau, options.gamma)
        n_times = state.n_times + lsearch.n_iters + 1
        # new x value
        x = lsearch.x_new
        # new residual
        r = lsearch.r_new
        # new gradient
        g = -A.trans(r)
        n_trans = state.n_trans + 1
        # new function value
        f = lsearch.f_val
        # update past values
        f_past = crn.cbuf_push_left(state.f_past, f)
        r_norm, r_gap = lasso_metrics(b, x, g, r, f, tau)
        # spectral step size for the next iteration
        s = x - state.x
        y = g - state.g
        sts = jnp.real(jnp.dot(jnp.conj(s), s))
        sty = jnp.real(jnp.dot(jnp.conj(s), y))
        alpha_next = lax.cond(sty <= 0,
            lambda _: alpha_max,
            lambda _: jnp.clip(sts / sty, alpha_min, alpha_max),
            None)
        return SPGL1LassoState(x=x, g=g, r=r, f_past=f_past,
            r_norm=r_norm, r_gap=r_gap,
            alpha=lsearch.alpha, alpha_next=alpha_next,
            iterations=state.iterations+1,
            # note: the original listing set n_trans=n_times here; fixed
            n_times=n_times, n_trans=n_trans,
            n_ls_iters=state.n_ls_iters + lsearch.n_iters)

    def cond_func(state):
        # print(state)
        a = state.iterations < options.max_iters
        b = state.r_gap > opt_tol
        c = state.r_norm >= opt_tol * b_norm
        a = jnp.logical_and(a, b)
        a = jnp.logical_and(a, c)
        jax.debug.callback(tracker, state, more=a)
        return a

    state = init()
    state = lax.while_loop(cond_func, body_func, state)
    # while cond_func(state):
    #     state = body_func(state)
    return state
def solve_lasso(A, b: jnp.ndarray, tau: float,
    options: SPGL1Options = SPGL1Options(),
    tracker=crs.noop_tracker):
    """Solves the LASSO problem using the SPGL1 algorithm
    """
    m, n = A.shape
    x0 = jnp.zeros(n)
    return solve_lasso_from(A, b, tau, x0, options=options, tracker=tracker)
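# Usage sketch (illustrative, not part of the original module;
# `cr.sparse.lop.matrix` is assumed to be the dense-matrix operator wrapper).
# `A` can be any linear operator exposing `.shape`, `.times` (x -> A x) and
# `.trans` (x -> A^T x).
def _demo_solve_lasso():
    from jax import random
    import cr.sparse.lop as lop
    key = random.PRNGKey(0)
    Phi = random.normal(key, (20, 50)) / math.sqrt(20)  # random sensing matrix
    A = lop.matrix(Phi)
    x = jnp.zeros(50).at[jnp.array([3, 11, 27])].set(jnp.array([1., -0.5, 0.75]))
    b = A.times(x)
    # solve min ||A z - b||_2^2 / 2  s.t.  ||z||_1 <= tau, with tau = ||x||_1
    state = solve_lasso(A, b, tau=2.25)
    return state.x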
# Note: `options` must be static under `jit` since `options.memory`
# determines an array shape in `jnp.full(options.memory, f)`.
solve_lasso_jit = jit(solve_lasso, static_argnames=("A", "options", "tracker"))


def analyze_lasso_state(A, b, tau, options, state, x0):
    m, n = A.shape
    x = state.x
    r = state.r
    g = state.g
    print(f'm={m}, n={n}, tau: {tau:.2f}, b_norm: {norm(b):.2f}')
    print(f'iterations={state.iterations}, times={state.n_times},' +
        f' trans={state.n_trans}, line search={state.n_ls_iters}')
    snr = crn.signal_noise_ratio(x0, x)
    prd = crn.percent_rms_diff(x0, x)
    print(f'SNR: {snr:.2f} dB, PRD: {prd:.1f} %')
    print(f'x0: l1: {crn.norm_l1(x0):.3f}, l2: {crn.norm_l2(x0):.3f}, linf: {crn.norm_linf(x0):.3f}')
    print(f'x : l1: {crn.norm_l1(x):.3f}, l2: {crn.norm_l2(x):.3f}, linf: {crn.norm_linf(x):.3f}')
    r_norm = state.r_norm
    print(f'r_norm: {r_norm:.4f}')
    print(f'alpha: {state.alpha:.3f}, alpha_n: {state.alpha_next:.3f}')
    f_past = state.f_past
    f = f_past[0]
    f_prev = f_past[1]
    f_change = jnp.abs(f - f_prev)
    rel_f_change = f_change / f
    print(f'f_val: {f:.2e} f_prev: {f_prev:.2e}, change: {f_change:.2e}, rel change: {rel_f_change * 100:.2f}%')
    print(f'g_norm: {norm(g):.2e} g_dnorm: {crn.norm_linf(g):.2e}')


############################################################################
# SPG-L1 Solver for BPIC problem
############################################################################
class SPGL1BPState(NamedTuple):
    """Solution state of the SPGL1 algorithm for BPIC problem
    """
    x: jnp.ndarray
    "Solution vector"
    g: jnp.ndarray
    "Gradient vector"
    r: jnp.ndarray
    "Residual vector"
    f_past: jnp.ndarray
    "Past function values"
    tau: float
    "The limit on the l1-norm"
    tau_changed: bool
    "Flag indicating if tau was changed"
    r_norm: float
    "Residual norm"
    r_gap: float
    "Relative duality gap"
    r_res_error: float
    "Relative error of residual norm from sigma"
    r_f_error: float
    "Relative error of objective value from sigma^2/2"
    alpha: float
    "Step size in the current iteration"
    alpha_next: float
    "Step size for the next iteration"
    # counters
    iterations: int
    n_times: int
    "Number of multiplications with A"
    n_trans: int
    "Number of multiplications with A^T"
    n_newton: int
    "Number of Newton steps"
    n_ls_iters: int
    "Number of line search iterations in the current iteration"

    def __str__(self):
        """Returns the string representation of the state
        """
        s = []
        f_val = self.f_past[0]
        g_norm = norm(self.g)
        ch = ' C ' if self.tau_changed else ""
        for x in [
            f'[{self.iterations}] ',
            f'r_norm: {self.r_norm:.6f}',
            f'r_gap: {self.r_gap:.2e}',
            f'g_norm: {g_norm:.3f}',
            f'f_val: {f_val:.4f} ',
            f'alpha: {self.alpha:.3f}',
            f'lsi: {self.n_ls_iters}',
            f'tau: {self.tau:.4f}{ch}',
        ]:
            s.append(x.rstrip())
        return u' '.join(s)
def bpic_metrics(b, x, g, r, f, sigma, tau):
    # dual norm of the gradient
    g_dnorm = dual_norm(g)
    # norm of the residual
    r_norm = norm(r)
    # duality gap
    gap = jnp.vdot(r, r - b) + tau * g_dnorm
    # relative duality gap
    f_m = jnp.maximum(1, f)
    r_m = jnp.maximum(1., r_norm)
    r_gap = jnp.abs(gap) / f_m
    res_error = r_norm - sigma
    f_error = f - sigma**2 / 2.0
    r_res_error = jnp.abs(res_error) / r_m
    r_f_error = jnp.abs(f_error) / f_m
    return g_dnorm, r_norm, r_gap, r_res_error, r_f_error


def tau_change(A, r, r_norm, sigma):
    y = r / r_norm
    lambda_ = crn.norm_linf(A.trans(y))
    phi = r_norm
    phi_d = -lambda_
    change = (sigma - phi) / phi_d
    return change


def compute_rgf(A, b, x):
    # update residual
    r = b - A.times(x)
    # compute gradient
    g = -A.trans(r)
    # objective value
    f = obj_val(r)
    return r, g, f


def update_xrgf(A, b, x, tau):
    # bring x to this ball
    x = project_to_l1_ball(x, tau)
    # update residual
    r = b - A.times(x)
    # compute gradient
    g = -A.trans(r)
    # objective value
    f = obj_val(r)
    return x, r, g, f
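# Note (editorial): following van den Berg & Friedlander, the BPIC solver
# performs Newton root finding on the Pareto curve phi(tau) = ||b - A x_tau||_2,
# whose derivative is phi'(tau) = -||A^T y||_inf with y = r / ||r||. A Newton
# step toward phi(tau) = sigma is therefore
#
#   tau_next = tau + (sigma - phi(tau)) / phi'(tau)
#            = tau + ||r|| * (||r|| - sigma) / ||A^T r||_inf
#
# which is exactly the `tau` update used in `solve_bpic_from` below
# (g_dnorm = ||A^T r||_inf, since g = -A^T r).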
def solve_bpic_from(A, b: jnp.ndarray, sigma: float, x0: jnp.ndarray,
    options: SPGL1Options = SPGL1Options(),
    tracker=crs.noop_tracker):
    """Solves the BPIC problem using the SPGL1 algorithm with an initial solution
    """
    # shape of the linear operator
    m, n = A.shape
    alpha_min = options.alpha_min
    alpha_max = options.alpha_max
    opt_tol = options.opt_tol
    b_norm = norm(b)

    def init():
        x = jnp.asarray(x0)
        # initial value of tau
        tau = crn.norm_l1(x)
        # compute initial residual, gradient etc.
        x, r, g, f = update_xrgf(A, b, x, tau)
        # compute all the metrics
        g_dnorm, r_norm, r_gap, r_res_error, r_f_error = bpic_metrics(
            b, x, g, r, f, sigma, tau)
        tau = jnp.maximum(0, tau + (r_norm * (r_norm - sigma)) / g_dnorm)
        # # update x as per the new value of tau
        # x, r, g, f = update_xrgf(A, b, x, tau)
        # # update the metrics
        # g_dnorm, r_norm, r_gap, r_res_error, r_f_error = bpic_metrics(
        #     b, x, g, r, f, sigma, tau)
        # projected gradient direction
        d = project_to_l1_ball(x - g, 0.) - x
        # initial step length calculation
        d_norm = crn.norm_linf(d)
        alpha = 1. / d_norm
        alpha = jnp.clip(alpha, alpha_min, alpha_max)
        # prepare the memory of past function values
        f_past = jnp.full(options.memory, f)
        return SPGL1BPState(x=x, g=g, r=r, f_past=f_past,
            tau=tau, tau_changed=True,
            r_norm=r_norm, r_gap=r_gap,
            r_res_error=r_res_error, r_f_error=r_f_error,
            alpha=alpha, alpha_next=alpha,
            iterations=1, n_times=2, n_trans=2,
            n_newton=1, n_ls_iters=0)

    # @jit
    def body_func(state):
        f_max = jnp.max(state.f_past)
        # perform line search
        lsearch = curvy_line_search(A, b, state.x, state.g,
            state.alpha_next, f_max, project_to_l1_ball,
            state.tau, options.gamma)
        n_times = state.n_times + lsearch.n_iters + 1
        # new x value
        x = lsearch.x_new
        # new residual
        r = lsearch.r_new
        # new gradient
        g = -A.trans(r)
        n_trans = state.n_trans + 1
        # new function value
        f = lsearch.f_val
        # compute various metrics
        g_dnorm, r_norm, r_gap, r_res_error, r_f_error = bpic_metrics(
            b, x, g, r, f, sigma, state.tau)
        # check if we need to update tau
        f_old = state.f_past[0]
        f_change = jnp.abs(f - f_old)
        tc_a = f_change <= options.dec_tol * f
        tc_b = f_change <= 1e-1 * f * jnp.abs(r_norm - sigma)
        flag_c = jnp.logical_or(
            jnp.logical_and(tc_a, r_norm > 2 * sigma),
            jnp.logical_and(tc_b, r_norm <= 2 * sigma),
        )
        # print(f'f: {f}, f_old: {f_old}, fc: {f_change} a: {tc_a}, b: {tc_b}, tc: {flag_c}')
        # we shall change tau only if it didn't change in the last iteration
        change_tau = jnp.logical_and(flag_c, jnp.logical_not(state.tau_changed))
        # update tau if necessary
        tau = lax.cond(change_tau,
            lambda _: jnp.maximum(0, state.tau + (r_norm * (r_norm - sigma)) / g_dnorm),
            lambda _: state.tau,
            None)
        # update the solution to be consistent with the new tau value if necessary
        tau_reduced = tau < state.tau
        x, r, g, f = lax.cond(tau_reduced,
            lambda _: update_xrgf(A, b, x, tau),
            lambda _: (x, r, g, f),
            None)
        n_times, n_trans = n_times + tau_reduced, n_trans + tau_reduced
        n_newton = state.n_newton + change_tau
        # update past objective values with the new objective value
        f_past = crn.cbuf_push_left(state.f_past, f)
        # compute the new step size
        s = x - state.x
        y = g - state.g
        sts = jnp.real(jnp.vdot(s, s))
        sty = jnp.real(jnp.vdot(s, y))
        alpha_next = lax.cond(sty <= 0,
            lambda _: alpha_max,
            lambda _: jnp.clip(sts / sty, alpha_min, alpha_max),
            None)
        return SPGL1BPState(x=x, g=g, r=r, f_past=f_past,
            tau=tau, tau_changed=change_tau,
            r_norm=r_norm, r_gap=r_gap,
            r_res_error=r_res_error, r_f_error=r_f_error,
            alpha=lsearch.alpha, alpha_next=alpha_next,
            iterations=state.iterations+1,
            n_times=n_times, n_trans=n_trans,
            n_newton=n_newton,
            n_ls_iters=state.n_ls_iters + lsearch.n_iters)

    @jit
    def cond_func(state):
        # if a and b are true then we continue; otherwise we check more conditions
        a = state.r_gap > jnp.maximum(opt_tol, state.r_f_error)
        b = state.r_res_error > opt_tol
        # we check the following three conditions if either a or b is false
        u = state.r_norm > sigma
        v = state.r_res_error > opt_tol
        w = state.r_norm > options.bp_tol * b_norm
        x = jnp.all(jnp.array([u, v, w]))
        cont = jnp.logical_or(jnp.logical_and(a, b), x)
        # check on the maximum number of iterations
        cont = jnp.logical_and(cont, state.iterations < options.max_iters)
        # check on the maximum number of matrix vector products
        f = state.n_times + state.n_trans < options.max_matvec
        cont = jnp.logical_and(cont, f)
        jax.debug.callback(tracker, state, more=cont)
        return cont

    state = init()
    state = lax.while_loop(cond_func, body_func, state)
    # while cond_func(state):
    #     print(state)
    #     state = body_func(state)
    #     print(state)
    return state
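# Illustrative sketch (hypothetical helper, not part of the original module):
# a tracker is any callable accepting `(state, more=...)`; it is invoked from
# inside the jitted loop via `jax.debug.callback`, so it can print
# per-iteration progress without breaking tracing, e.g.
#
#   state = solve_bpic_from(A, b, sigma, x0, tracker=_print_tracker)
def _print_tracker(state, more=True):
    print(str(state))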
solve_bpic_from_jit = jit(solve_bpic_from, static_argnames=("A", "options", "tracker"))
def solve_bpic(A, b: jnp.ndarray, sigma: float,
    options: SPGL1Options = SPGL1Options(),
    tracker=crs.noop_tracker):
    """Solves the BPIC problem using the SPGL1 algorithm
    """
    m, n = A.shape
    x0 = jnp.zeros(n, dtype=b.dtype)
    return solve_bpic_from(A, b, sigma, x0, options=options, tracker=tracker)
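# Usage sketch (illustrative; reuses the assumptions of `_demo_solve_lasso`).
# BPIC solves min ||x||_1 s.t. ||A x - b||_2 <= sigma, so sigma should be an
# estimate of the l2 norm of the measurement noise.
def _demo_solve_bpic():
    from jax import random
    import cr.sparse.lop as lop
    keys = random.split(random.PRNGKey(0), 2)
    Phi = random.normal(keys[0], (20, 50)) / math.sqrt(20)
    A = lop.matrix(Phi)
    x = jnp.zeros(50).at[jnp.array([3, 11, 27])].set(jnp.array([1., -0.5, 0.75]))
    sigma = 0.01
    noise = random.normal(keys[1], (20,))
    noise = sigma * noise / norm(noise)   # scale the noise to have l2 norm sigma
    b = A.times(x) + noise
    state = solve_bpic(A, b, sigma=sigma)
    return state.x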
solve_bpic_jit = jit(solve_bpic, static_argnames=("A", "options", "tracker"))


def analyze_bpic_state(A, b, sigma, options, state, x0):
    m, n = A.shape
    x = state.x
    r = state.r
    g = state.g
    print(f'm={m}, n={n}, sigma: {sigma:.2f}, b_norm: {norm(b):.2f}')
    print(f'iterations={state.iterations}, times={state.n_times},' +
        f' trans={state.n_trans}, newton={state.n_newton}, line search={state.n_ls_iters}')
    snr = crn.signal_noise_ratio(x0, x)
    prd = crn.percent_rms_diff(x0, x)
    print(f'SNR: {snr:.2f} dB, PRD: {prd:.1f} %')
    print(f'x0: l1: {crn.norm_l1(x0):.3f}, l2: {crn.norm_l2(x0):.3f}, linf: {crn.norm_linf(x0):.3f}')
    print(f'x : l1: {crn.norm_l1(x):.3f}, l2: {crn.norm_l2(x):.3f}, linf: {crn.norm_linf(x):.3f}')
    r_norm = state.r_norm
    rs = r_norm / sigma
    print(f'r_norm: {r_norm:.4f} r/sigma: {rs:.3f}')
    print(f'tau: {state.tau:.2e}, alpha: {state.alpha:.3f}, alpha_n: {state.alpha_next:.3f}')
    f_past = state.f_past
    f = f_past[0]
    f_prev = f_past[1]
    f_change = jnp.abs(f - f_prev)
    rel_f_change = f_change / f
    print(f'f_val: {f:.2e} f_prev: {f_prev:.2e}, change: {f_change:.2e}, rel change: {rel_f_change * 100:.2f}%')
    print(f'g_norm: {norm(g):.2e} g_dnorm: {crn.norm_linf(g):.2e}')
    if state.r_gap <= jnp.maximum(options.opt_tol, state.r_f_error):
        print(f'Relative gap {state.r_gap:.2e} is below optimality tolerance')
    if state.r_res_error <= options.opt_tol:
        print(f'Relative residual error {state.r_res_error:.2e} is below optimality tolerance')
def solve_bp(A, b: jnp.ndarray,
    options: SPGL1Options = SPGL1Options(),
    tracker=crs.noop_tracker):
    """Solves the Basis Pursuit problem using the SPGL1 algorithm

    Examples:
        * :ref:`gallery:0002`
        * :ref:`gallery:0003`
    """
    m, n = A.shape
    dtype = jnp.complex128 if not A.real else b.dtype
    x0 = jnp.zeros(n, dtype=dtype)
    sigma = 0.
    return solve_bpic_from(A, b, sigma, x0, options=options, tracker=tracker)
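# Usage sketch (illustrative): basis pursuit is the sigma = 0 special case of
# BPIC, so recovery from exact measurements reduces to
#
#   state = solve_bp(A, b)
#   x_hat = state.x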
solve_bp_jit = jit(solve_bp, static_argnames=("A", "options", "tracker"))