Skip to content

Commit

Permalink
#51 More Augment Functions - Wavelet Transform
Browse files Browse the repository at this point in the history
  • Loading branch information
tackes committed Oct 26, 2023
1 parent 71823e1 commit 8c44146
Showing 1 changed file with 282 additions and 0 deletions.
282 changes: 282 additions & 0 deletions src/pytimetk/feature_engineering/wavelet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
import numpy as np
import pandas as pd
import pandas_flavor as pf
import polars as pl

from typing import Union, List
from pytimetk.utils.checks import check_dataframe_or_groupby, check_date_column, check_value_column

from pytimetk.utils.pandas_helpers import flatten_multiindex_column_names
from pytimetk.utils.checks import check_dataframe_or_groupby, check_date_column, check_value_column
from pytimetk.utils.polars_helpers import pandas_to_polars_frequency, pandas_to_polars_aggregation_mapping


#@pf.register_dataframe_method
def augment_wavelet(
data: Union[pd.DataFrame, pd.core.groupby.generic.DataFrameGroupBy],
date_column: str,
value_column: str,
method: str,
sample_rate: str,
scales: Union[str, List[str]],
#engine: str = 'pandas'
):
"""
Apply the Wavely transform to specified columns of a DataFrame or
DataFrameGroupBy object.
Parameters
----------
data : pd.DataFrame or pd.core.groupby.generic.DataFrameGroupBy
Input DataFrame or DataFrameGroupBy object with one or more columns of
real-valued signals.
value_column : str or list
List of column names in 'data' to which the Hilbert transform will be
applied.
engine : str, optional
The `engine` parameter is used to specify the engine to use for
summarizing the data. It can be either "pandas" or "polars".
- The default value is "pandas".
- When "polars", the function will internally use the `polars` library
for summarizing the data. This can be faster than using "pandas" for
large datasets.
sample_rate :
Sampling rate of the input data.
For time-series data, the sample rate (sample_rate) typically refers
to the frequency at which data points are collected.
For example, if your data has a 30-minute interval, if you think of the
data in terms of "samples per hour", the sample rate would be:
sample_rate = samples / hour = 1 / 0.5 = 2
scales : str or list
Array of scales to use in the transform.
The choice of scales in wavelet analysis determines which frequencies
(or periodicities) in the data you want to analyze. In other words, the
scales determine the "window size" or the "look-back period" the wavelet
uses to analyze the data.
Smaller scales: Correspond to analyzing high-frequency changes
(short-term fluctuations) in the data.
Larger scales: Correspond to analyzing low-frequency changes
(long-term fluctuations) in the data.
The specific values for scales depend on what frequencies or
periodicities you expect in your data and wish to study.
For instance, if you believe there are daily, weekly, and monthly
patterns in your data, you'd choose scales that correspond to these
periodicities given your sampling rate.
For a daily pattern with data at 30-minute intervals:
scales = 2 * 24 = 48 because there are 48 half hour intervals in a day
For a weekly pattern with data at 30-minute intervals:
scales = 48 * 7 = 336 because there are 336 half hour intervals in a
week
Recommendation, use a range of values to cover both short term and long
term patterns, then adjust accordingly.
Returns
-------
df_wavelet : pd.DataFrame
DataFrame with added columns for CWT coefficients for each scale, with
a real and imaginary column added.
Notes
-----
For a detailed introduction to wavelet transforms, you can visit this
website.
https://ataspinar.com/2018/12/21/a-guide-for-using-the-wavelet-transform-in-machine-learning/
The Bump wavelet is a real-valued wavelet function, so its imaginary
part is inherently zero.
In the continuous wavelet transform (CWT), the Morlet and Analytic
Morlet wavelets are complex-valued, so their convolutions with the signal
yield complex results (with both real and imaginary parts).
Wavelets, in general, are mathematical functions that can decompose a
signal into its constituent parts at different scales. Different wavelet
functions are suitable for different types of signals and analytical goals.
Let's look at the three wavelet methods:
1. Morlet Wavelet:
Characteristics:
Essentially a complex sinusoid modulated by a Gaussian window.
It provides a good balance between time localization and frequency
localization.
When to use:
When you want a good compromise between time and frequency localization.
Particularly useful when you're interested in sinusoidal components or
oscillatory patterns of your data. Commonly used in time-frequency analysis
because of its simplicity and effectiveness.
2. Bump Wavelet:
Characteristics:
Has an oscillating behavior similar to the Morlet but has sharper time
localization. Its frequency localization isn't as sharp as its time
localization.
When to use:
When you are more interested in precisely identifying when certain events or
anomalies occur in your data. It can be especially useful for detecting
sharp spikes or short-lived events in your signal.
3. Analytic Morlet Wavelet:
Characteristics:
A variation of the Morlet wavelet that is designed to have no negative
frequencies when transformed. This means it's "analytic." Offers slightly
better frequency localization than the standard Morlet wavelet.
When to use:
When you're interested in phase properties of your signal.
Can be used when you need to avoid negative frequencies in your analysis,
making it useful for certain types of signals, like analytic signals.
Offers a cleaner spectrum in the frequency domain than the standard Morlet.
Examples
--------
```{python}
# Example 1: Using Pandas Engine on a pandas groupby object
import pytimetk as tk
import pandas as pd
df = tk.datasets.load_dataset('walmart_sales_weekly', parse_dates = ['Date'])
wavelet_df = (df.groupby('id').augment_wavelet(
date_column = 'Date',
value_column ='Weekly_Sales',
scales = [15],
sample_rate =1,
method = 'bump'
)
)
wavelet_df.head()
```
```{python}
# Example 2: Using Pandas Engine on a pandas dataframe
import pyti??metk as tk
import pandas as pd
df = tk.load_dataset('taylor_30_min', parse_dates = ['date'])
result_df = (augment_wavelet(
df,
date_column = 'date',
value_column ='value',
scales = [15],
sample_rate =1000,
method = 'morlet'
)
)
```
"""
# Run common checks
check_dataframe_or_groupby(data)
check_value_column(data, value_column)
check_date_column(data, date_column)


