diff --git a/tests/test_lightgbm_converter.py b/tests/test_lightgbm_converter.py index d6c37d5d..ea09d33d 100644 --- a/tests/test_lightgbm_converter.py +++ b/tests/test_lightgbm_converter.py @@ -1,6 +1,7 @@ """ Tests LightGBM converters. """ +import sys import unittest import warnings @@ -8,7 +9,7 @@ import hummingbird.ml from hummingbird.ml import constants -from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed +from hummingbird.ml._utils import lightgbm_installed, onnx_runtime_installed, tvm_installed, is_on_github_actions from tree_utils import gbdt_implementation_map if lightgbm_installed(): @@ -400,6 +401,10 @@ def test_lightgbm_onnx(self): # TVM backend tests. @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM") + @unittest.skipIf( + ((sys.platform == "linux") and is_on_github_actions()), + reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.", + ) def test_lightgbm_tvm_regressor(self): warnings.filterwarnings("ignore") @@ -417,6 +422,10 @@ def test_lightgbm_tvm_regressor(self): np.testing.assert_allclose(tvm_model.predict(X), model.predict(X)) @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM installed") + @unittest.skipIf( + ((sys.platform == "linux") and is_on_github_actions()), + reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.", + ) def test_lightgbm_tvm_classifier(self): warnings.filterwarnings("ignore") @@ -436,6 +445,10 @@ def test_lightgbm_tvm_classifier(self): # Test TVM with large input datasets. @unittest.skipIf(not (tvm_installed()), reason="TVM tests require TVM installed") + @unittest.skipIf( + ((sys.platform == "linux") and is_on_github_actions()), + reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.", + ) def test_lightgbm_tvm_classifier_large_dataset(self): warnings.filterwarnings("ignore")