Source code for cr.wavelets._src.cwt_tc

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from typing import NamedTuple, Callable, Tuple

from jax import jit, lax
import jax.numpy as jnp
import jax.numpy.fft as jfft
import scipy
import numpy as np
import matplotlib.pyplot as plt

from .cont_wavelets import *

from .util import next_pow_of_2, time_points, frequency_points
from .wavelet import to_wavelet

########################################################################################################
# CWT in time and frequency domains
########################################################################################################

def cwt_tc_time(data, wavelet_func, scales, dt=1., axis=-1):
    """Computes the continuous wavelet transform by convolution in the time domain.

    For each scale the wavelet is sampled at the time points of the data,
    multiplied by ``dt ** .5``, conjugated, reversed and convolved with
    ``data`` using ``mode='same'``.

    Args:
        data: Input array; its length along ``axis`` fixes the number of
            time points at which the wavelet is sampled.
        wavelet_func: Callable ``wavelet_func(t, scale)`` returning the
            wavelet samples (real or complex).
        scales: Sequence of scales; must support ``len`` and iteration.
        dt: Sample spacing.
        axis: Axis of ``data`` along which the wavelet samples are laid out.

    Returns:
        Array of shape ``(len(scales),) + data.shape`` whose dtype is the
        dtype of ``wavelet_func(1, 1.)``.

    NOTE(review): ``jnp.convolve`` is documented for 1-D inputs; confirm
    whether multi-dimensional ``data`` is actually supported here.
    """
    # One evaluation of the wavelet tells us the dtype of the output.
    sample = wavelet_func(1, 1.)
    num_scales = len(scales)
    n = data.shape[axis]
    output = jnp.empty((num_scales,) + data.shape, dtype=sample.dtype)
    # Index that keeps the wavelet samples along `axis` and inserts
    # singleton dimensions everywhere else.
    slices = [None for _ in data.shape]
    slices[axis] = slice(None)
    slices = tuple(slices)
    t = time_points(n, dt)
    # The normalisation does not depend on the scale, so compute it once.
    norm = dt ** .5
    for index, scale in enumerate(scales):
        # sample the wavelet at this scale and normalise
        wavelet_seq = norm * wavelet_func(t, scale)
        # conjugate and reverse it to obtain the convolution kernel
        kernel = jnp.conj(wavelet_seq)[::-1][slices]
        if jnp.isrealobj(kernel):
            coeffs = jnp.convolve(data, kernel, mode='same')
        else:
            # convolve the real and imaginary parts of the kernel separately
            coeffs_real = jnp.convolve(data, kernel.real, mode='same')
            coeffs_imag = jnp.convolve(data, kernel.imag, mode='same')
            coeffs = lax.complex(coeffs_real, coeffs_imag)
        output = output.at[index].set(coeffs)
    return output
