Skip to content

Commit

Permalink
[UnitTests] Use tvm.ir.assert_structural_equal whenever possible (#17092
Browse files Browse the repository at this point in the history
)

* [UnitTests] Use tvm.ir.assert_structural_equal whenever possible

Prior to commit, many unit tests were implemented as `assert
tvm.ir.structural_equal(output, expected)`.  While this is correct, it
doesn't provide much information when the test fails.  The
`tvm.ir.assert_structural_equal` method performs the equivalent check,
but displays the exact location where a mismatch occurs.

This commit replaces all use of `assert tvm.ir.structural_equal` with
`tvm.ir.assert_structural_equal`.

* fix unit tests
  • Loading branch information
Lunderberg authored Jun 14, 2024
1 parent 4ecae58 commit 292ecfd
Show file tree
Hide file tree
Showing 84 changed files with 525 additions and 573 deletions.
2 changes: 1 addition & 1 deletion tests/python/arith/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_reduce_combiner_simplify():

# Check that the remaining components are the expected ones.
for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
assert tvm.ir.structural_equal(lhs, rhs)
tvm.ir.assert_structural_equal(lhs, rhs)

# Test that components with side effects are not removed
dummy = tvm.ir.GlobalVar("dummy")
Expand Down
2 changes: 1 addition & 1 deletion tests/python/arith/test_arith_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_simplify_reshape_flattened_index():
ana.bind(i1, tvm.ir.Range(0, 3))

i_flattened = i0 * 3 + i1
assert tvm.ir.structural_equal(
tvm.ir.assert_structural_equal(
ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4),
i_flattened,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_primary_operands_all_scalars():
mod = relay.transform.InferType()(mod)
mod = ScalarToTensorConstants()(mod)
new_mod = relay.transform.InferType()(mod)
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
tvm.ir.assert_structural_equal(mod[global_var].body, new_mod[global_var].body)


@tvm.testing.requires_cmsisnn
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_all_primary_operands_tensor_constants():
mod = relay.transform.InferType()(mod)
mod = ScalarToTensorConstants()(mod)
new_mod = relay.transform.InferType()(mod)
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
tvm.ir.assert_structural_equal(mod[global_var].body, new_mod[global_var].body)


@tvm.testing.requires_cmsisnn
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_duplicate_constant_arguments():
mod = relay.transform.InferType()(mod)
mod = ScalarToTensorConstants()(mod)
new_mod = relay.transform.InferType()(mod)
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
tvm.ir.assert_structural_equal(mod[global_var].body, new_mod[global_var].body)


@tvm.testing.requires_cmsisnn
Expand Down Expand Up @@ -329,7 +329,7 @@ def get_mod():

expected = get_mod()["external_function"].body
actual = ScalarToTensorConstants()(get_mod())["external_function"].body
assert tvm.ir.structural_equal(expected, actual)
tvm.ir.assert_structural_equal(expected, actual)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_coreml_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_annotate():
mod = transform.PartitionGraph()(mod)

expected = _create_graph_annotated()
assert tvm.ir.structural_equal(mod, expected, map_free_vars=True)
tvm.ir.assert_structural_equal(mod, expected, map_free_vars=True)


@pytest.mark.skipif(not _has_xcode(), reason="Xcode is not available")
Expand Down
16 changes: 3 additions & 13 deletions tests/python/contrib/test_ethosn/test_convert_equivalents.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,6 @@
from .test_addition import _get_addition_qnn_params


def _assert_structural_equal(a, b):
"""Check structural equality of two Relay expressions."""
reason = (
"Actual and expected relay functions are not equal. "
"ConvertEquivalents is not correctly transforming the input "
"graph."
)
assert tvm.ir.structural_equal(a, b), reason


@requires_ethosn
@pytest.mark.parametrize("dtype", ["uint8", "int8"])
@pytest.mark.parametrize("shape,channels", [((1, 4, 4, 8), 8), ((1, 16, 12, 4), 4)])
Expand Down Expand Up @@ -114,7 +104,7 @@ def expected():
mod = before()
mod = ConvertEquivalents()(mod)
expected_mod = expected()
_assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])
tvm.ir.assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])


@requires_ethosn
Expand Down Expand Up @@ -221,7 +211,7 @@ def expected():
mod = before()
mod = ConvertEquivalents()(mod)
expected_mod = expected()
_assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])
tvm.ir.assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])


@requires_ethosn
Expand Down Expand Up @@ -438,7 +428,7 @@ def expected():
mod = before()
mod = ConvertEquivalents()(mod)
expected_mod = expected()
_assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])
tvm.ir.assert_structural_equal(mod["ethos-n_0"], expected_mod["ethos-n_0"])


@requires_ethosn
Expand Down
22 changes: 8 additions & 14 deletions tests/python/contrib/test_ethosn/test_inline_partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,6 @@
from . import infrastructure as tei


