diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b872a629..bea8f041c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add dataset integrity check using hash for internal datasets ([#151](https://github.com/etna-team/etna/pull/151))
 - Create page about internal datasets in documentation ([#175](https://github.com/etna-team/etna/pull/175))
 - Add usage example of internal datasets in `101-get_started.ipynb` and `305-classification.ipynb` tutorials ([#202](https://github.com/etna-team/etna/pull/202))
--
+- Add `size` method to `TSDataset` class ([#238](https://github.com/etna-team/etna/pull/238))
 
 ### Changed
 - Add `relevance_aggregation_mode` and `redundancy_aggregation_mode` into `MRMRFeatureSelectionTransform.params_to_tune` ([#212](https://github.com/etna-team/etna/pull/212))
diff --git a/etna/datasets/tsdataset.py b/etna/datasets/tsdataset.py
index 9385e61ec..5a8f68a3d 100644
--- a/etna/datasets/tsdataset.py
+++ b/etna/datasets/tsdataset.py
@@ -1651,3 +1651,23 @@ def to_torch_dataset(
         ts_samples = [samples for df_segment in ts_segments for samples in make_samples(df_segment)]
 
         return _TorchDataset(ts_samples=ts_samples)
+
+    def size(self) -> Tuple[int, int, Optional[int]]:
+        """Return size of TSDataset.
+
+        The order of sizes is (number of timestamps, number of segments, number of features).
+        The number of features is reported only if it is the same in each segment; otherwise it is ``None``.
+
+        Returns
+        -------
+        :
+            Tuple of TSDataset sizes
+        """
+        segments = self.segments
+        # A set collapses equal per-segment counts, so more than one element means the segments disagree.
+        feature_counts = {len(self.df[segment].columns.get_level_values("feature").unique()) for segment in segments}
+        if len(feature_counts) > 1:
+            return len(self.index), len(segments), None
+        # An empty set means there are no segments at all; report zero features in that case.
+        n_features = feature_counts.pop() if feature_counts else 0
+        return len(self.index), len(segments), n_features
diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py
index eda408463..b6d3562af 100644
--- a/tests/test_datasets/test_dataset.py
+++ b/tests/test_datasets/test_dataset.py
@@ -606,6 +606,35 @@ def test_dataset_segment_conversion_during_init(df_segments_int):
     assert np.all(ts.columns.get_level_values("segment") == ["1", "2"])
 
 
+def test_size_with_diff_number_of_features():
+    df_temp = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=2, freq="D")
+    df_exog_temp = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=1, freq="D")
+    df_exog_temp = df_exog_temp.rename({"target": "target_exog"}, axis=1)
+    ts_temp = TSDataset(df=TSDataset.to_dataset(df_temp), df_exog=TSDataset.to_dataset(df_exog_temp), freq="D")
+    assert ts_temp.size()[0] == len(df_exog_temp)
+    assert ts_temp.size()[1] == 2
+    assert ts_temp.size()[2] is None
+
+
+def test_size_target_only():
+    df_temp = generate_ar_df(start_time="2023-01-01", periods=40, n_segments=3, freq="D")
+    ts_temp = TSDataset(df=TSDataset.to_dataset(df_temp), freq="D")
+    assert ts_temp.size()[0] == len(df_temp) / 3
+    assert ts_temp.size()[1] == 3
+    assert ts_temp.size()[2] == 1
+
+
+def test_simple_size():
+    df_temp = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=2, freq="D")
+    df_exog_temp = generate_ar_df(start_time="2023-01-01", periods=30, n_segments=2, freq="D")
+    df_exog_temp = df_exog_temp.rename({"target": "target_exog"}, axis=1)
+    df_exog_temp["other_feature"] = 1
+    ts_temp = TSDataset(df=TSDataset.to_dataset(df_temp), df_exog=TSDataset.to_dataset(df_exog_temp), freq="D")
+    assert ts_temp.size()[0] == len(df_exog_temp) / 2
+    assert ts_temp.size()[1] == 2
+    assert ts_temp.size()[2] == 3
+
+
 @pytest.mark.xfail
 def test_make_future_raise_error_on_diff_endings(ts_diff_endings):
     with pytest.raises(ValueError, match="All segments should end at the same timestamp"):