# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum, auto
from typing import NamedTuple, List, Dict, Tuple
import jax.numpy as jnp
import jax.numpy.fft as jfft
from .families import FAMILY, wname_to_family_order, is_discrete_wavelet
from .coeffs import db, sym, coif, bior, dmey, sqrt2
from .cont_wavelets import WaveletFunctions, cmor, ricker
import re
class SYMMETRY(Enum):
    """Enumerates the kinds of symmetry a wavelet may exhibit."""

    UNKNOWN = -1
    """The symmetry is not known"""
    ASYMMETRIC = 0
    """Asymmetric wavelet"""
    NEAR_SYMMETRIC = 1
    """Nearly symmetric wavelet"""
    SYMMETRIC = 2
    """Symmetric wavelet"""
    ANTI_SYMMETRIC = 3
    """Anti-symmetric wavelet"""
class BaseWavelet(NamedTuple):
    """Basic descriptive properties shared by wavelets.

    Every field has a default value, so an instance can be created without
    any arguments.
    """

    support_width: int = 0
    """Width of the support of the wavelet"""
    symmetry: SYMMETRY = SYMMETRY.UNKNOWN
    """Kind of symmetry of the wavelet"""
    orthogonal: bool = False
    """Whether the wavelet is orthogonal"""
    biorthogonal: bool = False
    """Whether the wavelet is biorthogonal"""
    compact_support: bool = False
    """Whether the wavelet has compact support"""
    name: FAMILY = None
    """Wavelet family identifier (annotated as a FAMILY member)"""
    family_name: str = None
    """Full name of the wavelet family"""
    short_name: str = None
    """Short name of the wavelet family"""
class DiscreteWavelet(NamedTuple):
    """Represents information about a discrete wavelet

    The tuple bundles descriptive metadata together with the four filters
    (decomposition/reconstruction, low/high pass) of the wavelet.
    """
    support_width: int = -1
    """Length of the support for finite support wavelets"""
    symmetry: SYMMETRY = SYMMETRY.UNKNOWN
    """Indicates the kind of symmetry inside the wavelet"""
    orthogonal: bool = False
    """Indicates if the wavelet is orthogonal"""
    biorthogonal: bool = False
    """Indicates if the wavelet is biorthogonal"""
    compact_support: bool = False
    """Indicates if the wavelet has compact support"""
    name: str = ''
    """Name of the wavelet"""
    family_name: str = ''
    """Name of the wavelet family"""
    short_name: str = ''
    """Short name of the wavelet family"""
    # NOTE: ``jnp.ndarray`` is used for the annotations because the older
    # ``jnp.DeviceArray`` alias is not available in recent JAX releases.
    dec_hi: jnp.ndarray = None
    """Decomposition high pass filter"""
    dec_lo: jnp.ndarray = None
    """Decomposition low pass filter"""
    rec_hi: jnp.ndarray = None
    """Reconstruction high pass filter"""
    rec_lo: jnp.ndarray = None
    """Reconstruction low pass filter"""
    dec_len: int = 0
    """Length of decomposition filters"""
    rec_len: int = 0
    """Length of reconstruction filters"""
    vanishing_moments_psi: int = 0
    """Number of vanishing moments of the wavelet function"""
    vanishing_moments_phi: int = 0
    """Number of vanishing moments of the scaling function"""

    def __str__(self):
        """Returns the string representation of the discrete wavelet object
        """
        lines = [
            "Wavelet %s" % self.name,
            " Family name: %s" % self.family_name,
            " Short name: %s" % self.short_name,
            " Filters length: %d" % self.dec_len,
            " Orthogonal: %s" % self.orthogonal,
            " Biorthogonal: %s" % self.biorthogonal,
            " Symmetry: %s" % self.symmetry.name.lower(),
            " DWT: True",
            " CWT: False",
        ]
        return '\n'.join(line.rstrip() for line in lines)

    def wavefun(self, level=8):
        """Returns the scaling and wavelet functions for the wavelet

        Args:
            level (:obj:`int`, optional): Number of levels of reconstruction
                to get the approximation of scaling and wavelet functions.
                Default 8.

        Returns:
            The result of ``orth_wavefun`` when the wavelet is orthogonal,
            otherwise the result of ``biorth_wavefun`` when it is
            biorthogonal.

        Raises:
            NotImplementedError: if the wavelet is neither orthogonal nor
                biorthogonal.
        """
        # imported lazily inside the method, as in the original code
        from .discrete import orth_wavefun, biorth_wavefun
        # orthogonality is checked first: wavelets flagged as both
        # orthogonal and biorthogonal take the orthogonal path
        if self.orthogonal:
            return orth_wavefun(self, level=level)
        if self.biorthogonal:
            return biorth_wavefun(self, level=level)
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it would produce a TypeError instead.
        raise NotImplementedError(
            f"wavefun is not available for wavelet {self.name}")

    @property
    def filter_bank(self):
        """Returns the Quadrature Mirror Filter Bank associated with the wavelet (dec_lo, dec_hi, rec_lo, rec_hi)
        """
        return (self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi)

    @property
    def inverse_filter_bank(self):
        """Returns the filter bank associated with the inverse wavelet

        The tuple is (rec_lo, rec_hi, dec_lo, dec_hi) with every filter
        reversed.
        """
        return (self.rec_lo[::-1], self.rec_hi[::-1],
                self.dec_lo[::-1], self.dec_hi[::-1])
