Skip to content

Commit

Permalink
disable all the tvm tests in tests/
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Feb 10, 2024
1 parent 185cc5a commit aed6996
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
33 changes: 33 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
sparkml_installed,
pandas_installed,
prophet_installed,
is_on_github_actions,
)
from hummingbird.ml.exceptions import MissingBackend

Expand Down Expand Up @@ -396,6 +397,10 @@ def test_load_fails_bad_path_onnx(self):
self.assertRaises(AssertionError, hummingbird.ml.ONNXContainer.load, "nonsense")

@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_load_fails_bad_path_tvm(self):
self.assertRaises(AssertionError, hummingbird.ml.TVMContainer.load, "nonsense.zip")
self.assertRaises(AssertionError, hummingbird.ml.TVMContainer.load, "nonsense")
Expand Down Expand Up @@ -434,6 +439,10 @@ def test_torchscript_test_data(self):

# Test TVM requires test_data
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_test_data(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -451,6 +460,10 @@ def test_tvm_test_data(self):

# Test tvm save and load
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_save_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -474,6 +487,10 @@ def test_tvm_save_load(self):

# Test tvm save and load
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_save_load_digest(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -500,6 +517,10 @@ def test_tvm_save_load_digest(self):

# Test tvm save and generic load
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_save_generic_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -523,6 +544,10 @@ def test_tvm_save_generic_load(self):

# Test tvm save and load zip file
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_save_load_zip(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -545,6 +570,10 @@ def test_tvm_save_load_zip(self):
os.remove("tvm-tmp.zip")

@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_save_load_load(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -567,6 +596,10 @@ def test_tvm_save_load_load(self):
os.remove("tvm-tmp.zip")

@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_save_load_no_versions(self):
from hummingbird.ml.operator_converters import constants

Expand Down
52 changes: 52 additions & 0 deletions tests/test_extra_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,10 @@ def test_onnx_iforest_batch(self):

# Test tvm transform with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_batch_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
Expand All @@ -440,6 +444,10 @@ def test_tvm_batch_transform(self):

# Test tvm regression with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_regression_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -461,6 +469,10 @@ def test_tvm_regression_batch(self):

# Test tvm classification with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_classification_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -482,6 +494,10 @@ def test_tvm_classification_batch(self):

# Test tvm iforest with batching.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_iforest_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
Expand All @@ -503,6 +519,10 @@ def test_tvm_iforest_batch(self):

# Test tvm transform with batching and uneven numer of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_batch_remainder_transform(self):
warnings.filterwarnings("ignore")
model = StandardScaler(with_mean=True, with_std=True)
Expand All @@ -521,6 +541,10 @@ def test_tvm_batch_remainder_transform(self):

# Test tvm regression with batching and uneven numer of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_regression_remainder_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -542,6 +566,10 @@ def test_tvm_regression_remainder_batch(self):

# Test tvm classification with batching and uneven numer of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_classification_remainder_batch(self):
warnings.filterwarnings("ignore")
max_depth = 10
Expand All @@ -564,6 +592,10 @@ def test_tvm_classification_remainder_batch(self):

# Test tvm iforest with batching and uneven numer of records.
@unittest.skipIf(not tvm_installed(), reason="TVM test require TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_iforest_remainder_batch(self):
warnings.filterwarnings("ignore")
num_classes = 2
Expand Down Expand Up @@ -744,6 +776,10 @@ def test_pandas_batch_onnxml(self):
# Test batch with pandas tvm.
@unittest.skipIf(not pandas_installed(), reason="Test requires pandas installed")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_pandas_batch_tvm(self):
import pandas

Expand Down Expand Up @@ -801,6 +837,10 @@ def test_lightgbm_pytorch_extra_config(self):

# Test max fuse depth configuration in TVM.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_max_fuse(self):
warnings.filterwarnings("ignore")

Expand All @@ -816,6 +856,10 @@ def test_tvm_max_fuse(self):

# Test TVM without padding returns an errror is sizes don't match.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_no_padding(self):
warnings.filterwarnings("ignore")

Expand All @@ -832,6 +876,10 @@ def test_tvm_no_padding(self):

# Test padding in TVM.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_padding(self):
warnings.filterwarnings("ignore")

Expand All @@ -848,6 +896,10 @@ def test_tvm_padding(self):

# Test padding in TVM does not create problems when not necessary.
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_tvm_padding_2(self):
warnings.filterwarnings("ignore")

Expand Down
11 changes: 10 additions & 1 deletion tests/test_xgboost_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests XGBoost converters.
"""
import sys
import unittest
import warnings

Expand All @@ -9,7 +10,7 @@
from sklearn.model_selection import train_test_split

import hummingbird.ml
from hummingbird.ml._utils import xgboost_installed, tvm_installed, pandas_installed
from hummingbird.ml._utils import xgboost_installed, tvm_installed, pandas_installed, is_on_github_actions
from hummingbird.ml import constants
from tree_utils import gbdt_implementation_map

Expand Down Expand Up @@ -291,6 +292,10 @@ def test_xgb_classifier_converter_torchscript(self):
# TVM backend regression.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_xgb_regressor_converter_tvm(self):
warnings.filterwarnings("ignore")
import torch
Expand All @@ -311,6 +316,10 @@ def test_xgb_regressor_converter_tvm(self):
# Test TVM backend classification.
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
@unittest.skipIf(
((sys.platform == "linux") and is_on_github_actions()),
reason="This test is flaky on Ubuntu on GitHub Actions. See https://github.com/microsoft/hummingbird/pull/709 for more info.",
)
def test_xgb_classifier_converter_tvm(self):
warnings.filterwarnings("ignore")
import torch
Expand Down

0 comments on commit aed6996

Please sign in to comment.