Skip to content

Commit

Permalink
Support property attributes in repr of BaseMixin (#469)
Browse files Browse the repository at this point in the history
* replace self.__dict__[arg] with getattr(self, arg)

* update changelog

* add tests

* fix

* fix

---------

Co-authored-by: Egor Baturin <[email protected]>
  • Loading branch information
egoriyaa and Egor Baturin authored Sep 5, 2024
1 parent 7dbbf50 commit 72ab4dd
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-

### Changed
-
- Add support of property attributes in `__repr__` and `to_dict` of `BaseMixin` ([#469](https://github.com/etna-team/etna/pull/469))
-
-
-
Expand Down
6 changes: 3 additions & 3 deletions etna/core/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def __repr__(self):
if param.kind == param.VAR_POSITIONAL:
continue
elif param.kind == param.VAR_KEYWORD:
for arg_, value in self.__dict__[arg].items():
for arg_, value in getattr(self, arg).items():
args_str_representation += f"{arg_} = {repr(value)}, "
else:
try:
value = self.__dict__[arg]
value = getattr(self, arg)
except KeyError as e:
value = None
warnings.warn(f"You haven't set all parameters inside class __init__ method: {e}")
Expand Down Expand Up @@ -90,7 +90,7 @@ def to_dict(self):
init_parameters = self._get_init_parameters()
params = {}
for arg in init_parameters.keys():
value = self.__dict__[arg]
value = getattr(self, arg)
if value is None:
continue
params[arg] = BaseMixin._parse_value(value=value)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ filterwarnings = [
"ignore: Given top_k=.* is less than n_segments=.*. Algo will filter data without Gale-Shapley run.",
"ignore: This model doesn't work with exogenous features",
"ignore: Some of external objects in input parameters could be not",
"ignore: You haven't set all parameters inside class __init__ method.* 'is_freezed'",
# external warnings
"ignore: Attribute 'logging_metrics' is an instance of `nn.Module` and is already",
"ignore: Attribute 'loss' is an instance of `nn.Module` and is already",
Expand Down
11 changes: 11 additions & 0 deletions tests/test_core/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from etna.core.mixins import BaseMixin


class BaseDummy(BaseMixin):
def __init__(self, a: int = 1, b: int = 2):
self.a = a
self._b = b

@property
def b(self):
return self._b
6 changes: 6 additions & 0 deletions tests/test_core/test_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from tests.test_core.conftest import BaseDummy


def test_repr_public_property_private_attribute():
dummy = BaseDummy(a=1, b=2)
assert repr(dummy) == "BaseDummy(a = 1, b = 2, )"
9 changes: 9 additions & 0 deletions tests/test_core/test_set_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from etna.models import CatBoostMultiSegmentModel
from etna.pipeline import Pipeline
from etna.transforms import AddConstTransform
from tests.test_core.conftest import BaseDummy


def test_base_mixin_set_params_changes_params_estimator():
Expand Down Expand Up @@ -130,3 +131,11 @@ def test_update_nested_structure(nested_structure, keys, value, expected_result)
def test_update_nested_structure_fail(nested_structure, keys, value):
with pytest.raises(ValueError, match=f"Structure to update is .* with type .*"):
_ = BaseMixin._update_nested_structure(nested_structure, keys, value)


def test_set_params_public_property_private_attribute():
dummy = BaseDummy(a=1, b=2)
dummy = dummy.set_params(**{"a": 10, "b": 20})
expected_dict = {"a": 10, "b": 20, "_target_": "tests.test_core.conftest.BaseDummy"}
obtained_dict = dummy.to_dict()
assert obtained_dict == expected_dict
6 changes: 6 additions & 0 deletions tests/test_core/test_to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from etna.transforms import LogTransform
from etna.transforms.decomposition.change_points_based import RupturesChangePointsModel
from etna.transforms.decomposition.change_points_based import SklearnRegressionPerIntervalModel
from tests.test_core.conftest import BaseDummy


def ensemble_samples():
Expand Down Expand Up @@ -205,3 +206,8 @@ def __init__(self, a: _Dummy):
def test_warnings():
with pytest.warns(Warning, match="Some of external objects in input parameters could be not written in dict"):
_ = _InvalidParsing(_Dummy()).to_dict()


def test_to_dict_public_property_private_attribute():
dummy = BaseDummy(a=1, b=2)
assert dummy.to_dict() == {"a": 1, "b": 2, "_target_": "tests.test_core.conftest.BaseDummy"}

0 comments on commit 72ab4dd

Please sign in to comment.