# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from enum import Enum, auto
from typing import NamedTuple, List, Dict, Tuple, Optional

import jax
import jax.numpy as jnp
import jax.numpy.fft as jfft

from .families import FAMILY, wname_to_family_order, is_discrete_wavelet
from .coeffs import db, sym, coif, bior, dmey, sqrt2
from .cont_wavelets import WaveletFunctions, cmor, ricker
class SYMMETRY(Enum):
    """Describes the type of symmetry in a wavelet
    """
    UNKNOWN = -1
    """Unknown Symmetry"""
    ASYMMETRIC = 0
    """Asymmetric Wavelet"""
    NEAR_SYMMETRIC = 1
    """Near Symmetric Wavelet"""
    SYMMETRIC = 2
    """Symmetric Wavelet"""
    ANTI_SYMMETRIC = 3
    """Anti-symmetric Wavelet"""
class BaseWavelet(NamedTuple):
    """Represents basic information about a wavelet
    """
    # Length of the support (defaults to 0)
    support_width: int = 0
    # Kind of symmetry of the wavelet
    symmetry: SYMMETRY = SYMMETRY.UNKNOWN
    # Whether the wavelet is orthogonal
    orthogonal: bool = False
    # Whether the wavelet is biorthogonal
    biorthogonal: bool = False
    # Whether the wavelet has compact support
    compact_support: bool = False
    # Wavelet family identifier; None when not set
    name: Optional[FAMILY] = None
    # Human readable family name; None when not set
    family_name: Optional[str] = None
    # Short family name; None when not set
    short_name: Optional[str] = None
class DiscreteWavelet(NamedTuple):
    """Represents information about a discrete wavelet
    """
    support_width: int = -1
    """Length of the support for finite support wavelets"""
    symmetry: SYMMETRY = SYMMETRY.UNKNOWN
    """Indicates the kind of symmetry inside the wavelet"""
    orthogonal: bool = False
    """Indicates if the wavelet is orthogonal"""
    biorthogonal: bool = False
    """Indicates if the wavelet is biorthogonal"""
    compact_support: bool = False
    """Indicates if the wavelet has compact support"""
    name: str = ''
    """Name of the wavelet"""
    family_name: str = ''
    """Name of the wavelet family"""
    short_name: str = ''
    """Short name of the wavelet family"""
    dec_hi: Optional[jax.Array] = None
    """Decomposition high pass filter"""
    dec_lo: Optional[jax.Array] = None
    """Decomposition low pass filter"""
    rec_hi: Optional[jax.Array] = None
    """Reconstruction high pass filter"""
    rec_lo: Optional[jax.Array] = None
    """Reconstruction low pass filter"""
    dec_len: int = 0
    """Length of decomposition filters"""
    rec_len: int = 0
    """Length of reconstruction filters"""
    vanishing_moments_psi: int = 0
    """Number of vanishing moments of the wavelet function"""
    vanishing_moments_phi: int = 0
    """Number of vanishing moments of the scaling function"""

    def __str__(self):
        """Returns a multi-line, human readable summary of the wavelet
        """
        lines = [
            "Wavelet %s" % self.name,
            " Family name: %s" % self.family_name,
            " Short name: %s" % self.short_name,
            " Filters length: %d" % self.dec_len,
            " Orthogonal: %s" % self.orthogonal,
            " Biorthogonal: %s" % self.biorthogonal,
            " Symmetry: %s" % self.symmetry.name.lower(),
            " DWT: True",
            " CWT: False",
        ]
        return '\n'.join(line.rstrip() for line in lines)

    def wavefun(self, level=8):
        """Returns the scaling and wavelet functions for the wavelet

        Args:
            level (:obj:`int`, optional): Number of levels of reconstruction
                to get the approximation of scaling and wavelet functions.
                Default 8.

        Returns:
            The approximations computed by ``orth_wavefun_jit`` (orthogonal
            wavelets) or ``biorth_wavefun`` (biorthogonal wavelets).
            Presumably ``(phi, psi, x)`` and
            ``(phi_d, psi_d, phi_r, psi_r, x)`` respectively; confirm against
            the ``discrete`` module.

        Raises:
            NotImplementedError: if the wavelet is neither orthogonal nor
                biorthogonal.
        """
        # Imported lazily; presumably to avoid a circular import with the
        # ``discrete`` module.
        from .discrete import orth_wavefun_jit, biorth_wavefun
        # Orthogonal wavelets are also flagged biorthogonal, so check the
        # cheaper orthogonal path first.
        if self.orthogonal:
            return orth_wavefun_jit(self.rec_lo, self.rec_hi, level=level)
        if self.biorthogonal:
            return biorth_wavefun(self, level=level)
        raise NotImplementedError()

    @property
    def filter_bank(self):
        """Returns the Quadrature Mirror Filter Bank associated with the wavelet
        as the tuple ``(dec_lo, dec_hi, rec_lo, rec_hi)``
        """
        return (self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi)

    @property
    def inverse_filter_bank(self):
        """Returns the filter bank associated with the inverse wavelet.

        It is ``(rec_lo[::-1], rec_hi[::-1], dec_lo[::-1], dec_hi[::-1])``:
        the reversed reconstruction filters act as decomposition filters and
        vice versa.
        """
        return (self.rec_lo[::-1], self.rec_hi[::-1],
                self.dec_lo[::-1], self.dec_hi[::-1])