# JIT-compiled variant of cwt_tc_time. Positional arguments 1, 3 and 4
# (wavelet_func, dt, axis) are declared static.
cwt_tc_time_jit = jit(cwt_tc_time, static_argnums=(1,3,4))
def cwt_tc_frequency(data, wavelet_func, scales, dt=1., axis=-1):
    """Computes the CWT of data along ``axis`` in the frequency domain.

    The data is transformed with an FFT whose length is ``next_pow_of_2(n)``,
    multiplied by the conjugate of the wavelet evaluated at the matching
    angular frequencies for every scale, and transformed back. The result is
    cropped to the original length ``n`` along the transformed axis.

    Args:
        data: Input data (converted with ``jnp.asarray``).
        wavelet_func: Callable ``wavelet_func(wk, scales)`` giving the wavelet
            at angular frequencies ``wk`` for all ``scales``.
        scales: Scales of the transform (converted with ``jnp.asarray``).
        dt: Sample spacing.
        axis: Axis of ``data`` along which the transform is computed.

    Returns:
        Array with a leading scales axis followed by the axes of ``data``.
        For 1-D ``data`` the result is additionally passed through
        ``squeeze()``.
    """
    # make sure that the parameters are arrays
    data = jnp.asarray(data)
    scales = jnp.asarray(scales)
    # number of points per data vector and the padded FFT length
    n = data.shape[axis]
    padded_n = next_pow_of_2(n)
    spectrum = jfft.fft(data, n=padded_n, axis=axis)
    # angular frequencies at which the wavelet is evaluated
    omega = jfft.fftfreq(padded_n, d=dt) * 2 * jnp.pi
    scale_factor = (1 / dt) ** .5
    kernels = jnp.conj(scale_factor * wavelet_func(omega, scales))
    # Convert a negative axis and shift by one for the leading scales axis.
    axis = (axis % data.ndim) + 1
    # Keep the scales axis and the transform axis, broadcast over the rest.
    expand = tuple(
        slice(None) if position in (0, axis) else None
        for position in range(data.ndim + 1)
    )
    # multiplication in frequency space is convolution in time
    out = jfft.ifft(spectrum[None] * kernels[expand], n=padded_n, axis=axis)
    # drop the padding along the transform axis
    crop = tuple(
        slice(None, n) if position == axis else slice(None)
        for position in range(out.ndim)
    )
    result = out[crop]
    return result.squeeze() if data.ndim == 1 else result