def _assert_structural_equal(a, b):
"""Check structural equality of two Relay expressions."""
reason = (
"Actual and expected relay functions are not equal. "
"InlineNonComputeIntensiveSubgraphs is not correctly "
"transforming the input graph."
)
assert tvm.ir.structural_equal(a, b, map_free_vars=True), reason


@requires_ethosn
def test_single_reshape():
"""Check that a single reshape is inlined correctly."""
Expand All @@ -57,7 +47,7 @@ def expected():
mod = before()
mod = InlineNonComputeIntensivePartitions()(mod)
expected_mod = expected()
_assert_structural_equal(mod, expected_mod)
tvm.ir.assert_structural_equal(mod, expected_mod)


@requires_ethosn
Expand Down Expand Up @@ -86,7 +76,7 @@ def expected():
mod = before()
mod = InlineNonComputeIntensivePartitions()(mod)
expected_mod = expected()
_assert_structural_equal(mod, expected_mod)
tvm.ir.assert_structural_equal(mod, expected_mod)


@requires_ethosn
Expand All @@ -105,7 +95,7 @@ def before():
mod = before()
transformed_mod = InlineNonComputeIntensivePartitions()(mod)
for global_var in mod.get_global_vars():
_assert_structural_equal(mod[global_var], transformed_mod[global_var])
tvm.ir.assert_structural_equal(mod[global_var], transformed_mod[global_var])


@requires_ethosn
Expand Down Expand Up @@ -164,4 +154,8 @@ def expected():
mod = InlineNonComputeIntensivePartitions()(mod)
expected_mod = expected()
for global_var in mod.get_global_vars():
_assert_structural_equal(mod[global_var.name_hint], expected_mod[global_var.name_hint])
tvm.ir.assert_structural_equal(
mod[global_var.name_hint],
expected_mod[global_var.name_hint],
map_free_vars=True,
)
4 changes: 2 additions & 2 deletions tests/python/contrib/test_ethosu/test_extract_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _expected():

func, const = _get_func()
new_func, const_dict = extract_constants(func)
assert tvm.ir.structural_equal(new_func, _expected())
tvm.ir.assert_structural_equal(new_func, _expected())
assert 1 in const_dict
assert (const_dict[1] == const.data.asnumpy()).all()

Expand Down Expand Up @@ -89,7 +89,7 @@ def _expected():

func, consts = _get_func()
new_func, const_dict = extract_constants(func)
assert tvm.ir.structural_equal(new_func, _expected())
tvm.ir.assert_structural_equal(new_func, _expected())
for i, const in enumerate(consts):
assert i + 2 in const_dict
assert (const_dict[i + 2] == consts[i].data.asnumpy()).all()
Expand Down
36 changes: 13 additions & 23 deletions tests/python/contrib/test_ethosu/test_identity_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,6 @@ def _optimize(func, optimize=True):
return entry if isinstance(func, relay.Function) else entry.body


def _assert_structural_equal(a, b):
"""Check structural equality of two Relay expressions."""
reason = (
"Actual and expected relay functions are not equal. "
"IdentityOptimizer is not correctly removing redundant "
"identity operations."
)
assert tvm.ir.structural_equal(a, b), reason


def test_simple_reshape_identity_removal():
"""Check identity is removed when there is a reshape in
the graph and a compute operation follows."""
Expand All @@ -70,7 +60,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_simple_strided_slice_identity_removal():
Expand All @@ -90,7 +80,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_no_identity():
Expand All @@ -108,7 +98,7 @@ def get_graph():

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_reshape_last():
Expand All @@ -123,7 +113,7 @@ def get_graph():

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_requantize_identity_no_removal():
Expand All @@ -140,7 +130,7 @@ def get_graph():

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_activation_identity_no_removal():
Expand All @@ -155,7 +145,7 @@ def get_graph():

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_multiple_output_identity():
Expand All @@ -172,7 +162,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_many_output_identity():
Expand All @@ -195,7 +185,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_identity_before_concatenate_no_removal():
Expand All @@ -215,7 +205,7 @@ def get_graph():

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_identity_removal_with_multiple_transform_ops():
Expand All @@ -235,7 +225,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_identity_removal_on_binary_elementwise():
Expand All @@ -252,7 +242,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_identity_single_removal_on_binary_elementwise():
Expand All @@ -270,7 +260,7 @@ def get_graph(get_expected=False):

actual = _optimize(get_graph())
expected = _optimize(get_graph(get_expected=True), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_multiple_transform_ops_with_reduction_in_dimensionality():
Expand All @@ -289,7 +279,7 @@ def get_graph():

actual = _optimize(get_graph())
expected = _optimize(get_graph(), optimize=False)
_assert_structural_equal(actual, expected)
tvm.ir.assert_structural_equal(actual, expected)


def test_identity_optimizer_runs_in_compilation_pipeline():
Expand Down
Loading

0 comments on commit 292ecfd

Please sign in to comment.