class ContinuousWavelet(NamedTuple):
    """Describes a continuous wavelet

    Holds the descriptive metadata, the time window and frequency
    parameters, and (optionally) the functions used to evaluate the wavelet.
    """
    support_width: int = -1
    """Length of the support for finite support wavelets"""
    symmetry: SYMMETRY = SYMMETRY.UNKNOWN
    """Kind of symmetry inside the wavelet"""
    orthogonal: bool = False
    """Whether the wavelet is orthogonal"""
    biorthogonal: bool = False
    """Whether the wavelet is biorthogonal"""
    compact_support: bool = False
    """Whether the wavelet has compact support"""
    name: str = ''
    """Name of the wavelet"""
    family_name: str = ''
    """Name of the wavelet family"""
    short_name: str = ''
    """Short name of the wavelet family"""
    # additional parameters for continuous wavelets
    lower_bound: float = 0
    """Lower bound of the time window used to compute the wavelet function"""
    upper_bound: float = 0
    """Upper bound of the time window used to compute the wavelet function"""
    complex_cwt: bool = False
    """Flag telling whether the wavelet is complex or real"""
    center_frequency: float = -1.
    """Center frequency of the wavelet"""
    bandwidth_frequency: float = -1.
    """Bandwidth of the wavelet"""
    fbsp_order: int = 0
    """Order parameter recorded for frequency B-spline wavelets"""
    functions: WaveletFunctions = None
    """Functions associated with the wavelet"""

    def __str__(self):
        """Returns a multi-line, human readable summary of the wavelet
        """
        rows = (
            f"ContinuousWavelet {self.name!s}",
            f" Family name: {self.family_name!s}",
            f" Short name: {self.short_name!s}",
            f" Symmetry: {self.symmetry.name.lower()!s}",
            " DWT: False",
            " CWT: True",
            f" Complex CWT: {self.complex_cwt!s}",
        )
        return '\n'.join(row.rstrip() for row in rows)

    def wavefun(self, level=8, length=None):
        """Evaluates the wavelet function on an evenly spaced time grid

        The grid runs from ``lower_bound`` to ``upper_bound`` and the values
        are obtained by calling ``functions.time`` on it.

        Args:
            level (:obj:`int`, optional): The grid has ``2**level`` points
                when ``length`` is not given. Default 8.
            length (:obj:`int`, optional): Explicit number of grid points;
                overrides ``level`` when provided.

        Returns:
            tuple: ``(psi, t)`` - the wavelet values and the time grid.

        Raises:
            NotImplementedError: if no functions are attached to the wavelet.
        """
        if self.functions is None:
            raise NotImplementedError(f"No implementation available for {self.name}")
        time_function = self.functions.time
        default_length = 2**level
        num_points = default_length if length is None else length
        grid = jnp.linspace(self.lower_bound, self.upper_bound, num_points)
        return time_function(grid), grid

    @property
    def domain(self):
        """Width of the time window (``upper_bound - lower_bound``)
        """
        return self.upper_bound - self.lower_bound
