# Source code for cr.sparse._src.cvx.adm.yall1

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements the YALL1 ADMM algorithms from the paper :cite:`yang2011alternating`."""

from typing import NamedTuple, List, Dict

import jax.numpy as jnp
from jax import jit, lax, vmap
norm = jnp.linalg.norm

from cr.sparse import lop
from cr.sparse import RecoveryFullSolution

def project_to_box(z, w):
    """Shrinks every entry of ``z`` whose magnitude exceeds ``w``.

    Each entry is multiplied by the non-negative factor
    ``w / max(w, |z|)``. Entries with ``|z_i| <= w_i`` therefore keep their
    value. Larger entries are rescaled to magnitude ``w_i``, keeping their
    sign (or phase, for complex input).

    NOTE(review): assumes ``w > 0``; a zero weight paired with a zero entry
    evaluates ``0 / 0`` and yields NaN.

    Args:
        z: Array to project.
        w: Box half-widths, a scalar or an array broadcastable against ``z``.

    Returns:
        The elementwise product ``z * w / max(w, |z|)``.
    """
    return z * (w / jnp.maximum(w, jnp.abs(z)))

def project_to_real_upper_limit(z, w):
    """Caps the real part of ``z`` from above at ``w``.

    Any imaginary component of ``z`` is discarded. The result is the
    elementwise minimum of ``w`` and ``Re(z)``.

    Args:
        z: Real or complex array.
        w: Upper limit, a scalar or an array broadcastable against ``z``.

    Returns:
        Real array ``min(w, Re(z))``.
    """
    real_part = jnp.real(z)
    return jnp.minimum(w, real_part)

class BPState(NamedTuple):
    """Iteration state shared by the YALL1 ADMM solvers in this module.

    Each solver builds one of these in its ``init`` step and threads it
    through ``jax.lax.while_loop``. The final instance is what the solver
    returns.
    """

    #: Current estimate of the primal variable :math:`x`
    x: jnp.ndarray
    #: Primal variable from the preceding iteration (used for the relative-change test)
    x_prev: jnp.ndarray
    #: Dual variable :math:`z`
    z: jnp.ndarray
    #: Primal residual :math:`b - A x`
    rp: jnp.ndarray
    #: Dual residual :math:`z - A^T y`
    rd: jnp.ndarray
    #: Value of the primal objective function
    primal_objective: float
    #: Value of the dual objective function
    dual_objective: float
    #: Number of ADMM iterations performed so far
    iterations: int
    #: Number of applications of the forward operator (``A x``)
    n_times: int = 0
    #: Number of applications of the adjoint operator (``A^T b``)
    n_trans: int = 0

def bp_setup(A, b):
    """Returns the normalized inputs for calling :func:`solve_bp`.

    The measurements are scaled by their largest absolute entry ``b_max``.
    The initial primal estimate is the correspondingly scaled back-projection
    ``A^T b / b_max``.

    NOTE(review): if ``b`` is identically zero, ``b_max`` is 0 and the
    divisions below produce NaN/inf; callers should handle that case.

    Args:
        A: Linear operator providing ``trans`` (application of :math:`A^T`).
        b: Measurement vector.

    Returns:
        tuple: ``(b_scaled, x0, z0, w, b_max, n_times, n_trans)``, where

        - ``b_scaled`` is ``b / b_max``;
        - ``x0`` is ``A^T b / b_max``;
        - ``z0`` is a zero dual variable;
        - ``w`` holds unit weights;
        - ``b_max`` is the infinity norm of ``b`` as a Python float;
        - ``n_times`` / ``n_trans`` count the operator applications made here
          (0 forward, 1 adjoint).
    """
    Atb = A.trans(b)
    # operator applications performed during setup
    n_times = 0
    n_trans = 1
    n = Atb.shape[0]
    b_max = float(norm(b, ord=jnp.inf))
    x0 = Atb / b_max
    z0 = jnp.zeros(n)
    w = jnp.ones(n)
    b = b / b_max
    return b, x0, z0, w, b_max, n_times, n_trans

def finalize(state, b_max, n_times, n_trans, W=None, nonneg=False):
    """Finalizes the YALL1 solver by packaging its state as a solution.

    Args:
        state (BPState): Final state returned by one of the ``solve_*`` routines.
        b_max (float): Factor by which ``b`` was divided before solving (see
            :func:`bp_setup`). The solution and residual are multiplied back
            by it.
        n_times (int): Forward-operator applications made outside the solver.
        n_trans (int): Adjoint-operator applications made outside the solver.
        W: Optional sparsifying basis operator. When given, the solution is
            mapped back to signal space as ``W.times(x)``.
        nonneg (bool): If true, negative entries of the solution are clipped
            to zero.

    Returns:
        RecoveryFullSolution: Rescaled solution ``x`` and residual ``r``, with
        the iteration count and the total operator-application counts.
    """
    x = jnp.where(nonneg, jnp.maximum(0, state.x), state.x)
    # explicit None check: an operator's truthiness is not a reliable "was W given" test
    if W is not None:
        # go back from sparsifying basis to signal space
        x = W.times(x)
    return RecoveryFullSolution(x=b_max * x, r=b_max * state.rp,
        iterations=state.iterations,
        n_times=state.n_times + n_times,
        n_trans=state.n_trans + n_trans)

def solve_bp(A, b, x0, z0, w, nonneg, gamma, tolerance, max_iters):
    r"""
    Solves the problem :math:`\min \| x \|_1 \, \text{s.t.}\, A x = b` using ADMM

    This function implements eq 2.29 of the paper.

    NOTE(review): the closed-form y-update below involves no linear solve.
    It presumably assumes the rows of ``A`` are orthonormal
    (:math:`A A^T = I`); confirm for the operators passed in.

    Args:
        A: Linear operator providing ``times`` (:math:`A x`) and ``trans``
            (:math:`A^T y`).
        b: Measurement vector (``solve`` passes it already divided by ``b_max``).
        x0: Initial primal variable.
        z0: Initial dual variable.
        w: Per-entry weights of the l1 norm.
        nonneg (bool): If true, the z-update uses
            :func:`project_to_real_upper_limit` instead of :func:`project_to_box`.
        gamma (float): Step-size multiplier of the x (multiplier) update.
        tolerance (float): Stopping tolerance.
        max_iters (int): Iteration limit.

    Returns:
        BPState: State after the last ADMM iteration.
    """
    times = A.times
    trans = A.trans
    mu = jnp.mean(jnp.abs(b))
    b_by_mu = b / mu
    rp_norm_threshold = tolerance * norm(b)

    def init():
        # primal residual
        rp = b - times(x0)
        # dual residual
        rd = -trans(b)
        primal_objective = jnp.sum(jnp.abs(w * x0))
        dual_objective = 0.
        # init performed one forward and one adjoint application
        return BPState(x=x0, x_prev=jnp.zeros(x0.shape), z=z0, rp=rp, rd=rd,
                       primal_objective=primal_objective,
                       dual_objective=dual_objective,
                       iterations=0, n_times=1, n_trans=1)

    def iteration(state):
        # update y
        x_by_mu = state.x / mu
        y = times(state.z - x_by_mu) + b_by_mu
        Aty = trans(y)
        # update z
        z = Aty + x_by_mu
        z = jnp.where(nonneg, project_to_real_upper_limit(z, w),
                      project_to_box(z, w))
        n_times = state.n_times + 1
        n_trans = state.n_trans + 1
        # dual residual
        rd = z - Aty
        # update x (multiplier step)
        x = state.x - (gamma * mu) * rd
        # primal residual
        rp = b - times(x)
        n_times += 1
        # primal objective
        primal_objective = jnp.sum(jnp.abs(w * x))
        # dual objective
        dual_objective = b.T @ y
        # updated state
        return BPState(x=x, x_prev=state.x, z=z, rp=rp, rd=rd,
                       primal_objective=primal_objective,
                       dual_objective=dual_objective,
                       iterations=state.iterations + 1,
                       n_times=n_times, n_trans=n_trans)

    def double_iteration(state):
        # two ADMM steps per stopping-condition evaluation
        state = iteration(state)
        return iteration(state)

    def cond(state):
        """Returns True while the ADMM loop should continue.

        The loop continues while all of the following hold:

        - the iteration limit has not been reached;
        - the relative change in x exceeds ``tolerance * (1 - q)``;
        - at least one of these is above its threshold: the relative duality
          gap, the relative dual residual, or the primal residual norm.
        """
        q = 0.1
        # limit on number of iterations
        more_iters = state.iterations < max_iters
        # relative change in x
        x_norm = norm(state.x)
        x_relative_change = norm(state.x - state.x_prev) / x_norm
        x_unstable = x_relative_change > tolerance * (1 - q)
        # relative dual residual norm
        rel_rd = norm(state.rd) / norm(state.z)
        d_infeasible = rel_rd > tolerance
        # relative duality gap
        duality_gap = jnp.abs(state.dual_objective - state.primal_objective)
        relative_gap = duality_gap / state.primal_objective
        gap_infeasible = relative_gap > tolerance
        # feasibility of the primal residual norm
        rp_norm = norm(state.rp)
        p_infeasible = rp_norm >= rp_norm_threshold
        # continue if any convergence measure is still beyond tolerance
        condition = jnp.logical_or(gap_infeasible, d_infeasible)
        condition = jnp.logical_or(condition, p_infeasible)
        condition = jnp.logical_and(condition, more_iters)
        condition = jnp.logical_and(condition, x_unstable)
        return condition

    state = lax.while_loop(cond, double_iteration, init())
    return state

# A is static: it is an operator object, not a JAX array.
solve_bp_jit = jit(solve_bp, static_argnums=(0,))


def solve_l1_l2(A, b, x0, z0, w, nonneg, rho, gamma, tolerance, max_iters):
    r"""
    Solves the problem :math:`\min \| x \|_1 + \frac{1}{2 \rho} \| A x - b \|_2^2` using ADMM

    This function implements eq 2.25 of the paper.

    NOTE(review): the closed-form y-update below involves no linear solve.
    It presumably assumes the rows of ``A`` are orthonormal
    (:math:`A A^T = I`); confirm for the operators passed in.

    Args:
        A: Linear operator providing ``times`` (:math:`A x`) and ``trans``
            (:math:`A^T y`).
        b: Measurement vector (``solve`` passes it already divided by ``b_max``).
        x0: Initial primal variable.
        z0: Initial dual variable.
        w: Per-entry weights of the l1 norm.
        nonneg (bool): If true, the z-update uses
            :func:`project_to_real_upper_limit` instead of :func:`project_to_box`.
        rho (float): Weight of the quadratic penalty; must be positive
            (the objective divides by it).
        gamma (float): Step-size multiplier of the x (multiplier) update.
        tolerance (float): Stopping tolerance.
        max_iters (int): Iteration limit.

    Returns:
        BPState: State after the last ADMM iteration.
    """
    times = A.times
    trans = A.trans
    mu = jnp.mean(jnp.abs(b))
    rho_by_mu = rho / mu
    rho_by_mu_p1 = rho_by_mu + 1
    b_by_mu = b / mu

    def init():
        x = x0
        z = z0
        # primal residual
        rp = b - times(x)
        # dual residual
        rd = -trans(b)
        rp_norm_sqr = rp.T @ rp
        primal_objective = jnp.sum(jnp.abs(w * x)) + (0.5 / rho) * rp_norm_sqr
        dual_objective = 0.
        # init performed one forward and one adjoint application
        return BPState(x=x, x_prev=jnp.zeros(x.shape), z=z, rp=rp, rd=rd,
                       primal_objective=primal_objective,
                       dual_objective=dual_objective,
                       iterations=0, n_times=1, n_trans=1)

    def iteration(state):
        # update y
        x_by_mu = state.x / mu
        y = times(state.z - x_by_mu) + b_by_mu
        y = y / rho_by_mu_p1
        Aty = trans(y)
        # update z
        z = Aty + x_by_mu
        z = jnp.where(nonneg, project_to_real_upper_limit(z, w),
                      project_to_box(z, w))
        n_times = state.n_times + 1
        n_trans = state.n_trans + 1
        # dual residual
        rd = z - Aty
        # update x (multiplier step)
        x = state.x - (gamma * mu) * rd
        # primal residual
        rp = b - times(x)
        n_times += 1
        rp_norm_sqr = rp.T @ rp
        y_norm_sqr = y.T @ y
        # primal objective
        primal_objective = jnp.sum(jnp.abs(w * x)) + (0.5 / rho) * rp_norm_sqr
        # dual objective
        dual_objective = b.T @ y - (0.5 * rho) * y_norm_sqr
        # updated state
        return BPState(x=x, x_prev=state.x, z=z, rp=rp, rd=rd,
                       primal_objective=primal_objective,
                       dual_objective=dual_objective,
                       iterations=state.iterations + 1,
                       n_times=n_times, n_trans=n_trans)

    def double_iteration(state):
        # two ADMM steps per stopping-condition evaluation
        state = iteration(state)
        return iteration(state)

    def cond(state):
        """Returns True while the ADMM loop should continue.

        The loop continues while all of the following hold:

        - the iteration limit has not been reached;
        - the relative change in x exceeds ``tolerance * (1 - q)``;
        - the relative duality gap or the relative dual residual is above
          ``tolerance``.
        """
        q = 0.1
        # limit on number of iterations
        condition = state.iterations < max_iters
        # relative change in x
        x_norm = norm(state.x)
        x_relative_change = norm(state.x - state.x_prev) / x_norm
        condition = jnp.logical_and(condition,
                                    x_relative_change > tolerance * (1 - q))
        # relative dual residual norm
        rel_rd = norm(state.rd) / norm(state.z)
        # relative duality gap
        duality_gap = jnp.abs(state.dual_objective - state.primal_objective)
        relative_gap = duality_gap / state.primal_objective
        # if either duality gap or dual res norm are beyond tolerance, we continue
        rd_gap_cond = jnp.logical_or(relative_gap > tolerance, rel_rd > tolerance)
        condition = jnp.logical_and(condition, rd_gap_cond)
        return condition

    state = lax.while_loop(cond, double_iteration, init())
    return state

# A is static: it is an operator object, not a JAX array.
solve_l1_l2_jit = jit(solve_l1_l2, static_argnums=(0,))


def solve_l1_l2con(A, b, x0, z0, w, nonneg, delta, gamma, tolerance, max_iters):
    r"""
    Solves the problem :math:`\min \| x \|_1 \, \text{s.t.}\, \| A x - b \|_2 \leq \delta` using ADMM

    This function implements eq 2.27 of the paper.

    NOTE(review): the closed-form y-update below involves no linear solve.
    It presumably assumes the rows of ``A`` are orthonormal
    (:math:`A A^T = I`); confirm for the operators passed in.

    Args:
        A: Linear operator providing ``times`` (:math:`A x`) and ``trans``
            (:math:`A^T y`).
        b: Measurement vector (``solve`` passes it already divided by ``b_max``).
        x0: Initial primal variable.
        z0: Initial dual variable.
        w: Per-entry weights of the l1 norm.
        nonneg (bool): If true, the z-update uses
            :func:`project_to_real_upper_limit` instead of :func:`project_to_box`.
        delta (float): Bound on the residual norm; ``solve`` only calls this
            with ``delta > 0``.
        gamma (float): Step-size multiplier of the x (multiplier) update.
        tolerance (float): Stopping tolerance.
        max_iters (int): Iteration limit.

    Returns:
        BPState: State after the last ADMM iteration.
    """
    times = A.times
    trans = A.trans
    mu = jnp.mean(jnp.abs(b))
    b_by_mu = b / mu
    delta_by_mu = delta / mu
    # the primal residual may exceed delta by a relative margin of tolerance
    rp_norm_threshold = delta * (1 + tolerance)

    def init():
        # primal residual
        rp = b - times(x0)
        # dual residual
        rd = -trans(b)
        primal_objective = jnp.sum(jnp.abs(w * x0))
        dual_objective = 0.
        # init performed one forward and one adjoint application
        return BPState(x=x0, x_prev=jnp.zeros(x0.shape), z=z0, rp=rp, rd=rd,
                       primal_objective=primal_objective,
                       dual_objective=dual_objective,
                       iterations=0, n_times=1, n_trans=1)

    def iteration(state):
        # update y
        x_by_mu = state.x / mu
        y = times(state.z - x_by_mu) + b_by_mu
        # y minus its projection onto the l2 ball of radius delta / mu
        y_norm = norm(y)
        y = jnp.maximum(0, 1 - delta_by_mu / y_norm) * y
        Aty = trans(y)
        # update z
        z = Aty + x_by_mu
        z = jnp.where(nonneg, project_to_real_upper_limit(z, w),
                      project_to_box(z, w))
        n_times = state.n_times + 1
        n_trans = state.n_trans + 1
        # dual residual
        rd = z - Aty
        # update x (multiplier step)
        x = state.x - (gamma * mu) * rd
        # primal residual
        rp = b - times(x)
        n_times += 1
        # primal objective
        primal_objective = jnp.sum(jnp.abs(w * x))
        # dual objective
        dual_objective = b.T @ y - delta * norm(y)
        # updated state
        return BPState(x=x, x_prev=state.x, z=z, rp=rp, rd=rd,
                       primal_objective=primal_objective,
                       dual_objective=dual_objective,
                       iterations=state.iterations + 1,
                       n_times=n_times, n_trans=n_trans)

    def double_iteration(state):
        # two ADMM steps per stopping-condition evaluation
        state = iteration(state)
        return iteration(state)

    def cond(state):
        """Returns True while the ADMM loop should continue.

        The loop continues while all of the following hold:

        - the iteration limit has not been reached;
        - the relative change in x exceeds ``tolerance``;
        - at least one of these is above its threshold: the relative duality
          gap, the relative dual residual, or the primal residual norm
          (threshold ``delta * (1 + tolerance)``).
        """
        # limit on number of iterations
        more_iters = state.iterations < max_iters
        # relative change in x
        x_norm = norm(state.x)
        x_relative_change = norm(state.x - state.x_prev) / x_norm
        # NOTE(review): solve_bp / solve_l1_l2 use tolerance * (1 - q), q = 0.1,
        # for this test; confirm whether the stricter-free threshold here is intended.
        x_unstable = x_relative_change > tolerance
        # relative dual residual norm
        rel_rd = norm(state.rd) / norm(state.z)
        d_infeasible = rel_rd > tolerance
        # relative duality gap
        duality_gap = jnp.abs(state.dual_objective - state.primal_objective)
        relative_gap = duality_gap / state.primal_objective
        gap_infeasible = relative_gap > tolerance
        # feasibility of the primal residual norm
        rp_norm = norm(state.rp)
        p_infeasible = rp_norm > rp_norm_threshold
        # continue if any convergence measure is still beyond tolerance
        condition = jnp.logical_or(gap_infeasible, d_infeasible)
        condition = jnp.logical_or(condition, p_infeasible)
        condition = jnp.logical_and(condition, more_iters)
        condition = jnp.logical_and(condition, x_unstable)
        return condition

    state = lax.while_loop(cond, double_iteration, init())
    return state

# Only A is static, as in solve_bp_jit / solve_l1_l2_jit. The numeric
# arguments are traced, so changing delta or tolerance does not trigger
# recompilation, and array-valued arguments do not fail jit's hashability
# requirement for static args.
solve_l1_l2con_jit = jit(solve_l1_l2con, static_argnums=(0,))


def solve(A, b, x0=None, z0=None, W=None, weights=None, nonneg=False,
          rho=0., delta=0., gamma=1.0, tolerance=5e-3, max_iters=9999,
          jit=True):
    r"""Wrapper method to solve a variety of l1 minimization problems using ADMM

    The problem solved depends on ``rho`` and ``delta`` (checked in this order):

    * ``rho > 0``: :math:`\min \| x \|_1 + \frac{1}{2 \rho} \| A x - b \|_2^2`
      (eq 2.25 of the paper, :func:`solve_l1_l2`)
    * otherwise ``delta > 0``:
      :math:`\min \| x \|_1 \, \text{s.t.}\, \| A x - b \|_2 \leq \delta`
      (eq 2.27, :func:`solve_l1_l2con`)
    * otherwise basis pursuit:
      :math:`\min \| x \|_1 \, \text{s.t.}\, A x = b` (eq 2.29, :func:`solve_bp`)

    Args:
        A: Sensing operator/dictionary providing ``times`` and ``trans``
            (e.g. a ``cr.sparse.lop`` operator).
        b (jax.numpy.ndarray): Signal being approximated
        x0 (jax.numpy.ndarray): Initial value of solution (primary variable) :math:`x`
        z0 (jax.numpy.ndarray): Initial value of dual variable :math:`z`
        W: Optional orthonormal sparsifying basis operator. The solver works
            with the composite operator ``A @ W`` and maps the result back with
            ``W.times``. That is, it recovers ``x = W alpha`` with ``alpha``
            sparse.
        weights (jax.numpy.ndarray): The weights for individual entries in :math:`x`
        nonneg (bool): Flag to indicate if values in the solution are all non-negative
        rho (float): weight for the quadratic penalty term
        delta (float): constraint on the residual norm
        gamma (float): ADMM update parameter for :math:`x`
        tolerance (float): stopping tolerance of the ADMM iterations
        max_iters (int): maximum number of ADMM iterations
        jit (bool): if True, run the jit-compiled solver variants

    Returns:
        RecoveryFullSolution: Solution vector :math:`x` and residual :math:`r`

    This function is based on :cite:`yang2011alternating`.
    """
    if W is not None:
        # solve for the coefficients alpha of x = W alpha
        A = A @ W
    Atb = A.trans(b)
    # operator applications made here, outside the solvers
    n_times = 0
    n_trans = 1
    n = Atb.shape[0]
    b_max = float(norm(b, ord=jnp.inf))
    atb_max = float(norm(Atb, ord=jnp.inf))
    # Detect problems whose optimum is x = 0, in the same priority order as
    # the solver dispatch below. b == 0 is checked first: otherwise the
    # scaling by b_max would divide by zero.
    if b_max == 0:
        zero_solution = True
    elif rho > 0:
        zero_solution = atb_max <= rho
    elif delta > 0:
        zero_solution = float(norm(b)) <= delta
    else:
        zero_solution = False
    if zero_solution:
        x = jnp.zeros(n)
        return RecoveryFullSolution(x=x, r=b, iterations=0,
                                    n_times=n_times, n_trans=n_trans)
    if x0 is None:
        x0 = Atb / b_max
    if z0 is None:
        z0 = jnp.zeros(n)
    w = jnp.ones(n)
    if weights is not None:
        # make sure that the final weights are an array of size n
        w = w * weights
    # scale data and model parameters
    b = b / b_max
    if rho > 0:
        rho = rho / b_max
    if delta > 0:
        delta = delta / b_max
    if jit:
        if rho > 0:
            # It's an l1-l2 problem
            state = solve_l1_l2_jit(A, b, x0, z0, w, nonneg, rho, gamma,
                                    tolerance, max_iters)
        elif delta > 0:
            # It's an l1-l2 constrained problem BPIC
            state = solve_l1_l2con_jit(A, b, x0, z0, w, nonneg, delta, gamma,
                                       tolerance, max_iters)
        else:
            # It's a basis pursuit problem
            state = solve_bp_jit(A, b, x0, z0, w, nonneg, gamma,
                                 tolerance, max_iters)
    else:
        if rho > 0:
            # It's an l1-l2 problem BPDN
            state = solve_l1_l2(A, b, x0, z0, w, nonneg, rho, gamma,
                                tolerance, max_iters)
        elif delta > 0:
            # It's an l1-l2 constrained problem BPIC
            state = solve_l1_l2con(A, b, x0, z0, w, nonneg, delta, gamma,
                                   tolerance, max_iters)
        else:
            # It's a basis pursuit problem
            state = solve_bp(A, b, x0, z0, w, nonneg, gamma,
                             tolerance, max_iters)
    x = jnp.where(nonneg, jnp.maximum(0, state.x), state.x)
    if W is not None:
        # go back from sparsifying basis to signal space
        x = W.times(x)
    return RecoveryFullSolution(x=b_max * x, r=b_max * state.rp,
                                iterations=state.iterations,
                                n_times=state.n_times + n_times,
                                n_trans=state.n_trans + n_trans)