diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index 23321ce823c3..afd716cde389 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -230,7 +230,7 @@ def test_reduce_combiner_simplify(): # Check that the remaining components are the expected ones. for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]): - assert tvm.ir.structural_equal(lhs, rhs) + tvm.ir.assert_structural_equal(lhs, rhs) # Test that components with side effects are not removed dummy = tvm.ir.GlobalVar("dummy") diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 1a876548af31..9a0245d27487 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -32,7 +32,7 @@ def test_simplify_reshape_flattened_index(): ana.bind(i1, tvm.ir.Range(0, 3)) i_flattened = i0 * 3 + i1 - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4), i_flattened, ) diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py index df54f7ce55f1..88ae2cba5f57 100644 --- a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py +++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py @@ -211,7 +211,7 @@ def test_primary_operands_all_scalars(): mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) new_mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body) + tvm.ir.assert_structural_equal(mod[global_var].body, new_mod[global_var].body) @tvm.testing.requires_cmsisnn @@ -253,7 +253,7 @@ def test_all_primary_operands_tensor_constants(): mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) new_mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body) + tvm.ir.assert_structural_equal(mod[global_var].body, new_mod[global_var].body) @tvm.testing.requires_cmsisnn @@ -294,7 +294,7 @@ def test_duplicate_constant_arguments(): mod = relay.transform.InferType()(mod) mod = ScalarToTensorConstants()(mod) new_mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body) + tvm.ir.assert_structural_equal(mod[global_var].body, new_mod[global_var].body) @tvm.testing.requires_cmsisnn @@ -329,7 +329,7 @@ def get_mod(): expected = get_mod()["external_function"].body actual = ScalarToTensorConstants()(get_mod())["external_function"].body - assert tvm.ir.structural_equal(expected, actual) + tvm.ir.assert_structural_equal(expected, actual) if __name__ == "__main__": diff --git a/tests/python/contrib/test_coreml_codegen.py b/tests/python/contrib/test_coreml_codegen.py index f0cdf14aa019..f4f84876fe13 100644 --- a/tests/python/contrib/test_coreml_codegen.py +++ b/tests/python/contrib/test_coreml_codegen.py @@ -100,7 +100,7 @@ def test_annotate(): mod = transform.PartitionGraph()(mod) expected = _create_graph_annotated() - assert tvm.ir.structural_equal(mod, expected, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected, map_free_vars=True) @pytest.mark.skipif(not _has_xcode(), reason="Xcode is not available") diff --git 
a/tests/python/contrib/test_ethosn/test_convert_equivalents.py b/tests/python/contrib/test_ethosn/test_convert_equivalents.py index 58173a9ea6c3..5f05804517b2 100644 --- a/tests/python/contrib/test_ethosn/test_convert_equivalents.py +++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py @@ -30,16 +30,6 @@ from .test_addition import _get_addition_qnn_params -def _assert_structural_equal(a, b): - """Check structural equality of two Relay expressions.""" - reason = ( - "Actual and expected relay functions are not equal. " - "ConvertEquivalents is not correctly transforming the input " - "graph." - ) - assert tvm.ir.structural_equal(a, b), reason - - @requires_ethosn @pytest.mark.parametrize("dtype", ["uint8", "int8"]) @pytest.mark.parametrize("shape,channels", [((1, 4, 4, 8), 8), ((1, 16, 12, 4), 4)]) @@ -114,7 +104,7 @@ def expected(): mod = before() mod = ConvertEquivalents()(mod) expected_mod = expected() - _assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"]) + tvm.ir.assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"]) @requires_ethosn @@ -221,7 +211,7 @@ def expected(): mod = before() mod = ConvertEquivalents()(mod) expected_mod = expected() - _assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"]) + tvm.ir.assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"]) @requires_ethosn @@ -438,7 +428,7 @@ def expected(): mod = before() mod = ConvertEquivalents()(mod) expected_mod = expected() - _assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"]) + tvm.ir.assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"]) @requires_ethosn diff --git a/tests/python/contrib/test_ethosn/test_inline_partitions.py b/tests/python/contrib/test_ethosn/test_inline_partitions.py index 79c35fc5bcb2..735148bc660a 100644 --- a/tests/python/contrib/test_ethosn/test_inline_partitions.py +++ b/tests/python/contrib/test_ethosn/test_inline_partitions.py @@ -27,16 +27,6 @@ from . import infrastructure as tei -def _assert_structural_equal(a, b): - """Check structural equality of two Relay expressions.""" - reason = ( - "Actual and expected relay functions are not equal. " - "InlineNonComputeIntensiveSubgraphs is not correctly " - "transforming the input graph." 
- ) - assert tvm.ir.structural_equal(a, b, map_free_vars=True), reason - - @requires_ethosn def test_single_reshape(): """Check that a single reshape is inlined correctly.""" @@ -57,7 +47,7 @@ def expected(): mod = before() mod = InlineNonComputeIntensivePartitions()(mod) expected_mod = expected() - _assert_structural_equal(mod, expected_mod) + tvm.ir.assert_structural_equal(mod, expected_mod) @requires_ethosn @@ -86,7 +76,7 @@ def expected(): mod = before() mod = InlineNonComputeIntensivePartitions()(mod) expected_mod = expected() - _assert_structural_equal(mod, expected_mod) + tvm.ir.assert_structural_equal(mod, expected_mod) @requires_ethosn @@ -105,7 +95,7 @@ def before(): mod = before() transformed_mod = InlineNonComputeIntensivePartitions()(mod) for global_var in mod.get_global_vars(): - _assert_structural_equal(mod[global_var], transformed_mod[global_var]) + tvm.ir.assert_structural_equal(mod[global_var], transformed_mod[global_var]) @requires_ethosn @@ -164,4 +154,8 @@ def expected(): mod = InlineNonComputeIntensivePartitions()(mod) expected_mod = expected() for global_var in mod.get_global_vars(): - _assert_structural_equal(mod[global_var.name_hint], expected_mod[global_var.name_hint]) + tvm.ir.assert_structural_equal( + mod[global_var.name_hint], + expected_mod[global_var.name_hint], + map_free_vars=True, + ) diff --git a/tests/python/contrib/test_ethosu/test_extract_constants.py b/tests/python/contrib/test_ethosu/test_extract_constants.py index c5646b2c1229..204ff34bb806 100644 --- a/tests/python/contrib/test_ethosu/test_extract_constants.py +++ b/tests/python/contrib/test_ethosu/test_extract_constants.py @@ -45,7 +45,7 @@ def _expected(): func, const = _get_func() new_func, const_dict = extract_constants(func) - assert tvm.ir.structural_equal(new_func, _expected()) + tvm.ir.assert_structural_equal(new_func, _expected()) assert 1 in const_dict assert (const_dict[1] == const.data.asnumpy()).all() @@ -89,7 +89,7 @@ def _expected(): func, consts = _get_func() new_func, const_dict = extract_constants(func) - assert tvm.ir.structural_equal(new_func, _expected()) + tvm.ir.assert_structural_equal(new_func, _expected()) for i, const in enumerate(consts): assert i + 2 in const_dict assert (const_dict[i + 2] == consts[i].data.asnumpy()).all() diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py b/tests/python/contrib/test_ethosu/test_identity_optimizer.py index 3ae58dfc81ba..83aca640f767 100644 --- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py @@ -45,16 +45,6 @@ def _optimize(func, optimize=True): return entry if isinstance(func, relay.Function) else entry.body -def _assert_structural_equal(a, b): - """Check structural equality of two Relay expressions.""" - reason = ( - "Actual and expected relay functions are not equal. " - "IdentityOptimizer is not correctly removing redundant " - "identity operations." 
- ) - assert tvm.ir.structural_equal(a, b), reason - - def test_simple_reshape_identity_removal(): """Check identity is removed when there is a reshape in the graph and a compute operation follows.""" @@ -70,7 +60,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_simple_strided_slice_identity_removal(): @@ -90,7 +80,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_no_identity(): @@ -108,7 +98,7 @@ def get_graph(): actual = _optimize(get_graph()) expected = _optimize(get_graph(), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_reshape_last(): @@ -123,7 +113,7 @@ def get_graph(): actual = _optimize(get_graph()) expected = _optimize(get_graph(), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_requantize_identity_no_removal(): @@ -140,7 +130,7 @@ def get_graph(): actual = _optimize(get_graph()) expected = _optimize(get_graph(), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_activation_identity_no_removal(): @@ -155,7 +145,7 @@ def get_graph(): actual = _optimize(get_graph()) expected = _optimize(get_graph(), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_multiple_output_identity(): @@ -172,7 +162,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_many_output_identity(): @@ -195,7 +185,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_identity_before_concatenate_no_removal(): @@ -215,7 +205,7 @@ def get_graph(): actual = _optimize(get_graph()) expected = _optimize(get_graph(), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_identity_removal_with_multiple_transform_ops(): @@ -235,7 +225,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_identity_removal_on_binary_elementwise(): @@ -252,7 +242,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_identity_single_removal_on_binary_elementwise(): @@ -270,7 +260,7 @@ def get_graph(get_expected=False): actual = _optimize(get_graph()) expected = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def 
test_multiple_transform_ops_with_reduction_in_dimensionality(): @@ -289,7 +279,7 @@ def get_graph(): actual = _optimize(get_graph()) expected = _optimize(get_graph(), optimize=False) - _assert_structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_identity_optimizer_runs_in_compilation_pipeline(): diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index 69d549acbb3b..445eedbf64a8 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -49,15 +49,6 @@ def _optimize(func, optimize=True): return entry if isinstance(func, relay.Function) else entry.body -def _assert_structural_equal(a, b): - """Check structural equality of two Relay expressions.""" - reason = ( - "Actual and expected relay functions are not equal. " - "LayoutOptimizer is not correctly converting layouts." - ) - assert tvm.ir.structural_equal(a, b), reason - - def _compile_and_compare_model(tflite_graph, ifm_shape, dtype): """Compare running result of compilation against TFLite.""" tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) @@ -118,7 +109,7 @@ def get_graph(): a = _optimize(get_graph()) b = _optimize(get_graph(), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) @pytest.mark.parametrize("dtype", ["int8", "int32"]) @@ -157,7 +148,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_multiple_convolution(): @@ -190,7 +181,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_multiple_depthwise_convolution(): @@ -222,7 +213,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_ignore_transform_operations(): @@ -268,7 +259,7 @@ def get_graph(): a = _optimize(get_graph()) b = _optimize(get_graph(), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_ignore_concatenate(): @@ -314,7 +305,7 @@ def get_graph(): a = _optimize(get_graph()) b = _optimize(get_graph(), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_ignore_concatnate_with_layout_transform(): @@ -373,7 +364,7 @@ def get_graph(): a = _optimize(get_graph()) b = _optimize(get_graph(), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_multiple_inputs(): @@ -422,7 +413,7 @@ def get_graph(): a = _optimize(get_graph()) b = _optimize(get_graph(), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_multiple_outputs(): @@ -471,7 +462,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_multiple_binary_elementwise(): @@ -525,7 +516,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def 
test_multiple_pooling(): @@ -561,7 +552,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_multiple_unary_elementwise(): @@ -591,7 +582,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_op_without_ethosu_consumer(): @@ -632,7 +623,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_diamond_graph(): @@ -687,7 +678,7 @@ def get_graph(get_expected=False): a = _optimize(get_graph()) b = _optimize(get_graph(get_expected=True), optimize=False) - _assert_structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_same_output_multiple_convolutions(): diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py index dc3dd59a5a93..b8a275446207 100644 --- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py @@ -69,7 +69,7 @@ def after(): mod = LUTsOptimizer()(before()) mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod, after()) + tvm.ir.assert_structural_equal(mod, after()) def test_merge_lut_into_binary_elementwise(): @@ -111,7 +111,7 @@ def after(): mod = LUTsOptimizer()(before()) mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod, after()) + tvm.ir.assert_structural_equal(mod, after()) def test_multiple_luts(): @@ -146,7 +146,7 @@ def after(): mod = LUTsOptimizer()(before()) mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(mod, after()) + tvm.ir.assert_structural_equal(mod, after()) def test_lut_optimizer_runs_in_compilation_pipeline(): diff --git a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py index 062637b3bb94..5a6ed70a5902 100644 --- a/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py +++ b/tests/python/contrib/test_ethosu/test_outline_compiler_functions.py @@ -83,4 +83,4 @@ def expected(): global_vars = [str(gv) for gv in after.get_global_vars()] assert 'I.GlobalVar("ext_func")' in global_vars assert 'I.GlobalVar("ext_func_2")' not in global_vars - assert tvm.ir.structural_equal(after["ext_func"], exp["ext_func"]) + tvm.ir.assert_structural_equal(after["ext_func"], exp["ext_func"]) diff --git a/tests/python/contrib/test_ethosu/test_partition.py b/tests/python/contrib/test_ethosu/test_partition.py index 578485c8aa88..94896856db74 100644 --- a/tests/python/contrib/test_ethosu/test_partition.py +++ b/tests/python/contrib/test_ethosu/test_partition.py @@ -62,4 +62,4 @@ def get_graph(): mod = relay.transform.InferType()(get_graph()) partitioned_mod = ethosu.partition_for_ethosu(mod) - assert tvm.ir.structural_equal(mod, partitioned_mod) + tvm.ir.assert_structural_equal(mod, partitioned_mod) diff --git a/tests/python/contrib/test_ethosu/test_preprocess.py b/tests/python/contrib/test_ethosu/test_preprocess.py index 0a0aa2cf69a6..a80555b02277 100644 --- a/tests/python/contrib/test_ethosu/test_preprocess.py +++ b/tests/python/contrib/test_ethosu/test_preprocess.py @@ -67,7 +67,7 @@ def create_external_func1(mod_, 
compiler_name, symbol_name): mod = create_graph() exp = create_graph() mod = preprocess.preprocess_ext_io()(mod) - assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, exp, map_free_vars=True) def test_2ins_single_out(): @@ -140,7 +140,7 @@ def create_external_func1(mod_, compiler_name, symbol_name): mod = create_graph() exp = expected() mod = preprocess.preprocess_ext_io()(mod) - assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, exp, map_free_vars=True) def test_single_in_2outs(): @@ -219,7 +219,7 @@ def create_external_func1(mod_, compiler_name, symbol_name): exp = expected() mod = relay.transform.InferType()(mod) mod = preprocess.preprocess_ext_io()(mod) - assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, exp, map_free_vars=True) def test_4ins_2outs(): @@ -336,7 +336,7 @@ def create_external_func1(mod_, compiler_name, symbol_name): mod = create_graph() exp = expected() mod = preprocess.preprocess_ext_io()(mod) - assert tvm.ir.structural_equal(mod, exp, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, exp, map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py b/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py index b2c60b083cc1..0f8a9a739559 100644 --- a/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py +++ b/tests/python/contrib/test_hexagon/test_relay_simplify_conv_pat.py @@ -157,7 +157,7 @@ def test_simplify_conv_pat(hexagon_session: Session): mod = simplify_conv_pat(mod) mod = tvm.relay.transform.InferType()(mod) exp_relay_mod = tvm.relay.transform.InferType()(exp_relay_mod) - assert tvm.ir.structural_equal(mod["main"], exp_relay_mod["main"], map_free_vars=True) + tvm.ir.assert_structural_equal(mod["main"], exp_relay_mod["main"], map_free_vars=True) mod = tvm.relay.transform.FoldConstant()(mod) hexagon_lowered_opt = build_module( mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET) @@ -196,7 +196,7 @@ def test_negative(): orig_mod = tvm.relay.transform.InferType()(orig_mod) opt_mod = simplify_conv_pat(orig_mod) opt_mod = tvm.relay.transform.InferType()(opt_mod) - assert tvm.ir.structural_equal(orig_mod["main"], opt_mod["main"], map_free_vars=True) + tvm.ir.assert_structural_equal(orig_mod["main"], opt_mod["main"], map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py index 728ec8124359..4eda615a1dd5 100644 --- a/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py +++ b/tests/python/contrib/test_hexagon/test_relay_simplify_qnn_concat.py @@ -92,7 +92,7 @@ def test_simplify_qnn_concat(): out_mod = get_expected_output_module() out_mod = tvm.relay.transform.InferType()(out_mod) - assert tvm.ir.structural_equal(mod["main"], out_mod["main"]) + tvm.ir.assert_structural_equal(mod["main"], out_mod["main"]) if __name__ == "__main__": diff --git a/tests/python/contrib/test_hexagon/test_relay_transforms.py b/tests/python/contrib/test_hexagon/test_relay_transforms.py index ef57e298ab69..32c8ff126544 100644 --- a/tests/python/contrib/test_hexagon/test_relay_transforms.py +++ b/tests/python/contrib/test_hexagon/test_relay_transforms.py @@ -85,14 +85,14 @@ def test_rewrite_qdistilbert(): ref_func = relay.Function(relay.analysis.free_vars(ref), ref) ref_mod = 
tvm.IRModule.from_expr(ref_func) - assert tvm.ir.structural_equal(mod["main"], ref_mod["main"]) + tvm.ir.assert_structural_equal(mod["main"], ref_mod["main"]) # If the pattern does not match, should return the original. func = relay.expr.Tuple(expand_dims) # omitting concatenate mod = tvm.IRModule.from_expr(func) out_mod = rewrite_qdistilbert(mod) # out does not return ref_mod but the original mod - assert tvm.ir.structural_equal(mod["main"], out_mod["main"]) + tvm.ir.assert_structural_equal(mod["main"], out_mod["main"]) def test_remove_empty_pad(): @@ -113,7 +113,7 @@ def test_remove_empty_pad(): ref_func = relay.Function(relay.analysis.free_vars(ref), ref) ref_mod = tvm.IRModule.from_expr(ref_func) - assert tvm.ir.structural_equal(mod["main"], ref_mod["main"]) + tvm.ir.assert_structural_equal(mod["main"], ref_mod["main"]) if __name__ == "__main__": diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py index 058faa8a24e6..b4d12cf62ced 100644 --- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py +++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py @@ -373,7 +373,7 @@ def expected(): ref_mod = expected() - assert tvm.ir.structural_equal(partitioned_mod, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned_mod, ref_mod, map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/frontend/caffe2/test_graph.py b/tests/python/frontend/caffe2/test_graph.py index 51a9a53ec057..3bf5beff3fce 100644 --- a/tests/python/frontend/caffe2/test_graph.py +++ b/tests/python/frontend/caffe2/test_graph.py @@ -24,7 +24,7 @@ def compare_graph(lhs_mod, rhs_mod): lhs_mod = transform.InferType()(lhs_mod) rhs_mod = transform.InferType()(rhs_mod) - assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"]) + tvm.ir.assert_structural_equal(lhs_mod["main"], rhs_mod["main"]) def test_squeeze_net(): diff --git a/tests/python/frontend/mxnet/test_graph.py b/tests/python/frontend/mxnet/test_graph.py index 5c009febc296..63ce763f1725 100644 --- a/tests/python/frontend/mxnet/test_graph.py +++ b/tests/python/frontend/mxnet/test_graph.py @@ -26,7 +26,7 @@ def compare_graph(lhs_mod, rhs_mod): lhs_mod = transform.InferType()(lhs_mod) rhs_mod = transform.InferType()(rhs_mod) - assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"]) + tvm.ir.assert_structural_equal(lhs_mod["main"], rhs_mod["main"]) def test_mlp(): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 20d9c7cd33f2..a5811d0dbd46 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -117,7 +117,7 @@ def get_tvm_output_with_vm( freeze_params=freeze_params, convert_config=convert_config, ) - assert tvm.ir.structural_equal(mod, mod_with_span) + tvm.ir.assert_structural_equal(mod, mod_with_span) result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( *input_data, **params @@ -8480,7 +8480,7 @@ def _verify(self, res_fptr, golden_fptr): with_span = res_fptr() with tvm.testing.disable_span_filling(): without_span = res_fptr() - assert tvm.ir.structural_equal(with_span, without_span) + tvm.ir.assert_structural_equal(with_span, without_span) _verify_structural_equal_with_span(with_span, golden_fptr()) def test_conv2d_bias_add_span(self): diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index beaeeb999923..1cc1a46cea6b 100644 --- 
a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -53,7 +53,7 @@ def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=Fal mod_with_span, _ = relay.frontend.from_pytorch( script_module, input_shapes, keep_quantized_weight=keep_quantized_weight ) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) if keep_quantized_weight: for p in params.values(): @@ -639,7 +639,7 @@ def run_qnn_mergecomposite(script_module, input_name, ishape): mod, params = relay.frontend.from_pytorch(script_module, input_shapes) with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_shapes) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) pattern_table = get_pattern_table("test_table") with tvm.transform.PassContext(opt_level=3): pass_list = [ @@ -792,7 +792,7 @@ def forward(self, input): mod, _ = relay.frontend.from_pytorch(script_module, input_infos) with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_infos) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) output = mod["main"].body assert isinstance(output, relay.Tuple) and len(output) == 2 diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3b82c96a3631..a273af8fb89d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -183,7 +183,7 @@ def verify_model( if validate_structural_equal: with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) for arg in mod["main"].params[: len(input_names)]: assert arg.name_hint in input_names @@ -254,7 +254,7 @@ def verify_model_with_input( if validate_structural_equal: with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) with tvm.transform.PassContext(opt_level=3): for target in ["llvm", "cuda"]: @@ -2775,7 +2775,7 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=None) mod, params = relay.frontend.from_pytorch(input_model, input_shapes) with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(input_model, input_shapes) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) for tgt in targets: if not tvm.testing.device_enabled(tgt): @@ -5666,7 +5666,7 @@ def _verify(self, res_fptr, golden_fptr): with_span = res_fptr() with tvm.testing.disable_span_filling(): without_span = res_fptr() - assert tvm.ir.structural_equal(with_span, without_span) + tvm.ir.assert_structural_equal(with_span, without_span) _verify_structural_equal_with_span(with_span, golden_fptr()) def test_conv2d_bias_add(self): diff --git a/tests/python/frontend/pytorch/test_fx_quant.py 
b/tests/python/frontend/pytorch/test_fx_quant.py index b87c0b0f00b2..7f3083a7dcd0 100644 --- a/tests/python/frontend/pytorch/test_fx_quant.py +++ b/tests/python/frontend/pytorch/test_fx_quant.py @@ -44,7 +44,7 @@ def quantize_and_build(model, in_size): mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)]) with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)]) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) mod = relay.transform.InferType()(mod) # Make sure that the model is quantized diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py index e9dd2b380c1e..da4e1ae96e03 100644 --- a/tests/python/frontend/pytorch/test_lstm.py +++ b/tests/python/frontend/pytorch/test_lstm.py @@ -341,7 +341,7 @@ def test_custom_lstm(): mod, params = from_pytorch(script_module, input_shapes) with tvm.testing.enable_span_filling(): mod_with_span, _ = from_pytorch(script_module, input_shapes) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) with torch.no_grad(): pt_result = raw_model(inp.clone(), states) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 25e784b00a1b..9dd336f7e9d2 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -108,7 +108,7 @@ def test_detection_models(): mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_pytorch(scripted_model, shape_list) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) data = process_image(img) data_np = data.detach().numpy() diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py index 3ea423250010..b43af58d69a3 100644 --- a/tests/python/frontend/pytorch/test_rnns.py +++ b/tests/python/frontend/pytorch/test_rnns.py @@ -464,7 +464,7 @@ def get_onnx_model(model): mod_with_span, _ = relay.frontend.from_pytorch( traced_script_module, shape_desc ) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) elif format == "onnx": try: onnx_model = get_onnx_model(model) @@ -480,7 +480,7 @@ def get_onnx_model(model): mod, params = relay.frontend.from_onnx(onnx_model, shape_desc) with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_onnx(onnx_model, shape_desc) - assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, mod_with_span, map_free_vars=True) # Model compilation by tvm with tvm.transform.PassContext(opt_level=3): diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py index df7052008821..99d8f790028c 100644 --- a/tests/python/frontend/tensorflow/test_bn_dynamic.py +++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py @@ -69,7 +69,7 @@ def verify_fused_batch_norm(shape): mod, params = relay.frontend.from_tensorflow(constant_graph, outputs=["output"]) with tvm.testing.enable_span_filling(): mod_with_span, _ = 
relay.frontend.from_tensorflow(constant_graph, outputs=["output"]) - assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) + tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"]) with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=device, params=params) from tvm.contrib import graph_executor diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index ea4842771967..db270ccb2e9f 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -167,7 +167,7 @@ def run_tvm_graph( outputs=out_names, convert_config=convert_config, ) - assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"], map_free_vars=True) + tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"], map_free_vars=True) dev = tvm.device(target, 0) if mode == "debug": @@ -1868,7 +1868,7 @@ def test_read_variable_op(target, dev): mod_with_span, _ = relay.frontend.from_tensorflow( final_graph_def, layout=None, shape=shape_dict, outputs=None ) - assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) + tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"]) assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph") @@ -4164,7 +4164,7 @@ def _get_tvm_graph_module(graph_def): "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:6", ], ) - assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) + tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"]) target = "llvm" with tvm.transform.PassContext(opt_level=0): @@ -5809,7 +5809,7 @@ def test_moments(): mod, _ = from_tensorflow(g.as_graph_def(add_shapes=True)) with tvm.testing.enable_span_filling(): mod_with_span, _ = from_tensorflow(g.as_graph_def(add_shapes=True)) - assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"], map_free_vars=True) + tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"], map_free_vars=True) program = """ def @main(%A: Tensor[(4, 176, 8, 8), float32]) { @@ -5932,7 +5932,7 @@ def _verify(self, res_fptr, golden_fptr): with_span = res_fptr() with tvm.testing.disable_span_filling(): without_span = res_fptr() - assert tvm.ir.structural_equal(with_span, without_span) + tvm.ir.assert_structural_equal(with_span, without_span) _verify_structural_equal_with_span(with_span, golden_fptr()) def test_conv2d_bias_add_span(self): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ebf7bce250b1..75a2a37c636a 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -224,7 +224,7 @@ def run_tvm_graph( mod_with_span, _ = relay.frontend.from_tflite( tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict, op_converter=op_converter ) - assert tvm.ir.structural_equal(mod["main"], mod_with_span["main"]) + tvm.ir.assert_structural_equal(mod["main"], mod_with_span["main"]) if mode in ["debug", "vm"]: inputs = [] @@ -5548,7 +5548,7 @@ def _verify(res_fptr, golden_fptr): with_span = res_fptr() with tvm.testing.disable_span_filling(): without_span = res_fptr() - assert tvm.ir.structural_equal(with_span, without_span) + tvm.ir.assert_structural_equal(with_span, without_span) _verify_structural_equal_with_span(with_span, golden_fptr()) def _tf_to_tflite( diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py index 9ac0648eb36c..13e10cdbee2b 100644 --- 
a/tests/python/ir/test_ir_attrs.py +++ b/tests/python/ir/test_ir_attrs.py @@ -50,7 +50,7 @@ def test_attrs_equal(): dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20]) dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1) dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None) - assert tvm.ir.structural_equal(dattr0, dattr1) + tvm.ir.assert_structural_equal(dattr0, dattr1) assert not tvm.ir.structural_equal(dattr0, dattr2) assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1)) assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1)) diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index 986e48dc69b9..2355aa19adec 100644 --- a/tests/python/ir/test_ir_type.py +++ b/tests/python/ir/test_ir_type.py @@ -21,7 +21,7 @@ def check_json_roundtrip(node): json_str = tvm.ir.save_json(node) back = tvm.ir.load_json(json_str) - assert tvm.ir.structural_equal(back, node, map_free_vars=True) + tvm.ir.assert_structural_equal(back, node, map_free_vars=True) def test_prim_type(): diff --git a/tests/python/meta_schedule/test_meta_schedule_database.py b/tests/python/meta_schedule/test_meta_schedule_database.py index 11fbeb811ea7..f87c8753f8f7 100644 --- a/tests/python/meta_schedule/test_meta_schedule_database.py +++ b/tests/python/meta_schedule/test_meta_schedule_database.py @@ -104,7 +104,7 @@ def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord): assert str(a.run_secs) == str(b.run_secs) # AWAIT(@zxybazh): change to export after fixing "(bool)0" assert str(a.target) == str(b.target) - assert tvm.ir.structural_equal(a.workload.mod, b.workload.mod) + tvm.ir.assert_structural_equal(a.workload.mod, b.workload.mod) for arg0, arg1 in zip(a.args_info, b.args_info): assert str(arg0.as_json()) == str(arg1.as_json()) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 7fbf9a2da141..e7e8f94fc2ac 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -18,7 +18,6 @@ import pytest import tvm from tvm import relax -from tvm.ir import structural_equal import tvm.script from tvm.script import tir as T, relax as R @@ -117,7 +116,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" assert isinstance(s1.args[0], relax.ShapeExpr) - assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + tvm.ir.assert_structural_equal(s1.args[0], s0.sinfo_args[0].shape) s2 = block.bindings[1].value tvm.ir.expr.GlobalVar assert s2.op.name_hint == "exp" @@ -262,7 +261,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" assert isinstance(s1.args[0], relax.ShapeExpr) - assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + tvm.ir.assert_structural_equal(s1.args[0], s0.sinfo_args[0].shape) s2 = block.bindings[1].value assert s2.op.global_symbol == "test.op.identity" diff --git a/tests/python/relay/test_analysis_extract_intermediate_expr.py b/tests/python/relay/test_analysis_extract_intermediate_expr.py index 57585552b4a1..f0267ebc7951 100644 --- a/tests/python/relay/test_analysis_extract_intermediate_expr.py +++ b/tests/python/relay/test_analysis_extract_intermediate_expr.py @@ -108,22 +108,22 @@ def expected_4(): tuple_out = relay.op.split(z, indices_or_sections=1, axis=0) return tvm.IRModule.from_expr(tuple_out[0]) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( 
relay.analysis.extract_intermdeiate_expr(before(), 0), expected_0() ) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( relay.analysis.extract_intermdeiate_expr(before(), 1), expected_1() ) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( relay.analysis.extract_intermdeiate_expr(before(), 2), expected_2() ) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( (relay.analysis.extract_intermdeiate_expr(before(), 3)), expected_3() ) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( relay.analysis.extract_intermdeiate_expr(before(), 4), expected_4() ) - assert tvm.ir.structural_equal(relay.analysis.extract_intermdeiate_expr(before(), 5), before()) + tvm.ir.assert_structural_equal(relay.analysis.extract_intermdeiate_expr(before(), 5), before()) if __name__ == "__main__": diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index 26106c31d5ce..be4d52f8812a 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -27,7 +27,7 @@ def test_callgraph_construct(): mod["g1"] = relay.Function([x, y], x + y) call_graph = relay.analysis.CallGraph(mod) assert "g1" in str(call_graph) - assert tvm.ir.structural_equal(mod, call_graph.module) + tvm.ir.assert_structural_equal(mod, call_graph.module) def test_print_element(): diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 3950c02c08a4..6942c47491de 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -118,7 +118,7 @@ def test_ShapePattern(): shape = [10, 10] pattern = has_shape(shape) assert isinstance(pattern, ShapePattern) - assert tvm.ir.structural_equal(pattern.shape, shape) + tvm.ir.assert_structural_equal(pattern.shape, shape) def test_AttrPattern(): @@ -929,7 +929,7 @@ def pattern(): pat = pattern() new_out = rewrite(PatternCallback(pat), out) - assert tvm.ir.structural_equal(out, new_out) + tvm.ir.assert_structural_equal(out, new_out) def test_not_fuse_multi_diamond(): @@ -985,7 +985,7 @@ def test_fuse_batchnorm(): BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] ) @@ -1000,7 +1000,7 @@ def test_no_fuse_batchnorm(): fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta out = rewrite(BatchnormCallback(), fake_BN) - assert tvm.ir.structural_equal(out, fake_BN) + tvm.ir.assert_structural_equal(out, fake_BN) def test_fuse_double_batchnorm(): @@ -1018,7 +1018,7 @@ def test_fuse_double_batchnorm(): bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0] - assert tvm.ir.structural_equal(out, bn2) + tvm.ir.assert_structural_equal(out, bn2) def test_partial_fuse_double_batchnorm(): @@ -1035,7 +1035,7 @@ def test_partial_fuse_double_batchnorm(): bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0] - assert tvm.ir.structural_equal(out, bn2) + tvm.ir.assert_structural_equal(out, bn2) def test_fuse_batchnorm_commutation(): @@ -1048,21 +1048,21 @@ def test_fuse_batchnorm_commutation(): # commute add BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( 
out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] ) # associate divide/multiply BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] ) # associate multiply/divide BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta out = rewrite(BatchnormCallback(), BN) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] ) @@ -1121,7 +1121,7 @@ def callback(self, pre, post, node_map): three = relay.op.nn.conv2d(two, weight) four = relay.op.nn.conv2d(three, weight) - assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four) + tvm.ir.assert_structural_equal(DominatorRemovalCallback().rewrite(out), four) def algebraic_simplify(expr): @@ -1210,7 +1210,7 @@ def test_algebraic_simplify(): assert algebraic_simplify(zero / x) == zero assert algebraic_simplify(zerof / x) == zerof - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y ) @@ -1260,7 +1260,7 @@ def test_double_partition(): ) expected = func1(func0(x, w, b), w2, b2) - assert tvm.ir.structural_equal(partitioned, expected) + tvm.ir.assert_structural_equal(partitioned, expected) def test_partition_dominator(): @@ -1290,7 +1290,7 @@ def generate_diamond(inp, weight): f = relay.Function([i, w], generate_diamond(i, w)).with_attr( "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_" ) - assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight)) + tvm.ir.assert_structural_equal(partitioned, f(inp * inp, weight * weight)) def test_quadruple_partition_dominator(): @@ -1364,7 +1364,7 @@ def nested_diamond(inp, weight): reference = functions[3]( functions[2](functions[1](functions[0](inp, weight), weight), weight), weight ) - assert tvm.ir.structural_equal(partitioned, reference) + tvm.ir.assert_structural_equal(partitioned, reference) def get_BN(x, var, mean, beta, gamma, eps): @@ -1392,7 +1392,7 @@ def test_partition_batchnorm(): partitioned = BatchnormCallback().pattern.partition(BN) reference = f(gamma, x, mean, var, beta) - assert tvm.ir.structural_equal(partitioned, reference) + tvm.ir.assert_structural_equal(partitioned, reference) def test_partition_double_batchnorm(): @@ -1426,7 +1426,7 @@ def test_partition_double_batchnorm(): partitioned = BatchnormCallback().pattern.partition(BN2) reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta) - assert tvm.ir.structural_equal(partitioned, reference) + tvm.ir.assert_structural_equal(partitioned, reference) def test_overlappting_partitions(): @@ -1481,11 +1481,11 @@ def concat(*args): return relay.op.concatenate(relay.expr.Tuple(args), axis=0) one = concat_pattern.partition(concat(x)) - assert tvm.ir.structural_equal(one, create_func([xp], concat(xp))(x)) + tvm.ir.assert_structural_equal(one, create_func([xp], concat(xp))(x)) two = concat_pattern.partition(concat(x, y)) - assert tvm.ir.structural_equal(two, create_func([xp, yp], concat(xp, yp))(x, y)) + tvm.ir.assert_structural_equal(two, create_func([xp, yp], concat(xp, yp))(x, y)) three = concat_pattern.partition(concat(x, y, z)) - assert tvm.ir.structural_equal(three, create_func([xp, yp, zp], concat(xp, yp, zp))(x, y, z)) + tvm.ir.assert_structural_equal(three, create_func([xp, yp, 
zp], concat(xp, yp, zp))(x, y, z)) def test_partition_fuzzy_function_args(): @@ -1510,13 +1510,13 @@ def create_func(call): f1 = relay.Function([xp], xp + xp)(x) one = func_pattern.partition(f1 + b) - assert tvm.ir.structural_equal(one, create_func(f1)) + tvm.ir.assert_structural_equal(one, create_func(f1)) f2 = relay.Function([xp, yp], xp + yp)(x, y) two = func_pattern.partition(f2 + b) - assert tvm.ir.structural_equal(two, create_func(f2)) + tvm.ir.assert_structural_equal(two, create_func(f2)) f3 = relay.Function([xp, yp, zp], xp + yp + zp)(x, y, z) three = func_pattern.partition(f3 + b) - assert tvm.ir.structural_equal(three, create_func(f3)) + tvm.ir.assert_structural_equal(three, create_func(f3)) def test_partition_check(): @@ -1538,7 +1538,7 @@ def check(pre): reference = func(x, w) partitioned = pattern.partition(relu, check=check) - assert tvm.ir.structural_equal(partitioned, reference) + tvm.ir.assert_structural_equal(partitioned, reference) conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC") relu = relay.op.nn.relu(conv2d) @@ -1604,10 +1604,10 @@ def test_partition_option(): ) assert pattern1.match(relu) - assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu)) + tvm.ir.assert_structural_equal(func(x, w, b), pattern1.partition(relu)) assert pattern2.match(relu) - assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu)) + tvm.ir.assert_structural_equal(func(x, w, b), pattern2.partition(relu)) def test_partition_function(): @@ -1637,7 +1637,7 @@ def test_partition_function(): "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_" ) expr2 = func2(x, w, b) + b - assert tvm.ir.structural_equal(pattern.partition(expr), expr2) + tvm.ir.assert_structural_equal(pattern.partition(expr), expr2) def test_partition_optional_function(): @@ -1670,7 +1670,7 @@ def test_partition_optional_function(): "PartitionedFromPattern", "nn.conv2d_nn.relu_FunctionCall_" ) expr2 = func2(x, w) + b - assert tvm.ir.structural_equal(pattern.partition(expr), expr2) + tvm.ir.assert_structural_equal(pattern.partition(expr), expr2) def test_rewrite_function_with_fuzzy_body(): @@ -1703,7 +1703,7 @@ def callback(self, pre, post, node_map): return x + w out = rewrite(TestRewrite(), expr) - assert tvm.ir.structural_equal(out, x + w + b) + tvm.ir.assert_structural_equal(out, x + w + b) def test_partition_function_with_fuzzy_body(): @@ -1736,7 +1736,7 @@ def test_partition_function_with_fuzzy_body(): "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_" ) expr2 = func2(x, w, b) + b - assert tvm.ir.structural_equal(pattern.partition(expr), expr2) + tvm.ir.assert_structural_equal(pattern.partition(expr), expr2) def test_match_match(): @@ -1754,7 +1754,7 @@ def callback(self, pre, post, node_map): tvm.relay.prelude.Prelude(mod) # Apply rewrite on IR including relay.Match out = rewrite(TestRewrite(), mod["tensor_concatenate_int64"]) - assert tvm.ir.structural_equal(mod["tensor_concatenate_int64"], out) + tvm.ir.assert_structural_equal(mod["tensor_concatenate_int64"], out) def test_partition_constant_embedding(): @@ -1782,43 +1782,43 @@ def test_partition_constant_embedding(): pattern = is_op("nn.relu")( is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard()) ) - assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) - assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc)) + tvm.ir.assert_structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + tvm.ir.assert_structural_equal(lifted_func(x, wc, b), 
pattern.partition(reluc)) # Check lifting of input matches pattern = is_op("nn.relu")( is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var()), wildcard()) ) - assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) - assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) # Constants are not Inputs + tvm.ir.assert_structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + tvm.ir.assert_structural_equal(reluc, pattern.partition(reluc)) # Constants are not Inputs # Check embedding of constant matches pattern = is_op("nn.relu")( is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant()), wildcard()) ) - assert tvm.ir.structural_equal(relu, pattern.partition(relu)) - assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + tvm.ir.assert_structural_equal(relu, pattern.partition(relu)) + tvm.ir.assert_structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check embedding of constant ExprPatterns pattern = is_op("nn.relu")( is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_expr(wc)), wildcard()) ) - assert tvm.ir.structural_equal(relu, pattern.partition(relu)) - assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + tvm.ir.assert_structural_equal(relu, pattern.partition(relu)) + tvm.ir.assert_structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check lifting/embedding of Alt matches pattern = is_op("nn.relu")( is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var() | is_constant()), wildcard()) ) - assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) - assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + tvm.ir.assert_structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + tvm.ir.assert_structural_equal(embeded_func(x, b), pattern.partition(reluc)) # Check lifting/embedding of Alt matches with the other ordering pattern = is_op("nn.relu")( is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant() | is_var()), wildcard()) ) - assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) - assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + tvm.ir.assert_structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + tvm.ir.assert_structural_equal(embeded_func(x, b), pattern.partition(reluc)) def test_rewrite_once(): @@ -1846,12 +1846,12 @@ def test_one_callback(): # Let the rewriter run recursively out = rewrite(ConcatRewriter(False), concat) expected = x - assert tvm.ir.structural_equal(out, expected) + tvm.ir.assert_structural_equal(out, expected) # Run the rewriter once out = rewrite(ConcatRewriter(True), concat) expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0) - assert tvm.ir.structural_equal(out, expected) + tvm.ir.assert_structural_equal(out, expected) def test_multi_callbacks(): # This class recursively add a nn.relu operator after nn.softmax @@ -1901,14 +1901,14 @@ def recursive_concat(): [OneMoreReluRewriter(True), ConcatRewriter(True)], before(), ) - assert tvm.ir.structural_equal(out, once_concat()) + tvm.ir.assert_structural_equal(out, once_concat()) # Run ConcatRewriter recursively, OneMoreReluRewriter once out = rewrite( [OneMoreReluRewriter(True), ConcatRewriter(False)], before(), ) - assert tvm.ir.structural_equal(out, recursive_concat()) + tvm.ir.assert_structural_equal(out, recursive_concat()) test_one_callback() test_multi_callbacks() @@ -1992,7 +1992,7 @@ def test_partition_parallel_branch_with_same_input(): partitioned = 
pattern.partition(add) reference = f(l, conv2d, r) - assert tvm.ir.structural_equal(partitioned, reference) + tvm.ir.assert_structural_equal(partitioned, reference) def test_rewrite_with_pattern_recursion(): diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index 0ab0122fa798..1e5ab92cf2c5 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -29,11 +29,11 @@ def test_bind_params(): f = relay.Function([x, y], z) fbinded = relay.bind(f, {x: relay.const(1, "float32")}) fexpected = relay.Function([y], relay.add(relay.const(1, "float32"), y)) - assert tvm.ir.structural_equal(fbinded, fexpected) + tvm.ir.assert_structural_equal(fbinded, fexpected) zbinded = relay.bind(z, {y: x}) zexpected = relay.add(x, x) - assert tvm.ir.structural_equal(zbinded, zexpected) + tvm.ir.assert_structural_equal(zbinded, zexpected) def test_bind_duplicated_params(): diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index a808259d26af..97b631a22518 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -792,7 +792,7 @@ def func3(): sb.ret(a2) return relay.Function([p0, p1], sb.get()) - assert tvm.ir.structural_equal(func1(), func2()) + tvm.ir.assert_structural_equal(func1(), func2()) assert not tvm.ir.structural_equal(func1(), func3()) diff --git a/tests/python/relay/test_name_supply.py b/tests/python/relay/test_name_supply.py index 688be19c8171..f48fe0a47485 100644 --- a/tests/python/relay/test_name_supply.py +++ b/tests/python/relay/test_name_supply.py @@ -18,7 +18,7 @@ import tvm.testing from tvm import relay -from tvm.ir import GlobalVar, structural_equal +from tvm.ir import GlobalVar, structural_equal, assert_structural_equal from tvm.ir.supply import NameSupply from tvm.ir.supply import GlobalVarSupply @@ -39,7 +39,7 @@ def test_global_var_supply_from_none(): global_var = GlobalVar("test") var_supply.reserve_global(global_var) - assert structural_equal(var_supply.unique_global_for("test"), global_var) + assert_structural_equal(var_supply.unique_global_for("test"), global_var) assert not structural_equal(var_supply.fresh_global("test"), global_var) @@ -49,7 +49,7 @@ def test_global_var_supply_from_name_supply(): global_var = GlobalVar("test") var_supply.reserve_global(global_var) - assert structural_equal(var_supply.unique_global_for("test", False), global_var) + assert_structural_equal(var_supply.unique_global_for("test", False), global_var) assert not structural_equal(var_supply.unique_global_for("test"), global_var) @@ -63,7 +63,7 @@ def test_global_var_supply_from_ir_mod(): second_global_var = var_supply.fresh_global("test", False) - assert structural_equal(var_supply.unique_global_for("test", False), global_var) + assert_structural_equal(var_supply.unique_global_for("test", False), global_var) assert not structural_equal(var_supply.unique_global_for("test"), global_var) assert not structural_equal(second_global_var, global_var) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index eb57f795e238..2463baa725a4 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -74,7 +74,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + 
tvm.ir.assert_structural_equal(a, b) def test_alter_return_none(): @@ -97,7 +97,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(before(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) assert called[0] @@ -162,7 +162,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_multi(): @@ -208,7 +208,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_lrn(): @@ -260,7 +260,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_dual_path(): @@ -313,7 +313,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_resnet(): @@ -361,7 +361,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_broadcast_op(): @@ -409,7 +409,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_broadcast_scalar_op(): @@ -468,7 +468,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_scalar(): @@ -509,7 +509,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_scalar_regression(): @@ -599,7 +599,7 @@ def expected(): ) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_concatenate(): @@ -643,7 +643,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # NHWC layout transformation. 
def before_nhwc(): @@ -681,7 +681,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_nchw_upsamping_op(): @@ -720,7 +720,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_nchw_dyn_upsamping_op(): @@ -759,7 +759,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) @tvm.testing.parametrize_targets("llvm") @@ -872,7 +872,7 @@ def expected(): mod_new = tvm.IRModule() mod_before["main"] = a mod_new["main"] = b - assert tvm.ir.structural_equal(mod_before, mod_new) + tvm.ir.assert_structural_equal(mod_before, mod_new) def test_alter_layout_depthwise_conv2d(): @@ -916,7 +916,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_prelu(): @@ -956,7 +956,7 @@ def expected(): a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_pad(): @@ -994,7 +994,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # Check NHWC conversion. def before_nhwc(): @@ -1024,7 +1024,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # Check that conversion does not happen when padding along split axis. def before(): @@ -1052,7 +1052,7 @@ def expected(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_pool(): @@ -1090,7 +1090,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # Check NHWC conversion. def before_nhwc(): @@ -1120,7 +1120,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_sum(): @@ -1158,7 +1158,7 @@ def expected_nchw(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nchw(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # Check NHWC conversion. 
def before_nhwc(): @@ -1188,7 +1188,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_nhwc_arm(): @@ -1225,7 +1225,7 @@ def expected_nhwc(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_layout_nhwc_int8_aarch64(): @@ -1302,7 +1302,7 @@ def expected_nhwc_int8(): a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected_nhwc_int8(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_alter_op_with_global_var(): @@ -1349,7 +1349,7 @@ def expected(): a = transform.AlterOpLayout()(a) b = transform.InferType()(expected()) - assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b, map_free_vars=True) def test_alter_op_dense(): @@ -1383,7 +1383,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_not_inplace_modify(): @@ -1449,7 +1449,7 @@ def expected(): ): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) @pytest.mark.skipif( @@ -1475,7 +1475,7 @@ def expected(): with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.arm_cpu.dense_alter_op._alter_dense): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) @pytest.mark.skipif( @@ -1505,7 +1505,7 @@ def expected(): with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.arm_cpu.dense_alter_op._alter_dense): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) @pytest.mark.skipif( @@ -1534,7 +1534,7 @@ def expected(): with TempOpAttr("nn.dense", "FTVMAlterOpLayout", topi.arm_cpu.dense_alter_op._alter_dense): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_conv2d_strided_slice_packed_to_unpacked(): @@ -1583,7 +1583,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b) + tvm.ir.assert_structural_equal(a, b) def test_conv2d_strided_slice_arbitrary_stride(): @@ -1675,7 +1675,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) inp = np.random.uniform(size=(1, 16, 3, 3)).astype(np.float32) weight = 
np.random.uniform(size=(16, 16, 1, 1)).astype(np.float32) @@ -1737,7 +1737,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32) weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32) @@ -1799,7 +1799,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32) weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32) @@ -1887,7 +1887,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32) weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32) @@ -1959,7 +1959,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = run_opt_pass(before(), transform.AlterOpLayout()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) inp = np.random.uniform(size=(1, 4, 3, 3, 4)).astype(np.float32) weight = np.random.uniform(size=(4, 4, 1, 1, 4, 4)).astype(np.float32) @@ -2043,7 +2043,7 @@ def test_alter_with_subfunc(): func = relay.Function([x1], x3) mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(relay.transform.AlterOpLayout()(mod), mod) + tvm.ir.assert_structural_equal(relay.transform.AlterOpLayout()(mod), mod) def test_alter_with_reduce(): diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 908a06ffc8b2..a32f7d7f6190 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -217,7 +217,7 @@ def after(): for annotate_non_call_ops in [False, True]: result = transform.AnnotateTarget("test", annotate_non_call_ops)(before()) expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_type_propagation(): @@ -285,7 +285,7 @@ def after(annotate_non_call_ops): for annotate_non_call_ops in [True, False, True]: result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) expected = transform.InferType()(after(annotate_non_call_ops)) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_tuple(): @@ -339,7 +339,7 @@ def after(annotate_non_call_ops): for annotate_non_call_ops in [False, True]: result = transform.AnnotateTarget(target, 
annotate_non_call_ops)(before()) expected = transform.InferType()(after(annotate_non_call_ops)) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_composite_function(): @@ -384,7 +384,7 @@ def after(): result = transform.AnnotateTarget("test")(before()) expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_double_target(): @@ -402,7 +402,7 @@ def before(): mod = before() mod1 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod) mod2 = transform.AnnotateTarget("double.A", annotate_non_call_ops)(mod1) - assert tvm.ir.structural_equal(mod1, mod2) + tvm.ir.assert_structural_equal(mod1, mod2) def test_different_targets(): @@ -426,7 +426,7 @@ def before(): mod1 = transform.AnnotateTarget("different.A", annotate_non_call_ops)(mod) mod1 = transform.AnnotateTarget("different.B", annotate_non_call_ops)(mod1) mod2 = transform.AnnotateTarget(["different.A", "different.B"], annotate_non_call_ops)(mod) - assert tvm.ir.structural_equal(mod1, mod2) + tvm.ir.assert_structural_equal(mod1, mod2) def test_multiple_runs(): @@ -453,7 +453,7 @@ def before(): mod = transform.AnnotateTarget("A", annotate_non_call_ops)(before()) mod = transform.AnnotateTarget("B", annotate_non_call_ops)(mod) expected = transform.AnnotateTarget(["A", "B"], annotate_non_call_ops)(before()) - assert tvm.ir.structural_equal(expected, mod) + tvm.ir.assert_structural_equal(expected, mod) def test_ends_with_tuple(): @@ -504,7 +504,7 @@ def get_expected(annotate_non_call_ops, get_item): mod = get_model(get_item) mod = transform.AnnotateTarget("clip", annotate_non_call_ops)(mod) expected = transform.InferType()(get_expected(annotate_non_call_ops, get_item)) - assert tvm.ir.structural_equal(expected, mod) + tvm.ir.assert_structural_equal(expected, mod) def test_if_else(): @@ -576,7 +576,7 @@ def after(): expected = transform.InferType()(after()) for annotate_non_call_ops in [True, False]: result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_while_let(): @@ -677,7 +677,7 @@ def after(annotate_non_call_ops): for annotate_non_call_ops in [False, True]: result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) expected = transform.InferType()(after(annotate_non_call_ops)) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_if_free_vars(): @@ -743,7 +743,7 @@ def after(): for annotate_non_call_ops in [True, False]: result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_free_vars_zeros(): @@ -763,7 +763,7 @@ def after(): result = transform.AnnotateTarget(target)(before()) expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) def test_empty_tuple(): @@ -784,7 +784,7 @@ def after(): for annotate_non_call_ops in [True, False]: result = transform.AnnotateTarget(target, annotate_non_call_ops)(before()) expected = transform.InferType()(after()) - assert tvm.ir.structural_equal(expected, result) + tvm.ir.assert_structural_equal(expected, result) if __name__ == "__main__": diff --git 
a/tests/python/relay/test_pass_canonicalize_cast.py b/tests/python/relay/test_pass_canonicalize_cast.py index 321d866a9e46..2a7d83fe27df 100644 --- a/tests/python/relay/test_pass_canonicalize_cast.py +++ b/tests/python/relay/test_pass_canonicalize_cast.py @@ -61,7 +61,7 @@ def check(shape): mod[gv] = y_expected mod = _transform.InferType()(mod) y_expected = mod["expected"] - assert tvm.ir.structural_equal(y, y_expected) + tvm.ir.assert_structural_equal(y, y_expected) check((1, 16, 7, 7)) diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index b9a5cca85cd2..0d41ed1294f8 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -82,7 +82,7 @@ def check(x_shape, channels1, channels2, channels3, channels4): y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4, 4, 4, 4) check((1, 4, 16, 16), 4, 8, 4, 7) @@ -132,7 +132,7 @@ def check(x_shape, channels1, channels2): y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4, 8) @@ -175,7 +175,7 @@ def check(x_shape, channels1, channels2): y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4, 8) @@ -214,7 +214,7 @@ def check(x_shape, repeat): y = run_opt_pass(y_before, transform.CombineParallelConv2D(min_num_branches=2)) y_expected = expected(x, w, out_c, repeat) y_expected = run_opt_pass(y_expected, transform.InferType()) - assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True) + tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True) check((1, 4, 16, 16), 4) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index c3d579186d4a..49afe492a121 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -54,7 +54,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_binary_no_convert_layout(): @@ -81,7 +81,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.ConvertLayout({})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_convert_layout(): @@ -116,7 +116,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = 
run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_nhwc_convert_layout(): @@ -159,7 +159,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_transpose_convert_layout(): @@ -194,7 +194,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d_transpose": ["NCHW", "IOHW"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_bias_pool_convert_layout(): @@ -246,7 +246,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_bias_pool_uses_specified_convert_layout(): @@ -301,7 +301,7 @@ def expected(): ) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) def test_conv_concat_convert_layout(): @@ -349,7 +349,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_deformable_conv_bias_pool_convert_layout(): @@ -457,7 +457,7 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout): b = run_opt_pass( expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW"), transform.InferType() ) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # NCHW -> NHWC a = before(1, 3, 224, 224, 32, 3, 3, "NCHW") @@ -465,7 +465,7 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout): b = run_opt_pass( expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC"), transform.InferType() ) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_deformable_conv_bias_pool_uses_specified_convert_layout(): @@ -582,7 +582,7 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout, max_pool_l expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW", max_pool_layout="NHWC"), transform.InferType(), ) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) # NCHW -> NHWC a = before(1, 3, 224, 224, 32, 3, 3, "NCHW") @@ -598,7 +598,7 @@ def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout, max_pool_l expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC", max_pool_layout="NCHW"), transform.InferType(), ) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) def test_dual_path_convert_layout(): @@ -653,7 +653,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_bn_convert_layout(): @@ 
-888,7 +888,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_resnet_pool_uses_specified_convert_layout(): @@ -939,7 +939,7 @@ def expected(): ) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) def test_scalar_convert_layout(): @@ -975,7 +975,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_ln_convert_layout(): @@ -1022,7 +1022,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_InstanceNorm_convert_layout(): @@ -1069,7 +1069,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_bn_convert_layout(): @@ -1122,7 +1122,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_conv_requantize_convert_layout(): @@ -1188,7 +1188,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_conv_concat_convert_layout(): @@ -1282,7 +1282,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_conv_add_convert_layout(): @@ -1380,7 +1380,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_conv_nhwc_convert_layout(): @@ -1431,7 +1431,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d": ["NHWC", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_conv_transpose_requantize_convert_layout(): @@ -1498,7 +1498,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d_transpose": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_convert_kernel_layout(): @@ -1539,7 +1539,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", 
"OHWI"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_roi_align_convert_layout(): @@ -1592,7 +1592,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout(desired_layouts)) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_strided_slice_convert_layout(): @@ -1637,7 +1637,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_split_convert_layout(): @@ -1679,7 +1679,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_conv_split_convert_layout2(): def before(): @@ -1719,7 +1719,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_conv_split_convert_layout3(): def before(): @@ -1762,7 +1762,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_conv_split_convert_layout_blocking(): def before(): @@ -1810,7 +1810,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW4c", "OIHW4o"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) _test_conv_split_convert_layout1() _test_conv_split_convert_layout2() @@ -1858,7 +1858,7 @@ def expected(): a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_topk_convert_layout(): @@ -1898,7 +1898,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_roi_pool_convert_layout(): @@ -1951,7 +1951,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout(desired_layouts)) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_default_keyword(): @@ -1992,7 +1992,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_different_ops_convert_layout(): @@ -2098,7 +2098,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout(desired_layouts)) b = run_opt_pass(expected(), transform.InferType()) - assert 
tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_no_desired_layout(): @@ -2147,7 +2147,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_convert_with_config(): @@ -2219,7 +2219,7 @@ def expected(): with layout_config: a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["HWNC", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_squeeze_convert_layout(): @@ -2255,7 +2255,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_conv_squeeze_convert_layout2(): # all axes of dimension 1 are squeezed @@ -2288,7 +2288,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_conv_squeeze_convert_layout3(): # squeeze axis is empty @@ -2322,7 +2322,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) _test_conv_squeeze_convert_layout1() _test_conv_squeeze_convert_layout2() @@ -2366,7 +2366,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_conv_reduce_convert_layout2(): def _set_span(y, text): @@ -2414,7 +2414,7 @@ def expected(): assert "SpanSum" in a.astext() b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) _test_conv_reduce_convert_layout1() _test_conv_reduce_convert_layout2() @@ -2440,7 +2440,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NHWC"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def _test_image_resize_convert_layout_nhwc_to_nchw(): def before(): @@ -2461,7 +2461,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"image.resize2d": ["NCHW"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) _test_image_resize_convert_layout_nchw_to_nhwc() _test_image_resize_convert_layout_nhwc_to_nchw() @@ -2501,7 +2501,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_infer_correct_layout(): @@ -2587,7 +2587,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", 
"default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_conv_max_pool_uses_specified_convert_layout(): @@ -2636,7 +2636,7 @@ def expected(): ) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) def test_simulated_quantize_uses_specified_convert_layout(): @@ -2681,7 +2681,7 @@ def expected(): a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) @pytest.mark.parametrize( @@ -2792,7 +2792,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": [data_layout, kernel_layout]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n Expect = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) def test_conv_l2n_convert_layout(): @@ -2831,7 +2831,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 70dc1dd4f794..6374d20173b2 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -41,7 +41,7 @@ def optimize_and_check(before_program, after_program, passes): print(optimized_program) print("Expected:") print(after_program) - assert tvm.ir.structural_equal(optimized_program, after_program, map_free_vars=True) + tvm.ir.assert_structural_equal(optimized_program, after_program, map_free_vars=True) def test_dead_let(): diff --git a/tests/python/relay/test_pass_defuse_ops.py b/tests/python/relay/test_pass_defuse_ops.py index ec6431ee269a..4f446865c7a7 100644 --- a/tests/python/relay/test_pass_defuse_ops.py +++ b/tests/python/relay/test_pass_defuse_ops.py @@ -37,7 +37,7 @@ def before(): fused = run_opt_pass(x, transform.FuseOps()) defused = run_opt_pass(fused, transform.DefuseOps()) - assert tvm.ir.structural_equal(x, defused) + tvm.ir.assert_structural_equal(x, defused) def test_inception_like(): @@ -62,7 +62,7 @@ def before(dshape): fused = run_opt_pass(x, transform.FuseOps()) defused = run_opt_pass(fused, transform.DefuseOps()) - assert tvm.ir.structural_equal(x, defused) + tvm.ir.assert_structural_equal(x, defused) def test_defuse_complex(): @@ -206,9 +206,7 @@ def golden_defused(conv_layer1_weight, conv_layer2_weight): golden1 = golden_defused(conv_layer1_weight, conv_layer2_weight) golden1 = run_opt_pass(golden1, transform.InferType()) - assert tvm.ir.structural_equal(defused, golden1), ( - "Actual = \n" + str(defused) + "\nGolden = \n" + str(golden1) - ) + tvm.ir.assert_structural_equal(defused, golden1) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py b/tests/python/relay/test_pass_eliminate_common_subexpr.py index a8ca5058ad7f..fd4bb0c9fbfa 100644 --- a/tests/python/relay/test_pass_eliminate_common_subexpr.py +++ 
b/tests/python/relay/test_pass_eliminate_common_subexpr.py @@ -53,7 +53,7 @@ def expected(): z = before() z = run_opt_pass(z, transform.EliminateCommonSubexpr()) - assert tvm.ir.structural_equal(z, expected()) + tvm.ir.assert_structural_equal(z, expected()) def test_callback(): @@ -83,7 +83,7 @@ def fskip(expr): z = before() z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip)) - assert tvm.ir.structural_equal(z, expected()) + tvm.ir.assert_structural_equal(z, expected()) def test_tuple_get_time(): @@ -114,7 +114,7 @@ def expected(): z = before() z = run_opt_pass(z, transform.EliminateCommonSubexpr()) - assert tvm.ir.structural_equal(z, expected()) + tvm.ir.assert_structural_equal(z, expected()) def test_tuple_arg(): @@ -143,7 +143,7 @@ def expected(): z = before() z = run_opt_pass(z, transform.EliminateCommonSubexpr()) - assert tvm.ir.structural_equal(z, expected()) + tvm.ir.assert_structural_equal(z, expected()) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3425a9a72b9b..6edb3949d683 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -890,7 +890,7 @@ def conv2d(expr, type_map): # pylint: disable=unused-variable mod = tvm.relay.transform.InferType()(mod) mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=False)(mod) - assert tvm.ir.structural_equal(mod_int, mod) + tvm.ir.assert_structural_equal(mod_int, mod) # Catch a generic exception because the tvm FFI eats the python exception type with pytest.raises(Exception): mod_int = tvm.relay.transform.FakeQuantizationToInteger(hard_fail=True)(mod) @@ -902,7 +902,7 @@ def compare_expected_fq_qat_to_int(expr, expected_expr, args, allow_rounding_err mod_int = tvm.relay.transform.FakeQuantizationToInteger(False, True)(mod_def) mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr)) assert not tvm.ir.structural_equal(mod, mod_int) - assert tvm.ir.structural_equal(mod_int, mod_exp) + tvm.ir.assert_structural_equal(mod_int, mod_exp) result_def = ( relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm") .evaluate()(*args) diff --git a/tests/python/relay/test_pass_flatten_atrous_conv.py b/tests/python/relay/test_pass_flatten_atrous_conv.py index 39c92c5ed6c7..37b69a426df2 100644 --- a/tests/python/relay/test_pass_flatten_atrous_conv.py +++ b/tests/python/relay/test_pass_flatten_atrous_conv.py @@ -29,7 +29,7 @@ def compare_expected_fac(expr, expected_expr, args): mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr)) assert expr is expected_expr or not tvm.ir.structural_equal(mod_def, mod_flat) - assert tvm.ir.structural_equal(mod_flat, mod_exp) + tvm.ir.assert_structural_equal(mod_flat, mod_exp) result_def = ( relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm") diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index f69447d43e80..585ae5d7a21d 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -55,7 +55,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_const(): diff --git 
a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py index 35354508a953..f2bd360fc667 100644 --- a/tests/python/relay/test_pass_fold_explicit_padding.py +++ b/tests/python/relay/test_pass_fold_explicit_padding.py @@ -64,7 +64,7 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout, no_fold zz = run_opt_pass(conv, transform.FoldExplicitPadding()) expected = run_opt_pass(after, transform.InferType()) - assert tvm.ir.structural_equal(zz, expected) + tvm.ir.assert_structural_equal(zz, expected) mod1 = tvm.IRModule.from_expr(conv) mod2 = tvm.IRModule.from_expr(zz) @@ -187,7 +187,7 @@ def validate( zz = run_opt_pass(pool, transform.FoldExplicitPadding()) expected = run_opt_pass(after, transform.InferType()) - assert tvm.ir.structural_equal(zz, expected) + tvm.ir.assert_structural_equal(zz, expected) mod1 = tvm.IRModule.from_expr(pool) mod2 = tvm.IRModule.from_expr(zz) @@ -310,7 +310,7 @@ def expected(): a = run_opt_pass(before(), relay.transform.FoldExplicitPadding()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b, map_free_vars=True) def test_pad_qconv2d_no_fold(): @@ -336,9 +336,7 @@ def get_expr(): a = run_opt_pass(get_expr(), relay.transform.FoldExplicitPadding()) b = run_opt_pass(get_expr(), transform.InferType()) - assert tvm.ir.structural_equal(a, b, map_free_vars=True), ( - "\nActual = \n" + str(a) + "\nExpected = \n" + str(b) - ) + tvm.ir.assert_structural_equal(a, b, map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 8ffa3ef832e0..bf8dcc0d9c47 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -118,7 +118,7 @@ def check(shape, channels, blocking): y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 2, None) check((2, 2, 10, 10, 2), 8, (2, 4)) @@ -226,7 +226,7 @@ def check(dshape, channels, blocking): weight = relay.var("weight", type_dict["weight"]) y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 3), 3, None) check((2, 4, 10, 2, 2), 4, (2, 2)) @@ -266,7 +266,7 @@ def check(shape, channels, blocking): y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - assert tvm.ir.structural_equal(y1, y1_folded) + tvm.ir.assert_structural_equal(y1, y1_folded) check((2, 11, 10, 4), 4, None) check((2, 11, 10, 2, 2), 4, (2, 2)) @@ -304,7 +304,7 @@ def check(shape, channels, blocking, in_scale): y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - assert tvm.ir.structural_equal(y1, y1_folded) + tvm.ir.assert_structural_equal(y1, y1_folded) in_scale = relay.var("in_scale", shape=(4,)) check((2, 11, 10, 4), 4, None, in_scale) @@ -350,7 +350,7 @@ def check(shape, channels): y1 = 
before(x, weight, in_bias, in_scale, channels) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - assert tvm.ir.structural_equal(y1, y1_folded) + tvm.ir.assert_structural_equal(y1, y1_folded) check((2, 11, 10, 4), 4) @@ -413,7 +413,7 @@ def check(shape, channels, blocking): y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4, None) check((2, 2, 10, 10, 2), 8, (2, 2)) @@ -453,7 +453,7 @@ def check(data_shape, weight_shape): y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4), (3, 4)) check((3, 5), (4, 5)) @@ -539,7 +539,7 @@ def check(shape, in_channels, channels, blocking): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4, 8, None) check((2, 2, 10, 10, 16), 32, 64, (16, 16)) @@ -636,7 +636,7 @@ def check(shape, in_channels, channels, blocking): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4, 8, None) check((2, 2, 10, 10, 2), 4, 8, (2, 2)) @@ -798,7 +798,7 @@ def check(shape, in_channels, channels, blocking): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4, 4, None) check((2, 2, 10, 10, 2), 4, 4, (2, 2)) @@ -867,7 +867,7 @@ def check(shape, in_channels, channels, blocking, fbefore): y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - assert tvm.ir.structural_equal(y1_folded, y1) + tvm.ir.assert_structural_equal(y1_folded, y1) check((4, 4, 10, 10), 4, 4, None, fail1) check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1) @@ -899,7 +899,7 @@ def check(shape, channels, blocking, out_scale): y1 = before(x, weight, out_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - assert tvm.ir.structural_equal(y1, y1_folded) + tvm.ir.assert_structural_equal(y1, y1_folded) out_scale = relay.var("in_scale", shape=(4, 1, 1)) check((4, 4, 10, 10), 4, None, out_scale) @@ -972,7 +972,7 @@ def check(shape, channels, blocking): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_scale, channels, blocking) y1_expected = 
run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8, None) check((2, 2, 10, 10, 2), 8, (2, 2)) @@ -1013,7 +1013,7 @@ def check(data_shape, weight_shape): y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4), (3, 4)) check((3, 5), (4, 5)) @@ -1073,7 +1073,7 @@ def check(shape, channels): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, channels) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 4) @@ -1160,7 +1160,7 @@ def check(shape, channels, blocking): y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10, 10), 2, None) check((2, 2, 10, 10, 10, 2), 8, (2, 4)) @@ -1248,7 +1248,7 @@ def check(shape, in_channels, channels, blocking): y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) - assert tvm.ir.structural_equal(y1_folded, y1_expected) + tvm.ir.assert_structural_equal(y1_folded, y1_expected) check((2, 4, 10, 10, 10), 4, 8, None) check((2, 2, 10, 10, 10, 16), 32, 64, (16, 16)) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 714818328f66..11411a830658 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -49,7 +49,7 @@ def expected(): z = before() zz = run_opt_pass(z, transform.FuseOps()) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_conv2d_fuse(): @@ -114,7 +114,7 @@ def expected(dshape): z = before(dshape) zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_concatenate(): @@ -154,7 +154,7 @@ def expected(dshape): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_tuple_root(): @@ -191,7 +191,7 @@ def expected(dshape): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_stop_fusion(): @@ -222,7 +222,7 @@ def expected(dshape): z = before(dshape) zz = run_opt_pass(z, transform.FuseOps()) after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_fuse_myia_regression(): @@ -255,7 +255,7 @@ def expected(dshape, dtype): f 
= before(dshape, dtype) zz = run_opt_pass(f, transform.FuseOps()) after = run_opt_pass(expected(dshape, dtype), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_fuse_tuple_get_elemwise(): @@ -293,7 +293,7 @@ def expected(dim): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dim), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_tuple_get_root(): @@ -330,7 +330,7 @@ def expected(dim): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(dim), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def fuse0(mod): @@ -370,7 +370,7 @@ def expected(p0): m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, "llvm") after = run_opt_pass(expected(x), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + tvm.ir.assert_structural_equal(m["main"], after) def test_tuple_consecutive(): @@ -428,7 +428,7 @@ def expected(dshape): m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, "llvm") after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + tvm.ir.assert_structural_equal(m["main"], after) def test_inception_like(): @@ -498,7 +498,7 @@ def expected(dshape): m = fuse2(tvm.IRModule.from_expr(orig)) relay.build(m, "llvm") after = run_opt_pass(expected(dshape), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + tvm.ir.assert_structural_equal(m["main"], after) def test_fuse_parallel_injective(): @@ -530,7 +530,7 @@ def expected(): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) assert not relay.analysis.free_vars(zz) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_immutable(): @@ -560,8 +560,8 @@ def expected(): mod = transform.InferType()(before()) new_mod = transform.FuseOps(fuse_opt_level=2)(mod) - assert tvm.ir.structural_equal(mod, transform.InferType()(before())) - assert tvm.ir.structural_equal(new_mod, transform.InferType()(expected())) + tvm.ir.assert_structural_equal(mod, transform.InferType()(before())) + tvm.ir.assert_structural_equal(new_mod, transform.InferType()(expected())) def test_split(): @@ -612,7 +612,7 @@ def expected(n, max_fused_ops): zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2)) zz = run_opt_pass(z, transform.FuseOps()) after = run_opt_pass(expected(n, max_fused_ops), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) max_fused_ops = 10 n = 20 @@ -622,13 +622,13 @@ def expected(n, max_fused_ops): with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}): zz = run_opt_pass(z, transform.FuseOps()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) with tvm.target.Target("opencl"): with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}): cl_zz = run_opt_pass(z, transform.FuseOps()) - assert tvm.ir.structural_equal(cl_zz, after) + tvm.ir.assert_structural_equal(cl_zz, after) link_params = tvm.testing.parameter(False, True) @@ -664,7 +664,7 @@ def expected(link_params): with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}): m 
= run_opt_pass(before(), transform.InferType()) m = run_opt_pass(m, transform.FuseOps()) - assert tvm.ir.structural_equal(m, after) + tvm.ir.assert_structural_equal(m, after) relay.build(m, "llvm") @@ -698,7 +698,7 @@ def expected(link_params): with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}): m = run_opt_pass(before(), transform.InferType()) m = run_opt_pass(m, transform.FuseOps()) - assert tvm.ir.structural_equal(m, after) + tvm.ir.assert_structural_equal(m, after) relay.build(m, "llvm") @@ -728,7 +728,7 @@ def expected(): for tgt, dev in tvm.testing.enabled_targets(): relay.build(m, tgt) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + tvm.ir.assert_structural_equal(m["main"], after) def test_fuse_max_diamond(): @@ -769,7 +769,7 @@ def create_diamond_func(inp): fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps()) expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType()) - assert tvm.ir.structural_equal(fused, expected) + tvm.ir.assert_structural_equal(fused, expected) def test_fuse_dynamic_squeeze_slice_take(): @@ -823,7 +823,7 @@ def expected(): orig = before() m = fuse2(tvm.IRModule.from_expr(orig)) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(m["main"], after) + tvm.ir.assert_structural_equal(m["main"], after) inp = np.random.randn(16, channel_size).astype("float32") ref = tvm.topi.testing.softmax_python(inp).astype("float16") @@ -941,7 +941,7 @@ def create_accum_func(args_limit): expected = run_opt_pass(after(ops_num), transform.InferType()) - assert tvm.ir.structural_equal(fused, expected) + tvm.ir.assert_structural_equal(fused, expected) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py index f5898774f50b..482c2246654d 100644 --- a/tests/python/relay/test_pass_inline.py +++ b/tests/python/relay/test_pass_inline.py @@ -113,7 +113,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_call_chain_inline_multiple_levels(): @@ -186,7 +186,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_call_chain_inline_multiple_levels_extern_compiler(): @@ -264,7 +264,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_recursive_call_with_global(): @@ -315,7 +315,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_recursive_called(): @@ -324,7 +324,7 @@ def test_recursive_called(): mod["main"] = relay.Function([iarg], sum_up(iarg)) ref_mod = mod mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, ref_mod, map_free_vars=True) def test_recursive_not_called(): @@ -350,7 +350,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) ref_mod = expected() - assert 
tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, ref_mod, map_free_vars=True) def test_recursive_not_called_extern_compiler(): @@ -381,7 +381,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) ref_mod = expected() - assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, ref_mod, map_free_vars=True) def test_globalvar_as_call_arg(): @@ -428,7 +428,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_globalvar_as_call_arg_extern_compiler(): @@ -494,7 +494,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_inline_globalvar_without_args(): @@ -525,7 +525,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_inline_globalvar_without_args_extern_compiler(): @@ -559,7 +559,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_globalvar_called_by_multiple_functions(): @@ -637,7 +637,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_entry_with_inline(): @@ -667,7 +667,7 @@ def get_mod(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, get_mod(), map_free_vars=True) def test_callee_not_inline(): @@ -700,7 +700,7 @@ def get_mod(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, get_mod(), map_free_vars=True) def test_callee_not_inline_leaf_inline(): @@ -758,7 +758,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) def test_callee_not_inline_leaf_inline_extern_compiler(): @@ -823,7 +823,7 @@ def expected(): mod = get_mod() mod = relay.transform.Inline()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_legalize.py b/tests/python/relay/test_pass_legalize.py index 1466784394ac..614663a62df2 100644 --- a/tests/python/relay/test_pass_legalize.py +++ b/tests/python/relay/test_pass_legalize.py @@ -71,7 +71,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_legalize_none(): @@ -94,7 +94,7 @@ def legalize_conv2d(attrs, inputs, types): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(before(), transform.InferType()) - assert 
tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) assert called[0] @@ -140,7 +140,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_legalize_multi_input(): @@ -176,7 +176,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) @pytest.mark.parametrize( diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py index c9782aec1b2c..9f4a09dac46b 100644 --- a/tests/python/relay/test_pass_legalize_tensorcore.py +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -97,7 +97,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) for dtype in ["float16", "int8", "int4"]: # conv2d pad batch @@ -177,7 +177,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) # conv2d pad batch _test_legalize_conv2d((16, 16, 7, 64), (3, 3, 64, 64), (1, 0, 0), "int8") @@ -250,7 +250,7 @@ def expected(): a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) # dense for dtype in ["float16", "int8"]: @@ -345,7 +345,7 @@ def expected(): a = before() a = run_opt_pass(a, transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + tvm.ir.assert_structural_equal(a, b) for dtype in ["float16", "int8"]: _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), dtype, False) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 4088cfdef073..9da3869288e9 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -110,7 +110,7 @@ def get_rand(shape, dtype="float32"): def check_func(func, ref_func): func = run_infer_type(func) ref_func = run_infer_type(ref_func) - assert tvm.ir.structural_equal(func, ref_func) + tvm.ir.assert_structural_equal(func, ref_func) @tvm.testing.uses_gpu @@ -216,7 +216,7 @@ def transform_function(self, func, mod, ctx): # wrap in expr mod2 = tvm.IRModule.from_expr(f1) mod2 = tvm.relay.transform.InferType()(mod2) - assert tvm.ir.structural_equal(mod["main"], mod2["main"]) + tvm.ir.assert_structural_equal(mod["main"], mod2["main"]) @tvm.testing.uses_gpu @@ -504,7 +504,7 @@ def expected(): zz = mod["main"] zexpected = run_infer_type(expected()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_nested_sequential_with_scoping(): @@ -532,7 +532,7 @@ def expected(): zz = tvm.transform.Sequential(passes)(z) expected = relay.transform.InferType()(expected()) - assert tvm.ir.structural_equal(zz, expected) + tvm.ir.assert_structural_equal(zz, expected) def 
test_print_ir(capfd): diff --git a/tests/python/relay/test_pass_manifest_lifetimes.py b/tests/python/relay/test_pass_manifest_lifetimes.py index 98e203e697be..ee9f824582ab 100644 --- a/tests/python/relay/test_pass_manifest_lifetimes.py +++ b/tests/python/relay/test_pass_manifest_lifetimes.py @@ -35,7 +35,7 @@ def optimize_and_check(before_program, after_program, passes): print(optimized_program) print("Expected:") print(after_program) - assert tvm.ir.structural_equal(optimized_program, after_program, map_free_vars=True) + tvm.ir.assert_structural_equal(optimized_program, after_program, map_free_vars=True) def test_simple_linear(): diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index a2c1c1006ba8..440a56f43b21 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -84,7 +84,7 @@ def expected(): result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) golden = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, golden) + tvm.ir.assert_structural_equal(result, golden) def test_example_graph(): @@ -212,7 +212,7 @@ def expected(): mod = relay.transform.InferType()(mod) ref_mod = expected() ref_mod = relay.transform.InferType()(ref_mod) - assert tvm.ir.structural_equal(mod, ref_mod) + tvm.ir.assert_structural_equal(mod, ref_mod) def test_if_else(): diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 739db69e10f1..7983c5370bea 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -175,9 +175,7 @@ def check_result(pattern_table, graph, expected_graph, import_prelude=False): str(result) ) expected = run_opt_pass(expected_graph, relay.transform.InferType()) - assert tvm.ir.structural_equal( - result, expected, map_free_vars=True - ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected)) + tvm.ir.assert_structural_equal(result, expected, map_free_vars=True) def test_simple_merge(): diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index bec9041e4688..214b9fa330ec 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -73,7 +73,7 @@ def test_tuple(): f = Function([x], body, None, [t]) expected = relay.Function([x], x, None, [t]) expected = run_opt_pass(expected, transform.InferType()) - assert tvm.ir.structural_equal(dcpe(f), expected) + tvm.ir.assert_structural_equal(dcpe(f), expected) def test_const_inline(): @@ -81,7 +81,7 @@ def test_const_inline(): d = Var("d", t) double = Function([d], d + d) orig = double(const(4.0)) - assert tvm.ir.structural_equal(dcpe(orig), const(8.0)) + tvm.ir.assert_structural_equal(dcpe(orig), const(8.0)) def test_ref(): @@ -96,7 +96,7 @@ def test_ref(): expected = run_opt_pass(Function([d], d * d), transform.InferType()) # TODO(mbs): Revisit once DCE eliminates dead writes. 
actual = dcpe(square, ignore_impurity=True) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_empty_ad(): @@ -109,7 +109,7 @@ def test_empty_ad(): g = dcpe(f, grad=True, ignore_impurity=True) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) expected = run_opt_pass(expected, transform.InferType()) - assert tvm.ir.structural_equal(g, expected) + tvm.ir.assert_structural_equal(g, expected) def test_ad(): @@ -185,7 +185,7 @@ def test_head_cons(): f = Function([x], body, None, [t]) res = dcpe(f, mod) expected_mod = tvm.IRModule.from_expr(Function([x], x, t, [t])) - assert tvm.ir.structural_equal(res, expected_mod["main"]) + tvm.ir.assert_structural_equal(res, expected_mod["main"]) def test_map(): @@ -205,7 +205,7 @@ def test_map(): expected = mod["main"] orig = Function([], orig) res = dcpe(orig, mod=mod) - assert tvm.ir.structural_equal(res.body, expected.body) + tvm.ir.assert_structural_equal(res.body, expected.body) def test_loop(): @@ -220,7 +220,7 @@ def test_loop(): expected = mod["main"].body call = Function([], loop(const(1))) res = dcpe(call, mod=mod) - assert tvm.ir.structural_equal(res.body, expected) + tvm.ir.assert_structural_equal(res.body, expected) def test_swap_loop(): @@ -235,7 +235,7 @@ def test_swap_loop(): prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) res = Function([], prog) res = dcpe(res, mod=mod) - assert tvm.ir.structural_equal(prog, res.body) + tvm.ir.assert_structural_equal(prog, res.body) def test_abs_diff(): @@ -257,7 +257,7 @@ def test_abs_diff(): orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4)) + tvm.ir.assert_structural_equal(res.body, make_nat_expr(p, 4)) def test_match_nat_id(): @@ -274,7 +274,7 @@ def test_match_nat_id(): orig = nat_id(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3)) + tvm.ir.assert_structural_equal(res.body, make_nat_expr(p, 3)) def test_nat_id(): @@ -289,7 +289,7 @@ def test_nat_id(): orig = nat_id(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3)) + tvm.ir.assert_structural_equal(res.body, make_nat_expr(p, 3)) def test_global_match_nat_id(): @@ -303,7 +303,7 @@ def test_global_match_nat_id(): orig = Match(make_nat_expr(p, 3), [z_case, s_case]) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3)) + tvm.ir.assert_structural_equal(res.body, make_nat_expr(p, 3)) def test_double(): @@ -314,7 +314,7 @@ def test_double(): orig = double(make_nat_expr(p, 3)) orig = Function([], orig) res = dcpe(orig, mod=mod) - assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6)) + tvm.ir.assert_structural_equal(res.body, make_nat_expr(p, 6)) def test_concat(): diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index ce09a939cefc..5ee1c955b093 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -327,7 +327,7 @@ def expected(): mod = transform.PartitionGraph()(mod) fused_mod = transform.FuseOps(2)(mod) expected_mod = expected() - assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(fused_mod, expected_mod, map_free_vars=True) x_data = 
np.random.rand(8, 8).astype("float32") y_data = np.random.rand(8, 8).astype("float32") @@ -376,7 +376,7 @@ def expected(): mod = transform.PartitionGraph()(mod) fused_mod = transform.FuseOps(2)(mod) expected_mod = expected() - assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(fused_mod, expected_mod, map_free_vars=True) def test_extern_ccompiler_multiple_functions(): @@ -451,7 +451,7 @@ def expected(): fused_mod = transform.FuseOps(2)(mod) expected_mod = expected() - assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(fused_mod, expected_mod, map_free_vars=True) x_data = np.random.rand(8, 8).astype("float32") y_data = np.random.rand(8, 8).astype("float32") @@ -529,7 +529,7 @@ def get_func(): mod = transform.PartitionGraph()(mod) mod = transform.InferType()(mod) - assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected(), map_free_vars=True) ref_mod = tvm.IRModule() ref_mod["main"] = get_func() @@ -650,7 +650,7 @@ def expected(): partitioned = partition() ref_mod = expected() - assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, ref_mod, map_free_vars=True) def test_function_lifting_inline(): @@ -712,7 +712,7 @@ def expected(): partitioned = partition() ref_mod = expected() - assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, ref_mod, map_free_vars=True) def test_constant_propagation(): @@ -751,7 +751,7 @@ def expected(): expected_mod = expected() expected_mod = relay.transform.InferType()(expected_mod) - assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected_mod, map_free_vars=True) y_data = np.random.rand(8, 8).astype("float32") np_add = ones + y_data @@ -847,7 +847,7 @@ def expected(): mod["main"] = create_graph() ref_mod = expected() partitioned = transform.PartitionGraph()(mod) - assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, ref_mod, map_free_vars=True) def test_mixed_single_multiple_outputs(): @@ -914,7 +914,7 @@ def expected(): ref_mod = expected() partitioned = transform.PartitionGraph()(mod) - assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, ref_mod, map_free_vars=True) def test_dnnl_fuse(): @@ -1201,7 +1201,7 @@ def test_same_output_region(): mod = transform.PartitionGraph()(mod) expected_mod = expected_same_output_region() - assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected_mod, map_free_vars=True) def test_different_output_region(): mod = get_mod() @@ -1210,7 +1210,7 @@ def test_different_output_region(): mod = transform.PartitionGraph()(mod) expected_mod = expected_different_output_region() - assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, expected_mod, map_free_vars=True) test_same_output_region() test_different_output_region() @@ -1274,7 +1274,7 @@ def expected(): ref_mod = expected() partitioned = seq(mod) - assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, ref_mod, map_free_vars=True) def test_duplicate_merge_and_tuplegetitem(): @@ -1357,7 +1357,7 @@ def expected(): ref_mod = 
expected() partitioned = seq(mod) - assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, ref_mod, map_free_vars=True) def test_constant_tuples(): @@ -1477,7 +1477,7 @@ def expected(): partitioned = seq(create_graph()) partitioned = transform.InferType()(partitioned) expected_mod = transform.InferType()(expected()) - assert tvm.ir.structural_equal(partitioned, expected_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(partitioned, expected_mod, map_free_vars=True) def test_tuple_output_exec(): diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 4bb4e4813e30..adc93a0d2309 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -96,12 +96,12 @@ def expected(): # Check that Relay Legalize does not change the graph. a = run_opt_pass(a, relay.transform.Legalize()) b = run_opt_pass(before(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) # Check that QNN Legalize modifies the graph. a = run_opt_pass(a, relay.qnn.transform.Legalize()) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + tvm.ir.assert_structural_equal(a, b) def test_qnn_legalize_qnn_conv2d(): @@ -152,7 +152,7 @@ def _get_mod(data_dtype, kernel_dtype): "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod" ): legalized_mod = relay.qnn.transform.Legalize()(mod) - assert tvm.ir.structural_equal(mod, legalized_mod) + tvm.ir.assert_structural_equal(mod, legalized_mod) ################################################################ # Check transformations for platforms without fast Int8 support. @@ -176,7 +176,7 @@ def _get_mod(data_dtype, kernel_dtype): with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) - assert tvm.ir.structural_equal(mod, legalized_mod) + tvm.ir.assert_structural_equal(mod, legalized_mod) # ARM - so check that transformation has happened. with tvm.target.Target( @@ -249,7 +249,7 @@ def _get_mod(data_dtype, kernel_dtype): "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod" ): legalized_mod = relay.qnn.transform.Legalize()(mod) - assert tvm.ir.structural_equal(mod, legalized_mod) + tvm.ir.assert_structural_equal(mod, legalized_mod) ################################################################ # Check transformations for platforms without fast Int8 support. @@ -273,7 +273,7 @@ def _get_mod(data_dtype, kernel_dtype): with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) - assert tvm.ir.structural_equal(mod, legalized_mod) + tvm.ir.assert_structural_equal(mod, legalized_mod) # ARM - so check that transformation has happened. 
with tvm.target.Target( diff --git a/tests/python/relay/test_pass_remove_unused_functions.py b/tests/python/relay/test_pass_remove_unused_functions.py index 67efc9b20262..3c7aad40a506 100644 --- a/tests/python/relay/test_pass_remove_unused_functions.py +++ b/tests/python/relay/test_pass_remove_unused_functions.py @@ -117,7 +117,7 @@ def get_mod(): mod = get_mod() ref_mod = get_mod() mod = relay.transform.RemoveUnusedFunctions()(mod) - assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True) + tvm.ir.assert_structural_equal(mod, ref_mod, map_free_vars=True) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index ac6920d5b780..7e2971a04e1b 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -54,12 +54,12 @@ def symbolic(): z = before() zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) z = symbolic() zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(symbolic(), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) def test_simplify_transpose(): @@ -302,9 +302,7 @@ def expected11(): ]: after = run_opt_pass(before, transform.SimplifyExpr()) expected = run_opt_pass(expected, transform.InferType()) - assert tvm.ir.structural_equal(after, expected), "\nafter: {} \nexpected: {}".format( - after, expected - ) + tvm.ir.assert_structural_equal(after, expected) def test_simplify_full_elementwise(): @@ -348,12 +346,12 @@ def after_right(x, elem_op, value): z = before_left(x, op, full) zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(after_left(x, op, value), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) z = before_right(x, op, full) zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(after_right(x, op, value), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) # Test the case in which x is broadcast to full's shape full_ops = [] @@ -368,12 +366,12 @@ def after_right(x, elem_op, value): z = before_left(x, op, full) zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(before_left(x, op, full), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) z = before_right(x, op, full) zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(before_right(x, op, full), transform.InferType()) - assert tvm.ir.structural_equal(zz, after) + tvm.ir.assert_structural_equal(zz, after) for shape in [[10], [10, 10], [10, 10, 10]]: for dtype in ["float32", "int32", "bool"]: @@ -386,11 +384,11 @@ def check(x, y=None, do_nothing=False): expected = run_infer_type(x) if do_nothing: actual = run_opt_pass(x, transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) else: assert y is not None actual = run_opt_pass(y, transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) shape = [2, 3, 4] dtype = "float32" @@ -434,9 +432,9 @@ def test_simplify_same_cast(): expected = run_infer_type(data) actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual1, 
expected) + tvm.ir.assert_structural_equal(actual1, expected) actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual2, expected) + tvm.ir.assert_structural_equal(actual2, expected) def test_simplify_consecutive_cast(): @@ -451,13 +449,13 @@ def test_simplify_consecutive_cast(): actual1 = run_opt_pass(expr2, relay.transform.SimplifyExpr()) expected = run_infer_type(relay.cast(x, "int32")) - assert tvm.ir.structural_equal(actual1, expected) + tvm.ir.assert_structural_equal(actual1, expected) actual2 = run_opt_pass(expr3, relay.transform.SimplifyExpr()) expected = run_infer_type(relay.cast(x, "int64")) - assert tvm.ir.structural_equal(actual2, expected) + tvm.ir.assert_structural_equal(actual2, expected) actual3 = run_opt_pass(expr4, relay.transform.SimplifyExpr()) expected = run_infer_type(relay.cast(x, "float32")) - assert tvm.ir.structural_equal(actual3, expected) + tvm.ir.assert_structural_equal(actual3, expected) # cannot simplify the narrow cast x = relay.var("x", shape=(3, 4, 5), dtype="float32") @@ -466,14 +464,14 @@ def test_simplify_consecutive_cast(): expr2 = relay.cast_like(expr1, y) actual = run_opt_pass(expr2, relay.transform.SimplifyExpr()) expected = run_infer_type(relay.cast(expr1, "float32")) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) x = relay.var("x", shape=(3, 4), dtype="int64") expr1 = relay.cast(x, "bool") expr2 = relay.cast(expr1, "int32") actual = run_opt_pass(expr2, relay.transform.SimplifyExpr()) expected = run_infer_type(expr2) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_reshape_like(): @@ -483,7 +481,7 @@ def test_concretize_reshape_like(): expected = run_infer_type(relay.reshape(data, (6, 2, 2))) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_reshape_like_attrs(): @@ -493,7 +491,7 @@ def test_concretize_reshape_like_attrs(): expected = run_infer_type(relay.reshape(data, (2, 3, 2, 2))) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_zeros_like(): @@ -503,7 +501,7 @@ def test_concretize_zeros_like(): expected = run_infer_type(relay.zeros((3, 4, 5), dtype)) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_ones_like(): @@ -513,7 +511,7 @@ def test_concretize_ones_like(): expected = run_infer_type(relay.ones((3, 4, 5), dtype)) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_full_like(): @@ -524,7 +522,7 @@ def test_concretize_full_like(): expected = run_infer_type(relay.full(fill_value, (3, 4, 5), dtype)) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_collapse_sum_like(): @@ -534,7 +532,7 @@ def test_concretize_collapse_sum_like(): expected = run_infer_type(relay.collapse_sum_to(data, (3,))) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + 
tvm.ir.assert_structural_equal(actual, expected) def test_concretize_broadcast_to_like(): @@ -544,7 +542,7 @@ def test_concretize_broadcast_to_like(): expected = run_infer_type(relay.broadcast_to(data, (3, 3, 3))) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_cast_like(): @@ -555,7 +553,7 @@ def test_concretize_cast_like(): expected = run_infer_type(relay.cast(data, "int32")) actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_concretize_multiple(): @@ -580,14 +578,14 @@ def test_concretize_multiple(): expected = run_infer_type(ret_c) actual = run_opt_pass(ret, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + tvm.ir.assert_structural_equal(actual, expected) def test_simplify_mul_add(): def check_simple_fold(origin_exprs, expect_expr): for origin_expr in origin_exprs: simple_expr = run_opt_pass(origin_expr, transform.SimplifyExpr()) - assert tvm.ir.structural_equal(simple_expr, expect_expr) + tvm.ir.assert_structural_equal(simple_expr, expect_expr) n = 32 c1_val = np.random.uniform(size=n).astype("float32") @@ -670,7 +668,7 @@ def expected(c): for c in [1.0, 2.0, 2.5]: opt = run_opt_pass(before(c), transform.SimplifyExpr()) after = run_opt_pass(expected(c), transform.InferType()) - assert tvm.ir.structural_equal(opt, after) + tvm.ir.assert_structural_equal(opt, after) def test_simplify_dq_argmax(): @@ -686,7 +684,7 @@ def expected(): opt = run_opt_pass(before(), transform.SimplifyExpr()) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(opt, after) + tvm.ir.assert_structural_equal(opt, after) def test_simplify_dq_argmin(): @@ -702,7 +700,7 @@ def expected(): opt = run_opt_pass(before(), transform.SimplifyExpr()) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(opt, after) + tvm.ir.assert_structural_equal(opt, after) def test_simplify_dq_argsort(): @@ -718,7 +716,7 @@ def expected(): opt = run_opt_pass(before(), transform.SimplifyExpr()) after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(opt, after) + tvm.ir.assert_structural_equal(opt, after) def test_simplify_clip_cast(): @@ -797,9 +795,7 @@ def expected5(): ]: after = run_opt_pass(before, transform.SimplifyExpr()) expected = run_opt_pass(expected, transform.InferType()) - assert tvm.ir.structural_equal(after, expected), "\nafter: {} \nexpected: {}".format( - after, expected - ) + tvm.ir.assert_structural_equal(after, expected) def test_simplify_cast_clip(): @@ -842,9 +838,7 @@ def expected3(): ]: after = run_opt_pass(before, transform.SimplifyExpr()) expected = run_opt_pass(expected, transform.InferType()) - assert tvm.ir.structural_equal(after, expected), "\nafter: {} \nexpected: {}".format( - after, expected - ) + tvm.ir.assert_structural_equal(after, expected) def test_simplify_add(): @@ -859,7 +853,7 @@ def expected(): opt = run_opt_pass(before(), transform.SimplifyExpr()) ref = run_infer_type(expected()) - assert tvm.ir.structural_equal(opt, ref) + tvm.ir.assert_structural_equal(opt, ref) def test_binomials(): diff --git a/tests/python/relay/test_pass_simplify_inference.py b/tests/python/relay/test_pass_simplify_inference.py index 24a63e97b30e..42df54e5d2e7 100644 --- a/tests/python/relay/test_pass_simplify_inference.py +++ 
b/tests/python/relay/test_pass_simplify_inference.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from tvm.ir import IRModule, structural_equal +from tvm.ir import IRModule, assert_structural_equal from tvm import relay as rly from tvm.relay.transform import SimplifyInference, InferType @@ -72,7 +72,7 @@ def check(dim, axis, nstep): mod = simplify(mod) y1 = mod["main"].body - assert structural_equal(y1, y2, map_free_vars=True) + assert_structural_equal(y1, y2, map_free_vars=True) check(2, 1, 1) check(4, 1, 1) diff --git a/tests/python/relay/test_pass_split_args.py b/tests/python/relay/test_pass_split_args.py index 508f74f11269..04a3c5af1cd9 100644 --- a/tests/python/relay/test_pass_split_args.py +++ b/tests/python/relay/test_pass_split_args.py @@ -91,7 +91,7 @@ def expected(limit): limit = tvm.target.Target(target_name).max_function_args res = run_opt_pass(before(), transform.SplitArgs(limit)) exp = run_opt_pass(expected(limit), transform.InferType()) - assert tvm.ir.structural_equal(res, exp) + tvm.ir.assert_structural_equal(res, exp) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 70971d243c97..873124ebf13a 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -77,7 +77,7 @@ def test_order(): expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(a, x, expected_output) expected_output = run_opt_pass(expected_output, transform.InferType()) - assert tvm.ir.structural_equal(anf, expected_output) + tvm.ir.assert_structural_equal(anf, expected_output) def test_if(): @@ -94,7 +94,7 @@ def test_if(): expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(c, cond, expected_output) expected_output = run_opt_pass(expected_output, transform.InferType()) - assert tvm.ir.structural_equal(anf, expected_output) + tvm.ir.assert_structural_equal(anf, expected_output) def test_let_as_subexpr(): diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py index 2a97e985d91d..5c852e970190 100644 --- a/tests/python/relay/test_pass_to_basic_block_normal_form.py +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -136,7 +136,7 @@ def expected(): } """ expected_output = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True) + tvm.ir.assert_structural_equal(bblock, expected_output, map_free_vars=True) def test_nested_if(): @@ -205,7 +205,7 @@ def expected(): } """ expected_output = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(bblock, expected_output, map_free_vars=True) + tvm.ir.assert_structural_equal(bblock, expected_output, map_free_vars=True) check_basic_block_normal_form(bblock) @@ -294,7 +294,7 @@ def test_let1(): %x """ opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) - assert tvm.ir.structural_equal(body, opt_body) + tvm.ir.assert_structural_equal(body, opt_body) check_basic_block_normal_form(opt_body) def test_let1_1(): @@ -303,7 +303,7 @@ def test_let1_1(): body = relay.Let(x, d, relay.add(x, x)) body = run_opt_pass(body, transform.InferType()) opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) - assert tvm.ir.structural_equal(body, opt_body) + 
tvm.ir.assert_structural_equal(body, opt_body) check_basic_block_normal_form(opt_body) def test_let2(): @@ -325,7 +325,7 @@ def expected(): opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) expected_body = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(opt_body, expected_body) + tvm.ir.assert_structural_equal(opt_body, expected_body) check_basic_block_normal_form(opt_body) def test_let3(): @@ -339,7 +339,7 @@ def test_let3(): body = relay.Let(y, c, body) body = run_opt_pass(body, transform.InferType()) opt_body = run_opt_pass(body, transform.ToBasicBlockNormalForm()) - assert tvm.ir.structural_equal(body, opt_body) + tvm.ir.assert_structural_equal(body, opt_body) check_basic_block_normal_form(opt_body) test_let1() @@ -424,14 +424,14 @@ def expected_if_expr(x): expected_body = expected_if_expr(x) bblock = run_opt_pass(body, [transform.ToBasicBlockNormalForm(), transform.InferType()]) expected_bblock = run_opt_pass(expected_body, transform.InferType()) - assert tvm.ir.structural_equal(bblock, expected_bblock, map_free_vars=True) + tvm.ir.assert_structural_equal(bblock, expected_bblock, map_free_vars=True) check_basic_block_normal_form(bblock) func = relay.Function([x], body) expected_func = relay.Function([x], expected_body) bblock = run_opt_pass(func, [transform.ToBasicBlockNormalForm(), transform.InferType()]) expected_bblock = run_opt_pass(expected_func, transform.InferType()) - assert tvm.ir.structural_equal(bblock, expected_bblock) + tvm.ir.assert_structural_equal(bblock, expected_bblock) check_basic_block_normal_form(bblock) diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index 7e62ee8a75c8..98b4396a51f7 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -106,7 +106,7 @@ def test_threefry_generate_infer(): rand1 = tvm.relay.random.threefry_generate(key, oshape) f = tvm.relay.Function([], rand1) f = run_infer_type(f) - assert tvm.ir.structural_equal(f.ret_type, expected_type) + tvm.ir.assert_structural_equal(f.ret_type, expected_type) def test_threefry_split_infer(): @@ -117,7 +117,7 @@ def test_threefry_split_infer(): out_keys = tvm.relay.random.threefry_split(key) f = tvm.relay.Function([], out_keys) f = run_infer_type(f) - assert tvm.ir.structural_equal(f.ret_type, expected_type) + tvm.ir.assert_structural_equal(f.ret_type, expected_type) def test_uniform_infer(): @@ -132,7 +132,7 @@ def test_uniform_infer(): rand1 = tvm.relay.random.uniform(key, oshape, odtype) f = tvm.relay.Function([], rand1) f = run_infer_type(f) - assert tvm.ir.structural_equal(f.ret_type, expected_type) + tvm.ir.assert_structural_equal(f.ret_type, expected_type) @pytest.mark.xfail(raises=tvm.error.TVMError) diff --git a/tests/python/relay/test_recast.py b/tests/python/relay/test_recast.py index 19803594c968..fea8a2d2b402 100644 --- a/tests/python/relay/test_recast.py +++ b/tests/python/relay/test_recast.py @@ -40,7 +40,7 @@ def expected(): pre = before() post = recast(pre, "int8", "int32") expected = expected() - assert tvm.ir.structural_equal(expected, post) + tvm.ir.assert_structural_equal(expected, post) def test_recast_medium(): @@ -71,7 +71,7 @@ def expected(): pre = before() post = recast(pre, "int8", "int32") expected = expected() - assert tvm.ir.structural_equal(expected, post) + tvm.ir.assert_structural_equal(expected, post) def test_recast_skip(): @@ -99,7 +99,7 @@ def expected(): pre = before() post = recast(pre, "int8", "int32", skip_layers=[0]) expected = expected() - assert 
tvm.ir.structural_equal(expected, post) + tvm.ir.assert_structural_equal(expected, post) def test_recast_concat(): @@ -123,7 +123,7 @@ def expected(): pre = before() post = recast(pre, "float16", "float32", ops=["concatenate"]) expected = expected() - assert tvm.ir.structural_equal(expected, post) + tvm.ir.assert_structural_equal(expected, post) def test_recast_relu(): @@ -151,7 +151,7 @@ def expected(): pre = before() post = recast(pre, "float16", "float16", ops=["nn.conv2d", "nn.relu"]) expected = expected() - assert tvm.ir.structural_equal(expected, post) + tvm.ir.assert_structural_equal(expected, post) if __name__ == "__main__": diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index 4c97642498d9..ae5172f6caf0 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -163,7 +163,7 @@ def test_convert_single_conv(target_precision): expected_mod = tvm.relay.transform.InferType()(expected_mod) assert not tvm.ir.structural_equal(amp_mod, mod) - assert tvm.ir.structural_equal(amp_mod, expected_mod) + tvm.ir.assert_structural_equal(amp_mod, expected_mod) def test_convert_single_conv_fp64(): @@ -198,7 +198,7 @@ def test_convert_single_conv_fp64(): expected_mod = tvm.relay.transform.InferType()(expected_mod) assert not tvm.ir.structural_equal(amp_mod, mod) - assert tvm.ir.structural_equal(amp_mod, expected_mod) + tvm.ir.assert_structural_equal(amp_mod, expected_mod) def test_convert_conv_bn(target_precision): @@ -245,7 +245,7 @@ def test_convert_conv_bn(target_precision): expected_mod = tvm.IRModule.from_expr(bn[0]) expected_mod = tvm.relay.transform.InferType()(expected_mod) assert not tvm.ir.structural_equal(amp_mod, mod) - assert tvm.ir.structural_equal(amp_mod, expected_mod) + tvm.ir.assert_structural_equal(amp_mod, expected_mod) def test_do_not_convert_softmax(target_precision): @@ -257,7 +257,7 @@ def test_do_not_convert_softmax(target_precision): mod = tvm.relay.transform.InferType()(mod) out_mod = ToMixedPrecision(target_precision)(mod) orig_mod = tvm.relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(orig_mod, out_mod) + tvm.ir.assert_structural_equal(orig_mod, out_mod) def test_do_not_convert_arange(target_precision): @@ -267,7 +267,7 @@ def test_do_not_convert_arange(target_precision): mod = tvm.IRModule.from_expr(arange) out_mod = ToMixedPrecision(target_precision)(mod) orig_mod = tvm.relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(orig_mod, out_mod) + tvm.ir.assert_structural_equal(orig_mod, out_mod) def test_do_not_convert_summation(target_precision): @@ -284,7 +284,7 @@ def test_do_not_convert_summation(target_precision): mod = tvm.IRModule.from_expr(op(a)) out_mod = ToMixedPrecision(target_precision)(mod) orig_mod = tvm.relay.transform.InferType()(mod) - assert tvm.ir.structural_equal(orig_mod, out_mod) + tvm.ir.assert_structural_equal(orig_mod, out_mod) def test_green_gray_propagates_simple(target_precision): @@ -320,7 +320,7 @@ def test_green_gray_propagates_simple(target_precision): expected_mod = tvm.relay.transform.InferType()(expected_mod) assert not tvm.ir.structural_equal(amp_mod, mod) - assert tvm.ir.structural_equal(amp_mod, expected_mod) + tvm.ir.assert_structural_equal(amp_mod, expected_mod) def test_green_red_not_use_extraneous_cast(target_precision): @@ -382,7 +382,7 @@ def test_green_red_not_use_extraneous_cast(target_precision): expected_mod = tvm.IRModule.from_expr(result) expected_mod = InferType()(expected_mod) - 
assert tvm.ir.structural_equal(expected_mod, amp_mod) + tvm.ir.assert_structural_equal(expected_mod, amp_mod) def test_red_gray_propagates_simple(target_precision): @@ -401,7 +401,7 @@ def test_red_gray_propagates_simple(target_precision): mod, mod_params, mixed_precision_dtype=target_precision, atol=0.0, rtol=0.0 ) - assert tvm.ir.structural_equal(mod, output_mod) + tvm.ir.assert_structural_equal(mod, output_mod) def test_let_statement_simple(target_precision): @@ -450,7 +450,7 @@ def test_let_statement_simple(target_precision): expected_mod = tvm.IRModule.from_expr(let1) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) def test_where_simple(target_precision): @@ -476,7 +476,7 @@ def test_where_simple(target_precision): expected_mod = tvm.IRModule.from_expr(b) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) def test_batch_matmul_simple(target_precision): @@ -502,7 +502,7 @@ def test_batch_matmul_simple(target_precision): a = relay.nn.batch_matmul(data, weight, out_dtype=target_precision) expected_mod = tvm.IRModule.from_expr(a) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) def test_convert_follow_node_with_integer_arguments(target_precision): @@ -533,7 +533,7 @@ def test_convert_follow_node_with_integer_arguments(target_precision): take = relay.take(data, indices, axis=0) expected_mod = tvm.IRModule.from_expr(take) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) def test_clip(target_precision): @@ -555,7 +555,7 @@ def test_clip(target_precision): res = relay.clip(data, a_min=-128000, a_max=128000) expected_mod = tvm.IRModule.from_expr(res) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) def test_clip_with_pre_op(target_precision): @@ -582,7 +582,7 @@ def test_clip_with_pre_op(target_precision): res = relay.clip(res, a_min=-128000, a_max=128000) expected_mod = tvm.IRModule.from_expr(res) expected_mod = InferType()(expected_mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) def test_loop(target_precision): @@ -616,7 +616,7 @@ def _body(i, st): # Create expected module expected_mod = InferType()(mod) - assert tvm.ir.structural_equal(expected_mod, output_mod) + tvm.ir.assert_structural_equal(expected_mod, output_mod) if __name__ == "__main__": diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index ec88143db6a6..f18994d52ce9 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -244,7 +244,7 @@ def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) yy = infer_expr(y) - assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True) + tvm.ir.assert_structural_equal(yy.args[0], x, map_free_vars=True) assert yy.checked_type == relay.scalar_type("float32") assert x.vid.same_as(yy.args[0].vid) diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py index 77042be60285..3f9aaff3ee8d 100644 --- a/tests/python/relay/utils/tag_span.py +++ 
b/tests/python/relay/utils/tag_span.py @@ -91,7 +91,7 @@ def _verify_span(lhs, rhs): assert len(lhs_spans) == len(rhs_spans) for i in range(len(lhs_spans)): - assert tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]) + tvm.ir.assert_structural_equal(lhs_spans[i], rhs_spans[i]) def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_vars=False): @@ -103,6 +103,6 @@ def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_var if assert_mode: tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars) else: - assert tvm.ir.structural_equal(lhs, rhs, map_free_vars) + tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars) _verify_span(lhs, rhs) diff --git a/tests/python/te/test_te_hybrid_script.py b/tests/python/te/test_te_hybrid_script.py index d6b11785a4a3..862e80ffb6ce 100644 --- a/tests/python/te/test_te_hybrid_script.py +++ b/tests/python/te/test_te_hybrid_script.py @@ -189,7 +189,7 @@ def fanout(n, a): assert isinstance(ir, tvm.tir.For) assert ir.loop_var.name == "i" assert ir.min.value == 0 - assert tvm.ir.structural_equal(ir.extent, n - 3) + tvm.ir.assert_structural_equal(ir.extent, n - 3) # Check loopbody abody = ir.body assert isinstance(abody, tvm.tir.ProducerRealize) @@ -220,7 +220,7 @@ def fanout(n, a): assert value.a.indices[0].value == 0 assert value.b.producer.name == "a" assert len(value.b.indices) == 1 - assert tvm.ir.structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var) + tvm.ir.assert_structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var) divide = rbody[2] assert isinstance(divide, tvm.tir.ProducerStore) assert len(divide.indices) == 1 diff --git a/tests/python/te/test_te_schedule_tensorize.py b/tests/python/te/test_te_schedule_tensorize.py index ae5e7051bfba..79aecb78902a 100644 --- a/tests/python/te/test_te_schedule_tensorize.py +++ b/tests/python/te/test_te_schedule_tensorize.py @@ -108,13 +108,13 @@ def check(m, factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z], dom_map) - assert tvm.ir.structural_equal(out_dom[z.op.axis[0]].extent, factor) - assert tvm.ir.structural_equal(out_dom[z.op.axis[0]].min, xo * factor) - assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor) + tvm.ir.assert_structural_equal(out_dom[z.op.axis[0]].extent, factor) + tvm.ir.assert_structural_equal(out_dom[z.op.axis[0]].min, xo * factor) + tvm.ir.assert_structural_equal(in_dom.items()[0][1][0].extent, factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[z], out_dom, in_dom, vadd) ana = tvm.arith.Analyzer() - assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(vadd.op.body[0])) + tvm.ir.assert_structural_equal(ana.simplify(body[0]), ana.simplify(vadd.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) @@ -133,7 +133,7 @@ def check_cache_write(m, factor): finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z_global], dom_map) # outer loop var will be rebased, so min value is the new loop var and extent is 1 - assert tvm.ir.structural_equal(out_dom[xo].extent, 1) + tvm.ir.assert_structural_equal(out_dom[xo].extent, 1) assert isinstance(out_dom[xo].min, tvm.tir.Var) assert xo.var.name == out_dom[xo].min.name @@ -142,7 +142,7 @@ def check_cache_write(m, factor): ana = tvm.arith.Analyzer() vars = tvm.runtime.convert({xo.var: out_dom[xo].min}) vadd_body = tvm.tir.stmt_functor.substitute(vadd.op.body[0], vars) - assert 
tvm.ir.structural_equal(ana.simplify(body), ana.simplify(vadd_body)) + tvm.ir.assert_structural_equal(ana.simplify(body), ana.simplify(vadd_body)) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) @@ -183,14 +183,14 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir.structural_equal(out_dom[x].extent, 1) - assert tvm.ir.structural_equal(out_dom[y].extent, factor) - assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[y].extent, factor) + tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) ana = tvm.arith.Analyzer() - assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) + tvm.ir.assert_structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -207,13 +207,13 @@ def check_rfactor(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir.structural_equal(out_dom[x].extent, 1) - assert tvm.ir.structural_equal(out_dom[y].extent, factor) - assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[y].extent, factor) + tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) ana = tvm.arith.Analyzer() - assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) + tvm.ir.assert_structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -230,13 +230,13 @@ def check_rfactor_no_reset(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir.structural_equal(out_dom[x].extent, 1) - assert tvm.ir.structural_equal(out_dom[y].extent, factor) - assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[y].extent, factor) + tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) ana = tvm.arith.Analyzer() - assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) + tvm.ir.assert_structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -254,13 +254,13 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - assert tvm.ir.structural_equal(out_dom[x].extent, 1) - assert tvm.ir.structural_equal(out_dom[y].extent, factor) - assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) + tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[y].extent, factor) + 
tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) ana = tvm.arith.Analyzer() - assert tvm.ir.structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) + tvm.ir.assert_structural_equal(ana.simplify(body[0]), ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index 78185510fbab..1ab7662b0b6b 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -39,7 +39,7 @@ def test_buffer_access_ptr(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1]) aptr = Ab.access_ptr("rw") - assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m) + tvm.ir.assert_structural_equal(aptr.args[3], Ab.strides[0] * m) assert aptr.args[0].dtype == Ab.dtype assert aptr.args[4].value == Buffer.READ | Buffer.WRITE aptr = Ab.access_ptr("w") @@ -69,18 +69,18 @@ def test_buffer_access_ptr_extent(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw") - assert tvm.ir.structural_equal(aptr.args[3], m * n) + tvm.ir.assert_structural_equal(aptr.args[3], m * n) aptr = Ab.access_ptr("rw", offset=100) - assert tvm.ir.structural_equal(aptr.args[3], m * n - 100) + tvm.ir.assert_structural_equal(aptr.args[3], m * n - 100) Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1]) aptr = Ab.access_ptr("rw", offset=100) - assert tvm.ir.structural_equal(aptr.args[3], Ab.strides[0] * m - 100) + tvm.ir.assert_structural_equal(aptr.args[3], Ab.strides[0] * m - 100) # Test extent from input params aptr = Ab.access_ptr("rw", extent=200) - assert tvm.ir.structural_equal(aptr.args[3], 200) + tvm.ir.assert_structural_equal(aptr.args[3], 200) aptr = Ab.access_ptr("rw", offset=100, extent=100) - assert tvm.ir.structural_equal(aptr.args[3], 100) + tvm.ir.assert_structural_equal(aptr.args[3], 100) def test_buffer_vload(): @@ -109,7 +109,7 @@ def test_buffer_index_merge_mult_mod(): A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1)) def assert_simplified_equal(index_simplified, index_direct): - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( index_simplified, index_direct - ), "index_simplified=%s, index_direct=%s" % (index_simplified, index_direct) + ) diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tir-base/test_tir_ops.py index 8cffe8171a23..f2a18aeae519 100644 --- a/tests/python/tir-base/test_tir_ops.py +++ b/tests/python/tir-base/test_tir_ops.py @@ -79,7 +79,7 @@ def test_const_fold3(): ]: for v1 in [0, 1]: for v2 in [0, 1]: - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, "uint1")), tvm.tir.const(py_func(v1, v2), "uint1"), ) @@ -198,13 +198,13 @@ def test_if_then_else(): out = tvm.tir.if_then_else(cond, lhs, rhs) out2 = tvm.tir.if_then_else(not cond, rhs, lhs) out3 = tvm.tir.if_then_else(not cond, lhs, rhs) - assert tvm.ir.structural_equal(out, out2) == 1 + tvm.ir.assert_structural_equal(out, out2) if cond: - assert tvm.ir.structural_equal(out, lhs.astype(out_dtype)) == 1 - assert tvm.ir.structural_equal(out3, rhs.astype(out_dtype)) == 1 + tvm.ir.assert_structural_equal(out, lhs.astype(out_dtype)) + tvm.ir.assert_structural_equal(out3, rhs.astype(out_dtype)) else: - assert tvm.ir.structural_equal(out,
rhs.astype(out_dtype)) == 1 - assert tvm.ir.structural_equal(out3, lhs.astype(out_dtype)) == 1 + tvm.ir.assert_structural_equal(out, rhs.astype(out_dtype)) + tvm.ir.assert_structural_equal(out3, lhs.astype(out_dtype)) elif cond.dtype == "bool": out = tvm.tir.if_then_else(cond, lhs, rhs) assert out.dtype == out_dtype diff --git a/tests/python/tir-schedule/test_tir_schedule_utilities.py b/tests/python/tir-schedule/test_tir_schedule_utilities.py index f7b0e672b23c..0ad05ea83288 100644 --- a/tests/python/tir-schedule/test_tir_schedule_utilities.py +++ b/tests/python/tir-schedule/test_tir_schedule_utilities.py @@ -290,7 +290,7 @@ def test_get_producers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") block = "relu" if use_block_name else sch.get_block("relu") (producer,) = sch.get_producers(block) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( sch.get_sref(producer).stmt, sch.get_sref(sch.get_block("matmul")).stmt, ) @@ -301,7 +301,7 @@ def test_get_producers_multiple_buffer_depdencies(use_block_name): sch = tir.Schedule(mod=tuple_reduction, debug_mask="all") block = "T_add" if use_block_name else sch.get_block("T_add") (producer,) = sch.get_producers(block) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( sch.get_sref(producer).stmt, sch.get_sref(sch.get_block("data_red_temp")).stmt, ) @@ -311,7 +311,7 @@ def test_get_consumers(use_block_name): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") block = "matmul" if use_block_name else sch.get_block("matmul") (consumer,) = sch.get_consumers(block) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( sch.get_sref(consumer).stmt, sch.get_sref(sch.get_block("relu")).stmt, ) @@ -322,7 +322,7 @@ def test_get_consumers_multiple_buffer_depdencies(use_block_name): sch = tir.Schedule(mod=tuple_reduction, debug_mask="all") block = "data_red_temp" if use_block_name else sch.get_block("data_red_temp") (consumer,) = sch.get_consumers(block) - assert tvm.ir.structural_equal( + tvm.ir.assert_structural_equal( sch.get_sref(consumer).stmt, sch.get_sref(sch.get_block("T_add")).stmt, ) diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index e64d3c74932b..f773e56e5ccb 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -101,7 +101,7 @@ def test_cse(): # And this is the name and value of this variable cse_var_1 = body.var # Keep the variable accessible for later checking the replacements assert body.var.name == "cse_var_1" - assert tvm.ir.structural_equal(body.value, z1 + z2) + tvm.ir.assert_structural_equal(body.value, z1 + z2) assert isinstance(body.body, tvm.tir.SeqStmt) body = body.body @@ -126,19 +126,19 @@ def test_cse(): # And this is the name and value of this variable cse_var_2 = body.var # Keep the variable accessible for later checking the replacements assert body.var.name == "cse_var_2" - assert tvm.ir.structural_equal(body.value, x + y) + tvm.ir.assert_structural_equal(body.value, x + y) body = body.body body.var.name == "a" # Check that the replacement has been done correctly! - assert tvm.ir.structural_equal(body.value, cse_var_2 + cse_var_1) + tvm.ir.assert_structural_equal(body.value, cse_var_2 + cse_var_1) body = body.body body.var.name == "b" # Check that the replacement has been done correctly!
-    assert tvm.ir.structural_equal(body.value, cse_var_2 + z3)
+    tvm.ir.assert_structural_equal(body.value, cse_var_2 + z3)

     assert isinstance(body.body, tvm.tir.BufferStore)
@@ -201,7 +201,7 @@ def test_cse_ifNode_1():
     # The let-in introduced by the CSE should appear now, inside the Then branch of the If node
     assert body.var.name == "cse_var_1"
     # and it should contain the expression (y+z) that was redundant
-    assert tvm.ir.structural_equal(body.value, y + z)
+    tvm.ir.assert_structural_equal(body.value, y + z)


 # Second test for if nodes : Some duplicated computations appear in both the Then and Else branch.
@@ -252,7 +252,7 @@ def test_cse_ifNode_2():
     # The let-in introduced by the CSE should appear now, at the toplevel (i.e. before the If)
     assert body.var.name == "cse_var_1"
     # and it should contain the expression (y+z) that was redundant
-    assert tvm.ir.structural_equal(body.value, y + z)
+    tvm.ir.assert_structural_equal(body.value, y + z)


 # -------------------------------------------------------------------------------------------------
@@ -294,7 +294,7 @@ def test_cse_cascade():
     cse_var_2 = body.var  # Keep the variable accessible for later checking the replacements
     assert body.var.name == "cse_var_2"
     # and it should contain the expression (x+y)
-    assert tvm.ir.structural_equal(body.value, (x + y))
+    tvm.ir.assert_structural_equal(body.value, (x + y))

     body = body.body

@@ -304,7 +304,7 @@ def test_cse_cascade():
     cse_var_1 = body.var  # Keep the variable accessible for later checking the replacements
     assert body.var.name == "cse_var_1"
     # and it should contain the expression cse_var_2+z
-    assert tvm.ir.structural_equal(body.value, cse_var_2 + z)
+    tvm.ir.assert_structural_equal(body.value, cse_var_2 + z)

     body = body.body

@@ -317,9 +317,9 @@ def test_cse_cascade():
     store2 = body[1]
     store3 = body[2]

-    assert tvm.ir.structural_equal(store1.value, cse_var_1)
-    assert tvm.ir.structural_equal(store2.value, cse_var_1)
-    assert tvm.ir.structural_equal(store3.value, cse_var_2)
+    tvm.ir.assert_structural_equal(store1.value, cse_var_1)
+    tvm.ir.assert_structural_equal(store2.value, cse_var_1)
+    tvm.ir.assert_structural_equal(store3.value, cse_var_2)


 # -----------------------------------------------------------------------------------------
@@ -342,7 +342,7 @@ def test_no_normalization_without_commoning():

     body = body["main"].body  # Gets the body of the main, i.e. the full statement
     assert body.var.name == "a"
-    assert tvm.ir.structural_equal(body.value, x + (y + z))
+    tvm.ir.assert_structural_equal(body.value, x + (y + z))


 # -------------------------------------------------
diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py
index 2b3f73e24f88..6468ac5396ef 100644
--- a/tests/python/tir-transform/test_tir_transform_loop_partition.py
+++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py
@@ -567,7 +567,7 @@ def test_explicit_partition_hint():
     mod = tvm.tir.transform.StorageFlatten(64)(mod)
     mod = tvm.tir.transform.LoopPartition()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
-    assert tvm.ir.structural_equal(mod["main"], partitioned_concat)
+    tvm.ir.assert_structural_equal(mod["main"], partitioned_concat)


 def partition_from_scheduled_tir(prim_func, pass_cfg):
@@ -629,7 +629,7 @@ def test_condition_mutually_exclusive():
     mod = partition_from_scheduled_tir(
         concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}}
     )
