Wavelets

CR-Sparse provides support for both DWT (Discrete Wavelet Transform) and CWT (Continuous Wavelet Transform).

The support for discrete wavelets is a partial port of the wavelets functionality in the PyWavelets project. While PyWavelets gets its performance from C extensions, we have built the same functionality on top of the JAX API.

  • The API and the implementation are both based on the functional programming paradigm.

  • There are no C extensions. All implementation is pure Python.

  • The implementation takes advantage of XLA and can run easily on GPUs and TPUs.

The Continuous Wavelet Transform has been implemented following [TC98]. A reference implementation using NumPy is also available.

The code examples in this section assume the following imports:

import jax.numpy as jnp
import cr.sparse as crs
import cr.sparse.wt as wt

Discrete Wavelets

The API is available at two levels:

  • Functions which directly correspond to the high level API of PyWavelets.

  • Lower level functions which are JIT compiled.

The high-level functions handle a variety of argument types; for example, they accept Python lists as well as JAX nd-arrays. Because of this flexibility, they cannot be JIT compiled. The lower-level functions have been designed so that their arguments satisfy the JIT rules of JAX, which means they can be embedded inside other JIT-compiled functions.
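The following is a minimal sketch, not the CR-Sparse implementation, of why a lower-level signature such as dwt_(data, dec_lo, dec_hi, mode) suits JIT compilation: the filters are plain arrays, while the mode is a string that has to be marked static under jax.jit. The names dwt_sketch, _PAD_MODES, dwt_jit and energy_of_details are hypothetical, and the "extend, convolve, keep every second sample" scheme is an assumption modeled on the PyWavelets convention, checked here only against the Haar filters.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from PyWavelets-style mode names to jnp.pad modes.
_PAD_MODES = {"zero": "constant", "symmetric": "symmetric", "periodic": "wrap"}


def dwt_sketch(data, dec_lo, dec_hi, mode):
    """Single-level 1D DWT sketch: extend, convolve, keep every second sample."""
    flen = dec_lo.shape[0]  # filter length is a static shape under jit
    ext = jnp.pad(data, flen - 1, mode=_PAD_MODES[mode])
    lo = jnp.convolve(ext, dec_lo, mode="valid")  # length N + flen - 1
    hi = jnp.convolve(ext, dec_hi, mode="valid")
    return lo[1::2], hi[1::2]  # downsample by 2, starting at index 1


# The mode string selects Python-level behaviour, so it must be static.
dwt_jit = jax.jit(dwt_sketch, static_argnames=("mode",))


@jax.jit
def energy_of_details(x, dec_lo, dec_hi):
    """Embeds the JIT-friendly function inside another jitted function."""
    _, cd = dwt_jit(x, dec_lo, dec_hi, mode="zero")
    return jnp.sum(cd ** 2)


s = 1.0 / jnp.sqrt(2.0)
haar_lo = jnp.array([s, s])   # PyWavelets' haar dec_lo
haar_hi = jnp.array([-s, s])  # PyWavelets' haar dec_hi
ca, cd = dwt_jit(jnp.array([1.0, 2.0, 3.0, 4.0]), haar_lo, haar_hi, mode="zero")
```

For this input, ca and cd come out as (3, 7)/√2 and (-1, -1)/√2, the same values PyWavelets returns for a Haar decomposition of [1, 2, 3, 4].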

Current support is focused on discrete wavelet transforms. The following wavelets are supported:

bior1.1 bior1.3 bior1.5 bior2.2 bior2.4 bior2.6 bior2.8 bior3.1 bior3.3 bior3.5 bior3.7 bior3.9 bior4.4 bior5.5 bior6.8 coif1 coif2 coif3 coif4 coif5 coif6 coif7 coif8 coif9 coif10 coif11 coif12 coif13 coif14 coif15 coif16 coif17 db1 db2 db3 db4 db5 db6 db7 db8 db9 db10 db11 db12 db13 db14 db15 db16 db17 db18 db19 db20 db21 db22 db23 db24 db25 db26 db27 db28 db29 db30 db31 db32 db33 db34 db35 db36 db37 db38 dmey haar rbio1.1 rbio1.3 rbio1.5 rbio2.2 rbio2.4 rbio2.6 rbio2.8 rbio3.1 rbio3.3 rbio3.5 rbio3.7 rbio3.9 rbio4.4 rbio5.5 rbio6.8 sym2 sym3 sym4 sym5 sym6 sym7 sym8 sym9 sym10 sym11 sym12 sym13 sym14 sym15 sym16 sym17 sym18 sym19 sym20

High-level API

Data types

FAMILY

An enumeration describing the wavelet families supported in this library

SYMMETRY

Describes the type of symmetry in a wavelet

DiscreteWavelet

Represents information about a discrete wavelet

Wavelets

families([short])

Returns the list of (discrete) wavelet families supported by this package.

build_wavelet(name)

Builds a wavelet object from the name of the wavelet

wavelist([family, kind])

Returns the list of wavelets supported by this library for a specific wavelet family.

is_discrete_wavelet(name)

Returns whether the wavelet family is a family of discrete wavelets

wname_to_family_order(name)

Returns the wavelet family and order from the name

build_discrete_wavelet(family, order)

Builds a discrete wavelet from its family and order
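The wavelet names listed earlier follow a simple pattern: an alphabetic family prefix followed by an optional order, as in db4 or bior2.2. The hypothetical helper below only illustrates that pattern with a regular expression. It is not wname_to_family_order, whose actual return types (for example, how the order of bior2.2 is represented) are not described here.

```python
import re

# Hypothetical: family prefix (letters) followed by an optional order such as
# "4", "20" or "2.2". This is not the library's wname_to_family_order.
_NAME_RE = re.compile(r"^([a-z]+)(\d+(?:\.\d+)?)?$")


def split_wavelet_name(name):
    """Split a wavelet name into (family, order_string); order may be ''."""
    match = _NAME_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized wavelet name: {name!r}")
    return match.group(1), match.group(2) or ""


for name in ["haar", "db4", "sym20", "coif17", "bior2.2", "rbio6.8", "dmey"]:
    print(name, split_wavelet_name(name))
```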

Wavelet transforms

dwt(data, wavelet[, mode, axis])

Computes single level discrete wavelet decomposition

idwt(ca, cd, wavelet[, mode, axis])

Computes single level discrete wavelet reconstruction

dwt2(image, wavelet[, mode, axes])

Computes single level wavelet decomposition for 2D images

idwt2(coeffs, wavelet[, mode, axes])

Computes single level wavelet reconstruction for 2D images

downcoef(part, data, wavelet[, mode, level])

Partial discrete wavelet decomposition (multi-level)

upcoef(part, coeffs, wavelet[, mode, level, …])

Partial discrete wavelet reconstruction from one part of coefficients (multi-level)

wavedec(data, wavelet[, mode, level, axis])

Computes multilevel 1D discrete wavelet transform

waverec(coeffs, wavelet[, mode, axis])

Multilevel 1D inverse discrete wavelet transform

dwt_axis(data, wavelet, axis[, mode])

Computes single level wavelet decomposition along a given axis

idwt_axis(ca, cd, wavelet, axis[, mode])

Computes single level wavelet reconstruction along a given axis

dwt_column(data, wavelet[, mode])

Computes single level wavelet decomposition along columns (axis-0)

dwt_row(data, wavelet[, mode])

Computes single level wavelet decomposition along rows (axis-1)

dwt_tube(data, wavelet[, mode])

Computes single level wavelet decomposition along tubes (axis-2)

idwt_column(ca, cd, wavelet[, mode])

Computes single level wavelet reconstruction along columns (axis-0)

idwt_row(ca, cd, wavelet[, mode])

Computes single level wavelet reconstruction along rows (axis-1)

idwt_tube(ca, cd, wavelet[, mode])

Computes single level wavelet reconstruction along tubes (axis-2)
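To make the structure of these transforms concrete, here is a Haar-only sketch in JAX, assuming signal lengths divisible by 2 at every level. For Haar with even-length input, the boundary extension mode does not affect the coefficients, which keeps the sketch short. The names haar_dwt, haar_idwt, haar_dwt_axis, haar_wavedec and haar_waverec are hypothetical; the [cA_n, cD_n, ..., cD_1] ordering of the multilevel coefficients follows the PyWavelets convention and is assumed, not confirmed, to match wavedec here.

```python
import jax.numpy as jnp

SQRT_HALF = 1.0 / jnp.sqrt(2.0)


def haar_dwt(x):
    """Single-level Haar DWT along the last axis (even length assumed)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    ca = (even + odd) * SQRT_HALF
    cd = (even - odd) * SQRT_HALF  # sign convention of PyWavelets' haar
    return ca, cd


def haar_idwt(ca, cd):
    """Inverse of haar_dwt: rebuild even/odd samples and interleave them."""
    even = (ca + cd) * SQRT_HALF
    odd = (ca - cd) * SQRT_HALF
    return jnp.stack([even, odd], axis=-1).reshape(even.shape[:-1] + (-1,))


def haar_dwt_axis(x, axis):
    """Apply haar_dwt along an arbitrary axis by moving that axis to the end."""
    ca, cd = haar_dwt(jnp.moveaxis(x, axis, -1))
    return jnp.moveaxis(ca, -1, axis), jnp.moveaxis(cd, -1, axis)


def haar_wavedec(x, level):
    """Multilevel 1D decomposition returning [cA_n, cD_n, ..., cD_1]."""
    details = []
    ca = x
    for _ in range(level):
        ca, cd = haar_dwt(ca)
        details.append(cd)
    return [ca] + details[::-1]


def haar_waverec(coeffs):
    """Invert haar_wavedec, starting from the coarsest level."""
    ca = coeffs[0]
    for cd in coeffs[1:]:
        ca = haar_idwt(ca, cd)
    return ca
```

In this sketch, haar_dwt_axis(x, 0) and haar_dwt_axis(x, 1) play the roles that dwt_column and dwt_row play for column-wise and row-wise decomposition.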

Utilities

modes

List of the names of the supported signal extension modes.

pad(data, pad_widths, mode)

Pads a given 1D signal using a given boundary mode.

dwt_max_level(input_len, filter_len)

Returns the maximum level of useful DWT decomposition based on data length and filter length

dwt_coeff_len(data_len, filter_len, mode)

Returns the length of wavelet decomposition output based on data length, filter length and mode

up_sample(x, s)

Upsample x by a factor s by introducing zeros in between
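In PyWavelets, which this package ports, the length and level utilities reduce to simple integer formulas. The sketch below assumes CR-Sparse uses the same formulas. The _sketch names are hypothetical, and up_sample_sketch assumes an output length of (len(x) - 1) * s + 1 (zeros only between samples), which may differ from the library's choice.

```python
import math

import jax.numpy as jnp


def dwt_coeff_len_sketch(data_len, filter_len, mode):
    """Length of each single-level DWT output (PyWavelets formula)."""
    if mode == "periodization":
        return (data_len + 1) // 2              # ceil(N / 2)
    return (data_len + filter_len - 1) // 2     # floor((N + F - 1) / 2)


def dwt_max_level_sketch(input_len, filter_len):
    """Deepest useful level: floor(log2(N / (F - 1))), or 0 if undefined."""
    if filter_len <= 1 or input_len < filter_len - 1:
        return 0
    return int(math.floor(math.log2(input_len / (filter_len - 1))))


def up_sample_sketch(x, s):
    """Place x[k] at index k * s and fill the gaps with zeros."""
    n = x.shape[0]
    out = jnp.zeros((n - 1) * s + 1, dtype=x.dtype)
    return out.at[::s].set(x)
```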

Lower-level API

dwt_(data, dec_lo, dec_hi, mode)

Computes single level discrete wavelet decomposition

idwt_(ca, cd, rec_lo, rec_hi, mode)

Computes single level discrete wavelet reconstruction

downcoef_(data, filter, mode)

Partial discrete wavelet decomposition

upcoef_(coeffs, filter, mode)

Partial discrete wavelet reconstruction from one part of coefficients

dwt_axis_(data, dec_lo, dec_hi, axis, mode)

Applies the DWT along a given axis

idwt_axis_(ca, cd, rec_lo, rec_hi, axis, mode)

Applies the Inverse DWT along a given axis
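The lower-level functions take filter arrays rather than a wavelet object. For an orthogonal wavelet, all four filters follow from dec_lo through the quadrature mirror relation. The hypothetical helper orthogonal_filter_bank below builds them using the sign convention of PyWavelets (checked here against its Haar filters). It does not apply to biorthogonal families such as bior and rbio, whose decomposition and reconstruction filters differ.

```python
import jax.numpy as jnp


def orthogonal_filter_bank(dec_lo):
    """Derive (dec_lo, dec_hi, rec_lo, rec_hi) from an orthogonal low-pass filter."""
    dec_lo = jnp.asarray(dec_lo)
    flen = dec_lo.shape[0]
    # dec_hi[k] = (-1)**(k + 1) * dec_lo[flen - 1 - k]
    signs = jnp.where(jnp.arange(flen) % 2 == 0, -1.0, 1.0)
    dec_hi = signs * dec_lo[::-1]
    # Reconstruction filters are the time-reversed decomposition filters.
    rec_lo = dec_lo[::-1]
    rec_hi = dec_hi[::-1]
    return dec_lo, dec_hi, rec_lo, rec_hi


# Daubechies-2 low-pass filter in closed form (PyWavelets' db2 dec_lo ordering).
r3 = jnp.sqrt(3.0)
db2_lo = jnp.array([1 - r3, 3 - r3, 3 + r3, 1 + r3]) / (4 * jnp.sqrt(2.0))
lo_d, hi_d, lo_r, hi_r = orthogonal_filter_bank(db2_lo)
```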

Signal Extension Modes

Real-world signals are finite and are typically stored in finite-size arrays. Computing the wavelet transform near the boundaries of a signal inevitably involves assuming some form of extrapolation beyond them. The simplest method is to extend the signal with zeros. Depending on how the signal was extrapolated, reconstruction from its wavelet coefficients may introduce boundary artifacts, so the extension method should be chosen carefully for the application at hand.

We currently provide the following signal extension modes.

zero

The signal is extended by adding zeros:

>>> wt.pad(jnp.array([1,2,4,-1,2,-1]), 2, 'zero')
DeviceArray([ 0,  0,  1,  2,  4, -1,  2, -1,  0,  0], dtype=int64)

constant

Border values of the signal are replicated:

>>> wt.pad(jnp.array([1,2,4,-1,2,-1]), 2, 'constant')
DeviceArray([ 1,  1,  1,  2,  4, -1,  2, -1, -1, -1], dtype=int64)

symmetric

The signal is extended by mirroring the samples at the border. The border sample itself is repeated in the extension:

>>> wt.pad(jnp.array([1,2,4,-1,2,-1]), 2, 'symmetric')
DeviceArray([ 2,  1,  1,  2,  4, -1,  2, -1, -1,  2], dtype=int64)

reflect

The signal is extended by reflecting the samples about the border sample. The border sample itself is not repeated in the extension:

>>> wt.pad(jnp.array([1,2,4,-1,2,-1]), 2, 'reflect')
DeviceArray([ 4,  2,  1,  2,  4, -1,  2, -1,  2, -1], dtype=int64)

periodic

The signal is extended periodically: the samples at the end are repeated in the extension at the beginning, and the samples at the beginning are repeated in the extension at the end:

>>> wt.pad(jnp.array([1,2,4,-1,2,-1]), 2, 'periodic')
DeviceArray([ 2, -1,  1,  2,  4, -1,  2, -1,  1,  2], dtype=int64)

periodization

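The signal is extended in the same way as in the periodic mode. The major difference is that the decomposition produces the smallest possible number of wavelet coefficients: for a signal of length N, the approximation and detail coefficient arrays each have ceil(N/2) entries, so a single-level decomposition yields as many coefficients as there are samples (one extra when N is odd). All extra values are trimmed.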

Many of the signal extension modes are similar to the padding modes supported by the jax.numpy.pad function. However, the naming convention is different and follows PyWavelets.
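The pad outputs shown above can be reproduced with jax.numpy.pad under the correspondence below. pad_sketch and PYWT_TO_JNP_PAD are hypothetical names used only to illustrate the naming difference. Periodization is left out because it also changes how many coefficients the DWT keeps, not just the padding.

```python
import jax.numpy as jnp

# Correspondence between the mode names used here and jax.numpy.pad modes.
PYWT_TO_JNP_PAD = {
    "zero": "constant",        # jnp.pad's default constant value is 0
    "constant": "edge",        # replicate the border value
    "symmetric": "symmetric",  # mirror, border sample repeated
    "reflect": "reflect",      # mirror about the border sample
    "periodic": "wrap",        # wrap around
}


def pad_sketch(data, pad_widths, mode):
    """Pad a 1D signal using a PyWavelets-style mode name."""
    return jnp.pad(data, pad_widths, mode=PYWT_TO_JNP_PAD[mode])


x = jnp.array([1, 2, 4, -1, 2, -1])
for mode in PYWT_TO_JNP_PAD:
    print(mode, pad_sketch(x, 2, mode))
```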

Continuous Wavelets

Further Reading

TC98

Christopher Torrence and Gilbert P. Compo. A practical guide to wavelet analysis. Bulletin of the American Meteorological Society, 79(1):61–78, 1998.