Skip to content

Commit

Permalink
Series decomposition transform with FFT (#430)
Browse files Browse the repository at this point in the history
* added implementation

* added tests

* moved fixture

* updated inference tests

* updated doc

* review fixes

* review fixes

* added tests

* updated changelog
  • Loading branch information
brsnw250 authored Jul 31, 2024
1 parent 4710a4e commit 12c0e8f
Show file tree
Hide file tree
Showing 10 changed files with 635 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `TSDataset.features` property to get list of all features in a dataset ([#405](https://github.com/etna-team/etna/pull/405))
- Add `MADOutlierTransform` class for anomaly detection ([#415](https://github.com/etna-team/etna/pull/415))
- Add `MeanEncoderTransform` ([#413](https://github.com/etna-team/etna/pull/413))
- Add `FourierDecomposeTransform` transform for series decomposition using DFT ([#430](https://github.com/etna-team/etna/pull/430))

### Changed
- Allow to change `device`, `batch_size` and `num_workers` of embedding models ([#396](https://github.com/etna-team/etna/pull/396))
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Decomposition transforms and their utilities:
decomposition.MedianPerIntervalModel
decomposition.SklearnPreprocessingPerIntervalModel
decomposition.SklearnRegressionPerIntervalModel
decomposition.FourierDecomposeTransform

Categorical encoding transforms:

Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from etna.transforms.decomposition import ChangePointsSegmentationTransform
from etna.transforms.decomposition import ChangePointsTrendTransform
from etna.transforms.decomposition import DeseasonalityTransform
from etna.transforms.decomposition import FourierDecomposeTransform
from etna.transforms.decomposition import IrreversibleChangePointsTransform
from etna.transforms.decomposition import LinearTrendTransform
from etna.transforms.decomposition import ReversibleChangePointsTransform
Expand Down
1 change: 1 addition & 0 deletions etna/transforms/decomposition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
from etna.transforms.decomposition.deseasonal import DeseasonalityTransform
from etna.transforms.decomposition.detrend import LinearTrendTransform
from etna.transforms.decomposition.detrend import TheilSenTrendTransform
from etna.transforms.decomposition.dft_based import FourierDecomposeTransform
from etna.transforms.decomposition.stl import STLTransform
200 changes: 200 additions & 0 deletions etna/transforms/decomposition/dft_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from typing import List

import numpy as np
import pandas as pd

from etna.datasets import TSDataset
from etna.datasets.utils import determine_num_steps
from etna.transforms import IrreversibleTransform


class FourierDecomposeTransform(IrreversibleTransform):
"""Transform that uses Fourier transformation to estimate series decomposition.
Note
----
This transform decomposes only in-sample data. For the future timestamps it produces ``NaN``.
For the dataset to be transformed, it should contain at least the minimum amount of in-sample timestamps that are required by transform.
Warning
-------
This transform adds new columns to the dataset, that correspond to the selected frequencies. Such columns are named with
``dft_{i}`` suffix. Suffix index do NOT indicate any relation to the frequencies. Produced names should be thought of as
arbitrary identifiers to the produced sinusoids.
"""

def __init__(self, k: int, in_column: str = "target", residuals: bool = False):
"""Init ``FourierDecomposeTransform``.
Parameters
----------
k:
how many top positive frequencies selected for the decomposition. Selection performed proportional to the amplitudes.
in_column:
name of the processed column.
residuals:
whether to add residuals after decomposition. This guarantees that all components, including residuals, sum up to the series.
"""
if k <= 0:
raise ValueError("Parameter `k` must be positive integer!")

self.k = k
self.in_column = in_column
self.residuals = residuals

self._first_timestamp = None
self._last_timestamp = None

super().__init__(required_features=[in_column])

def get_regressors_info(self) -> List[str]:
"""Return the list with regressors created by the transform."""
return []

def _fit(self, df: pd.DataFrame):
"""Fit transform with the dataframe."""
pass

def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Transform provided dataframe."""
pass

@staticmethod
def _get_num_pos_freqs(series: pd.Series) -> int:
"""Get number of positive frequencies for the series."""
num_obs = len(series)
return int(np.ceil((num_obs - 1) / 2) + 1)

def _check_segments(self, df: pd.DataFrame):
"""Check if series satisfy conditions."""
segments_with_missing = []
min_num_pos_freq = float("inf")
for segment in df:
series = df[segment]
series = series.loc[series.first_valid_index() : series.last_valid_index()]
if series.isna().any():
segments_with_missing.append(segment)

min_num_pos_freq = min(min_num_pos_freq, self._get_num_pos_freqs(series))

if len(segments_with_missing) > 0:
raise ValueError(
f"Feature `{self.in_column}` contains missing values in segments: {segments_with_missing}!"
)

if self.k > min_num_pos_freq:
raise ValueError(f"Parameter `k` must not be greater then {min_num_pos_freq} for the provided dataset!")

def _dft_components(self, series: pd.Series) -> pd.DataFrame:
"""Estimate series decomposition using FFT."""
initial_index = series.index
series = series.loc[series.first_valid_index() : series.last_valid_index()]

num_pos_freqs = self._get_num_pos_freqs(series)

# compute Fourier decomposition of the series
dft_series = np.fft.fft(series)

# compute "amplitudes" for each frequency
abs_dft_series = np.abs(dft_series)

# select top-k indices
abs_pos_dft_series = abs_dft_series[:num_pos_freqs]
top_k_idxs = np.argpartition(abs_pos_dft_series, num_pos_freqs - self.k)[-self.k :]

# select top-k and separate each frequency
freq_matrix = np.diag(dft_series)
freq_matrix = freq_matrix[:num_pos_freqs]
selected_freqs = freq_matrix[top_k_idxs]

# return frequencies to initial domain
components = np.fft.ifft(selected_freqs).real

components_df = pd.DataFrame(
data=components.T, columns=[f"dft_{i}" for i in range(components.shape[0])], index=series.index
)

if self.residuals:
components_df["dft_residuals"] = series.values - np.sum(components, axis=0)

# return trailing and leading nans to the series if any existed initially
if not components_df.index.equals(initial_index):
components_df = components_df.reindex(index=initial_index, fill_value=np.nan)

return components_df

def fit(self, ts: TSDataset) -> "FourierDecomposeTransform":
"""Fit the transform and the decomposition model.
Parameters
----------
ts:
dataset to fit the transform on.
Returns
-------
:
the fitted transform instance.
"""
self._first_timestamp = ts.index.min()
self._last_timestamp = ts.index.max()

self._check_segments(df=ts[..., self.in_column].droplevel("feature", axis=1))

return self

def transform(self, ts: TSDataset) -> TSDataset:
"""Transform ``TSDataset`` inplace.
Parameters
----------
ts:
Dataset to transform.
Returns
-------
:
Transformed ``TSDataset``.
"""
if self._first_timestamp is None:
raise ValueError("Transform is not fitted!")

if ts.index.min() < self._first_timestamp:
raise ValueError(
f"First index of the dataset to be transformed must be larger or equal than {self._first_timestamp}!"
)

if ts.index.min() > self._last_timestamp:
raise ValueError(
f"Dataset to be transformed must contain historical observations in range {self._first_timestamp} - {self._last_timestamp}"
)

segment_df = ts[..., self.in_column].droplevel("feature", axis=1)

ts_max_timestamp = ts.index.max()
if ts_max_timestamp > self._last_timestamp:
future_steps = determine_num_steps(self._last_timestamp, ts_max_timestamp, freq=ts.freq)
segment_df.iloc[-future_steps:] = np.nan

self._check_segments(df=segment_df)

segments = segment_df.columns
segment_components = []
for segment in segments:
components_df = self._dft_components(series=segment_df[segment])
components_df.columns = f"{self.in_column}_" + components_df.columns

components_df.columns = pd.MultiIndex.from_product(
[[segment], components_df.columns], names=["segment", "feature"]
)

segment_components.append(components_df)

segment_components = pd.concat(segment_components, axis=1)

ts.add_columns_from_pandas(segment_components)

return ts


__all__ = ["FourierDecomposeTransform"]
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,3 +908,29 @@ def ts_with_binary_exog() -> TSDataset:
df_exog = TSDataset.to_dataset(df_exog)
ts = TSDataset(df, freq="D", df_exog=df_exog, known_future="all")
return ts


@pytest.fixture()
def outliers_solid_tsds():
"""Create TSDataset with outliers and same last date."""
timestamp = pd.date_range("2021-01-01", end="2021-02-20", freq="D")
target1 = [np.sin(i) for i in range(len(timestamp))]
target1[10] += 10

target2 = [np.sin(i) for i in range(len(timestamp))]
target2[8] += 8
target2[15] = 2
target2[26] -= 12

df1 = pd.DataFrame({"timestamp": timestamp, "target": target1, "segment": "1"})
df2 = pd.DataFrame({"timestamp": timestamp, "target": target2, "segment": "2"})
df = pd.concat([df1, df2], ignore_index=True)
df_exog = df.copy()
df_exog.columns = ["timestamp", "regressor_1", "segment"]
ts = TSDataset(
df=TSDataset.to_dataset(df).iloc[:-10],
df_exog=TSDataset.to_dataset(df_exog),
freq="D",
known_future="all",
)
return ts
Loading

0 comments on commit 12c0e8f

Please sign in to comment.