-    assert tvm.ir.structural_equal(
+    tvm.ir.assert_structural_equal(
         mod["main"], partitioned_concat_3.with_attr("global_symbol", "main")
     )

@@ -681,7 +681,7 @@ def partitioned_main(
     mod = tvm.tir.transform.UnrollLoop()(mod)
     mod = tvm.tir.transform.RemoveNoOp()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
-    assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main"))
+    tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main"))


 def test_loop_partition_recursive_unroll_hint():
@@ -750,7 +750,7 @@ def partitioned_main():
             }
         },
     )
-    assert tvm.ir.structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main"))
+    tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main"))


 def test_loop_partition_keep_loop_annotations():
@@ -784,7 +784,7 @@ def after(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None:
             }
         },
     )
-    assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main"))
+    tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main"))


 def test_loop_partition_with_unit_loop_in_condition():
@@ -832,7 +832,7 @@ def after(
             }
         },
     )
-    assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main"))
+    tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main"))


 @T.prim_func
@@ -1059,7 +1059,7 @@ def test_single_point_partition(origin, expected):
             }
         },
     )
-    assert tvm.ir.structural_equal(mod["main"], expected)
+    tvm.ir.assert_structural_equal(mod["main"], expected)


 if __name__ == "__main__":
diff --git a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
index 647c44631312..553c7457708c 100644
--- a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
+++ b/tests/python/tir-transform/test_tir_transform_prim_func_pass.py
@@ -42,7 +42,7 @@ def transform_function(self, func, mod, ctx):

     mod = tvm.IRModule({"main": func})
     mod = TestReplaceFunc(new_func)(mod)