# JIT-compiled variant of cwt_tc_frequency. Positional arguments 1, 3 and 4
# (wavelet_func, dt, axis) are declared static.
cwt_tc_frequency_jit = jit(cwt_tc_frequency, static_argnums=(1,3, 4))
def cwt_tc(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
    """Computes the CWT of data along a specified axis with a specified wavelet.

    Args:
        data: Input data.
        scales: Scales at which the transform is computed.
        wavelet: Wavelet specification; resolved through ``to_wavelet``.
        sampling_period: Sample spacing, forwarded as ``dt``.
        method: ``'conv'`` for the time-domain implementation or ``'fft'``
            for the frequency-domain implementation.
        axis: Axis of ``data`` along which the transform is computed.

    Returns:
        The output of the selected jitted CWT implementation.

    Raises:
        NotImplementedError: If ``method`` is neither ``'conv'`` nor ``'fft'``.
    """
    wavelet = to_wavelet(wavelet)
    # dt and axis are passed positionally so that they always land on the
    # positions listed in static_argnums of the jitted functions, matching
    # how `analyze` invokes them.
    if method == 'conv':
        return cwt_tc_time_jit(
            data, wavelet.functions.time, scales, sampling_period, axis)
    if method == 'fft':
        return cwt_tc_frequency_jit(
            data, wavelet.functions.frequency, scales, sampling_period, axis)
    raise NotImplementedError("The specified method is not supported yet")
########################################################################################################
# Tuple Describing a Continuous Wavelet Analysis Result
########################################################################################################

class WaveletAnalysis(NamedTuple):
    """Continuous Wavelet Analysis of a 1D data signal
    """
    data : jnp.ndarray
    """ data on which analysis is being performed"""
    wavelet: WaveletFunctions
    """ The wavelet being used for analysis"""
    dt: float
    """ sample spacing / period"""
    dj : float
    """ scale resolution """
    mask_coi: bool
    """Disregard wavelet power outside the cone of influence"""
    frequency: bool
    """The method used for computing CWT time domain or frequency domain"""
    axis : int
    """ The axis along which the analysis will be performed"""
    scales: jnp.ndarray
    """ The scales at which the analysis was performed"""
    scalogram : jnp.ndarray
    """The resultant scalogram"""

    @property
    def n(self):
        """Returns the length of data along the axis on which CWT is being computed"""
        return self.data.shape[self.axis]

    @property
    def times(self):
        """Returns the nominal time points for the dataset"""
        return time_points(self.n, self.dt)

    @property
    def fourier_period(self):
        """Return a function that calculates the equivalent Fourier
        period as a function of scale.
        """
        return self.wavelet.fourier_period

    @property
    def scale_from_period(self):
        """Return a function that calculates the wavelet scale
        from the fourier period
        """
        return self.wavelet.scale_from_period

    @property
    def fourier_periods(self):
        """Return the equivalent Fourier periods for the scales used."""
        return self.fourier_period(self.scales)

    @property
    def fourier_frequencies(self):
        """Return the equivalent frequencies.

        This is equivalent to 1.0 / self.fourier_periods
        """
        return jnp.reciprocal(self.fourier_periods)

    @property
    def s0(self):
        """Return ``find_s0(self.wavelet, self.dt)``."""
        return find_s0(self.wavelet, self.dt)

    @property
    def w_k(self):
        """Angular frequency as a function of Fourier index.

        Computed as ``2 * pi * fftfreq(n, dt)``.
        """
        return 2 * jnp.pi * jfft.fftfreq(self.n, self.dt)

    @property
    def magnitude(self):
        """Returns the magnitude of the scalogram"""
        return jnp.abs(self.scalogram)

    @property
    def power(self):
        """Calculate the wavelet power spectrum"""
        return jnp.abs(self.scalogram) ** 2

    @property
    def coi(self):
        """The Cone of Influence is the region near the edges of the
        input signal in which edge effects may be important.

        Return a tuple (T, S) that describes the edge of the cone
        of influence as a single line in (time, scale).
        """
        times = self.times
        scales = self.scales
        Tmin = times.min()
        Tmax = times.max()
        Tmid = Tmin + (Tmax - Tmin) / 2
        # 100 logarithmically spaced scales between the smallest and largest
        s = np.logspace(np.log10(scales.min()), np.log10(scales.max()), 100)
        coi_width = self.wavelet.coi(s)
        c1 = Tmin + coi_width
        c2 = Tmax - coi_width
        # keep the left edge before the midpoint and the right edge after it
        left = np.where(c1 < Tmid)
        right = np.where(c2 > Tmid)
        C = np.hstack((c1[left], c2[right]))
        S = np.hstack((s[left], s[right]))
        # sort w.r.t time
        order = C.argsort()
        return C[order], S[order]

    @property
    def wavelet_transform_delta(self):
        """Calculate the delta wavelet transform.

        Returns an array of the transform computed over the scales.
        """
        wavelet_func = self.wavelet.frequency  # wavelet as f(w_k, s)
        WK, S = jnp.meshgrid(self.w_k, self.scales)
        # compute Y_ over all s, w_k and sum over k
        norm = (2 * jnp.pi * S / self.dt) ** .5
        # NOTE(review): WK and S share one shape while S.T is its transpose;
        # confirm that wavelet_func accepts this combination (see the FIXME
        # on C_d).
        W_d = (1 / self.n) * jnp.sum(norm * wavelet_func(WK, S.T), axis=1)
        # N.B This W_d is 1D (defined only at n=0)
        return W_d

    @property
    def C_d(self):
        """Compute the parameter C_delta, used in reconstruction.

        See section 3.i of TC98.

        FIXME: this doesn't work. TC98 gives 0.776 for the Morlet
        wavelet with dj=0.125.
        """
        dj = self.dj
        dt = self.dt
        s = self.scales
        W_d = self.wavelet_transform_delta
        # value of the wavelet function at t=0
        Y_00 = self.wavelet.time(0).real
        real_sum = jnp.sum(W_d.real / s ** .5)
        C_d = real_sum * (dj * dt ** .5 / Y_00)
        return C_d

    def plot_power(self, ax=None, coi=True):
        """Create a basic wavelet power plot with time on the
        x-axis, scale on the y-axis, and a cone of influence
        overlaid.

        Args:
            ax: Matplotlib axes to draw on; a new figure and axes are
                created when it is None.
            coi: Whether to shade the cone of influence.

        Returns:
            The axes that were drawn on.
        """
        if ax is None:
            _, ax = plt.subplots()
        times = self.times
        scales = self.scales
        Time, Scale = jnp.meshgrid(times, scales)
        ax.contourf(Time, Scale, self.power, 100)
        ax.set_yscale('log')
        ax.grid(True)
        if coi:
            coi_time, coi_scale = self.coi
            ax.fill_between(x=coi_time,
                            y1=coi_scale,
                            y2=self.scales.max(),
                            color='gray',
                            alpha=0.3)
        ax.set_xlim(times.min(), times.max())
        return ax


########################################################################################################
# Tools for Wavelet Analysis
########################################################################################################
def find_s0(wavelet, dt):
    """Find the smallest resolvable scale by finding where the
    equivalent Fourier period is equal to 2 * dt. For a Morlet
    wavelet, this is roughly 1.

    Args:
        wavelet: Object exposing ``fourier_period(scale)``.
        dt: Sample spacing.

    Returns:
        The scale at which ``wavelet.fourier_period(s) == 2 * dt``, as found
        by ``scipy.optimize.fsolve`` starting from an initial guess of 1.
    """
    # Import the submodule explicitly: a bare `import scipy` is not
    # guaranteed to make `scipy.optimize` available as an attribute.
    from scipy import optimize

    def period_gap(s):
        return wavelet.fourier_period(s) - 2 * dt

    return optimize.fsolve(period_gap, 1)[0]
def find_optimal_scales(s0, dt, dj, n):
    """Build the geometric sequence of scales ``s0 * 2 ** (dj * j)``.

    Args:
        s0: Smallest scale.
        dt: Sample spacing.
        dj: Scale resolution (spacing of the exponents).
        n: Number of data points.

    Returns:
        Array of scales for ``j = 0 .. J`` where
        ``J = int((1 / dj) * log2(n * dt / s0))``.
    """
    # index of the largest scale
    largest_index = int((1 / dj) * math.log2(n * dt / s0))
    exponents = dj * jnp.arange(0, largest_index + 1)
    return s0 * 2 ** exponents
# Wavelet used by `analyze` when the caller does not supply one.
# NOTE(review): `morlet` presumably comes from the star import of
# .cont_wavelets -- confirm.
DEFAULT_WAVELET = morlet(w0=6)
def analyze(data, wavelet=DEFAULT_WAVELET, scales=None, dt=1., dj=0.125,
            mask_coi=False, frequency=False, axis=-1):
    """Performs wavelet analysis on a dataset.

    Args:
        data: Array to analyse; its ``shape`` is read along ``axis``.
        wavelet: Wavelet providing ``time`` and ``frequency`` functions.
        scales: Scales to analyse at. When None they are derived from
            ``find_s0`` and ``find_optimal_scales``.
        dt: Sample spacing.
        dj: Scale resolution used when the scales are derived.
        mask_coi: Stored on the result unchanged.
        frequency: Use the frequency-domain CWT when true, otherwise the
            time-domain CWT.
        axis: Axis of ``data`` along which the analysis is performed.

    Returns:
        A ``WaveletAnalysis`` holding the inputs, the scales as an array and
        the computed scalogram.
    """
    n = data.shape[axis]
    if scales is None:
        # start from the smallest resolvable scale for this wavelet
        scales = find_optimal_scales(find_s0(wavelet, dt), dt, dj, n)
    if frequency:
        transform, wavelet_func = cwt_tc_frequency_jit, wavelet.frequency
    else:
        transform, wavelet_func = cwt_tc_time_jit, wavelet.time
    scalogram = transform(data, wavelet_func, scales, dt, axis)
    return WaveletAnalysis(
        data=data,
        wavelet=wavelet,
        dt=dt,
        dj=dj,
        mask_coi=mask_coi,
        frequency=frequency,
        axis=axis,
        scales=jnp.asarray(scales),
        scalogram=scalogram,
    )