def qmf(h):
    """Returns the quadrature mirror filter of a given filter

    The filter is reversed and every entry at an odd position of the
    reversed filter is negated. ``h`` must support the JAX ``.at[]`` update
    syntax.
    """
    flipped = h[::-1]
    odd_entries = flipped[1::2]
    return flipped.at[1::2].set(-odd_entries)
def orthogonal_filter_bank(scaling_filter):
    """Builds the orthogonal filter bank generated by a scaling filter

    Args:
        scaling_filter: filter with an even number of entries.

    Returns:
        tuple: ``(dec_lo, dec_hi, rec_lo, rec_hi)``.

    Raises:
        ValueError: if the filter has an odd number of entries.
    """
    if scaling_filter.shape[0] % 2 != 0:
        raise ValueError('scaling_filter must be of even length.')
    # rescale the filter so that its entries sum to ``sqrt2``
    total = jnp.sum(scaling_filter)
    rec_lo = sqrt2 * scaling_filter / total
    rec_hi = qmf(rec_lo)
    # each decomposition filter is the reversed reconstruction filter
    return (rec_lo[::-1], rec_hi[::-1], rec_lo, rec_hi)
def filter_bank_(rec_lo):
    """Derives a full filter bank from a stored reconstruction low pass filter

    Used with the values saved in coeffs.py. No normalization is applied.

    Returns:
        tuple: ``(dec_lo, dec_hi, rec_lo, rec_hi)``.
    """
    rec_hi = qmf(rec_lo)
    # each decomposition filter is the reversed reconstruction filter
    return (rec_lo[::-1], rec_hi[::-1], rec_lo, rec_hi)
def mirror(h):
    """Modulates a filter with an alternating sign pattern

    Entry ``k`` (counting from 0) is multiplied by ``(-1)**(k + 1)``, so the
    entries at even positions change sign.
    """
    exponents = jnp.arange(1, h.shape[0] + 1)
    signs = (-1)**exponents
    return signs * h
def negate_evens(g):
    """Returns a copy of ``g`` with the entries at even positions negated"""
    even_entries = g[0::2]
    return g.at[0::2].set(-even_entries)
def negate_odds(g):
    """Returns a copy of ``g`` with the entries at odd positions negated"""
    odd_entries = g[1::2]
    return g.at[1::2].set(-odd_entries)
def bior_index(n, m):
    """Locates the coefficients of the biorthogonal wavelet ``bior{n}.{m}``

    Args:
        n (int): first order of the wavelet (the digit before the dot).
        m (int): second order of the wavelet (the digit after the dot).

    Returns:
        tuple: ``(idx, m_max)``. The builders in this module read the filter
        stored at position ``idx + 1`` of the table for ``n`` and slice the
        filter at position 0 starting at offset ``m_max - m``. Both values
        are ``None`` when the combination is not supported.
    """
    idx = m_max = None
    # Only combinations for which ``idx`` and ``m_max - m`` are meaningful
    # are accepted. Without the parity/range checks an unsupported ``m``
    # would alias to a neighbouring entry (e.g. 1.2 -> the row of 1.3) or
    # produce a negative index or slice start.
    if n == 1:
        if m % 2 == 1 and 1 <= m <= 5:
            idx, m_max = m // 2, 5
    elif n == 2:
        if m % 2 == 0 and 2 <= m <= 8:
            idx, m_max = m // 2 - 1, 8
    elif n == 3:
        if m % 2 == 1 and 1 <= m <= 9:
            idx, m_max = m // 2, 9
    elif n == 4 or n == 5:
        if m == n:
            idx, m_max = 0, m
    elif n == 6:
        if m == 8:
            idx, m_max = 0, 8
    return idx, m_max
def _table_entry(table, index):
    """Returns ``table[index]``, or ``None`` when the index is out of range (negative indices are rejected, not wrapped)"""
    if 0 <= index < len(table):
        return table[index]
    return None


def _orthogonal_wavelet(rec_lo, **properties):
    """Builds an orthogonal DiscreteWavelet whose filter bank is derived from ``rec_lo``"""
    dec_lo, dec_hi, rec_lo, rec_hi = filter_bank_(rec_lo)
    return DiscreteWavelet(orthogonal=True,
                           biorthogonal=True,
                           compact_support=True,
                           dec_hi=dec_hi,
                           dec_lo=dec_lo,
                           rec_hi=rec_hi,
                           rec_lo=rec_lo,
                           **properties)


