From 6a5a6bc69607060dfc1a193ac728b6f53bf3a020 Mon Sep 17 00:00:00 2001 From: Felix Tarkoey Date: Mon, 12 Oct 2020 14:45:27 +0200 Subject: [PATCH] Proor of concept on how to fix issues #531, #535 and #570 --- pywt/_cwt.py | 45 +++++++++++++++++++++---------------- pywt/_functions.py | 55 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 21 deletions(-) diff --git a/pywt/_cwt.py b/pywt/_cwt.py index a47cf9885..feae2f4d8 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -1,8 +1,9 @@ from math import floor, ceil +from scipy import interpolate from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet, Wavelet, _check_dtype) -from ._functions import integrate_wavelet, scale2frequency +from ._functions import evaluate_wavelet, scale2frequency __all__ = ["cwt"] @@ -123,13 +124,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): dt_out = dt_cplx if wavelet.complex_cwt else dt out = np.empty((np.size(scales),) + data.shape, dtype=dt_out) precision = 10 - int_psi, x = integrate_wavelet(wavelet, precision=precision) - int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi + psi, x = evaluate_wavelet(wavelet, precision=precision) + psi = np.conj(psi) if wavelet.complex_cwt else psi - # convert int_psi, x to the same precision as the data - dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt - int_psi = np.asarray(int_psi, dtype=dt_psi) + # convert psi, x to the same precision as the data + dt_psi = dt_cplx if psi.dtype.kind == 'c' else dt + psi = np.asarray(psi, dtype=dt_psi) x = np.asarray(x, dtype=data.real.dtype) + # FIXME: The original wavelet function could be used here, but + # interpolation is computationally more efficient. + wavefun = interpolate.interp1d(x, psi, kind='cubic', assume_sorted=True) if method == 'fft': size_scale0 = -1 @@ -146,41 +150,44 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): data = data.reshape((-1, data.shape[-1])) for i, scale in enumerate(scales): - step = x[1] - x[0] - j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) - j = j.astype(int) # floor - if j[-1] >= int_psi.size: - j = np.extract(j < int_psi.size, j) - int_psi_scale = int_psi[j][::-1] + # FIXME: Boundary points might be discarded erroneously + if np.sign(x[0])*np.sign(x[-1])<0: + # Wavelet is sampled at 0.0 if the range includes it + xsl = np.arange(0.0, x[0], -1.0/scale) + xsr = np.arange(0.0, x[-1], 1.0/scale) + xs = np.concatenate((xsl[:0:-1], xsr)) + else: + xs = np.arange(x[0], x[-1], 1.0/scale) + psi_scale = wavefun(xs)[::-1] if method == 'conv': if data.ndim == 1: - conv = np.convolve(data, int_psi_scale) + conv = np.convolve(data, psi_scale) else: # batch convolution via loop conv_shape = list(data.shape) - conv_shape[-1] += int_psi_scale.size - 1 + conv_shape[-1] += psi_scale.size - 1 conv_shape = tuple(conv_shape) conv = np.empty(conv_shape, dtype=dt_out) for n in range(data.shape[0]): - conv[n, :] = np.convolve(data[n], int_psi_scale) + conv[n, :] = np.convolve(data[n], psi_scale) else: # The padding is selected for: # - optimal FFT complexity # - to be larger than the two signals length to avoid circular # convolution size_scale = next_fast_len( - data.shape[-1] + int_psi_scale.size - 1 + data.shape[-1] + psi_scale.size - 1 ) if size_scale != size_scale0: # Must recompute fft_data when the padding size changes. fft_data = fftmodule.fft(data, size_scale, axis=-1) size_scale0 = size_scale - fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1) + fft_wav = fftmodule.fft(psi_scale, size_scale, axis=-1) conv = fftmodule.ifft(fft_wav * fft_data, axis=-1) - conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1] + conv = conv[..., :data.shape[-1] + psi_scale.size - 1] - coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + coef = conv / np.sqrt(scale) if out.dtype.kind != 'c': coef = coef.real # transform axis is always -1 due to the data reshape above diff --git a/pywt/_functions.py b/pywt/_functions.py index 86033967a..368b7aaf7 100644 --- a/pywt/_functions.py +++ b/pywt/_functions.py @@ -17,8 +17,8 @@ from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet -__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf", - "orthogonal_filter_bank", +__all__ = ["integrate_wavelet", "evaluate_wavelet", "central_frequency", + "scale2frequency", "qmf", "orthogonal_filter_bank", "intwave", "centrfrq", "scal2frq", "orthfilt"] @@ -119,6 +119,57 @@ def integrate_wavelet(wavelet, precision=8): return _integrate(psi_d, step), _integrate(psi_r, step), x +def evaluate_wavelet(wavelet, precision=8): + """ + Evaluate `psi` wavelet function between lower and upper bound. + + Parameters + ---------- + wavelet : Wavelet instance or str + Wavelet to evaluate. If a string, should be the name of a wavelet. + precision : int, optional + Number of wavelet function points computed with Wavelet's + wavefun(level=precision) method (default: 8). + + Returns + ------- + [psi, x] : + for orthogonal wavelets + [psi_d, psi_r, x] : + for other wavelets + + + Examples + -------- + >>> from pywt import Wavelet, evaluate_wavelet + >>> wavelet1 = Wavelet('db2') + >>> [psi, x] = evaluate_wavelet(wavelet1, precision=5) + >>> wavelet2 = Wavelet('bior1.3') + >>> [psi_d, psi_r, x] = evaluate_wavelet(wavelet2, precision=5) + + """ + + if type(wavelet) in (tuple, list): + psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1]) + return psi, x + elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)): + wavelet = DiscreteContinuousWavelet(wavelet) + + functions_approximations = wavelet.wavefun(precision) + + if len(functions_approximations) == 2: # continuous wavelet + psi, x = functions_approximations + return psi, x + + elif len(functions_approximations) == 3: # orthogonal wavelet + phi, psi, x = functions_approximations + return psi, x + + else: # biorthogonal wavelet + phi_d, psi_d, phi_r, psi_r, x = functions_approximations + return psi_d, psi_r, x + + def central_frequency(wavelet, precision=8): """ Computes the central frequency of the `psi` wavelet function.