Skip to content

Commit

Permalink
Fix ensembles to have an empty params_to_tune (#432)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-a-bunin authored Aug 2, 2024
1 parent 28294eb commit f84ed67
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix forecast visualization with `horizon=1` ([#426](https://github.com/etna-team/etna/pull/426))
- Set upper bound `<2` on numpy version ([#431](https://github.com/etna-team/etna/pull/431))
-
- Fix `VotingEnsemble`, `StackingEnsemble`, `DirectEnsemble` have a valid `params_to_tune` that returns empty dict ([#432](https://github.com/etna-team/etna/pull/432))
- Fix passing custom model to `STLTransform` ([#412](https://github.com/etna-team/etna/pull/412))
-
- Update `TSDataset.describe`, `TSDataset.info` to exclude target intervals and target components in `num_exogs` ([#405](https://github.com/etna-team/etna/pull/405))
Expand Down
4 changes: 2 additions & 2 deletions etna/ensembles/direct_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ def _predict(
def params_to_tune(self) -> Dict[str, BaseDistribution]:
"""Get hyperparameter grid to tune.
Not implemented for this class.
Currently, returns empty dict, but could have a proper implementation in the future.
Returns
-------
:
Grid with hyperparameters.
"""
raise NotImplementedError(f"{self.__class__.__name__} doesn't support this method!")
return {}
4 changes: 2 additions & 2 deletions etna/ensembles/stacking_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ def _predict(
def params_to_tune(self) -> Dict[str, BaseDistribution]:
"""Get hyperparameter grid to tune.
Not implemented for this class.
Currently, returns empty dict, but could have a proper implementation in the future.
Returns
-------
:
Grid with hyperparameters.
"""
raise NotImplementedError(f"{self.__class__.__name__} doesn't support this method!")
return {}
4 changes: 2 additions & 2 deletions etna/ensembles/voting_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ def _predict(
def params_to_tune(self) -> Dict[str, BaseDistribution]:
"""Get hyperparameter grid to tune.
Not implemented for this class.
Currently, returns empty dict, but could have a proper implementation in the future.
Returns
-------
:
Grid with hyperparameters.
"""
raise NotImplementedError(f"{self.__class__.__name__} doesn't support this method!")
return {}
6 changes: 3 additions & 3 deletions tests/test_ensembles/test_direct_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def test_predict_with_return_components_fails(example_tsds, direct_ensemble_pipe
direct_ensemble_pipeline.predict(ts=example_tsds, return_components=True)


def test_params_to_tune_not_implemented(direct_ensemble_pipeline):
with pytest.raises(NotImplementedError, match="DirectEnsemble doesn't support this method"):
_ = direct_ensemble_pipeline.params_to_tune()
def test_params_to_tune(direct_ensemble_pipeline):
result = direct_ensemble_pipeline.params_to_tune()
assert result == {}


@pytest.mark.parametrize("n_jobs", (1, 5))
Expand Down
6 changes: 3 additions & 3 deletions tests/test_ensembles/test_stacking_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,6 @@ def test_ts_with_segment_named_target(
assert isinstance(df, pd.DataFrame)


def test_params_to_tune_not_implemented(stacking_ensemble_pipeline):
with pytest.raises(NotImplementedError, match="StackingEnsemble doesn't support this method"):
_ = stacking_ensemble_pipeline.params_to_tune()
def test_params_to_tune(stacking_ensemble_pipeline):
result = stacking_ensemble_pipeline.params_to_tune()
assert result == {}
6 changes: 3 additions & 3 deletions tests/test_ensembles/test_voting_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,6 @@ def test_predict_with_return_components_fails(example_tsds, voting_ensemble_naiv
voting_ensemble_naive.predict(ts=example_tsds, return_components=True)


def test_params_to_tune_not_implemented(voting_ensemble_pipeline):
with pytest.raises(NotImplementedError, match="VotingEnsemble doesn't support this method"):
_ = voting_ensemble_pipeline.params_to_tune()
def test_params_to_tune(voting_ensemble_pipeline):
result = voting_ensemble_pipeline.params_to_tune()
assert result == {}

0 comments on commit f84ed67

Please sign in to comment.