Source code for cr.sparse._src.sls.fista

# Copyright 2021 CR-Suite Development Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fast iterative shrinkage and thresholding algorithm (FISTA)
"""

from jax import jit, lax
import jax.numpy as jnp

from typing import NamedTuple

from cr.nimble import arr_l2norm, arr_l2norm_sqr, arr_vdot
from .defs import identity_op, default_threshold

class FISTAState(NamedTuple):
    """FISTA algorithm state

    Carried from one iteration of ``fista`` to the next; the final state is
    what ``fista`` returns.
    """
    x: jnp.ndarray
    """Current solution estimate"""
    z: jnp.ndarray
    """Extrapolated (momentum) point from which the next gradient step is taken"""
    r: jnp.ndarray
    """Current residual, evaluated at ``z``"""
    t: float
    """FISTA specific momentum coefficient"""
    r_norm_sqr: jnp.ndarray
    """Square of residual norm"""
    x_change_norm: jnp.ndarray
    """Norm of the change in ``x`` made by the latest iteration"""
    iterations: int
    """Number of iterations completed so far"""
def fista(
    operator,
    b,
    x0,
    step_size,
    threshold_func=default_threshold,
    basis=identity_op,
    res_norm_rtol=1e-3,
    x_norm_change_tol=1e-10,
    max_iters=1000,
):
    r"""Solves the problem
    :math:`\widehat{x} = \text{arg} \min_{x} \frac{1}{2}\| b - A x \|_2^2 + \lambda \mathbf{R}(x)`
    via fast iterative shrinkage and thresholding.

    It supports the more general problem where :math:`x` is sparse in a basis :math:`B`

    .. math::

        \widehat{x} = \text{arg} \min_{x} \frac{1}{2}\| b - A x \|_2^2 + \lambda \mathbf{R} (B^H x)

    Args:
        operator (cr.sparse.lop.Operator): A linear operator :math:`A`
        b (jax.numpy.ndarray): Data vector
        x0 (jax.numpy.ndarray): An initial estimate :math:`x_0` of the model vector :math:`x`
        step_size (float): Step size for the gradient step of each iteration
        threshold_func (Function): A user defined thresholding function. It is
            called as ``threshold_func(iteration, alpha)`` where ``alpha`` is the
            representation of the current estimate in the basis.
            See :ref:`sls:thresholding` for details.
        basis (cr.sparse.lop.Operator): A sparsifying basis :math:`B` for :math:`x`
        res_norm_rtol (float): Relative tolerance for norm of residual
            relative to norm of :math:`b`
        x_norm_change_tol (float): Tolerance for change in :math:`x` in each iteration
        max_iters (int): Maximum number of iterations

    Returns:
        FISTAState: a named tuple containing the solution :math:`x`,
        the residual :math:`r` and other details.

    The iteration stops as soon as any one of the following holds: the squared
    residual norm is no larger than ``(res_norm_rtol * ||b||)^2``, the norm of
    the change in :math:`x` is no larger than ``x_norm_change_tol``, or
    ``max_iters`` iterations have been performed.

    Note that the residual kept in the state (and used by the stopping test)
    is evaluated at the extrapolated point :math:`z`, i.e. :math:`r = b - A z`,
    not at :math:`x`.

    See :ref:`sls:ista` for an introduction to IST algorithm.

    Both ``x0`` and ``b`` can be multi-dimensional arrays.
    """
    # The stopping test compares squared norms, so square the relative tolerance.
    b_norm_sqr = arr_l2norm_sqr(b)
    r_norm_sqr_threshold = b_norm_sqr * (res_norm_rtol ** 2)

    def init():
        # compute the initial residual
        r = b - operator.times(x0)
        # compute the norm of the initial residual
        r_norm_sqr = arr_l2norm_sqr(r)
        # x_change_norm starts at a large value so that the change-in-x test
        # cannot stop the loop before the first iteration has run
        return FISTAState(x=x0, z=x0, r=r, t=1.,
            r_norm_sqr=r_norm_sqr, x_change_norm=1e10,
            iterations=0)

    def body(state):
        # compute the gradient step (state.r is the residual at state.z)
        grad = step_size * operator.trans(state.r)
        # update the solution starting from the extrapolated point
        x = state.z + grad
        # compute the representation of x in the sparsifying basis
        alpha = basis.trans(x)
        # apply the thresholding function on the sparse representation
        alpha = threshold_func(state.iterations, alpha)
        # convert back to data space
        x = basis.times(alpha)
        # update t
        t = (1. + jnp.sqrt(1. + 4. * state.t ** 2)) / 2.
        # update z: extrapolate along the latest change in x
        z = x + ((state.t - 1.) / t) * (x - state.x)
        # update the residual
        r = b - operator.times(z)
        # compute the norm of the current residual
        r_norm_sqr = arr_l2norm_sqr(r)
        # compute the norm of change in x
        x_change_norm = arr_l2norm(x - state.x)
        return FISTAState(x=x, z=z, r=r, t=t,
            r_norm_sqr=r_norm_sqr, x_change_norm=x_change_norm,
            iterations=state.iterations+1)

    def cond(state):
        not_converged = jnp.greater(state.r_norm_sqr, r_norm_sqr_threshold)
        not_converged = jnp.logical_and(state.x_change_norm > x_norm_change_tol, not_converged)
        # return true if the algorithm hasn't converged and there are more iterations to go
        return jnp.logical_and(state.iterations < max_iters, not_converged)

    state = lax.while_loop(cond, body, init())
    return state
# JIT-compiled variant of ``fista``.
fista_jit = jit(
    fista,
    # Only ``b`` (position 1) and ``x0`` (position 2) are traced array inputs;
    # the operator, step size, thresholding function, basis, tolerances and
    # iteration limit are all marked static.
    static_argnums=(0, 3, 4, 5, 6, 7, 8),
)