Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brsnw250 committed Jul 31, 2024
1 parent 942e1f9 commit 8d3e999
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 64 deletions.
2 changes: 1 addition & 1 deletion etna/transforms/decomposition/model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def transform(self, ts: TSDataset) -> TSDataset:

if ts.index.min() < self._first_timestamp:
raise ValueError(
f"First index of the dataset to be transformed must be larger or equal then {self._first_timestamp}!"
f"First index of the dataset to be transformed must be larger or equal than {self._first_timestamp}!"
)

if ts.index.min() > self._last_timestamp:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_transforms/test_decomposition/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df


@pytest.fixture()
def ts_with_exogs() -> TSDataset:
periods = 100
periods_exog = periods + 10
df = generate_ar_df(start_time="2020-01-01", periods=periods, freq="D", n_segments=2)
df_exog = generate_ar_df(start_time="2020-01-01", periods=periods_exog, freq="D", n_segments=2, random_seed=2)
df_exog.rename(columns={"target": "exog"}, inplace=True)
df_exog["holiday"] = np.random.choice([0, 1], size=periods_exog * 2)

ts = TSDataset(df, freq="D", df_exog=df_exog, known_future="all")
return ts


@pytest.fixture()
def ts_with_exogs_train_test(ts_with_exogs):
return ts_with_exogs.train_test_split(test_size=20)


@pytest.fixture()
def forward_stride_datasets(ts_with_exogs):
train_df = ts_with_exogs.df.iloc[:-10]
test_df = ts_with_exogs.df.iloc[-20:]

train_ts = TSDataset(df=train_df, freq=ts_with_exogs.freq)
test_ts = TSDataset(df=test_df, freq=ts_with_exogs.freq)

return train_ts, test_ts
30 changes: 0 additions & 30 deletions tests/test_transforms/test_decomposition/test_dft_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.metrics import MAE
from etna.models import CatBoostPerSegmentModel
from etna.models import HoltWintersModel
Expand All @@ -23,35 +22,6 @@ def simple_pipeline_with_decompose(in_column, horizon, k):
return pipeline


@pytest.fixture()
def ts_with_exogs() -> TSDataset:
periods = 100
periods_exog = periods + 10
df = generate_ar_df(start_time="2020-01-01", periods=periods, freq="D", n_segments=2)
df_exog = generate_ar_df(start_time="2020-01-01", periods=periods_exog, freq="D", n_segments=2, random_seed=2)
df_exog.rename(columns={"target": "exog"}, inplace=True)
df_exog["holiday"] = np.random.choice([0, 1], size=periods_exog * 2)

ts = TSDataset(df, freq="D", df_exog=df_exog, known_future="all")
return ts


@pytest.fixture()
def ts_with_exogs_train_test(ts_with_exogs):
return ts_with_exogs.train_test_split(test_size=20)


@pytest.fixture()
def forward_stride_datasets(ts_with_exogs):
train_df = ts_with_exogs.df.iloc[:-10]
test_df = ts_with_exogs.df.iloc[-20:]

train_ts = TSDataset(df=train_df, freq=ts_with_exogs.freq)
test_ts = TSDataset(df=test_df, freq=ts_with_exogs.freq)

return train_ts, test_ts


@pytest.fixture()
def ts_with_missing(ts_with_exogs):
target_df = ts_with_exogs[..., "target"]
Expand Down
33 changes: 0 additions & 33 deletions tests/test_transforms/test_decomposition/test_model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df
from etna.metrics import MAE
from etna.models import BATSModel
from etna.models import CatBoostPerSegmentModel
Expand All @@ -28,37 +26,6 @@ def simple_pipeline_with_decompose(in_column, horizon):
return pipeline


@pytest.fixture()
def ts_with_exogs() -> TSDataset:
periods = 100
periods_exog = periods + 10
df = generate_ar_df(start_time="2020-01-01", periods=periods, freq="D", n_segments=2)
df_exog = generate_ar_df(start_time="2020-01-01", periods=periods_exog, freq="D", n_segments=2, random_seed=2)
df_exog.rename(columns={"target": "exog"}, inplace=True)
df_exog["holiday"] = np.random.choice([0, 1], size=periods_exog * 2)

df = TSDataset.to_dataset(df)
df_exog = TSDataset.to_dataset(df_exog)
ts = TSDataset(df, freq="D", df_exog=df_exog, known_future="all")
return ts


@pytest.fixture()
def ts_with_exogs_train_test(ts_with_exogs):
return ts_with_exogs.train_test_split(test_size=20)


@pytest.fixture()
def forward_stride_datasets(ts_with_exogs):
train_df = ts_with_exogs.df.iloc[:-10]
test_df = ts_with_exogs.df.iloc[-20:]

train_ts = TSDataset(df=train_df, freq=ts_with_exogs.freq)
test_ts = TSDataset(df=test_df, freq=ts_with_exogs.freq)

return train_ts, test_ts


@pytest.mark.parametrize("in_column", ("target", "feat"))
def test_init(in_column):
transform = ModelDecomposeTransform(model=HoltWintersModel(), in_column=in_column)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3015,6 +3015,7 @@ def test_inverse_transform_future_with_target_fail_difference(
"transform, dataset_name, expected_changes",
[
(FourierDecomposeTransform(in_column="target", k=5, residuals=True), "regular_ts", {}),
(ModelDecomposeTransform(model=HoltWintersModel(), in_column="target", residuals=True), "regular_ts", {}),
],
)
def test_inverse_transform_future_with_target_fail_require_history(
Expand Down
5 changes: 5 additions & 0 deletions tests/test_transforms/test_inference/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2855,6 +2855,11 @@ def test_transform_future_with_target(self, transform, dataset_name, expected_ch
"regular_ts",
{"create": {"target_dft_0", "target_dft_1", "target_dft_residuals"}},
),
(
ModelDecomposeTransform(model=HoltWintersModel(), in_column="target", residuals=True),
"regular_ts",
{"create": {"target_level", "target_residuals"}},
),
),
)
def test_transform_future_with_target_fail_require_history(
Expand Down

0 comments on commit 8d3e999

Please sign in to comment.