diff --git a/CHANGELOG.md b/CHANGELOG.md index c83b33d4b..532165868 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Changed -- +- Add support of property attributes in `__repr__` and `to_dict` of `BaseMixin` ([#469](https://github.com/etna-team/etna/pull/469)) - - - diff --git a/etna/core/mixins.py b/etna/core/mixins.py index 99f47be77..d638287af 100644 --- a/etna/core/mixins.py +++ b/etna/core/mixins.py @@ -32,11 +32,11 @@ def __repr__(self): if param.kind == param.VAR_POSITIONAL: continue elif param.kind == param.VAR_KEYWORD: - for arg_, value in self.__dict__[arg].items(): + for arg_, value in getattr(self, arg).items(): args_str_representation += f"{arg_} = {repr(value)}, " else: try: - value = self.__dict__[arg] + value = getattr(self, arg) except KeyError as e: value = None warnings.warn(f"You haven't set all parameters inside class __init__ method: {e}") @@ -90,7 +90,7 @@ def to_dict(self): init_parameters = self._get_init_parameters() params = {} for arg in init_parameters.keys(): - value = self.__dict__[arg] + value = getattr(self, arg) if value is None: continue params[arg] = BaseMixin._parse_value(value=value) diff --git a/pyproject.toml b/pyproject.toml index 3681d4620..2977f6de1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,7 +209,6 @@ filterwarnings = [ "ignore: Given top_k=.* is less than n_segments=.*. Algo will filter data without Gale-Shapley run.", "ignore: This model doesn't work with exogenous features", "ignore: Some of external objects in input parameters could be not", - "ignore: You haven't set all parameters inside class __init__ method.* 'is_freezed'", # external warnings "ignore: Attribute 'logging_metrics' is an instance of `nn.Module` and is already", "ignore: Attribute 'loss' is an instance of `nn.Module` and is already", diff --git a/tests/test_core/conftest.py b/tests/test_core/conftest.py new file mode 100644 index 000000000..b5c1e1c02 --- /dev/null +++ b/tests/test_core/conftest.py @@ -0,0 +1,11 @@ +from etna.core.mixins import BaseMixin + + +class BaseDummy(BaseMixin): + def __init__(self, a: int = 1, b: int = 2): + self.a = a + self._b = b + + @property + def b(self): + return self._b diff --git a/tests/test_core/test_repr.py b/tests/test_core/test_repr.py new file mode 100644 index 000000000..05fccc16a --- /dev/null +++ b/tests/test_core/test_repr.py @@ -0,0 +1,6 @@ +from tests.test_core.conftest import BaseDummy + + +def test_repr_public_property_private_attribute(): + dummy = BaseDummy(a=1, b=2) + assert repr(dummy) == "BaseDummy(a = 1, b = 2, )" diff --git a/tests/test_core/test_set_params.py b/tests/test_core/test_set_params.py index 04f193ac6..673078cde 100644 --- a/tests/test_core/test_set_params.py +++ b/tests/test_core/test_set_params.py @@ -4,6 +4,7 @@ from etna.models import CatBoostMultiSegmentModel from etna.pipeline import Pipeline from etna.transforms import AddConstTransform +from tests.test_core.conftest import BaseDummy def test_base_mixin_set_params_changes_params_estimator(): @@ -130,3 +131,11 @@ def test_update_nested_structure(nested_structure, keys, value, expected_result) def test_update_nested_structure_fail(nested_structure, keys, value): with pytest.raises(ValueError, match=f"Structure to update is .* with type .*"): _ = BaseMixin._update_nested_structure(nested_structure, keys, value) + + +def test_set_params_public_property_private_attribute(): + dummy = BaseDummy(a=1, b=2) + dummy = dummy.set_params(**{"a": 10, "b": 20}) + expected_dict = {"a": 10, "b": 20, "_target_": "tests.test_core.conftest.BaseDummy"} + obtained_dict = dummy.to_dict() + assert obtained_dict == expected_dict diff --git a/tests/test_core/test_to_dict.py b/tests/test_core/test_to_dict.py index bc2548b1f..51a009c18 100644 --- a/tests/test_core/test_to_dict.py +++ b/tests/test_core/test_to_dict.py @@ -29,6 +29,7 @@ from etna.transforms import LogTransform from etna.transforms.decomposition.change_points_based import RupturesChangePointsModel from etna.transforms.decomposition.change_points_based import SklearnRegressionPerIntervalModel +from tests.test_core.conftest import BaseDummy def ensemble_samples(): @@ -205,3 +206,8 @@ def __init__(self, a: _Dummy): def test_warnings(): with pytest.warns(Warning, match="Some of external objects in input parameters could be not written in dict"): _ = _InvalidParsing(_Dummy()).to_dict() + + +def test_to_dict_public_property_private_attribute(): + dummy = BaseDummy(a=1, b=2) + assert dummy.to_dict() == {"a": 1, "b": 2, "_target_": "tests.test_core.conftest.BaseDummy"}