wavelet_functions = {
'morlet': morlet_wavelet,
'bump': bump_wavelet,
'analytic_morlet': analytic_morlet_wavelet
}

# Sort the DataFrame by the date column before applying the CWT
if isinstance(data, pd.DataFrame):
data = data.sort_values(by=date_column)

if method not in wavelet_functions:
raise ValueError(f"Invalid method '{method}'. Available methods are {list(wavelet_functions.keys())}")

# Select the wavelet function
wavelet_function = wavelet_functions[method]

# Compute the CWT
def compute_cwt(signal, wavelet_function, scales, sampling_rate):
coefficients = []

for scale in scales:
# Adjust the wavelet time vector based on the sample rate
wavelet_data = wavelet_function(np.arange(-len(signal) // 2, len(signal) // 2) / sampling_rate / scale)
convolution = np.convolve(signal, np.conj(wavelet_data), mode='same')
coefficients.append(convolution)
return np.array(coefficients)

# Define helper function
def _apply_cwt(df):
values = df[value_column].values
coeffs = compute_cwt(values, wavelet_function, scales, sample_rate)
for idx, scale in enumerate(scales):
df[f'{method}_scale_{scale}_real'] = coeffs[idx].real
df[f'{method}_scale_{scale}_imag'] = coeffs[idx].imag
return df

# Check if data is a groupby object
if isinstance(data, pd.core.groupby.generic.DataFrameGroupBy):
return pd.concat([_apply_cwt(group.sort_values(by=date_column)) for _, group in data]).reset_index(drop=True)

return _apply_cwt(data)

# Monkey-patch the method to the DataFrameGroupBy class
pd.core.groupby.DataFrameGroupBy.augment_wavelet = augment_wavelet

def morlet_wavelet(t, fc=1.0):
"""Compute the Complex Morlet wavelet"""
return np.exp(1j * np.pi * fc * t) * np.exp(-t**2 / 2)

def bump_wavelet(t, w=1.0):
"""Compute the Bump wavelet."""
s1 = np.exp(-1 / (1 - t**2))
s2 = np.exp(-w**2 / (w**2 - t**2))
condition = np.logical_and(t > -1, t < 1)
return np.where(condition, s1 * s2, 0)

def analytic_morlet_wavelet(t, w=5.0):
"""Compute the Analytic Morlet wavelet."""
s1 = np.exp(2j * np.pi * w * t)
s2 = np.exp(-(t**2) / 2)
return s1 * s2


# import pytimetk as tk

# df = tk.load_dataset('taylor_30_min', parse_dates = ['date'])
# df.head()

# result_df = (augment_wavelet(
# df,
# date_column = 'date',
# value_column ='value',
# scales = [48,336],
# sample_rate = 2,
# method = 'morlet'
# )
# )
# import pytimetk as tk
# from scipy import signal
# signal.cwt(result_df.value, signal.ricker, [15])

# df = tk.datasets.load_dataset('walmart_sales_weekly', parse_dates = ['Date'])

# result_df = (df.groupby('id').augment_wavelet(
# date_column = 'Date',
# value_column ='Weekly_Sales',
# scales = [52],
# sample_rate =1/(24*7),
# method = 'bump'
# )
# )

0 comments on commit 8c44146

Please sign in to comment.