Note
Click here to download the full example code
Wavelet Decomposition of an Image¶
This is a simple example which demonstrates how
cr.sparse.wt.dwt2
function can be used to
perform 2D wavelet decomposition.
Its interface is identical to the corresponding function in PyWavelets library.
This example is adapted from PyWavelets documentation.
Configure JAX to work with 64-bit floating point precision.
from jax.config import config
config.update("jax_enable_x64", True)
Let’s import necessary libraries
import jax.numpy as jnp
# CR-Sparse libraries
import cr.sparse.wt as wt
# We use PyWavelets only for sample data
import pywt.data
# Plotting
import matplotlib.pyplot as plt
Load the Cameraman image
original = pywt.data.camera()
Perform wavelet decomposition
Split the coefficients tuple into individual parts
LL, (LH, HL, HH) = coeffs2
Plot the decomposition
titles = ['Approximation', ' Horizontal detail',
'Vertical detail', 'Diagonal detail']
fig = plt.figure(figsize=(12, 3))
for i, a in enumerate([LL, LH, HL, HH]):
ax = fig.add_subplot(1, 4, i + 1)
ax.imshow(a, interpolation="nearest", cmap=plt.cm.gray)
ax.set_title(titles[i], fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()
Total running time of the script: ( 0 minutes 0.776 seconds)