diff --git a/CHANGELOG.md b/CHANGELOG.md
index d8416a07e..94bb83678 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -56,7 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 -
 -
 -
--
+- Fix holidays when loading datasets `traffic_2008_10T` and `traffic_2008_hourly` ([#462](https://github.com/etna-team/etna/pull/462))
 -
 -
 -
diff --git a/etna/datasets/internal_datasets.py b/etna/datasets/internal_datasets.py
index 67cb74469..fdf367efb 100644
--- a/etna/datasets/internal_datasets.py
+++ b/etna/datasets/internal_datasets.py
@@ -15,7 +15,6 @@
 from typing import Tuple
 from typing import Union
 
-import holidays
 import numpy as np
 import pandas as pd
 
@@ -162,7 +161,7 @@ def load_dataset(
         data, dataset_hash = read_dataset(dataset_path=dataset_dir / f"{name}_{part}.csv.gz")
         if dataset_hash != datasets_dict[name]["hash"][part]:
             warnings.warn(
-                f"Local hash and expected hash are different for {name} record part {part}."
+                f"Local hash and expected hash are different for {name} record part {part}. "
                 "The first possible reason is that the local copy of the dataset is out of date. In this case you can "
                 "try setting rebuild_dataset=True to rebuild the dataset. The second possible reason is that the local "
                 "copy of the dataset reflects a more recent version of the data than your version of the library. "
@@ -348,11 +347,22 @@ def read_data(path: Path, part: str) -> np.ndarray:
     targets = np.concatenate([targets_train, targets_test], axis=0)
     targets = targets[np.argsort(ts_indecies)].reshape(-1, 963)
 
-    drop_days = (
-        list(holidays.country_holidays(country="US", years=2008).keys())
-        + list(holidays.country_holidays(country="US", years=2009).keys())[:3]
-        + [date(2009, 3, 8), date(2009, 3, 10)]
-    )
+    # federal holidays and days with anomalies
+    drop_days = [
+        date(2008, 1, 1),
+        date(2008, 1, 21),
+        date(2008, 2, 18),
+        date(2008, 5, 26),
+        date(2008, 7, 4),
+        date(2008, 9, 1),
+        date(2008, 10, 13),
+        date(2008, 11, 11),
+        date(2008, 11, 27),
+        date(2008, 12, 25),
+        date(2009, 1, 1),
+        date(2009, 1, 19),
+        date(2009, 2, 16),
+    ] + [date(2008, 3, 8), date(2009, 3, 9)]
 
     dates_df = pd.DataFrame({"timestamp": pd.date_range("2008-01-01 00:00:00", "2009-03-30 23:50:00", freq="10T")})
     dates_df["dt"] = dates_df["timestamp"].dt.date
@@ -917,9 +927,9 @@ def list_datasets() -> List[str]:
         "freq": "10T",
         "parts": ("train", "test", "full"),
         "hash": {
-            "train": "4d8d1367fd5341475b852fe9779d0b05",
+            "train": "f22f77c170e698f4f51231b24e5bc9f0",
             "test": "261ee7b09e50d1c7e1e74ccf08412f3f",
-            "full": "f0c9229d78cfa5b0abf5be950b6843b2",
+            "full": "d1d05602b15aa30d461e21148483a0c8",
         },
     },
     "traffic_2008_hourly": {
@@ -927,9 +937,9 @@ def list_datasets() -> List[str]:
         "freq": "H",
         "parts": ("train", "test", "full"),
         "hash": {
-            "train": "7e6609cce30ae22004c7d7b1d39a35d5",
+            "train": "161748edc508b4e206344fcbb984bf9a",
             "test": "adc3fa06ee856c6481faa400e9e9f602",
-            "full": "8d988a81e8c368164aada708be27a1c2",
+            "full": "899bc1fa3fc334868a9e41033a2c3a52",
         },
     },
     "traffic_2015_hourly": {
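
For reviewers, a minimal sketch (not part of the patch) cross-checking the hard-coded `drop_days` list against the `holidays` package that this diff removes as a dependency. Exact output can vary across `holidays` versions, which is part of the point of hard-coding the dates:

```python
# Minimal sketch, assuming the `holidays` package (the dependency removed by
# this patch) is still installed; results may vary across `holidays` versions.
from datetime import date

import holidays

hardcoded = [
    date(2008, 1, 1),
    date(2008, 1, 21),
    date(2008, 2, 18),
    date(2008, 5, 26),
    date(2008, 7, 4),
    date(2008, 9, 1),
    date(2008, 10, 13),
    date(2008, 11, 11),
    date(2008, 11, 27),
    date(2008, 12, 25),
    date(2009, 1, 1),
    date(2009, 1, 19),
    date(2009, 2, 16),
]

# The removed code took every US holiday of 2008 plus the first three of 2009;
# iterating a holidays object yields its dates, sorted here for determinism
# (the old code relied on the dict's insertion order).
us_2008 = sorted(holidays.country_holidays(country="US", years=2008))
us_2009 = sorted(holidays.country_holidays(country="US", years=2009))

print(hardcoded == us_2008 + us_2009[:3])  # expected: True
```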
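The updated `hash` values follow from the changed `drop_days`: rebuilding the `train` and `full` parts produces different CSV bytes, so their digests change, while the `test` hashes stay the same. Below is a sketch of recomputing such a digest, assuming, as the 32-character hex strings and the `dataset_dir / f"{name}_{part}.csv.gz"` path in `load_dataset` suggest, an MD5 over the decompressed part file; `etna`'s actual `read_dataset` helper may compute it differently:

```python
# Hypothetical helper (an assumption, not etna's actual code): hash the
# decompressed bytes of a `<name>_<part>.csv.gz` dataset part with MD5.
import gzip
import hashlib
from pathlib import Path


def dataset_md5(path: Path) -> str:
    """Return the MD5 hex digest of the decompressed bytes of a .csv.gz part."""
    with gzip.open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


# Example: after rebuilding with rebuild_dataset=True, this should match the
# "train" entry for traffic_2008_10T above.
print(dataset_md5(Path("traffic_2008_10T_train.csv.gz")))
```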