Skip to content

Commit

Permalink
[python-package] use dataclass for CallbackEnv (#6048)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Aug 21, 2023
1 parent 5fe84f8 commit 4ea170f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
1 change: 1 addition & 0 deletions .ci/test-python-oldest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#
echo "installing lightgbm's dependencies"
pip install \
'dataclasses' \
'numpy==1.12.0' \
'pandas==0.24.0' \
'scikit-learn==0.18.2' \
Expand Down
28 changes: 16 additions & 12 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# coding: utf-8
"""Callbacks library."""
import collections
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

if TYPE_CHECKING:
from .engine import CVBooster

__all__ = [
'early_stopping',
Expand Down Expand Up @@ -43,14 +47,14 @@ def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) ->


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"CallbackEnv",
["model",
"params",
"iteration",
"begin_iteration",
"end_iteration",
"evaluation_result_list"])
@dataclass
class CallbackEnv:
model: Union[Booster, "CVBooster"]
params: Dict[str, Any]
iteration: int
begin_iteration: int
end_iteration: int
evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
Expand Down Expand Up @@ -126,7 +130,7 @@ def _init(self, env: CallbackEnv) -> None:
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
self.eval_result.setdefault(data_name, collections.OrderedDict())
self.eval_result.setdefault(data_name, OrderedDict())
if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, [])
else:
Expand Down
8 changes: 4 additions & 4 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# coding: utf-8
"""Library with training routines of LightGBM."""
import collections
import copy
import json
from collections import OrderedDict, defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -293,7 +293,7 @@ def train(
booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score
break
booster.best_score = collections.defaultdict(collections.OrderedDict)
booster.best_score = defaultdict(OrderedDict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster:
Expand Down Expand Up @@ -526,7 +526,7 @@ def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]:
"""Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = collections.OrderedDict()
cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {}
for one_result in raw_results:
for one_line in one_result:
Expand Down Expand Up @@ -717,7 +717,7 @@ def cv(
.set_feature_name(feature_name) \
.set_categorical_feature(categorical_feature)

results = collections.defaultdict(list)
results = defaultdict(list)
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold,
params=params, seed=seed, fpreproc=fpreproc,
stratified=stratified, shuffle=shuffle,
Expand Down
1 change: 1 addition & 0 deletions python-package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
"dataclasses ; python_version < '3.7'",
"numpy",
"scipy"
]
Expand Down

0 comments on commit 4ea170f

Please sign in to comment.