class ContinuousWavelet(NamedTuple):
    """Represents information about a continuous wavelet
    """
    support_width: int = -1
    """Length of the support for finite support wavelets"""
    symmetry: SYMMETRY = SYMMETRY.UNKNOWN
    """Indicates the kind of symmetry inside the wavelet"""
    orthogonal: bool = False
    """Indicates if the wavelet is orthogonal"""
    biorthogonal: bool = False
    """Indicates if the wavelet is biorthogonal"""
    compact_support: bool = False
    """Indicates if the wavelet has compact support"""
    name: str = ''
    """Name of the wavelet"""
    family_name: str = ''
    """Name of the wavelet family"""
    short_name: str = ''
    """Short name of the wavelet family"""
    # additional parameters for continuous wavelets
    lower_bound: float = 0
    """time window lower bound for computing the wavelet function"""
    upper_bound: float = 0
    """time window upper bound for computing the wavelet function"""
    complex_cwt: bool = False
    """flag indicating if the wavelet is complex or real"""
    center_frequency: float = -1.
    """center frequency of the wavelet (-1 when not specified)"""
    bandwidth_frequency: float = -1.
    """bandwidth of the wavelet (-1 when not specified)"""
    fbsp_order: int = 0
    """B-spline order parameter (used by the frequency B-spline family)"""
    functions: Optional[WaveletFunctions] = None
    """Functions associated with the wavelet"""

    def __str__(self):
        """Returns a multi-line, human readable summary of the wavelet
        """
        lines = [
            "ContinuousWavelet %s" % self.name,
            " Family name: %s" % self.family_name,
            " Short name: %s" % self.short_name,
            " Symmetry: %s" % self.symmetry.name.lower(),
            " DWT: False",
            " CWT: True",
            " Complex CWT: %s" % self.complex_cwt,
        ]
        return '\n'.join(line.rstrip() for line in lines)

    def wavefun(self, level=8, length=None):
        """Returns the wavelet function sampled over ``[lower_bound, upper_bound]``

        Args:
            level (:obj:`int`, optional): The function is sampled at
                ``2**level`` points unless ``length`` is given. Default 8.
            length (:obj:`int`, optional): Number of sample points. Overrides
                ``level`` when provided. Default None.

        Returns:
            (psi, t): the wavelet function values ``psi`` at the evenly spaced
            time points ``t``

        Raises:
            NotImplementedError: if ``functions`` is None for this wavelet
        """
        if self.functions is None:
            raise NotImplementedError(f"No implementation available for {self.name}")
        func = self.functions.time
        output_length = 2**level if length is None else length
        t = jnp.linspace(self.lower_bound, self.upper_bound, output_length)
        psi = func(t)
        return psi, t

    @property
    def domain(self):
        """Returns the width of the time window ``upper_bound - lower_bound``
        """
        return self.upper_bound - self.lower_bound
def qmf(h):
    """Builds the quadrature mirror filter of ``h``.

    The filter is reversed and every entry at an odd position of the reversed
    filter changes sign.
    """
    flipped = h[::-1]
    odd_taps = flipped[1::2]
    return flipped.at[1::2].set(-odd_taps)