def _biorthogonal_wavelet(order, reverse):
    """Builds a ``bior`` (or, with ``reverse=True``, ``rbio``) wavelet from the two-digit ``order``"""
    # the order encodes both digits of the name: bior{n}.{m} <-> n * 10 + m
    n, m = divmod(order, 10)
    idx, m_max = bior_index(n, m)
    if idx is None or m_max is None:
        return None
    arr = bior[n-1]
    # arr[idx + 1] is read below, so that position must exist
    if idx + 1 >= len(arr):
        return None
    filters_length = 2*m if n == 1 else 2*m + 2
    start = m_max - m
    primary = arr[0][start:start + filters_length]
    if reverse:
        # the reverse family swaps the roles of the two low pass filters
        dec_lo = primary[::-1]
        rec_lo = arr[idx+1]
    else:
        rec_lo = primary
        dec_lo = arr[idx+1][::-1]
    rec_hi = negate_odds(dec_lo)
    dec_hi = negate_evens(rec_lo)
    prefix = "rbio" if reverse else "bior"
    # NOTE(review): support_width and the vanishing moment counts are
    # computed from the combined two-digit ``order`` with the same formulas
    # as the Coiflet branch; this looks like a copy-paste. The values are
    # kept unchanged here - confirm the intended ones.
    return DiscreteWavelet(support_width=6*order-1,
                           symmetry=SYMMETRY.SYMMETRIC,
                           orthogonal=False,
                           biorthogonal=True,
                           compact_support=True,
                           name=f'{prefix}{n}.{m}',
                           family_name="Reverse biorthogonal" if reverse else "Biorthogonal",
                           short_name=prefix,
                           dec_hi=dec_hi,
                           dec_lo=dec_lo,
                           rec_hi=rec_hi,
                           rec_lo=rec_lo,
                           dec_len=filters_length,
                           rec_len=filters_length,
                           vanishing_moments_psi=2*order,
                           vanishing_moments_phi=2*order-1)


def build_discrete_wavelet(family: FAMILY, order: int):
    """Builds a discrete wavelet by its family and order

    Args:
        family (FAMILY): the wavelet family.
        order (int): order of the wavelet inside the family. For the
            ``bior``/``rbio`` families both digits of the name are encoded
            as ``n * 10 + m``. It is not used for the Haar and Discrete
            Meyer wavelets.

    Returns:
        DiscreteWavelet: the wavelet, or ``None`` when the family is not a
        supported discrete family or no coefficients are stored for the
        requested order.
    """
    nv = family.value
    if nv == FAMILY.HAAR.value:
        return _orthogonal_wavelet(db[0],
                                   support_width=1,
                                   symmetry=SYMMETRY.ASYMMETRIC,
                                   name="Haar",
                                   family_name="Haar",
                                   short_name="haar",
                                   dec_len=2,
                                   rec_len=2,
                                   vanishing_moments_psi=1,
                                   vanishing_moments_phi=0)
    if nv == FAMILY.DB.value:
        # the table starts at db1
        rec_lo = _table_entry(db, order - 1)
        if rec_lo is None:
            return None
        return _orthogonal_wavelet(rec_lo,
                                   support_width=2*order-1,
                                   symmetry=SYMMETRY.ASYMMETRIC,
                                   name=f'db{order}',
                                   family_name="Daubechies",
                                   short_name="db",
                                   dec_len=2*order,
                                   rec_len=2*order,
                                   vanishing_moments_psi=order,
                                   vanishing_moments_phi=0)
    if nv == FAMILY.SYM.value:
        # the table starts at sym2
        rec_lo = _table_entry(sym, order - 2)
        if rec_lo is None:
            return None
        return _orthogonal_wavelet(rec_lo,
                                   support_width=2*order-1,
                                   symmetry=SYMMETRY.NEAR_SYMMETRIC,
                                   name=f'sym{order}',
                                   family_name="Symlets",
                                   short_name="sym",
                                   dec_len=2*order,
                                   rec_len=2*order,
                                   vanishing_moments_psi=order,
                                   vanishing_moments_phi=0)
    if nv == FAMILY.COIF.value:
        # the table starts at coif1
        stored = _table_entry(coif, order - 1)
        if stored is None:
            return None
        # the stored Coiflet values are multiplied by sqrt2 before use
        return _orthogonal_wavelet(stored*sqrt2,
                                   support_width=6*order-1,
                                   symmetry=SYMMETRY.NEAR_SYMMETRIC,
                                   name=f'coif{order}',
                                   family_name="Coiflets",
                                   short_name="coif",
                                   dec_len=6*order,
                                   rec_len=6*order,
                                   vanishing_moments_psi=2*order,
                                   vanishing_moments_phi=2*order-1)
    if nv == FAMILY.BIOR.value:
        return _biorthogonal_wavelet(order, reverse=False)
    if nv == FAMILY.RBIO.value:
        return _biorthogonal_wavelet(order, reverse=True)
    if nv == FAMILY.DMEY.value:
        return _orthogonal_wavelet(dmey,
                                   support_width=1,
                                   symmetry=SYMMETRY.SYMMETRIC,
                                   name="dmey",
                                   family_name="Discrete Meyer (FIR Approximation)",
                                   short_name="dmey",
                                   dec_len=62,
                                   rec_len=62,
                                   vanishing_moments_psi=-1,
                                   vanishing_moments_phi=-1)
    return None
