Source code for cr.sparse._src.sls.power

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Power iteration method for linear operators
"""

from jax import jit, lax
import jax.numpy as jnp

from typing import NamedTuple

from cr.nimble import arr_l2norm, arr_vdot

class PowerIterState(NamedTuple):
    """State for the power iterations algorithm
    """
    # (unnormalized) eigen vector guess
    v: jnp.ndarray
    # eigen value estimate
    old_estimate: float
    # new eigen value estimate
    new_estimate: float
    # number of iterations
    iterations: int

[docs]class PowerIterSolution(NamedTuple): """Solution of the eigen vector estimate """ v: jnp.ndarray """The estimated eigen vector""" s : float """The estimated eigen value""" iterations: int """Number of iterations to converge"""
[docs]def power_iterations( operator, b, max_iters=100, error_tolerance=1e-6): """Computes the largest eigen value of a (symmetric) linear operator by power method Args: operator (cr.sparse.lop.Operator): A symmetric linear operator :math:`A` b (jax.numpy.ndarray): A user provided initial guess for the largest eigen vector max_iters (int): Maximum number of iterations error_tolerance (float): Tolerance for relative change in largest eigen value Returns: PowerIterSolution: A named tuple containing the largest eigen value, corresponding eigen vector and the number of iterations for convergence The operator may accept multi-dimensional arrays as input. E.g. a 2D convolution operator will accept 2D images as input. In such cases, the eigen vector will also be a multi-dimensional array. Note: This will not work for matrices with complex eigen values. """ def init(): return PowerIterState(v=b, old_estimate=-1e20, new_estimate=1e20, iterations=0) def cond(state): # check if the gap between new and old estimate is still high change = state.new_estimate - state.old_estimate relchange = jnp.abs(change / state.old_estimate) not_converged = jnp.greater(relchange, error_tolerance) # return true if the the algorithm hasn't converged and there are more iterations to go return jnp.logical_and(state.iterations < max_iters, not_converged) def body(state): """One step of power iteration.""" v = state.v # normalize v_norm = arr_l2norm(v) v = v / v_norm # compute the next vector v_new = operator.times(v) # estimate the eigen value new_estimate = jnp.vdot(v, v_new) return PowerIterState(v=v_new, old_estimate=state.new_estimate, new_estimate=new_estimate, iterations=state.iterations+1) state = lax.while_loop(cond, body, init()) # We have converged v = state.v # normalize the eigen vector again v = v / arr_l2norm(v) return PowerIterSolution(v = v, s=state.new_estimate, iterations=state.iterations)
power_iterations_jit = jit(power_iterations, static_argnums=(0, 2, 3))