def orthogonal_filter_bank(scaling_filter):
    """Derives the orthogonal filter bank from a scaling filter.

    The scaling filter is rescaled so that its taps sum to ``sqrt2``. That
    rescaled filter is the reconstruction low pass filter. The other filters
    follow from it by time reversal and the quadrature mirror construction.

    Returns:
        tuple: ``(dec_lo, dec_hi, rec_lo, rec_hi)``

    Raises:
        ValueError: if the scaling filter has an odd number of taps
    """
    if scaling_filter.shape[0] % 2 != 0:
        raise ValueError('scaling_filter must be of even length.')
    low = sqrt2 * scaling_filter / jnp.sum(scaling_filter)
    high = qmf(low)
    return (low[::-1], high[::-1], low, high)
def filter_bank_(rec_lo):
    """Expands a stored reconstruction low pass filter (from coeffs.py) into
    the full filter bank ``(dec_lo, dec_hi, rec_lo, rec_hi)``.

    The decomposition filters are the time reversed reconstruction filters.
    The high pass reconstruction filter is the QMF of ``rec_lo``.
    """
    high = qmf(rec_lo)
    return rec_lo[::-1], high[::-1], rec_lo, high
def mirror(h):
    """Multiplies ``h`` by the alternating sequence ``-1, +1, -1, +1, ...``.

    Entries at even positions change sign; entries at odd positions are kept.
    """
    # 2 * (k % 2) - 1 is -1 for even k and +1 for odd k, i.e. (-1)**(k + 1)
    signs = 2 * (jnp.arange(h.shape[0]) % 2) - 1
    return signs * h
def negate_evens(g):
    """Returns a copy of ``g`` whose entries at even positions (0, 2, ...) are negated."""
    even_taps = g[::2]
    return g.at[::2].set(-even_taps)
def negate_odds(g):
    """Returns a copy of ``g`` whose entries at odd positions (1, 3, ...) are negated."""
    odd_taps = g[1::2]
    return g.at[1::2].set(-odd_taps)
def bior_index(n, m):
    """Locates the ``bior{n}.{m}`` / ``rbio{n}.{m}`` wavelet in the coefficient tables.

    Only the combinations tabulated for biorthogonal spline wavelets are accepted:

    * ``n == 1``: ``m`` in 1, 3, 5
    * ``n == 2``: ``m`` in 2, 4, 6, 8
    * ``n == 3``: ``m`` in 1, 3, 5, 7, 9
    * ``n == 4`` or ``n == 5``: ``m == n``
    * ``n == 6``: ``m == 8``

    Args:
        n (int): first digit of the wavelet order
        m (int): second digit of the wavelet order

    Returns:
        (idx, m_max): ``idx`` is the position of the requested member among the
        members tabulated for ``n``. ``m_max`` is the largest ``m`` tabulated
        for ``n``. Both are ``None`` when ``(n, m)`` is not supported.
    """
    # Validate m per n (as PyWavelets does); otherwise invalid names such
    # as bior2.1 map to negative or aliased table positions.
    if n == 1 and m in (1, 3, 5):
        return m // 2, 5
    if n == 2 and m in (2, 4, 6, 8):
        return m // 2 - 1, 8
    if n == 3 and m in (1, 3, 5, 7, 9):
        return m // 2, 9
    if n in (4, 5) and m == n:
        return 0, m
    if n == 6 and m == 8:
        return 0, 8
    return None, None
def _orthogonal_wavelet(rec_lo, *, support_width, symmetry, name, family_name,
                        short_name, filter_length, vanishing_moments_psi,
                        vanishing_moments_phi):
    """Assembles an orthogonal (hence also biorthogonal), compactly supported
    wavelet from its reconstruction low pass filter."""
    dec_lo, dec_hi, rec_lo, rec_hi = filter_bank_(rec_lo)
    return DiscreteWavelet(support_width=support_width,
                           symmetry=symmetry,
                           orthogonal=True,
                           biorthogonal=True,
                           compact_support=True,
                           name=name,
                           family_name=family_name,
                           short_name=short_name,
                           dec_hi=dec_hi,
                           dec_lo=dec_lo,
                           rec_hi=rec_hi,
                           rec_lo=rec_lo,
                           dec_len=filter_length,
                           rec_len=filter_length,
                           vanishing_moments_psi=vanishing_moments_psi,
                           vanishing_moments_phi=vanishing_moments_phi)


