diff --git a/src/objective/regression_objective.hpp b/src/objective/regression_objective.hpp index 71c1a6d7cdfe..eb149756c205 100644 --- a/src/objective/regression_objective.hpp +++ b/src/objective/regression_objective.hpp @@ -24,14 +24,14 @@ namespace LightGBM { for (data_size_t i = 0; i < cnt_data; ++i) { \ ref_data[i] = data_reader(i); \ } \ - const double float_pos = static_cast(1.0 - alpha) * cnt_data; \ - const data_size_t pos = static_cast(float_pos); \ + const double float_pos = static_cast(cnt_data - 1) * (1.0 - alpha); \ + const data_size_t pos = static_cast(float_pos) + 1; \ if (pos < 1) { \ return ref_data[ArrayArgs::ArgMax(ref_data)]; \ } else if (pos >= cnt_data) { \ return ref_data[ArrayArgs::ArgMin(ref_data)]; \ } else { \ - const double bias = float_pos - pos; \ + const double bias = float_pos - (pos - 1); \ if (pos > cnt_data / 2) { \ ArrayArgs::ArgMaxAtK(&ref_data, 0, cnt_data, pos - 1); \ T v1 = ref_data[pos - 1]; \ diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index e87cea3bfcbb..d87f90bec775 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -142,7 +142,7 @@ def test_regression(objective): elif objective == 'quantile': assert ret < 1311 else: - assert ret < 338 + assert ret < 343 assert evals_result['valid_0']['l2'][-1] == pytest.approx(ret)