Thinking in JAX
The JAX API is similar to NumPy's. However, functions written in JAX must be Just In Time (JIT) compiled, and this requirement leads to many differences from NumPy. In this section, we give a set of examples showing how to write numerical code properly with JAX.
JAX arrays are immutable.
The JIT compiler must be able to determine the shapes of a function's outputs statically from the shapes of its inputs.
jax.lax is a low-level module containing several helper functions for expressing complex logic in a functional manner. We will often use its functions in the examples below.
Python while loops whose condition depends on array values cannot be JIT compiled.
A Python for loop will be unrolled by the JIT compiler (if the iteration count can be statically determined). This increases compilation time and should be avoided.
Use lax.while_loop and lax.fori_loop as functional alternatives.
JAX doesn’t support array views.
If some arguments to a function determine the shape of intermediate arrays in the function body or the shape of its output, then they must be marked as static to the JIT compiler via static_argnums or static_argnames.
Python if/else blocks whose condition depends on array values cannot be JIT compiled. You can use jnp.where or lax.cond to build equivalent logic.
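The following minimal sketch shows the two most common replacements side by side: lax.fori_loop in place of a Python for loop, and jnp.where in place of an if/else on array values. The function names product_of_elements and relu are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def product_of_elements(x):
    """Hypothetical helper: multiply all entries of x using lax.fori_loop."""
    def body(i, acc):
        # body receives the loop index and the current carry, returns the new carry
        return acc * x[i]

    # x.shape[0] is static under jit, so the trip count is known at trace time
    return lax.fori_loop(0, x.shape[0], body, jnp.ones((), dtype=x.dtype))


@jax.jit
def relu(x):
    """Hypothetical helper: elementwise 'if x < 0 then 0 else x' without Python if/else."""
    return jnp.where(x < 0, 0.0, x)
```

The loop carry (here the running product) must keep the same shape and dtype on every iteration, which is why it is initialized with the dtype of x.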
Sometimes you may worry that you are writing too many low-level functions just to satisfy the JIT compiler, for logic that a few for/while loops would express in plain NumPy code. This additional complexity usually pays off in the end: once the JIT compiler accepts your implementation, the generated code is often much faster than the NumPy version.
As most of us are not used to writing functional code, it can take a lot of effort to come up with proper JAX-compatible designs in the beginning. The more you code in JAX, the easier it becomes to think functionally.
All the code snippets in this tutorial are taken from the code in …
In the following, jnp is a short name for jax.numpy:
import jax.numpy as jnp
Activating 64-bit mode
By default, JAX uses 32-bit floating point numbers. For sparse reconstruction algorithms, 32-bit precision is often not enough. Make sure to configure JAX to use 64-bit floating point numbers before calling any other JAX functions:
import jax
jax.config.update("jax_enable_x64", True)
Set a value at a particular index:
x = x.at[i].set(value)
Add a value at a particular index:
x = x.at[i].add(value)
Subtract a value at a particular index:
x = x.at[i].add(-value)
Swapping two elements at indices i and j:
xi = x[i]
xj = x[j]
x = x.at[i].set(xj)
x = x.at[j].set(xi)
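Because swapping two entries does not change any array shape, i and j can stay ordinary traced arguments and nothing needs to be marked static. Below is a small jitted wrapper as a sketch; the name swap is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def swap(x, i, j):
    """Return a copy of x with entries i and j exchanged (x itself is untouched)."""
    xi = x[i]
    xj = x[j]
    # Each .at[...].set() returns a new array, since JAX arrays are immutable
    x = x.at[i].set(xj)
    x = x.at[j].set(xi)
    return x
```

Calling swap(x, 0, 2) leaves the caller's x unchanged and returns a new array.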
Check if a vector contains strictly increasing values:
jnp.all(jnp.diff(x) > 0)
Check if all values in a vector are equal:
jnp.all(x == x[0])
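Both checks reduce to a single jnp.all over a boolean array. They therefore work unchanged inside jax.jit, as this sketch with hypothetical function names shows.

```python
import jax
import jax.numpy as jnp


@jax.jit
def is_strictly_increasing(x):
    # Every consecutive difference must be positive
    return jnp.all(jnp.diff(x) > 0)


@jax.jit
def all_equal(x):
    # Compare every entry against the first one
    return jnp.all(x == x[0])
```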
Convert a vector to a row vector (1xn):
x[None, :]
Convert a vector to a column vector (nx1):
x[:, None]
Construct a unit vector of length n with a one at the i-th position:
jnp.zeros(n).at[i].set(1)
Note that n determines the length of the array, so it must be known statically by the JIT compiler (e.g., marked via static_argnames).
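The following sketch shows the row and column reshapes, and a unit vector builder in which n is marked static while i stays traced. The name unit_vector is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
row = x[None, :]   # shape (1, 3)
col = x[:, None]   # shape (3, 1)


@partial(jax.jit, static_argnames="n")
def unit_vector(i, n):
    """Hypothetical helper: length-n vector with a one at index i."""
    # n fixes the output shape, so it must be static; i can remain a traced value
    return jnp.zeros(n).at[i].set(1.0)
```

Calling unit_vector with a different n triggers a recompilation. Changing only i does not.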
Right shift the contents of a vector by one element:
x = jnp.zeros_like(x).at[1:].set(x[:-1])
We first construct an array of the same shape as x containing all zeros.
We then fill the last n-1 elements of this array (all except the first element) with the first n-1 elements of x.
The last element of x is dropped.
Our focus is on expressing our logic in a functional manner.
We leave it to the JIT compiler to come up with the efficient implementation of the logic for the target architecture.
If we want to right shift by n elements (with n >= 1 known statically), then the logic becomes:
x = jnp.zeros_like(x).at[n:].set(x[:-n])
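The shift amount fixes the sizes of the slices x[:-n] and out[n:], so it has to be a static argument under jax.jit. Here is a sketch with a hypothetical function name:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n")
def right_shift(x, n):
    """Hypothetical helper: shift x right by n >= 1 places, filling with zeros."""
    # With n == 0, x[:-0] would be empty, so this sketch assumes n >= 1
    return jnp.zeros_like(x).at[n:].set(x[:-n])
```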
Return the magnitudes of elements of a vector in descending order:
jnp.sort(jnp.abs(x))[::-1]
We first get the magnitudes with jnp.abs.
We then sort the result in ascending order using jnp.sort.
We finally reverse the array into descending order by indexing with [::-1].
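Because jnp.abs returns real magnitudes, the same expression also works for complex vectors. A short check:

```python
import jax.numpy as jnp

x = jnp.array([3.0 + 0.0j, -5.0 + 0.0j, 4.0j])
# jnp.abs gives real magnitudes even for complex input; sort ascending, then reverse
magnitudes_desc = jnp.sort(jnp.abs(x))[::-1]
```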
Let us be more adventurous. We wish to find out how many of the largest elements in a vector a are enough to capture a fraction q of the total energy of the vector. The vector can be real or complex. We shall break this down into multiple steps.
Compute energy of individual elements:
a = jnp.conj(a) * a
Sort the energies in descending order:
a = jnp.sort(a)[::-1]
Compute the total energy:
s = jnp.sum(a)
Normalize the energies to fractions:
a = a / s
Compute the cumulative energies starting from the largest coefficient:
cmf = jnp.cumsum(a)
Find the index at which the cumulative energy first reaches the required fraction q (jnp.argmax returns the position of the first True value):
index = jnp.argmax(cmf >= q)
The required number of elements to capture a fraction q of the energy is index + 1.
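Putting the steps together gives the jitted sketch below. The function name is hypothetical. It uses jnp.abs(a) ** 2 in place of jnp.conj(a) * a so that the energies stay real-valued for complex input. Otherwise it follows the steps above one for one.

```python
import jax
import jax.numpy as jnp


@jax.jit
def num_largest_for_energy(a, q):
    """Hypothetical helper: how many of the largest entries of a capture
    a fraction q of its energy (a may be real or complex)."""
    e = jnp.abs(a) ** 2            # energy of each element (always real)
    e = jnp.sort(e)[::-1]          # largest energies first
    e = e / jnp.sum(e)             # normalize to fractions
    cmf = jnp.cumsum(e)            # cumulative energy fractions
    index = jnp.argmax(cmf >= q)   # first position reaching q
    return index + 1
```

For a = [3, 4] the energies are 16 and 9 out of 25. A single element therefore captures 64% of the energy.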
Consider the following function:
def f(x, alpha):
    if alpha == 0:
        return x
    return x / alpha
We shall now build this logic step by step using lax.cond.
The condition to check is alpha == 0.
We have to define two functions: one for the case where the condition is true and another for the case where it is false. For both cases, we shall define anonymous functions using the lambda keyword.
Here is the function for the true case:
lambda x : x
Here is the function for the false case:
lambda x: x / alpha
Both functions take x as their argument. Now we can combine these elements to form the functionally equivalent code:
lax.cond(alpha == 0, lambda x : x, lambda x: x / alpha, x)
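Wrapped in a jitted function, the lax.cond expression behaves like the Python version above, even when alpha is a traced value. This is a minimal sketch. Note that both branches must return arrays of the same shape and dtype, so it assumes a floating point x.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def f(x, alpha):
    # Both branches receive the same operand x; alpha is captured by the lambda
    return lax.cond(alpha == 0, lambda x: x, lambda x: x / alpha, x)
```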
We suggest reading the official documentation of lax.cond to understand its details.
A circular buffer is a fixed-size array into which one can push values from either the left or the right side. When we push a new element, the oldest element on the other side is removed.
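Assume that we are given a buffer and need to push a value val from the left side: every element moves one place to the right, the last element drops out, and val is written at index 0.
If we need to push a value from the right side, the process is mirrored: every element moves one place to the left, the first element drops out, and val is written at the last index.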
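Both pushes can be written as chained .at updates, as in the sketch below. The function names push_left and push_right are hypothetical. Each call returns a new buffer of the same size.

```python
import jax
import jax.numpy as jnp


@jax.jit
def push_left(buf, val):
    """Hypothetical helper: insert val at index 0, dropping the last element."""
    # buf[:-1] is read from the original buffer before any update
    return buf.at[1:].set(buf[:-1]).at[0].set(val)


@jax.jit
def push_right(buf, val):
    """Hypothetical helper: insert val at the last index, dropping the first element."""
    return buf.at[:-1].set(buf[1:]).at[-1].set(val)
```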
jnp.linalg.norm is the workhorse for general norm computation. However, for specific cases we can often compute norms ourselves with simple expressions.
Computing the l-1 norm:
jnp.sum(jnp.abs(x))
Computing the l-2 norm:
jnp.sqrt(jnp.abs(jnp.vdot(x, x)))
Computing the l-inf norm:
jnp.max(jnp.abs(x))
Column-wise norms
Often in sparse signal processing, we deal with a matrix whose columns are vectors, and we have to compute the norm of each column.
Column-wise l-2 norm:
jnp.linalg.norm(X, ord=2, axis=0, keepdims=False)
keepdims=False (the default) reduces the result to a 1D array with one norm per column. With keepdims=True, the result would have shape (1, n) instead.
If we wish to compute the norm along rows, we can just use axis=1.
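The simple vector norms carry over to matrices by reducing along an axis. The sketch below computes column-wise l-1, l-2 and l-inf norms by hand, and row-wise l-2 norms with jnp.linalg.norm. The matrix X is made-up example data.

```python
import jax.numpy as jnp

X = jnp.array([[3.0, 1.0],
               [4.0, 0.0],
               [0.0, -2.0]])

# Column-wise norms: reduce over axis 0, one value per column
l1_cols = jnp.sum(jnp.abs(X), axis=0)
l2_cols = jnp.sqrt(jnp.sum(jnp.abs(X) ** 2, axis=0))
linf_cols = jnp.max(jnp.abs(X), axis=0)

# Row-wise l-2 norms: switch to axis=1, one value per row
l2_rows = jnp.linalg.norm(X, ord=2, axis=1)
```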
A common task is normalizing a vector so that it has unit norm. Care must be taken for the case where the vector is zero.
We can shift the norm value by a very small amount before carrying out the division. For 32-bit floating point numbers, a suitable shift is the machine epsilon (the gap between 1.0 and the next representable number), given by:
EPS = jnp.finfo(jnp.float32).eps
Then normalization can be written as:
s = jnp.sqrt(jnp.abs(jnp.vdot(x, x))) + EPS
x = jnp.divide(x, s)
This approach avoids a conditional expression such as lax.cond. It is good to avoid conditional code as much as possible, as it can become a bottleneck (especially when the numerical code is running on GPU hardware).
Since this normalization is slightly inaccurate, you should check whether the inaccuracy is acceptable for your use case.
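As a sketch, the epsilon-shifted normalization fits in a jitted function with no branches. A zero vector then maps to zeros instead of NaNs. The name normalize is hypothetical, and EPS assumes 32-bit inputs.

```python
import jax
import jax.numpy as jnp

# Machine epsilon for float32 (assumes 32-bit inputs)
EPS = jnp.finfo(jnp.float32).eps


@jax.jit
def normalize(x):
    """Scale x to (nearly) unit l-2 norm; a zero vector stays zero instead of NaN."""
    s = jnp.sqrt(jnp.abs(jnp.vdot(x, x))) + EPS
    return jnp.divide(x, s)
```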
Checking if a matrix is symmetric:
jnp.allclose(A, A.T)
Computing the Hermitian transpose:
jnp.conjugate(jnp.swapaxes(A, -1, -2))
Checking if a real matrix has orthonormal columns:
G = A.T @ A
m = G.shape[0]
I = jnp.eye(m)
result = jnp.allclose(G, I)
Checking for orthonormal rows, given a base absolute tolerance atol that is scaled up with the matrix size:
G = A @ A.T
m = G.shape[0]
I = jnp.eye(m)
result = jnp.allclose(G, I, atol=m*m*atol)
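The matrix checks above can be collected into small helpers. The function names and the default tolerance of 1e-6 below are hypothetical choices for illustration. A 2x2 rotation matrix serves as a test case, since it has orthonormal columns but is not symmetric.

```python
import jax.numpy as jnp


def is_symmetric(A, atol=1e-6):
    return jnp.allclose(A, A.T, atol=atol)


def hermitian(A):
    # Conjugate transpose of the last two axes (also works for batches of matrices)
    return jnp.conjugate(jnp.swapaxes(A, -1, -2))


def has_orthonormal_columns(A, atol=1e-6):
    G = A.T @ A              # Gram matrix of the columns
    m = G.shape[0]
    return jnp.allclose(G, jnp.eye(m), atol=atol)
```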
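Extracting the off-diagonal elements of a matrix:
mask = ~jnp.eye(*A.shape, dtype=bool)
off_diagonal_elements = A[mask]
Boolean-mask indexing needs a mask whose values are known at trace time. This snippet therefore runs eagerly but fails inside jax.jit, where the jnp-built mask is a traced value.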
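One way to keep the mask static inside jax.jit is to build it with NumPy from the shape of A, which is always static. The sketch below takes this approach. The function name is hypothetical, and it relies on JAX accepting concrete NumPy boolean arrays as indices.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def off_diagonal(A):
    # A.shape is static under jit, and a NumPy mask is a concrete constant,
    # so the number of selected elements is known at trace time
    mask = ~np.eye(*A.shape, dtype=bool)
    return A[mask]
```

The elements come out in row-major order, skipping the diagonal. Rectangular matrices work too.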
Setting the diagonal elements of a given matrix:
indices = jnp.diag_indices(min(A.shape))
A = A.at[indices].set(value)
Adding something to the diagonal elements of a matrix:
indices = jnp.diag_indices(min(A.shape))
A = A.at[indices].add(value)
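Because min(A.shape) is a plain Python integer even under jax.jit, the diagonal updates can be jitted directly. This sketch uses hypothetical helper names. value may be a scalar or a vector with one entry per diagonal element.

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_diagonal(A, value):
    """Hypothetical helper: copy of A with every diagonal entry set to value."""
    indices = jnp.diag_indices(min(A.shape))
    return A.at[indices].set(value)


@jax.jit
def add_to_diagonal(A, value):
    """Hypothetical helper: copy of A with value added to each diagonal entry."""
    indices = jnp.diag_indices(min(A.shape))
    return A.at[indices].add(value)
```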
Finding the index of the largest element (by magnitude) in each column of a matrix:
jnp.argmax(jnp.abs(A), axis=0)
Premultiplying a matrix A with a diagonal matrix whose diagonal elements are given by a vector d:
jnp.multiply(d[:, None], A)
Post-multiplying a matrix A with a diagonal matrix whose diagonal elements are given by a vector d:
jnp.multiply(A, d[None, :])
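The broadcasting forms can be checked against explicit products with jnp.diag. The matrix and vectors below are made-up example data.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, -7.0],
               [-3.0, 2.0],
               [2.0, 5.0]])
d_rows = jnp.array([1.0, 2.0, 3.0])   # one entry per row of A
d_cols = jnp.array([10.0, 100.0])     # one entry per column of A

# Row index of the largest-magnitude entry in each column
argmax_cols = jnp.argmax(jnp.abs(A), axis=0)

# diag(d_rows) @ A scales row k by d_rows[k]; A @ diag(d_cols) scales column k by d_cols[k]
pre = jnp.multiply(d_rows[:, None], A)
post = jnp.multiply(A, d_cols[None, :])
```

The broadcasting versions never form the dense diagonal matrix, which saves both memory and a full matrix product.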
Extracting bxb diagonal blocks from a matrix:
n = A.shape[0]
nb = n // b
starts = [i*b for i in range(nb)]
blocks = jnp.array([A[s:s+b, s:s+b] for s in starts])
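A loop-free alternative, offered here as a sketch and not as the method used in the original code, reshapes A into a grid of blocks and picks the diagonal ones with advanced indexing. The function name is hypothetical. b must be static because it determines the output shape.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="b")
def diagonal_blocks(A, b):
    """Hypothetical loop-free variant: the nb = n // b diagonal b x b blocks of A."""
    n = A.shape[0]
    nb = n // b
    # Keep only the part of A covered by whole blocks, then view it as a grid
    # of blocks: A4[p, i, q, j] == A[p*b + i, q*b + j]
    A4 = A[: nb * b, : nb * b].reshape(nb, b, nb, b)
    k = jnp.arange(nb)
    # Advanced indices separated by a slice move to the front: shape (nb, b, b)
    return A4[k, :, k, :]
```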
Constructing a Toeplitz matrix
A Toeplitz matrix is completely specified by its first row and column. E.g.,
[[1 2 3 4]
 [2 1 2 3]
 [3 2 1 2]
 [4 3 2 1]]
Suppose we are given the first row and first column of the Toeplitz matrix and we are required to construct the whole matrix. We can do so without any loops, using some indexing tricks.
def toeplitz_mat(c, r):
    m = len(c)
    n = len(r)
    # assumes c[0] == r[0]; c[0] ends up on the diagonal
    w = jnp.concatenate((c[::-1], r[1:]))
    # backwards indices
    a = -jnp.arange(m, dtype=int)
    # forwards indices
    b = jnp.arange(m-1, m+n-1, dtype=int)
    # combine indices for the toeplitz matrix
    indices = a[:, None] + b[None, :]
    # form the toeplitz matrix
    mat = w[indices]
    return mat
We combined the first row and first column elements into a single array w. We then constructed an index matrix in which each element is an index into w, identifying the element to be placed in the output Toeplitz matrix. Forming the Toeplitz matrix then becomes a simple indexing step.
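Another loop-free formulation, shown here as an alternative sketch, builds row and column index grids directly. It then picks c[i-j] below the diagonal and r[j-i] above it with jnp.where. The function name is hypothetical. The clipping only keeps the unused branch of jnp.where in bounds.

```python
import jax.numpy as jnp


def toeplitz_where(c, r):
    """Hypothetical alternative: entry (i, j) is c[i-j] if i >= j, else r[j-i]."""
    m, n = len(c), len(r)
    i = jnp.arange(m)[:, None]   # row indices, shape (m, 1)
    j = jnp.arange(n)[None, :]   # column indices, shape (1, n)
    lower = c[jnp.clip(i - j, 0, m - 1)]   # valid wherever i >= j
    upper = r[jnp.clip(j - i, 0, n - 1)]   # valid wherever j > i
    return jnp.where(i >= j, lower, upper)
```

With c = r = [1, 2, 3, 4], this reproduces the example matrix above. It also handles a non-square case such as a 3x4 matrix.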
Basic Signal Processing
Scaling a vector to the range [0, 1]:
shift = jnp.min(x)
x = x - shift
scale = jnp.max(x)
x = x / scale
This assumes that x is not constant (otherwise scale is zero). To undo the scaling:
x = x * scale
x = x + shift
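Returning shift and scale alongside the scaled vector makes the round trip explicit. Below is a sketch with hypothetical helper names, under the same assumption that x is not constant.

```python
import jax.numpy as jnp


def scale_to_unit_range(x):
    """Hypothetical helper: map x to [0, 1] and return (scaled, shift, scale).
    Assumes x is not constant, otherwise scale == 0."""
    shift = jnp.min(x)
    x = x - shift
    scale = jnp.max(x)
    return x / scale, shift, scale


def undo_unit_range(y, shift, scale):
    # Inverse of scale_to_unit_range
    return y * scale + shift
```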
Hard thresholding to the K largest elements (by magnitude):
indices = jnp.argsort(jnp.abs(x))
I = indices[:-K-1:-1]
x_I = x[I]
Here the tuple (I, x_I) identifies the indices and values of the K largest entries. To build the full-length approximation, we do the following:
x = jnp.zeros_like(x)
x = x.at[I].set(x_I)
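Alternatively, we can compute the K-sparse approximation directly by zeroing out everything except the K largest entries:
indices = jnp.argsort(jnp.abs(x))
x = x.at[indices[:-K]].set(0)
In both variants, K determines slice sizes, so it must be known statically by the JIT compiler (and be at least 1).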
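The first variant can be packaged as a jitted function with K marked static, as sketched below. The name hard_threshold is hypothetical. It returns both the K-sparse vector and the index set I.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="K")
def hard_threshold(x, K):
    """Hypothetical helper: keep the K largest-magnitude entries of x, zero the rest.
    K is static (it sets slice sizes) and assumed to satisfy 1 <= K <= len(x)."""
    indices = jnp.argsort(jnp.abs(x))
    I = indices[:-K - 1:-1]     # indices of the K largest magnitudes, largest first
    x_I = x[I]
    return jnp.zeros_like(x).at[I].set(x_I), I
```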
A common signal processing task is to divide a signal x into windows of length w such that consecutive windows overlap by m samples. Achieving this in JAX again requires an indexing trick:
step = w - m
starts = jnp.arange(0, len(x) - w + 1, step)
block = jnp.arange(w)
idx = starts[:, None] + block[None, :]
windows = x[idx]
This constructs the windows of x in each row of the resulting matrix. If you wish the windows to be column wise, just take the transpose.
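Under jax.jit, w and m must be static because they fix the number and length of the windows. len(x) is already static, since shapes are known at trace time. Here is a sketch with a hypothetical function name:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("w", "m"))
def sliding_windows(x, w, m):
    """Hypothetical helper: rows are windows of length w overlapping by m samples."""
    step = w - m
    starts = jnp.arange(0, len(x) - w + 1, step)   # first index of each window
    block = jnp.arange(w)                           # offsets within a window
    idx = starts[:, None] + block[None, :]          # (num_windows, w) index matrix
    return x[idx]
```

For x = 0, 1, ..., 9 with w = 4 and m = 2, the windows start at 0, 2, 4 and 6.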