# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
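Implements different versions of the wavelet transform.

Notation: a signal s_J of length n = 2^J is decomposed down to the coarsest
level L, which yields J - L levels of detail coefficients d_j.

Example 1: n = 1024, J = 10, L = 6.
    J - L = 10 - 6 = 4 levels.
    s_10 = s_6 + d_6 + d_7 + d_8 + d_9
    The levels are visited in the order j = J-1:-1:L = [9, 8, 7, 6].

Example 2: n = 256, J = 8, L = 0.
    s_8 = s_0 + d_0 + d_1 + d_2 + d_3 + d_4 + d_5 + d_6 + d_7.

In general: s_J = s_L + sum(L <= j < J) d_j.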
"""
from functools import partial
from jax import jit, lax, vmap
import jax.numpy as jnp
from cr.nimble import promote_arg_dtypes
from .dyad import *
from .multirate import *
from .wavelet import build_wavelet
from .util import *
######################################################################################
# Single level wavelet decomposition/reconstruction
######################################################################################
@partial(jit, static_argnums=(1, 2))
def pad_(data, p, mode):
    """Extends a 1-D signal at both ends according to a signal extension mode

    Args:
        data (jax.numpy.ndarray): Signal to be extended
        p (int): Number of samples to add on each side (static under jit).
            For 'periodization' only ``p // 2`` samples are added per side.
        mode (str): One of 'symmetric', 'reflect', 'constant', 'zero',
            'periodic', 'periodization' (static under jit)

    Returns:
        jax.numpy.ndarray: The extended signal

    Raises:
        ValueError: If ``mode`` is not one of the supported names
    """
    if mode == 'periodization':
        # Promote odd-sized dimensions to even length by duplicating the
        # last value, then wrap around by half the requested width.
        odd = data.shape[0] % 2
        evened = jnp.pad(data, (0, odd), mode='edge')
        return jnp.pad(evened, p//2, mode='wrap')
    if mode == 'zero':
        return jnp.pad(data, p, mode='constant', constant_values=0)
    # Wavelet extension mode name -> jnp.pad mode name. Note that the
    # wavelet 'constant' mode replicates the edge samples.
    pad_modes = {
        'symmetric': 'symmetric',
        'reflect': 'reflect',
        'constant': 'edge',
        'periodic': 'wrap',
    }
    if mode in pad_modes:
        return jnp.pad(data, p, mode=pad_modes[mode])
    raise ValueError("mode must be one of ['symmetric', 'constant', 'reflect', 'zero', 'periodic', 'periodization']")
@partial(jit, static_argnums=(3,))
def dwt_(data, dec_lo, dec_hi, mode):
    """Computes single level discrete wavelet decomposition

    Args:
        data (jax.numpy.ndarray): 1-D input signal
        dec_lo (jax.numpy.ndarray): Low pass decomposition filter
        dec_hi (jax.numpy.ndarray): High pass decomposition filter
        mode (str): Signal extension mode (static under jit)

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): The low pass and high pass
        filtered, stride-2 down-sampled outputs with the first and last
        sample of each removed.
    """
    taps = len(dec_lo)
    # Extend the signal as per the mode and give it the
    # (batch, channel, length) layout used for the strided convolution.
    signal = pad_(data, taps, mode)[None, None, :]
    conv_padding = [(1, 1)] if mode == 'periodization' else [(0, 1)]
    stride = (2,)

    def analyze(kernel):
        # The filter is passed in reversed order to the convolution primitive
        reversed_kernel = kernel[::-1][None, None, :]
        out = lax.conv_general_dilated(signal, reversed_kernel, stride, conv_padding)
        return out[0, 0, :]

    lo = analyze(dec_lo)
    hi = analyze(dec_hi)
    # The same trimming applies to every mode, including 'periodization'
    return lo[1:-1], hi[1:-1]
def dwt(data, wavelet, mode="symmetric", axis=-1):
    """Computes single level discrete wavelet decomposition

    Args:
        data (jax.numpy.ndarray): Input signal array whose DWT is to be computed
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used to compute DWT (by name or object)
        mode (:obj:`str`, optional): Signal extension mode to be used during DWT computation. Default 'symmetric'.
            See :ref:`Modes <ref-wt-modes>` for available modes.
        axis (int, optional): The axis along which the vectors from data will be picked for computing DWT. Default -1 (last axis).

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): A tuple (cA, cD) containing the approximation and detail coefficients for the data.

    Note:
        Complex input is handled by transforming the real and imaginary
        parts separately (along the same axis) and recombining the results.

    Example:
        Computing the haar/db1 wavelet decomposition::

            >>> ca, cd = wt.dwt([1,2,3,4,4,3,2,1], 'db1')
            >>> print(ca)
            [2.12132034 4.94974747 4.94974747 2.12132034]
            >>> print(cd)
            [-0.70710678 -0.70710678 0.70710678 0.70710678]
    """
    wavelet = ensure_wavelet_(wavelet)
    data = jnp.asarray(data)
    if jnp.iscomplexobj(data):
        # Forward the requested axis; otherwise the recursive calls would
        # silently fall back to the default (last) axis.
        car, cdr = dwt(data.real, wavelet, mode, axis)
        cai, cdi = dwt(data.imag, wavelet, mode, axis)
        return lax.complex(car, cai), lax.complex(cdr, cdi)
    axis = check_axis_(axis, data.ndim)
    # Bring the data and both filters to a common inexact dtype
    data, dec_lo, dec_hi = promote_arg_dtypes(data, wavelet.dec_lo, wavelet.dec_hi)
    if data.ndim == 1:
        return dwt_(data, dec_lo, dec_hi, mode)
    return dwt_axis_(data, dec_lo, dec_hi, axis, mode)
@partial(jit, static_argnums=(4,))
def idwt_(ca, cd, rec_lo, rec_hi, mode):
    """Computes single level discrete wavelet reconstruction

    Args:
        ca (jax.numpy.ndarray): 1-D approximation coefficients
        cd (jax.numpy.ndarray): 1-D detail coefficients
        rec_lo (jax.numpy.ndarray): Low pass reconstruction filter
        rec_hi (jax.numpy.ndarray): High pass reconstruction filter
        mode (str): Signal extension mode (static under jit)

    Returns:
        jax.numpy.ndarray: Sum of the filtered low pass and high pass
        branches, trimmed at both ends.
    """
    taps = len(rec_lo)
    half = taps // 2
    periodized = mode == 'periodization'
    approx = up_sample(ca, 2)
    detail = up_sample(cd, 2)
    if periodized:
        # Wrap both branches around by half the filter length
        approx = jnp.pad(approx, half, mode='wrap')
        detail = jnp.pad(detail, half, mode='wrap')
    # Low pass and high pass portions of the next level of approximation
    low_part = jnp.convolve(approx, rec_lo, 'same')
    high_part = jnp.convolve(detail, rec_hi, 'same')
    recon = low_part + high_part
    if periodized:
        # Drop the samples contributed by the wrap-around padding
        return recon[half:(-taps) // 2]
    trim = half - 1
    # A zero trim must not be used as a slice bound (it would empty the result)
    return recon[trim:-trim] if trim > 0 else recon
@partial(jit, static_argnums=(3,))
def idwt_joined_(w, rec_lo, rec_hi, mode):
    """Computes single level discrete wavelet reconstruction from joined coefficients

    The first half of ``w`` is used as the approximation coefficients and
    the remainder as the detail coefficients.
    """
    half = len(w) // 2
    return idwt_(w[:half], w[half:], rec_lo, rec_hi, mode)
def idwt(ca, cd, wavelet, mode="symmetric", axis=-1):
    """Computes single level discrete wavelet reconstruction

    Args:
        ca (jax.numpy.ndarray or None): Approximation coefficients. If None,
            zeros shaped like ``cd`` are used.
        cd (jax.numpy.ndarray or None): Detail coefficients. If None, zeros
            shaped like ``ca`` are used.
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.
        axis (int, optional): The axis along which the reconstruction is computed. Default -1 (last axis).

    Returns:
        jax.numpy.ndarray: The reconstructed data

    Raises:
        ValueError: If both ``ca`` and ``cd`` are None, if their shapes
            differ, or if there are fewer than ``wavelet.rec_len // 2``
            coefficients along ``axis``.
    """
    if ca is None and cd is None:
        raise ValueError("Both ca and cd cannot be None")
    # make sure that ca and cd are arrays; a missing part is replaced by zeros
    if ca is not None:
        ca = jnp.asarray(ca)
    if cd is not None:
        cd = jnp.asarray(cd)
    if cd is None:
        cd = jnp.zeros_like(ca)
    if ca is None:
        ca = jnp.zeros_like(cd)
    if ca.shape != cd.shape:
        raise ValueError("ca and cd must have identical shape.")
    wavelet = ensure_wavelet_(wavelet)
    axis = check_axis_(axis, ca.ndim)
    # The coefficient count that matters is the one along the transform axis
    if ca.shape[axis] < wavelet.rec_len // 2:
        raise ValueError("Insufficient coefficients for wavelet reconstruction.")
    if jnp.iscomplexobj(ca) or jnp.iscomplexobj(cd):
        # Reconstruct real and imaginary parts separately along the same axis
        xr = idwt(jnp.real(ca), jnp.real(cd), wavelet, mode, axis)
        xi = idwt(jnp.imag(ca), jnp.imag(cd), wavelet, mode, axis)
        return lax.complex(xr, xi)
    rec_lo = wavelet.rec_lo
    rec_hi = wavelet.rec_hi
    if ca.ndim == 1:
        return idwt_(ca, cd, rec_lo, rec_hi, mode)
    return idwt_axis_(ca, cd, rec_lo, rec_hi, axis, mode)
######################################################################################
# Wavelet decomposition/reconstruction for only one set of coefficients
######################################################################################
@partial(jit, static_argnums=(2,))
def downcoef_(data, filter, mode):
    """Partial discrete wavelet decomposition

    Applies a single decomposition filter to a 1-D signal.

    Args:
        data (jax.numpy.ndarray): 1-D input signal
        filter (jax.numpy.ndarray): Decomposition filter to apply
        mode (str): Signal extension mode (static under jit)

    Returns:
        jax.numpy.ndarray: The filtered, stride-2 down-sampled output with
        its first and last sample removed.
    """
    taps = len(filter)
    # (batch, channel, length) layout of the extended signal
    signal = pad_(data, taps, mode)[None, None, :]
    conv_padding = [(1, 1)] if mode == 'periodization' else [(0, 1)]
    # The filter is passed in reversed order to the convolution primitive
    kernel = filter[::-1][None, None, :]
    result = lax.conv_general_dilated(signal, kernel, (2,), conv_padding)
    # The same trimming applies to every mode, including 'periodization'
    return result[0, 0, :][1:-1]
def downcoef(part, data, wavelet, mode='symmetric', level=1):
    """Partial discrete wavelet decomposition (multi-level)

    Args:
        part: Selects which decomposition filter is used at the last level
            (passed to ``part_dec_filter_``).
        data (jax.numpy.ndarray): 1-D input signal
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.
        level (int, optional): Number of decomposition levels. Default 1.

    Returns:
        jax.numpy.ndarray: The requested coefficients at the given level

    Raises:
        ValueError: If ``level < 1`` or ``data`` has more than one dimension
    """
    if level < 1:
        raise ValueError("Value of level must be greater than 0.")
    # Accept array-likes (e.g. lists) just like dwt does
    data = jnp.asarray(data)
    if data.ndim > 1:
        raise ValueError("downcoef only supports 1d data.")
    if jnp.iscomplexobj(data):
        real = downcoef(part, data.real, wavelet, mode, level)
        imag = downcoef(part, data.imag, wavelet, mode, level)
        return lax.complex(real, imag)
    wavelet = ensure_wavelet_(wavelet)
    part_filter = part_dec_filter_(part, wavelet)
    # Promote the data together with the filters (as dwt does) so that the
    # convolution always receives operands of one common dtype.
    data, dec_lo, part_filter = promote_arg_dtypes(data, wavelet.dec_lo, part_filter)
    # We do averaging for all levels except the last one
    for _ in range(level - 1):
        data = downcoef_(data, dec_lo, mode)
    # In the last iteration, we apply 'a' or 'd' as needed
    return downcoef_(data, part_filter, mode)
@partial(jit, static_argnums=(2,))
def upcoef_(coeffs, filter, mode):
    """Partial discrete wavelet reconstruction from one part of coefficients

    Args:
        coeffs (jax.numpy.ndarray): 1-D coefficients
        filter (jax.numpy.ndarray): Reconstruction filter to apply
        mode (str): Signal extension mode (static under jit)

    Returns:
        jax.numpy.ndarray: Full convolution of the up-sampled coefficients
        with the filter, trimmed according to ``mode``.
    """
    taps = len(filter)
    periodized = mode == 'periodization'
    upsampled = up_sample(coeffs, 2)
    if periodized:
        # Wrap around by half the filter length on each side
        upsampled = jnp.pad(upsampled, taps//2, mode='wrap')
    filtered = jnp.convolve(upsampled, filter, 'full')
    if periodized:
        # Discard the part of the output produced by the wrap-around padding
        return filtered[taps-1:-taps]
    return filtered[:-1]
@partial(jit, static_argnums=(2, 3))
def upcoef_a(coeffs, rec_lo, mode, level):
    """Applies ``upcoef_`` with the low pass reconstruction filter ``level`` times

    ``mode`` and ``level`` are static under jit, so the loop is unrolled
    at trace time.
    """
    result = coeffs
    for _ in range(level):
        result = upcoef_(result, rec_lo, mode)
    return result
@partial(jit, static_argnums=(3, 4))
def upcoef_d(coeffs, rec_hi, rec_lo, mode, level):
    """Applies ``upcoef_`` once with the high pass filter, then ``level - 1`` times with the low pass filter

    ``mode`` and ``level`` are static under jit, so the loop is unrolled
    at trace time.
    """
    # The first step uses the high pass reconstruction filter
    result = upcoef_(coeffs, rec_hi, mode)
    # Every remaining step uses the low pass reconstruction filter
    for _ in range(level - 1):
        result = upcoef_(result, rec_lo, mode)
    return result
def upcoef(part, coeffs, wavelet, mode='symmetric', level=1, take=0):
    """Partial discrete wavelet reconstruction from one part of coefficients (multi-level)

    Args:
        part: Selects which reconstruction filter is used at the first step
            (passed to ``part_rec_filter_``).
        coeffs (jax.numpy.ndarray): 1-D coefficients
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.
        level (int, optional): Number of reconstruction levels. Default 1.
        take (int, optional): If ``0 < take < wavelet.rec_len``, the result is
            trimmed by ``wavelet.rec_len - take`` samples in total, split
            between both ends. Default 0 (no trimming).

    Returns:
        jax.numpy.ndarray: The reconstructed signal

    Raises:
        ValueError: If ``level < 1`` or ``coeffs`` has more than one dimension
    """
    if level < 1:
        raise ValueError("Value of level must be greater than 0.")
    # Accept array-likes (e.g. lists) as well
    coeffs = jnp.asarray(coeffs)
    if coeffs.ndim > 1:
        raise ValueError("upcoef only supports 1d coeffs.")
    if jnp.iscomplexobj(coeffs):
        # Forward ``take`` so both parts are trimmed like a real input would be
        real = upcoef(part, coeffs.real, wavelet, mode, level, take)
        imag = upcoef(part, coeffs.imag, wavelet, mode, level, take)
        return lax.complex(real, imag)
    wavelet = ensure_wavelet_(wavelet)
    part_filter = part_rec_filter_(part, wavelet)
    rec_lo = wavelet.rec_lo
    # The first step uses the filter selected by ``part``; every further
    # level uses the low pass reconstruction filter.
    coeffs = upcoef_(coeffs, part_filter, mode)
    for _ in range(level - 1):
        coeffs = upcoef_(coeffs, rec_lo, mode)
    # NOTE(review): rec_len is the filter length here, not the length of the
    # reconstructed signal -- confirm this is the intended meaning of ``take``.
    rec_len = wavelet.rec_len
    if 0 < take < rec_len:
        left_bound = right_bound = (rec_len-take) // 2
        if (rec_len-take) % 2:
            # right_bound must never be zero for indexing to work
            right_bound = right_bound + 1
        return coeffs[left_bound:-right_bound]
    return coeffs
######################################################################################
# Single level wavelet decomposition/reconstruction along a given axis
######################################################################################
@partial(jit, static_argnums=(3, 4))
def dwt_axis_(data, dec_lo, dec_hi, axis, mode):
    """Applies the DWT along a given axis

    Every 1-D slice of ``data`` along ``axis`` is passed through ``dwt_``
    with the same filters and mode. ``axis`` and ``mode`` are static under jit.
    """
    def decompose(vector):
        return dwt_(vector, dec_lo, dec_hi, mode)

    return jnp.apply_along_axis(decompose, axis, data)
def dwt_axis(data, wavelet, axis, mode="symmetric"):
    """Computes single level wavelet decomposition along a given axis

    Args:
        data (jax.numpy.ndarray): Input data
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        axis (int): The axis along which the decomposition is computed
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): A tuple (cA, cD) of
        approximation and detail coefficients.
    """
    wavelet = ensure_wavelet_(wavelet)
    if not jnp.iscomplexobj(data):
        # Bring the data and both filters to a common inexact dtype
        data, dec_lo, dec_hi = promote_arg_dtypes(data, wavelet.dec_lo, wavelet.dec_hi)
        return dwt_axis_(data, dec_lo, dec_hi, axis, mode)
    # Complex data: transform real and imaginary parts separately
    real_a, real_d = dwt_axis(data.real, wavelet, axis, mode)
    imag_a, imag_d = dwt_axis(data.imag, wavelet, axis, mode)
    return lax.complex(real_a, imag_a), lax.complex(real_d, imag_d)
@partial(jit, static_argnums=(4, 5))
def idwt_axis_(ca, cd, rec_lo, rec_hi, axis, mode):
    """Applies the Inverse DWT along a given axis

    ``ca`` and ``cd`` are concatenated along ``axis`` so that every 1-D slice
    carries both coefficient sets, and each slice is then reconstructed by
    ``idwt_joined_``. ``axis`` and ``mode`` are static under jit.
    """
    joined = jnp.concatenate((ca, cd), axis=axis)

    def reconstruct(vector):
        return idwt_joined_(vector, rec_lo, rec_hi, mode)

    return jnp.apply_along_axis(reconstruct, axis, joined)
def idwt_axis(ca, cd, wavelet, axis, mode="symmetric"):
    """Computes single level wavelet reconstruction along a given axis

    Args:
        ca (jax.numpy.ndarray or None): Approximation coefficients. If None,
            zeros shaped like ``cd`` are used.
        cd (jax.numpy.ndarray or None): Detail coefficients. If None, zeros
            shaped like ``ca`` are used.
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        axis (int): The axis along which the reconstruction is computed
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.

    Returns:
        jax.numpy.ndarray: The reconstructed data

    Raises:
        ValueError: If both ``ca`` and ``cd`` are None or their shapes differ
    """
    wavelet = ensure_wavelet_(wavelet)
    if ca is None and cd is None:
        raise ValueError("Both ca and cd cannot be None")
    # make sure that ca and cd are arrays; a missing part is replaced by zeros
    if ca is not None:
        ca = jnp.asarray(ca)
    if cd is not None:
        cd = jnp.asarray(cd)
    if cd is None:
        cd = jnp.zeros_like(ca)
    if ca is None:
        ca = jnp.zeros_like(cd)
    if ca.shape != cd.shape:
        raise ValueError("ca and cd must have identical shape.")
    if jnp.iscomplexobj(ca) or jnp.iscomplexobj(cd):
        # Reconstruct real and imaginary parts separately
        xr = idwt_axis(jnp.real(ca), jnp.real(cd), wavelet, axis, mode)
        xi = idwt_axis(jnp.imag(ca), jnp.imag(cd), wavelet, axis, mode)
        return lax.complex(xr, xi)
    return idwt_axis_(ca, cd, wavelet.rec_lo, wavelet.rec_hi, axis, mode)
def dwt_column(data, wavelet, mode="symmetric"):
    """Computes single level wavelet decomposition along columns (axis-0)

    Thin wrapper over ``dwt_axis`` with the axis fixed to 0.
    """
    return dwt_axis(data, wavelet, axis=0, mode=mode)
def dwt_row(data, wavelet, mode="symmetric"):
    """Computes single level wavelet decomposition along rows (axis-1)

    Thin wrapper over ``dwt_axis`` with the axis fixed to 1.
    """
    return dwt_axis(data, wavelet, axis=1, mode=mode)
def dwt_tube(data, wavelet, mode="symmetric"):
    """Computes single level wavelet decomposition along tubes (axis-2)

    Thin wrapper over ``dwt_axis`` with the axis fixed to 2.
    """
    return dwt_axis(data, wavelet, axis=2, mode=mode)
def idwt_column(ca, cd, wavelet, mode="symmetric"):
    """Computes single level wavelet reconstruction along columns (axis-0)

    Thin wrapper over ``idwt_axis`` with the axis fixed to 0.
    """
    return idwt_axis(ca, cd, wavelet, axis=0, mode=mode)
def idwt_row(ca, cd, wavelet, mode="symmetric"):
    """Computes single level wavelet reconstruction along rows (axis-1)

    Thin wrapper over ``idwt_axis`` with the axis fixed to 1.
    """
    return idwt_axis(ca, cd, wavelet, axis=1, mode=mode)
def idwt_tube(ca, cd, wavelet, mode="symmetric"):
    """Computes single level wavelet reconstruction along tubes (axis-2)

    Thin wrapper over ``idwt_axis`` with the axis fixed to 2.
    """
    return idwt_axis(ca, cd, wavelet, axis=2, mode=mode)
######################################################################################
# Single level wavelet decomposition/reconstruction on 2 dimensions
######################################################################################
# vmap-ed variants of the 1-D transform dwt_: the data argument is batched
# along axis 0 (dwt2_rw_) or along axis 1 (dwt2_cw_), while the two filters
# and the mode (in_axes=None) are shared by every slice. The outputs are
# stacked along the same axis that the input was batched over.
dwt2_rw_ = vmap(dwt_, in_axes=(0, None, None, None), out_axes=0)
dwt2_cw_ = vmap(dwt_, in_axes=(1, None, None, None), out_axes=1)
def dwt2(image, wavelet, mode="symmetric", axes=(-2, -1)):
    """Computes single level wavelet decomposition for 2D images

    Args:
        image (jax.numpy.ndarray): Input data with at least two dimensions
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.
        axes (sequence of 2 ints, optional): The two axes along which the
            decomposition is computed, in order. Default (-2, -1).

    Returns:
        tuple: ``(caa, (cda, cad, cdd))`` where the first letter names the
        branch (a: approximation, d: detail) taken along ``axes[0]`` and the
        second letter the branch taken along ``axes[1]``.

    Raises:
        ValueError: If ``axes`` does not contain exactly two entries
    """
    wavelet = ensure_wavelet_(wavelet)
    image = promote_arg_dtypes(image)
    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("Expected two dimensions")
    # make sure that axes are positive
    first, second = [a + image.ndim if a < 0 else a for a in axes]
    # Decompose along the first axis, then split each half along the second
    ca, cd = dwt_axis(image, wavelet, first, mode)
    caa, cad = dwt_axis(ca, wavelet, second, mode)
    cda, cdd = dwt_axis(cd, wavelet, second, mode)
    return caa, (cda, cad, cdd)
def idwt2(coeffs, wavelet, mode="symmetric", axes=(-2, -1)):
    """Computes single level wavelet reconstruction for 2D images

    Args:
        coeffs (tuple): Coefficients laid out as ``(caa, (cda, cad, cdd))``
        wavelet (str or cr.sparse.wt.DiscreteWavelet): The wavelet to be used (by name or object)
        mode (:obj:`str`, optional): Signal extension mode. Default 'symmetric'.
        axes (sequence of 2 ints, optional): The two axes along which the
            reconstruction is computed. Default (-2, -1).

    Returns:
        jax.numpy.ndarray: The reconstructed image

    Raises:
        ValueError: If ``axes`` does not contain exactly two entries
    """
    wavelet = ensure_wavelet_(wavelet)
    caa, (cda, cad, cdd) = coeffs
    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("Expected two dimensions")
    # make sure that axes are positive
    outer, inner = [a + caa.ndim if a < 0 else a for a in axes]
    # Undo the split along the second axis first, then along the first axis
    approx = idwt_axis(caa, cad, wavelet, inner, mode)
    detail = idwt_axis(cda, cdd, wavelet, inner, mode)
    return idwt_axis(approx, detail, wavelet, outer, mode)
######################################################################################
# Single level wavelet decomposition/reconstruction on n dimensions
######################################################################################
######################################################################################
# Multi level wavelet decomposition/reconstruction
######################################################################################
def forward_periodized_orthogonal(qmf, x, L=0):
    """Computes the forward wavelet transform of x

    * Uses the periodized version of x
    * with an orthogonal wavelet basis
    * length of x must be dyadic.

    if L == 0, then we perform full wavelet decomposition

    Args:
        qmf: Filter passed on to ``hi_pass_down_sample`` and ``lo_pass_down_sample``
        x (jax.numpy.ndarray): Signal to be decomposed
        L (int, optional): Coarsest level at which the decomposition stops. Default 0.

    Returns:
        jax.numpy.ndarray: Array shaped like x in which, at each level, the
        leading segment has been replaced by its low pass half followed by
        its high pass half.
    """
    # The length of x is assumed to be a power of 2; this is not verified here.
    length = x.shape[0]
    num_levels = dyadic_length_int(x)
    # Number of leading samples still to be decomposed
    active = length
    for _ in range(num_levels - 1, L - 1, -1):
        segment = x[:active]
        # High pass and low pass down-sampled versions of the segment
        detail = hi_pass_down_sample(qmf, segment)
        approx = lo_pass_down_sample(qmf, segment)
        # Store the approximation first, followed by the detail part
        x = x.at[:active].set(jnp.concatenate((approx, detail)))
        active = active // 2
    return x


forward_periodized_orthogonal_jit = jit(forward_periodized_orthogonal, static_argnums=(2,))
def inverse_periodized_orthogonal(qmf, w, L=0):
    """Computes the inverse wavelet transform of w

    * Uses the periodized version of the signal
    * with an orthogonal wavelet basis
    * length of w must be dyadic.

    Args:
        qmf: Filter passed on to ``up_sample_lo_pass`` and ``up_sample_hi_pass``
        w (jax.numpy.ndarray): Wavelet coefficients; the first ``2**L``
            entries are read as the coarse approximation and each following
            block (doubling in size) as the detail part of the next level.
        L (int, optional): Coarsest level present in ``w``. Default 0.

    Returns:
        jax.numpy.ndarray: The reconstructed signal
    """
    # Let's get the dyadic length of w
    J = dyadic_length_int(w)
    # initialize the reconstruction with its coarse approximation
    mid = 2**L
    approx = w[:mid]
    end = mid * 2
    for _ in range(L, J):
        detail = w[mid:end]
        # Compute the low pass portion of the next level of approximation
        x_low = up_sample_lo_pass(qmf, approx)
        # Compute the high pass portion of the next level of approximation
        x_hi = up_sample_hi_pass(qmf, detail)
        # Compute the next level approximation of x
        approx = x_low + x_hi
        # The next detail block starts where this one ended and is twice as long
        mid, end = end, end * 2
    return approx


inverse_periodized_orthogonal_jit = jit(inverse_periodized_orthogonal, static_argnums=(2,))