def _biorthogonal_wavelet(order, reverse):
    """Assembles the ``bior`` wavelet (or the ``rbio`` wavelet when ``reverse``
    is True) of two digit order ``10 * n + m``. Returns None if not tabulated."""
    n, m = divmod(order, 10)
    idx, m_max = bior_index(n, m)
    if idx is None or m_max is None:
        return None
    table = bior[n - 1]
    # table[0] holds the filter shared by every member with this n and
    # table[idx + 1] the member specific filter; reject out of range indices.
    if not 0 <= idx < len(table) - 1:
        return None
    filter_length = 2 * m if n == 1 else 2 * m + 2
    start = m_max - m
    shared = table[0][start:start + filter_length]
    specific = table[idx + 1]
    if reverse:
        dec_lo, rec_lo = shared[::-1], specific
        prefix, family_name = 'rbio', "Reverse biorthogonal"
    else:
        dec_lo, rec_lo = specific[::-1], shared
        prefix, family_name = 'bior', "Biorthogonal"
    # high pass filters follow from the low pass ones by sign alternation
    rec_hi = negate_odds(dec_lo)
    dec_hi = negate_evens(rec_lo)
    # Support width and vanishing moments are not tabulated for these
    # families; -1 marks them as unknown (as done for dmey below).
    return DiscreteWavelet(support_width=-1,
                           symmetry=SYMMETRY.SYMMETRIC,
                           orthogonal=False,
                           biorthogonal=True,
                           compact_support=True,
                           name=f'{prefix}{n}.{m}',
                           family_name=family_name,
                           short_name=prefix,
                           dec_hi=dec_hi,
                           dec_lo=dec_lo,
                           rec_hi=rec_hi,
                           rec_lo=rec_lo,
                           dec_len=filter_length,
                           rec_len=filter_length,
                           vanishing_moments_psi=-1,
                           vanishing_moments_phi=-1)


def build_discrete_wavelet(family: FAMILY, order: int):
    """Builds a discrete wavelet by its family and order

    Args:
        family (FAMILY): family of the wavelet
        order (int): order of the wavelet inside the family. For the ``bior``
            and ``rbio`` families it is the two digit number ``10 * n + m``
            (e.g. 22 for ``bior2.2``). Ignored for ``haar`` and ``dmey``.

    Returns:
        DiscreteWavelet: the wavelet, or None if the family/order combination
        is not available
    """
    nv = family.value
    if nv == FAMILY.HAAR.value:
        return _orthogonal_wavelet(db[0],
                                   support_width=1,
                                   symmetry=SYMMETRY.ASYMMETRIC,
                                   name="Haar",
                                   family_name="Haar",
                                   short_name="haar",
                                   filter_length=2,
                                   vanishing_moments_psi=1,
                                   vanishing_moments_phi=0)
    if nv == FAMILY.DB.value:
        index = order - 1
        # guard both ends: a negative index would silently wrap around
        if not 0 <= index < len(db):
            return None
        return _orthogonal_wavelet(db[index],
                                   support_width=2 * order - 1,
                                   symmetry=SYMMETRY.ASYMMETRIC,
                                   name=f'db{order}',
                                   family_name="Daubechies",
                                   short_name="db",
                                   filter_length=2 * order,
                                   vanishing_moments_psi=order,
                                   vanishing_moments_phi=0)
    if nv == FAMILY.SYM.value:
        # the symlet table starts at sym2
        index = order - 2
        if not 0 <= index < len(sym):
            return None
        return _orthogonal_wavelet(sym[index],
                                   support_width=2 * order - 1,
                                   symmetry=SYMMETRY.NEAR_SYMMETRIC,
                                   name=f'sym{order}',
                                   family_name="Symlets",
                                   short_name="sym",
                                   filter_length=2 * order,
                                   vanishing_moments_psi=order,
                                   vanishing_moments_phi=0)
    if nv == FAMILY.COIF.value:
        index = order - 1
        if not 0 <= index < len(coif):
            return None
        # the stored coiflet coefficients are rescaled by sqrt2 here
        return _orthogonal_wavelet(coif[index] * sqrt2,
                                   support_width=6 * order - 1,
                                   symmetry=SYMMETRY.NEAR_SYMMETRIC,
                                   name=f'coif{order}',
                                   family_name="Coiflets",
                                   short_name="coif",
                                   filter_length=6 * order,
                                   vanishing_moments_psi=2 * order,
                                   vanishing_moments_phi=2 * order - 1)
    if nv == FAMILY.BIOR.value:
        return _biorthogonal_wavelet(order, reverse=False)
    if nv == FAMILY.RBIO.value:
        return _biorthogonal_wavelet(order, reverse=True)
    if nv == FAMILY.DMEY.value:
        return _orthogonal_wavelet(dmey,
                                   support_width=1,
                                   symmetry=SYMMETRY.SYMMETRIC,
                                   name="dmey",
                                   family_name="Discrete Meyer (FIR Approximation)",
                                   short_name="dmey",
                                   filter_length=62,
                                   vanishing_moments_psi=-1,
                                   vanishing_moments_phi=-1)
    return None
