From f84ed67c0ad15505042ea7fbd4ae1698bec9b56a Mon Sep 17 00:00:00 2001 From: d-a-bunin <142778107+d-a-bunin@users.noreply.github.com> Date: Fri, 2 Aug 2024 10:21:19 +0300 Subject: [PATCH] Fix ensembles to have an empty `params_to_tune` (#432) --- CHANGELOG.md | 1 + etna/ensembles/direct_ensemble.py | 4 ++-- etna/ensembles/stacking_ensemble.py | 4 ++-- etna/ensembles/voting_ensemble.py | 4 ++-- tests/test_ensembles/test_direct_ensemble.py | 6 +++--- tests/test_ensembles/test_stacking_ensemble.py | 6 +++--- tests/test_ensembles/test_voting_ensemble.py | 6 +++--- 7 files changed, 16 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 605342211..11210d737 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix forecast visualization with `horizon=1` ([#426](https://github.com/etna-team/etna/pull/426)) - Set upper bound `<2` on numpy version ([#431](https://github.com/etna-team/etna/pull/431)) - +- Fix `VotingEnsemble`, `StackingEnsemble`, `DirectEnsemble` have a valid `params_to_tune` that returns empty dict ([#432](https://github.com/etna-team/etna/pull/432)) - Fix passing custom model to `STLTransform` ([#412](https://github.com/etna-team/etna/pull/412)) - - Update `TSDataset.describe`, `TSDataset.info` to exclude target intervals and target components in `num_exogs` ([#405](https://github.com/etna-team/etna/pull/405)) diff --git a/etna/ensembles/direct_ensemble.py b/etna/ensembles/direct_ensemble.py index 7c8178a0c..de75c185e 100644 --- a/etna/ensembles/direct_ensemble.py +++ b/etna/ensembles/direct_ensemble.py @@ -176,11 +176,11 @@ def _predict( def params_to_tune(self) -> Dict[str, BaseDistribution]: """Get hyperparameter grid to tune. - Not implemented for this class. + Currently, returns empty dict, but could have a proper implementation in the future. Returns ------- : Grid with hyperparameters. """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't support this method!") + return {} diff --git a/etna/ensembles/stacking_ensemble.py b/etna/ensembles/stacking_ensemble.py index ebbc9d8ab..ae48416d4 100644 --- a/etna/ensembles/stacking_ensemble.py +++ b/etna/ensembles/stacking_ensemble.py @@ -281,11 +281,11 @@ def _predict( def params_to_tune(self) -> Dict[str, BaseDistribution]: """Get hyperparameter grid to tune. - Not implemented for this class. + Currently, returns empty dict, but could have a proper implementation in the future. Returns ------- : Grid with hyperparameters. """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't support this method!") + return {} diff --git a/etna/ensembles/voting_ensemble.py b/etna/ensembles/voting_ensemble.py index 43ad590f7..4a43632c2 100644 --- a/etna/ensembles/voting_ensemble.py +++ b/etna/ensembles/voting_ensemble.py @@ -235,11 +235,11 @@ def _predict( def params_to_tune(self) -> Dict[str, BaseDistribution]: """Get hyperparameter grid to tune. - Not implemented for this class. + Currently, returns empty dict, but could have a proper implementation in the future. Returns ------- : Grid with hyperparameters. """ - raise NotImplementedError(f"{self.__class__.__name__} doesn't support this method!") + return {} diff --git a/tests/test_ensembles/test_direct_ensemble.py b/tests/test_ensembles/test_direct_ensemble.py index fa75902fb..52da44761 100644 --- a/tests/test_ensembles/test_direct_ensemble.py +++ b/tests/test_ensembles/test_direct_ensemble.py @@ -206,9 +206,9 @@ def test_predict_with_return_components_fails(example_tsds, direct_ensemble_pipe direct_ensemble_pipeline.predict(ts=example_tsds, return_components=True) -def test_params_to_tune_not_implemented(direct_ensemble_pipeline): - with pytest.raises(NotImplementedError, match="DirectEnsemble doesn't support this method"): - _ = direct_ensemble_pipeline.params_to_tune() +def test_params_to_tune(direct_ensemble_pipeline): + result = direct_ensemble_pipeline.params_to_tune() + assert result == {} @pytest.mark.parametrize("n_jobs", (1, 5)) diff --git a/tests/test_ensembles/test_stacking_ensemble.py b/tests/test_ensembles/test_stacking_ensemble.py index bcda4c56e..bf261fb7e 100644 --- a/tests/test_ensembles/test_stacking_ensemble.py +++ b/tests/test_ensembles/test_stacking_ensemble.py @@ -453,6 +453,6 @@ def test_ts_with_segment_named_target( assert isinstance(df, pd.DataFrame) -def test_params_to_tune_not_implemented(stacking_ensemble_pipeline): - with pytest.raises(NotImplementedError, match="StackingEnsemble doesn't support this method"): - _ = stacking_ensemble_pipeline.params_to_tune() +def test_params_to_tune(stacking_ensemble_pipeline): + result = stacking_ensemble_pipeline.params_to_tune() + assert result == {} diff --git a/tests/test_ensembles/test_voting_ensemble.py b/tests/test_ensembles/test_voting_ensemble.py index 4ba56753c..41831a811 100644 --- a/tests/test_ensembles/test_voting_ensemble.py +++ b/tests/test_ensembles/test_voting_ensemble.py @@ -310,6 +310,6 @@ def test_predict_with_return_components_fails(example_tsds, voting_ensemble_naiv voting_ensemble_naive.predict(ts=example_tsds, return_components=True) -def test_params_to_tune_not_implemented(voting_ensemble_pipeline): - with pytest.raises(NotImplementedError, match="VotingEnsemble doesn't support this method"): - _ = voting_ensemble_pipeline.params_to_tune() +def test_params_to_tune(voting_ensemble_pipeline): + result = voting_ensemble_pipeline.params_to_tune() + assert result == {}