From e765fc8c3e1b9ac8f0287c7e5df5e14e1f4cfadd Mon Sep 17 00:00:00 2001 From: Kento Nozawa Date: Sat, 17 Aug 2024 16:05:19 +0900 Subject: [PATCH] Fix E721 errors --- tests/lightgbm/test_optimize.py | 6 +++--- tests/sklearn/test_sklearn.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/lightgbm/test_optimize.py b/tests/lightgbm/test_optimize.py index 370067e3..536f8289 100644 --- a/tests/lightgbm/test_optimize.py +++ b/tests/lightgbm/test_optimize.py @@ -286,7 +286,7 @@ def test_no_eval_set_args(self) -> None: callbacks=[early_stopping(stopping_rounds=2)], ) - assert excinfo.type == ValueError + assert excinfo.type is ValueError assert str(excinfo.value) == "`valid_sets` is required." @pytest.mark.parametrize( @@ -314,7 +314,7 @@ def test_inconsistent_study_direction(self, metric: str, study_direction: str) - study=study, ) - assert excinfo.type == ValueError + assert excinfo.type is ValueError assert str(excinfo.value).startswith("Study direction is inconsistent with the metric") def test_with_minimum_required_args(self) -> None: @@ -769,7 +769,7 @@ def test_inconsistent_study_direction(self, metric: str, study_direction: str) - study=study, ) - assert excinfo.type == ValueError + assert excinfo.type is ValueError assert str(excinfo.value).startswith("Study direction is inconsistent with the metric") def test_with_minimum_required_args(self) -> None: diff --git a/tests/sklearn/test_sklearn.py b/tests/sklearn/test_sklearn.py index cd48f177..822760c1 100644 --- a/tests/sklearn/test_sklearn.py +++ b/tests/sklearn/test_sklearn.py @@ -112,8 +112,8 @@ def test_optuna_search_properties() -> None: assert np.allclose(optuna_search.classes_, np.array([0, 1, 2])) assert optuna_search.n_trials_ == 10 assert optuna_search.user_attrs_ == {"dataset": "blobs"} - assert type(optuna_search.predict_log_proba(X)) == np.ndarray - assert type(optuna_search.predict_proba(X)) == np.ndarray + assert isinstance(optuna_search.predict_log_proba(X), np.ndarray) + assert isinstance(optuna_search.predict_proba(X), np.ndarray) @pytest.mark.filterwarnings("ignore::UserWarning") @@ -139,8 +139,8 @@ def test_optuna_search_transforms() -> None: est, {}, cv=3, error_score="raise", random_state=0, return_train_score=True ) optuna_search.fit(X) - assert type(optuna_search.transform(X)) == np.ndarray - assert type(optuna_search.inverse_transform(X)) == np.ndarray + assert isinstance(optuna_search.transform(X), np.ndarray) + assert isinstance(optuna_search.inverse_transform(X), np.ndarray) def test_optuna_search_invalid_estimator() -> None: