# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from jax import jit, lax, vmap
import jax.numpy as jnp
from .wavelet import build_wavelet, DiscreteWavelet
######################################################################################
# Utility functions
######################################################################################
def dwt_max_level(input_len, filter_len):
    """Return the maximum useful DWT decomposition level.

    The level is derived from the signal length and the decomposition
    filter length.

    Args:
        input_len: Number of samples in the signal to decompose.
        filter_len: Decomposition filter length. May be given as an integer,
            as a wavelet name (resolved with ``build_wavelet``), or as a
            ``DiscreteWavelet`` (its ``dec_len`` is used).

    Returns:
        ``int(log2(input_len // (filter_len - 1)))``, or 0 when
        ``input_len < filter_len - 1``.

    Raises:
        ValueError: if a wavelet name cannot be resolved, or if the filter
            length is not an integer >= 2.
    """
    resolved = build_wavelet(filter_len) if isinstance(filter_len, str) else filter_len
    if resolved is None:
        raise ValueError("Invalid wavelet")
    flen = resolved.dec_len if isinstance(resolved, DiscreteWavelet) else resolved
    if flen < 2 or int(flen) != flen:
        raise ValueError("filter_len must be an integer >= 2")
    span = flen - 1
    # Too short for even one level of decomposition.
    if input_len < span:
        return 0
    return int(math.log2(input_len // span))
# Boundary extension modes advertised as supported by this module.
# Not supported yet: "smooth", "antisymmetric", "antireflect".
modes = [
    "zero",
    "constant",
    "symmetric",
    "periodic",
    "periodization",
    "reflect",
]
def dwt_coeff_len(data_len, filter_len, mode='symmetric'):
    """Return the length of one level of DWT output coefficients.

    Args:
        data_len: Number of input samples (must be >= 1).
        filter_len: Decomposition filter length. May be given as an integer,
            as a wavelet name (resolved with ``build_wavelet``), or as a
            ``DiscreteWavelet`` (its ``dec_len`` is used).
        mode: Signal extension mode. Only ``'periodization'`` changes the
            formula; every other value uses the full-convolution length.

    Returns:
        ``(data_len + 1) // 2`` for ``'periodization'``, otherwise
        ``(data_len + filter_len - 1) // 2``.

    Raises:
        ValueError: for an unresolvable wavelet name, or for non-positive
            ``data_len`` / ``filter_len``.
    """
    resolved = build_wavelet(filter_len) if isinstance(filter_len, str) else filter_len
    if resolved is None:
        raise ValueError("Invalid wavelet")
    flen = resolved.dec_len if isinstance(resolved, DiscreteWavelet) else resolved
    if data_len < 1:
        raise ValueError("Value of data_len must be greater than zero.")
    if flen < 1:
        raise ValueError("Value of filter_len must be greater than zero.")
    periodized = mode == 'periodization'
    return (data_len + 1) // 2 if periodized else (data_len + flen - 1) // 2
def dwt_coeff_lengths(data_len, filter_len, max_level, mode='symmetric'):
    """Return the coefficient lengths for each level of a multilevel DWT.

    The single-level length formula (see ``dwt_coeff_len``) is applied
    repeatedly for up to ``max_level`` levels. The list stops early if a
    length drops below 1.

    Args:
        data_len: Number of input samples (must be >= 1).
        filter_len: Decomposition filter length. May be given as an integer,
            as a wavelet name (resolved with ``build_wavelet``), or as a
            ``DiscreteWavelet`` (its ``dec_len`` is used).
        max_level: Maximum number of levels to compute.
        mode: Signal extension mode; only ``'periodization'`` is special-cased.

    Returns:
        A list of per-level coefficient lengths, level 1 first.

    Raises:
        ValueError: for an unresolvable wavelet name, or for non-positive
            ``data_len`` / ``filter_len``.
    """
    resolved = build_wavelet(filter_len) if isinstance(filter_len, str) else filter_len
    if resolved is None:
        raise ValueError("Invalid wavelet")
    flen = resolved.dec_len if isinstance(resolved, DiscreteWavelet) else resolved
    if data_len < 1:
        raise ValueError("Value of data_len must be greater than zero.")
    if flen < 1:
        raise ValueError("Value of filter_len must be greater than zero.")
    periodized = mode == 'periodization'
    lengths = []
    current = data_len
    for _ in range(max_level):
        current = (current + 1) // 2 if periodized else (current + flen - 1) // 2
        if current < 1:
            break
        lengths.append(current)
    return lengths
def pad_smooth(vector, pad_width, iaxis, kwargs):
    """``jnp.pad`` callback implementing 'smooth' (linear) extension.

    Each padded side is filled by extrapolating the line through the two
    outermost original samples on that side.

    Args:
        vector: 1-D array whose first ``pad_width[0]`` and last
            ``pad_width[1]`` entries are padding slots to be filled.
        pad_width: ``(left, right)`` number of padding slots.
        iaxis: Axis being padded (unused; required by the ``jnp.pad``
            callable protocol).
        kwargs: Extra keyword arguments (unused; same reason).

    Returns:
        A new array with the padding slots filled.
    """
    npad_l, npad_r = pad_width
    size = vector.shape[0]
    nsig = size - npad_l - npad_r
    # A single sample has no slope; fall back to constant (edge) extension
    # instead of reading a "neighbour" out of the padding slots.
    has_slope = nsig > 1
    if npad_l > 0:
        left = vector[npad_l]
        slope_left = (left - vector[npad_l + 1]) if has_slope else 0
        vector = vector.at[:npad_l].set(
            left + jnp.arange(npad_l, 0, -1) * slope_left)
    # Guard needed: with npad_r == 0, ``vector.at[-0:]`` would select the
    # whole vector instead of an empty tail.
    if npad_r > 0:
        r_idx = size - npad_r - 1
        right = vector[r_idx]
        slope_right = (right - vector[r_idx - 1]) if has_slope else 0
        vector = vector.at[size - npad_r:].set(
            right + jnp.arange(1, npad_r + 1) * slope_right)
    return vector
def pad_antisymmetric(vector, pad_width, iaxis, kwargs):
    """``jnp.pad`` callback implementing antisymmetric extension.

    The result follows the pattern
    ``... -x2 -x1 | x1 x2 ... xn | -xn -x(n-1) ...``. It is built by taking a
    symmetric extension of the original samples and then negating every
    other reflected segment on each side, starting with the one adjacent to
    the signal.

    Args:
        vector: 1-D array whose first ``pad_width[0]`` and last
            ``pad_width[1]`` entries are padding slots.
        pad_width: ``(left, right)`` number of padding slots.
        iaxis: Axis being padded (unused; required by the ``jnp.pad``
            callable protocol).
        kwargs: Extra keyword arguments (unused; same reason).

    Returns:
        A new array of the same size with the padding slots filled.

    Raises:
        ValueError: if there are no unpadded samples to extend.
    """
    npad_l, npad_r = pad_width
    size = vector.size
    vsize_nonpad = size - npad_l - npad_r
    if vsize_nonpad < 1:
        # Segment width would be 0 and the loops below would never end.
        raise ValueError("pad_antisymmetric needs at least one unpadded sample")
    # Slice the original samples explicitly: ``vector[l:-r]`` is empty when r == 0.
    original = vector[npad_l:npad_l + vsize_nonpad]
    vector = jnp.pad(original, pad_width, mode='symmetric')
    # Width of each reflected segment.
    seg_width = vsize_nonpad

    # Negate odd-numbered reflected segments to the right of the signal.
    r_edge = npad_l + vsize_nonpad - 1
    n = 1
    while r_edge <= size:
        segment = slice(r_edge + 1, min(r_edge + 1 + seg_width, size))
        if n % 2:
            vector = vector.at[segment].set(-vector[segment])
        r_edge += seg_width
        n += 1

    # Negate odd-numbered reflected segments to the left of the signal.
    # JAX arrays are immutable, so the update must be re-assigned.
    l_edge = npad_l
    n = 1
    while l_edge >= 0:
        segment = slice(max(0, l_edge - seg_width), l_edge)
        if n % 2:
            vector = vector.at[segment].set(-vector[segment])
        l_edge -= seg_width
        n += 1
    return vector
def make_even_shape(data):
    """Grow every odd-length dimension of ``data`` by one to make it even.

    The extra slice appended at the end of an odd dimension repeats the last
    value along that dimension (``jnp.pad`` in 'edge' mode). Even dimensions
    are left unchanged.

    Args:
        data: Array of any rank.

    Returns:
        The padded array.
    """
    trailing_pad = [(0, extent % 2) for extent in data.shape]
    return jnp.pad(data, trailing_pad, mode='edge')
def pad(data, pad_widths, mode):
    """Pad a signal using a wavelet-style boundary extension mode.

    Works on arrays of any rank; ``pad_widths`` follows the ``jnp.pad``
    conventions (an int, a ``(before, after)`` pair, or one pair per axis).

    ``pad_widths`` is passed through to ``jnp.pad`` unchanged. It must stay
    a static Python/NumPy value: converting it to a JAX array would make it
    a tracer under ``jax.jit``, which ``jnp.pad`` rejects.

    Mode mapping:
        - 'symmetric' -> ``jnp.pad`` 'symmetric'
        - 'reflect' -> ``jnp.pad`` 'reflect'
        - 'antireflect' -> ``jnp.pad`` 'reflect' with ``reflect_type="odd"``
        - 'constant' -> ``jnp.pad`` 'edge' (replicate the border value)
        - 'zero' -> ``jnp.pad`` 'constant' with zeros
        - 'smooth' -> ``jnp.pad`` with the ``pad_smooth`` callback
        - 'periodic' -> ``jnp.pad`` 'wrap'
        - 'periodization' -> odd-length dimensions are first extended by
          repeating their last value, then 'wrap' is applied

    Args:
        data: Input array (converted with ``jnp.asarray``).
        pad_widths: Padding widths in ``jnp.pad`` format.
        mode: One of the mode names above.

    Returns:
        The padded array.

    Raises:
        ValueError: if ``mode`` is not recognized.
    """
    data = jnp.asarray(data)
    if mode == 'symmetric':
        return jnp.pad(data, pad_widths, mode='symmetric')
    elif mode == 'reflect':
        return jnp.pad(data, pad_widths, mode='reflect')
    elif mode == 'antireflect':
        return jnp.pad(data, pad_widths, mode='reflect', reflect_type="odd")
    elif mode == 'constant':
        return jnp.pad(data, pad_widths, mode='edge')
    elif mode == 'zero':
        return jnp.pad(data, pad_widths, mode='constant', constant_values=0)
    elif mode == 'smooth':
        return jnp.pad(data, pad_widths, pad_smooth)
    # elif mode == 'antisymmetric':
    #     return jnp.pad(data, pad_widths, pad_antisymmetric)
    elif mode == 'periodic':
        return jnp.pad(data, pad_widths, mode='wrap')
    elif mode == 'periodization':
        # Promote odd-sized dimensions to even length by duplicating the
        # last value.
        edge_pad_widths = [(0, data.shape[ax] % 2)
                           for ax in range(data.ndim)]
        data = jnp.pad(data, edge_pad_widths, mode='edge')
        return jnp.pad(data, pad_widths, mode='wrap')
    else:
        raise ValueError("mode must be one of ['symmetric', 'constant', 'reflect', 'antireflect', 'zero', 'smooth', 'periodic', 'periodization']")
def next_pow_of_2(n):
    """Return the smallest integer power of 2 that is >= n.

    Uses exact integer arithmetic (``int.bit_length``) rather than floating
    point ``log2``. The float version rounds large inputs such as
    ``2**53 + 1`` down and returns a power of 2 smaller than ``n``.

    Args:
        n: A positive number. Non-integers are rounded up first; every power
            of 2 >= 1 is an integer, so this does not change the answer.

    Returns:
        An ``int`` power of 2 (1 for any ``0 < n <= 1``).

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    m = math.ceil(n)
    # For m >= 1, (m - 1).bit_length() is the number of bits needed for
    # m - 1, so 1 << that is the first power of 2 not below m.
    return 1 << (m - 1).bit_length()
######################################################################################
# Utility functions for continuous wavelets
######################################################################################
def time_points(n, dt=1):
    """Return ``n`` evenly spaced time samples centred on zero.

    Consecutive samples are ``dt`` apart. Examples with ``dt=1``:
    ``n=3`` gives ``[-1, 0, 1]`` and ``n=4`` gives ``[-1.5, -0.5, 0.5, 1.5]``.
    In general the range is ``[-(n-1)/2, (n-1)/2] * dt``.
    """
    half_span = (n - 1.0) / 2
    return (jnp.arange(0, n) - half_span) * dt
def frequency_points(n, dt=1.):
    """Return ``n`` angular frequencies, zero-centred, for sample spacing ``dt``.

    The ``jnp.fft.fftfreq`` bins are reordered with ``fftshift`` so that
    negative frequencies come first. They are then scaled by ``2*pi/dt``.
    """
    centered = jnp.fft.fftshift(jnp.fft.fftfreq(n))
    return 2 * jnp.pi * centered / dt
def scales_from_voices_per_octave(nu, range):
    """Return the scales ``2 ** (range / nu)``.

    Args:
        nu: Number of voices per octave.
        range: Scale indices; a number or an array. The name shadows the
            builtin but is kept because callers may pass it by keyword.
    """
    octaves = range / nu
    return 2 ** octaves
######################################################################################
# Local utility functions
######################################################################################
def ensure_wavelet_(wavelet):
    """Resolve a wavelet name to a wavelet object; pass other values through.

    Strings are looked up with ``build_wavelet``.

    Raises:
        ValueError: if the result is ``None``, either because the name was
            not found or because ``None`` was passed in.
    """
    resolved = build_wavelet(wavelet) if isinstance(wavelet, str) else wavelet
    if resolved is None:
        raise ValueError("Invalid wavelet")
    return resolved
def part_dec_filter_(part, wavelet):
    """Return the decomposition filter for a coefficient part.

    ``'a'`` (approximation) selects ``wavelet.dec_lo`` and ``'d'`` (detail)
    selects ``wavelet.dec_hi``. Any other value raises ``ValueError``.
    """
    if part not in ('a', 'd'):
        raise ValueError(f'Invalid part: {part}')
    return wavelet.dec_lo if part == 'a' else wavelet.dec_hi
def part_rec_filter_(part, wavelet):
    """Return the reconstruction filter for a coefficient part.

    ``'a'`` (approximation) selects ``wavelet.rec_lo`` and ``'d'`` (detail)
    selects ``wavelet.rec_hi``. Any other value raises ``ValueError``.
    """
    if part not in ('a', 'd'):
        raise ValueError(f'Invalid part: {part}')
    return wavelet.rec_lo if part == 'a' else wavelet.rec_hi
def check_axis_(axis, ndim):
    """Normalize a possibly negative axis index and validate it.

    Args:
        axis: Axis index. Negative values count from the end, NumPy style.
        ndim: Number of dimensions of the array being indexed.

    Returns:
        The equivalent non-negative axis, in ``[0, ndim)``.

    Raises:
        ValueError: if ``axis`` lies outside ``[-ndim, ndim)``.
    """
    normalized = axis + ndim if axis < 0 else axis
    # Check both bounds: an axis below -ndim is still negative after
    # normalization and must be rejected too.
    if not 0 <= normalized < ndim:
        # Report the axis the caller passed, not the normalized value.
        raise ValueError(f"Invalid axis: {axis} with ndim: {ndim}")
    return normalized