# Regular expression used to pull the numeric parameters (e.g. the
# bandwidth frequency and the center frequency) out of a continuous wavelet
# name. Each match is a run of non-digit characters followed by a number
# written as digits, optional dots and optional digits. Only the number is
# captured, so ``re.findall`` on 'cmor1.5-1.0' yields ['1.5', '1.0'].
cwt_pattern = re.compile(r'\D+(\d+\.*\d*)+')
def _get_bw_center_freqs(freqs, bandwidth_frequency, center_frequency):
if len(freqs) == 2:
bandwidth_frequency = float(freqs[0])
center_frequency = float(freqs[1])
return bandwidth_frequency, center_frequency
def _get_m_b_c(freqs, fbsp_order, bandwidth_frequency, center_frequency):
if len(freqs) == 3:
fbsp_order = int(freqs[0])
bandwidth_frequency = float(freqs[1])
center_frequency = float(freqs[2])
return fbsp_order, bandwidth_frequency, center_frequency
def build_continuous_wavelet(name: str, family: FAMILY, order: int):
    """Builds a continuous wavelet by its name, family and order

    Args:
        name (str): full wavelet name; numeric parameters embedded in it
            (e.g. ``cmor1.5-1.0``) are parsed with ``cwt_pattern``.
        family (FAMILY): the wavelet family.
        order (int): order of the wavelet; only consulted for the ``gaus``
            and ``cgau`` families.

    Returns:
        ContinuousWavelet: the wavelet, or ``None`` when the family is not
        handled here or the order of a ``gaus``/``cgau`` wavelet exceeds 8.

    Raises:
        ValueError: if the name continues past its first four characters but
            no number can be found in it.
    """
    # everything after the four character base name
    suffix = name[4:]
    # identify the numeric parameters if present
    matches = cwt_pattern.findall(name)
    if suffix and not matches:
        raise ValueError("No frequencies have been specified.")
    freqs = [float(match) for match in matches]

    # keyword arguments that are identical for every continuous wavelet
    shared = dict(support_width=-1,
                  orthogonal=False,
                  biorthogonal=False,
                  compact_support=False,
                  name=name)
    family_id = family.value

    if family_id == FAMILY.GAUS.value:
        if order > 8:
            return None
        parity = SYMMETRY.SYMMETRIC if order % 2 == 0 else SYMMETRY.ANTI_SYMMETRIC
        return ContinuousWavelet(symmetry=parity,
                                 family_name="Gaussian",
                                 short_name="gaus",
                                 complex_cwt=False,
                                 lower_bound=-5.,
                                 upper_bound=5.,
                                 center_frequency=0.,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0,
                                 **shared)

    if family_id == FAMILY.MEXH.value:
        ricker_functions = ricker()
        return ContinuousWavelet(symmetry=SYMMETRY.SYMMETRIC,
                                 family_name="Mexican hat wavelet",
                                 short_name="mexh",
                                 complex_cwt=False,
                                 lower_bound=-8.,
                                 upper_bound=8.,
                                 center_frequency=0.25,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0,
                                 functions=ricker_functions,
                                 **shared)

    if family_id == FAMILY.MORL.value:
        return ContinuousWavelet(symmetry=SYMMETRY.SYMMETRIC,
                                 family_name="Morlet wavelet",
                                 short_name="morl",
                                 complex_cwt=False,
                                 lower_bound=-8.,
                                 upper_bound=8.,
                                 center_frequency=0.,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0,
                                 **shared)

    if family_id == FAMILY.CGAU.value:
        if order > 8:
            return None
        parity = SYMMETRY.SYMMETRIC if order % 2 == 0 else SYMMETRY.ANTI_SYMMETRIC
        return ContinuousWavelet(symmetry=parity,
                                 family_name="Complex Gaussian wavelets",
                                 short_name="cgau",
                                 complex_cwt=True,
                                 lower_bound=-5.,
                                 upper_bound=5.,
                                 center_frequency=0.,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0,
                                 **shared)

    if family_id == FAMILY.SHAN.value:
        # defaults: bandwidth 0.5, center frequency 1.0
        bandwidth, center = _get_bw_center_freqs(freqs, 0.5, 1.)
        return ContinuousWavelet(symmetry=SYMMETRY.ASYMMETRIC,
                                 family_name="Shannon wavelets",
                                 short_name="shan",
                                 complex_cwt=True,
                                 lower_bound=-20.,
                                 upper_bound=20.,
                                 center_frequency=center,
                                 bandwidth_frequency=bandwidth,
                                 fbsp_order=0,
                                 **shared)

    if family_id == FAMILY.FBSP.value:
        # defaults: order 2, bandwidth 1.0, center frequency 0.5
        spline_order, bandwidth, center = _get_m_b_c(freqs, 2, 1., 0.5)
        return ContinuousWavelet(symmetry=SYMMETRY.ASYMMETRIC,
                                 family_name="Frequency B-Spline wavelets",
                                 short_name="fbsp",
                                 complex_cwt=True,
                                 lower_bound=-20.,
                                 upper_bound=20.,
                                 center_frequency=center,
                                 bandwidth_frequency=bandwidth,
                                 fbsp_order=spline_order,
                                 **shared)

    if family_id == FAMILY.CMOR.value:
        # defaults: bandwidth 1.0, center frequency 0.5
        bandwidth, center = _get_bw_center_freqs(freqs, 1., 0.5)
        cmor_functions = cmor(bandwidth, center)
        return ContinuousWavelet(symmetry=SYMMETRY.ASYMMETRIC,
                                 family_name="Complex Morlet wavelets",
                                 short_name="cmor",
                                 complex_cwt=True,
                                 lower_bound=-8.,
                                 upper_bound=8.,
                                 center_frequency=center,
                                 bandwidth_frequency=bandwidth,
                                 fbsp_order=2,
                                 functions=cmor_functions,
                                 **shared)

    return None
