Source code for cr.wavelets._src.codec.codec_a

# Copyright 2022 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cr.wavelets as crwt
from .blocks import *

[docs]def build_codec(wavelet_name, level, block_length, max_prd, thresholds): """Builds an encoder and decoder function for a given choice of wavelet, decomposition level, signal block length, target percentage root mean square difference and coefficient energy thresholding targets """ wavelet = crwt.to_wavelet(wavelet_name) def encode(block): coeffs = crwt.wavedec(block, wavelet, level=level) th_coeffs, binmaps = threshold(coeffs, thresholds) nz_th_coeffs = remove_zeros(th_coeffs, binmaps) scaled_coeffs, shifts, scales = scale_to_0_1(nz_th_coeffs) # quantize_to_prd_target num_bits = 9 cur_prd = 0 prd_vals = np.empty(num_bits) while num_bits >= 5 and cur_prd <= max_prd: num_bits -= 1 quantized_coeffs = quantize_1(scaled_coeffs, num_bits) inv_quant_coeffs = inv_quantize_1(quantized_coeffs, num_bits) unscaled_coeffs = descale_from_0_1(inv_quant_coeffs, shifts, scales) unscaled_coeffs = add_zeros(unscaled_coeffs, binmaps) reconstructed = crwt.waverec(unscaled_coeffs, wavelet) cur_prd = crn.prd(block, reconstructed) prd_vals[num_bits] = cur_prd if cur_prd > max_prd: num_bits += 1 cur_prd = prd_vals[num_bits] quantized_coeffs = quantize_1(scaled_coeffs, num_bits) combined_coeffs = combine_arrays(quantized_coeffs) combined_binmaps = combine_arrays(binmaps) result = encode_cbss_to_bits( combined_coeffs, combined_binmaps, shifts, scales, num_bits) return result def decode(bits): (c_coeffs, c_binmaps, shifts, scales, num_bits) = decode_cbss_from_bits( wavelet_name, level, block_length, bits) coeffs, binmaps = split_coefs_binmaps( wavelet_name, level, block_length, c_coeffs, c_binmaps) inv_quant_coeffs = inv_quantize_1(coeffs, num_bits) unscaled_coeffs = descale_from_0_1(inv_quant_coeffs, shifts, scales) unscaled_coeffs = add_zeros(unscaled_coeffs, binmaps) reconstructed = crwt.waverec(unscaled_coeffs, wavelet) return reconstructed return encode, decode