# Regular expression for extracting the numeric parameters (e.g. bandwidth and
# center frequencies) embedded in a continuous wavelet name. It is used with
# re.findall: each match is a run of non-digits followed by a number, and the
# captured group is that number, so "cmor1.5-1.0" yields ['1.5', '1.0'].
cwt_pattern = re.compile(r'\D+(\d+\.*\d*)+')
def _get_bw_center_freqs(freqs, bandwidth_frequency, center_frequency):
if len(freqs) == 2:
bandwidth_frequency = float(freqs[0])
center_frequency = float(freqs[1])
return bandwidth_frequency, center_frequency
def _get_m_b_c(freqs, fbsp_order, bandwidth_frequency, center_frequency):
if len(freqs) == 3:
fbsp_order = int(freqs[0])
bandwidth_frequency = float(freqs[1])
center_frequency = float(freqs[2])
return fbsp_order, bandwidth_frequency, center_frequency
def _check_param_count(name, freqs, expected):
    """Raises ValueError when ``name`` carries parameters but not exactly ``expected`` of them."""
    if freqs and len(freqs) != expected:
        raise ValueError(
            f"Wavelet {name} expects {expected} parameters in its name, "
            f"got {len(freqs)}.")


def build_continuous_wavelet(name: str, family: FAMILY, order: int):
    """Builds a continuous wavelet by its name, family and order

    Args:
        name (str): full wavelet name. For the ``shan``, ``cmor`` and ``fbsp``
            families it may carry the wavelet parameters, e.g. ``cmor1.5-1.0``
            or ``fbsp1-1.5-1.0``.
        family (FAMILY): family of the wavelet
        order (int): order of the wavelet (used by the ``gaus`` and ``cgau``
            families, which support orders 1 to 8)

    Returns:
        ContinuousWavelet: the wavelet, or None if the family/order is not supported

    Raises:
        ValueError: if the name has a suffix after its four letter base name
            but no numbers can be parsed from it, or if a parametrized family
            receives the wrong number of parameters
    """
    # everything after the four letter base name (e.g. "1.5-1.0" in "cmor1.5-1.0")
    subname = name[4:]
    # identify the numeric parameters (B-C or M-B-C pattern) if present
    freqs = re.findall(cwt_pattern, name)
    if subname and len(freqs) == 0:
        raise ValueError("No frequencies have been specified.")
    freqs = [float(freq) for freq in freqs]
    nv = family.value
    if nv == FAMILY.GAUS.value:
        if order is None or not 1 <= order <= 8:
            return None
        symmetry = SYMMETRY.SYMMETRIC if order % 2 == 0 else SYMMETRY.ANTI_SYMMETRIC
        return ContinuousWavelet(support_width=-1,
                                 symmetry=symmetry,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Gaussian",
                                 short_name="gaus",
                                 complex_cwt=False,
                                 lower_bound=-5.,
                                 upper_bound=5.,
                                 center_frequency=0.,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0)
    if nv == FAMILY.MEXH.value:
        return ContinuousWavelet(support_width=-1,
                                 symmetry=SYMMETRY.SYMMETRIC,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Mexican hat wavelet",
                                 short_name="mexh",
                                 complex_cwt=False,
                                 lower_bound=-8.,
                                 upper_bound=8.,
                                 center_frequency=0.25,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0,
                                 functions=ricker())
    if nv == FAMILY.MORL.value:
        return ContinuousWavelet(support_width=-1,
                                 symmetry=SYMMETRY.SYMMETRIC,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Morlet wavelet",
                                 short_name="morl",
                                 complex_cwt=False,
                                 lower_bound=-8.,
                                 upper_bound=8.,
                                 center_frequency=0.,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0)
    if nv == FAMILY.CGAU.value:
        if order is None or not 1 <= order <= 8:
            return None
        symmetry = SYMMETRY.SYMMETRIC if order % 2 == 0 else SYMMETRY.ANTI_SYMMETRIC
        return ContinuousWavelet(support_width=-1,
                                 symmetry=symmetry,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Complex Gaussian wavelets",
                                 short_name="cgau",
                                 complex_cwt=True,
                                 lower_bound=-5.,
                                 upper_bound=5.,
                                 center_frequency=0.,
                                 bandwidth_frequency=0.,
                                 fbsp_order=0)
    if nv == FAMILY.SHAN.value:
        _check_param_count(name, freqs, 2)
        bandwidth_frequency, center_frequency = _get_bw_center_freqs(freqs, 0.5, 1.)
        return ContinuousWavelet(support_width=-1,
                                 symmetry=SYMMETRY.ASYMMETRIC,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Shannon wavelets",
                                 short_name="shan",
                                 complex_cwt=True,
                                 lower_bound=-20.,
                                 upper_bound=20.,
                                 center_frequency=center_frequency,
                                 bandwidth_frequency=bandwidth_frequency,
                                 fbsp_order=0)
    if nv == FAMILY.FBSP.value:
        _check_param_count(name, freqs, 3)
        fbsp_order, bandwidth_frequency, center_frequency = _get_m_b_c(freqs, 2, 1., 0.5)
        return ContinuousWavelet(support_width=-1,
                                 symmetry=SYMMETRY.ASYMMETRIC,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Frequency B-Spline wavelets",
                                 short_name="fbsp",
                                 complex_cwt=True,
                                 lower_bound=-20.,
                                 upper_bound=20.,
                                 center_frequency=center_frequency,
                                 bandwidth_frequency=bandwidth_frequency,
                                 fbsp_order=fbsp_order)
    if nv == FAMILY.CMOR.value:
        _check_param_count(name, freqs, 2)
        bandwidth_frequency, center_frequency = _get_bw_center_freqs(freqs, 1., 0.5)
        return ContinuousWavelet(support_width=-1,
                                 symmetry=SYMMETRY.ASYMMETRIC,
                                 orthogonal=False,
                                 biorthogonal=False,
                                 compact_support=False,
                                 name=name,
                                 family_name="Complex Morlet wavelets",
                                 short_name="cmor",
                                 complex_cwt=True,
                                 lower_bound=-8.,
                                 upper_bound=8.,
                                 center_frequency=center_frequency,
                                 bandwidth_frequency=bandwidth_frequency,
                                 fbsp_order=2,
                                 functions=cmor(bandwidth_frequency, center_frequency))
    return None