-    assert tvm.ir.structural_equal(mod["main"].body, new_func.body)
+    tvm.ir.assert_structural_equal(mod["main"].body, new_func.body)


 def test_cow_pass():
diff --git a/tests/python/tir-transform/test_transform_default_gpu_schedule.py b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
index 63809beade8a..0a648338490c 100644
--- a/tests/python/tir-transform/test_transform_default_gpu_schedule.py
+++ b/tests/python/tir-transform/test_transform_default_gpu_schedule.py
@@ -451,7 +451,7 @@ def full(
     target = tvm.target.Target("nvidia/geforce-rtx-3070")
     with target, tvm.transform.PassContext(opt_level=3):
         After = DefaultGPUSchedule()(Before)
-    assert tvm.ir.structural_equal(After, Expected)
+    tvm.ir.assert_structural_equal(After, Expected)


 def test_add_on_metal():
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 1e595c8441b2..9bc9800c1cb8 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -1045,20 +1045,20 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
             assert len(src_coeff) > 1
             assert len(dst_coeff) > 1
             assert len(extents) != 0
-            assert tvm.ir.structural_equal(
+            tvm.ir.assert_structural_equal(
                 analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0
             )
-            assert tvm.ir.structural_equal(
+            tvm.ir.assert_structural_equal(
                 analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0
             )
-            assert tvm.ir.structural_equal(src_coeff[-2], 1)
-            assert tvm.ir.structural_equal(dst_coeff[-2], 1)
+            tvm.ir.assert_structural_equal(src_coeff[-2], 1)
+            tvm.ir.assert_structural_equal(dst_coeff[-2], 1)
             if env.BATCH > 1:
                 assert len(src_coeff) > 2
                 assert len(dst_coeff) > 2
                 assert len(extents) > 1
-                assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
-                assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
+                tvm.ir.assert_structural_equal(src_coeff[-3], env.BLOCK_OUT)
+                tvm.ir.assert_structural_equal(dst_coeff[-3], env.BLOCK_OUT)

             # Apply tensorization of the loop coefficients
             src_offset = src_coeff[-1]