def build_wavelet(name):
    """Builds a wavelet object by the name of the wavelet

    The name is lower-cased and resolved to a family and an order, then the
    discrete or the continuous builder is used depending on the family.

    Args:
        name (str): Name of the wavelet

    Returns:
        A discrete wavelet object (:obj:`DiscreteWavelet`) for discrete
        families, a continuous wavelet object (:obj:`ContinuousWavelet`)
        otherwise.

    Raises:
        ValueError: if no wavelet could be built for the name.

    Example:

        ::

            >>> wavelet = wt.build_wavelet('db1')
            >>> print(wavelet)
            Wavelet db1
             Family name: Daubechies
             Short name: db
             Filters length: 2
             Orthogonal: True
             Biorthogonal: True
             Symmetry: asymmetric
             DWT: True
             CWT: False
            >>> dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank
            >>> print(dec_lo)
            [0.70710678 0.70710678]
            >>> print(dec_hi)
            [-0.70710678 0.70710678]
            >>> print(rec_lo)
            [0.70710678 0.70710678]
            >>> print(rec_hi)
            [ 0.70710678 -0.70710678]
            >>> phi, psi, x = wavelet.wavefun()
    """
    name = name.lower()
    family, order = wname_to_family_order(name)
    built = (build_discrete_wavelet(family, order)
             if is_discrete_wavelet(family)
             else build_continuous_wavelet(name, family, order))
    # other wavelet types are not supported for now
    if built is None:
        raise ValueError(f"Invalid wavelet name {name}")
    return built