def build_wavelet(name):
    """Builds a wavelet object by the name of the wavelet

    Args:
        name (str): Name of the wavelet (case insensitive)

    Returns:
        DiscreteWavelet or ContinuousWavelet: the wavelet object

    Raises:
        ValueError: if no wavelet can be built for the given name

    Example:
        ::

            >>> wavelet = wt.build_wavelet('db1')
            >>> print(wavelet)
            Wavelet db1
            Family name: Daubechies
            Short name: db
            Filters length: 2
            Orthogonal: True
            Biorthogonal: True
            Symmetry: asymmetric
            DWT: True
            CWT: False
            >>> dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank
            >>> print(dec_lo)
            [0.70710678 0.70710678]
            >>> print(dec_hi)
            [-0.70710678 0.70710678]
            >>> print(rec_lo)
            [0.70710678 0.70710678]
            >>> print(rec_hi)
            [ 0.70710678 -0.70710678]
            >>> phi, psi, x = wavelet.wavefun()
    """
    name = name.lower()
    family, order = wname_to_family_order(name)
    if is_discrete_wavelet(family):
        wavelet = build_discrete_wavelet(family, order)
    else:
        wavelet = build_continuous_wavelet(name, family, order)
    # other wavelet types are not supported for now
    if wavelet is None:
        raise ValueError(f"Invalid wavelet name {name}")
    return wavelet
