From 4ea170f30a7b185d3804741d6fa2cbafe9115cc2 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Mon, 21 Aug 2023 12:05:37 -0500 Subject: [PATCH] [python-package] use dataclass for CallbackEnv (#6048) --- .ci/test-python-oldest.sh | 1 + python-package/lightgbm/callback.py | 28 ++++++++++++++++------------ python-package/lightgbm/engine.py | 8 ++++---- python-package/pyproject.toml | 1 + 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/.ci/test-python-oldest.sh b/.ci/test-python-oldest.sh index 09cc24633e15..3a0ea08dddda 100644 --- a/.ci/test-python-oldest.sh +++ b/.ci/test-python-oldest.sh @@ -7,6 +7,7 @@ # echo "installing lightgbm's dependencies" pip install \ + 'dataclasses' \ 'numpy==1.12.0' \ 'pandas==0.24.0' \ 'scikit-learn==0.18.2' \ diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 77856f5bdab6..ccf0059faf84 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -1,10 +1,14 @@ # coding: utf-8 """Callbacks library.""" -import collections +from collections import OrderedDict +from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning +from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning + +if TYPE_CHECKING: + from .engine import CVBooster __all__ = [ 'early_stopping', @@ -43,14 +47,14 @@ def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> # Callback environment used by callbacks -CallbackEnv = collections.namedtuple( - "CallbackEnv", - ["model", - "params", - "iteration", - "begin_iteration", - "end_iteration", - "evaluation_result_list"]) +@dataclass +class CallbackEnv: + model: Union[Booster, "CVBooster"] + params: Dict[str, Any] + iteration: int + begin_iteration: int + end_iteration: int + evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]] def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str: @@ -126,7 +130,7 @@ def _init(self, env: CallbackEnv) -> None: data_name, eval_name = item[:2] else: # cv data_name, eval_name = item[1].split() - self.eval_result.setdefault(data_name, collections.OrderedDict()) + self.eval_result.setdefault(data_name, OrderedDict()) if len(item) == 4: self.eval_result[data_name].setdefault(eval_name, []) else: diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 2d640d741629..5e687e4d2678 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -1,8 +1,8 @@ # coding: utf-8 """Library with training routines of LightGBM.""" -import collections import copy import json +from collections import OrderedDict, defaultdict from operator import attrgetter from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -293,7 +293,7 @@ def train( booster.best_iteration = earlyStopException.best_iteration + 1 evaluation_result_list = earlyStopException.best_score break - booster.best_score = collections.defaultdict(collections.OrderedDict) + booster.best_score = defaultdict(OrderedDict) for dataset_name, eval_name, score, _ in evaluation_result_list: booster.best_score[dataset_name][eval_name] = score if not keep_training_booster: @@ -526,7 +526,7 @@ def _agg_cv_result( raw_results: List[List[Tuple[str, str, float, bool]]] ) -> List[Tuple[str, str, float, bool, float]]: """Aggregate cross-validation results.""" - cvmap: Dict[str, List[float]] = collections.OrderedDict() + cvmap: Dict[str, List[float]] = OrderedDict() metric_type: Dict[str, bool] = {} for one_result in raw_results: for one_line in one_result: @@ -717,7 +717,7 @@ def cv( .set_feature_name(feature_name) \ .set_categorical_feature(categorical_feature) - results = collections.defaultdict(list) + results = defaultdict(list) cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold, params=params, seed=seed, fpreproc=fpreproc, stratified=stratified, shuffle=shuffle, diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 2d7bc34576dc..40d57e1af634 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence" ] dependencies = [ + "dataclasses ; python_version < '3.7'", "numpy", "scipy" ]