def rec_integrate(function, dt):
    """Integrates sampled function values with the rectangle method

    Returns the running (cumulative) sum of ``function`` scaled by the
    sample spacing ``dt``.
    """
    running_sum = jnp.cumsum(function)
    return running_sum * dt
def to_wavelet(wavelet):
    """Resolves a wavelet name to a wavelet object

    A string is passed to ``build_wavelet``; any other object is returned
    unchanged.

    Raises:
        ValueError: if the resulting wavelet is ``None``.
    """
    resolved = build_wavelet(wavelet) if isinstance(wavelet, str) else wavelet
    if resolved is None:
        raise ValueError("Invalid wavelet")
    return resolved
def integrate_wavelet(wavelet, precision=8):
    """Integrate wavelet function using the rectangle integration method

    Args:
        wavelet: a wavelet object or the name of a wavelet.
        precision (int, optional): passed to ``wavefun`` of the wavelet as
            the level of the approximation. Default 8.

    Returns:
        tuple: ``(integral_psi, t)`` when ``wavefun`` yields ``(psi, t)`` or
        ``(phi, psi, t)``; ``(integral_psi_d, integral_psi_r, t)`` when it
        yields ``(phi_d, psi_d, phi_r, psi_r, t)``.

    Raises:
        ValueError: if the wavelet is invalid or ``wavefun`` returns a tuple
            of an unexpected length.
    """
    wavelet = to_wavelet(wavelet)
    approximations = wavelet.wavefun(precision)
    count = len(approximations)
    if count == 2:
        psi, t = approximations
        return rec_integrate(psi, t[1] - t[0]), t
    if count == 3:
        _, psi, t = approximations
        return rec_integrate(psi, t[1] - t[0]), t
    if count == 5:
        # the original unpacked an undefined name here (NameError)
        _, psi_d, _, psi_r, t = approximations
        dt = t[1] - t[0]
        return rec_integrate(psi_d, dt), rec_integrate(psi_r, dt), t
    raise ValueError(
        f"Unexpected number of wavelet function approximations: {count}")
def central_frequency(wavelet, precision=8):
    """Computes the central frequency of the wavelet function

    Args:
        wavelet: a wavelet object or the name of a wavelet.
        precision (int, optional): passed to ``wavefun`` of the wavelet as
            the level of the approximation. Default 8.

    Returns:
        The ``center_frequency`` attribute of the wavelet when it is defined
        and positive; otherwise the frequency of the dominant non-DC peak in
        the spectrum of the sampled wavelet function.

    Raises:
        ValueError: if the wavelet is invalid.
    """
    wavelet = to_wavelet(wavelet)
    # Let's see if the central frequency is defined for the wavelet.
    # Discrete wavelets have no such attribute, and values that are not
    # positive (0 or the -1 default) mean "not set", so it is computed below.
    known = getattr(wavelet, 'center_frequency', None)
    if known is not None and known > 0:
        return known
    # get the wavelet functions
    approximations = wavelet.wavefun(precision)
    if len(approximations) == 2:
        psi, t = approximations
    else:
        # (phi, psi, t) or (phi_d, psi_d, phi_r, psi_r, t): the wavelet
        # function is the second entry and the grid is the last one
        psi, t = approximations[1], approximations[-1]
    domain = t[-1] - t[0]
    # identify the peak frequency [skip the DC]
    index = jnp.argmax(jnp.abs(jfft.fft(psi)[1:])) + 2
    # fold a peak found in the second half of the spectrum back
    if index > len(psi) / 2:
        index = len(psi) - index + 2
    # convert to Hz
    return 1.0 / (domain / (index - 1))
def scale2frequency(wavelet, scales, precision=8):
    """Converts scales to frequencies for a wavelet

    Each frequency is the central frequency of the wavelet divided by the
    corresponding scale.

    Args:
        wavelet: a wavelet object or the name of a wavelet.
        scales: a scale or an array-like of scales.
        precision (int, optional): forwarded to ``central_frequency``.
            Default 8.
    """
    scale_values = jnp.asarray(scales)
    center = central_frequency(wavelet, precision=precision)
    return center / scale_values