def rec_integrate(function, dt):
    """Integrates sampled values with the rectangle rule.

    Returns the running sum of ``function`` scaled by the sample spacing ``dt``,
    i.e. the cumulative integral evaluated at every sample.
    """
    return jnp.cumsum(function) * dt
def to_wavelet(wavelet):
    """Resolves ``wavelet`` to a wavelet object.

    Strings are passed to ``build_wavelet``; anything else is returned as is.

    Raises:
        ValueError: if the resolved wavelet is None
    """
    resolved = build_wavelet(wavelet) if isinstance(wavelet, str) else wavelet
    if resolved is None:
        raise ValueError("Invalid wavelet")
    return resolved
def integrate_wavelet(wavelet, precision=8):
    """Integrate wavelet function using the rectangle integration method

    Args:
        wavelet: a wavelet object or the name of a wavelet
        precision (int): level passed to ``wavelet.wavefun`` (default 8)

    Returns:
        ``(int_psi, t)`` when ``wavefun`` returns ``(psi, t)`` or
        ``(phi, psi, t)``. ``(int_psi_d, int_psi_r, t)`` when it returns
        ``(phi_d, psi_d, phi_r, psi_r, t)``.

    Raises:
        ValueError: if ``wavefun`` returns an unexpected number of values
    """
    wavelet = to_wavelet(wavelet)
    approximations = wavelet.wavefun(precision)
    count = len(approximations)
    if count not in (2, 3, 5):
        raise ValueError(
            f"Unexpected number of wavefun approximations: {count}")
    # the sample grid is always the last entry
    t = approximations[-1]
    dt = t[1] - t[0]
    if count == 5:
        _, psi_d, _, psi_r, _ = approximations
        return rec_integrate(psi_d, dt), rec_integrate(psi_r, dt), t
    # psi is the second to last entry of both (psi, t) and (phi, psi, t)
    psi = approximations[-2]
    return rec_integrate(psi, dt), t
def central_frequency(wavelet, precision=8):
    """Computes the central frequency of the wavelet function

    If the wavelet stores a positive ``center_frequency``, that value is
    returned. Otherwise the frequency is estimated from the FFT peak of the
    sampled wavelet function, skipping the DC component.

    Args:
        wavelet: a wavelet object or the name of a wavelet
        precision (int): level passed to ``wavelet.wavefun`` (default 8)

    Returns:
        the central frequency (in cycles per unit of the sample grid)

    Raises:
        ValueError: if ``wavefun`` returns an unexpected number of values
    """
    wavelet = to_wavelet(wavelet)
    # DiscreteWavelet has no center_frequency field. ContinuousWavelet uses
    # -1 as "not specified" and 0 for families where it is not tabulated.
    cf = getattr(wavelet, 'center_frequency', None)
    if cf is not None and cf > 0:
        return cf
    # get the wavelet functions
    approximations = wavelet.wavefun(precision)
    count = len(approximations)
    if count == 2:
        psi, t = approximations
    elif count == 3:
        _, psi, t = approximations
    elif count == 5:
        _, psi, _, _, t = approximations
    else:
        raise ValueError(
            f"Unexpected number of wavefun approximations: {count}")
    domain = t[-1] - t[0]
    # identify the peak frequency [skip the DC]
    index = jnp.argmax(jnp.abs(jfft.fft(psi)[1:])) + 2
    # fold indices from the upper (negative frequency) half of the spectrum
    if index > len(psi) / 2:
        index = len(psi) - index + 2
    # convert to Hz
    return 1.0 / (domain / (index - 1))
def scale2frequency(wavelet, scales, precision=8):
    """Converts scales to frequencies for a wavelet

    Each frequency is the wavelet's central frequency divided by the
    corresponding scale.
    """
    scale_array = jnp.asarray(scales)
    return central_frequency(wavelet, precision=precision) / scale_array