diff --git a/tests/common/quantization/metatypes.py b/tests/common/quantization/metatypes.py index 8b8255c59a7..d4f43c7cf41 100644 --- a/tests/common/quantization/metatypes.py +++ b/tests/common/quantization/metatypes.py @@ -36,11 +36,13 @@ class BatchNormTestMetatype(TestMetatype): @METATYPES_FOR_TEST.register() class Conv2dTestMetatype(TestMetatype): name = "conv2d" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() class MatMulTestMetatype(TestMetatype): name = "matmul" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() @@ -76,6 +78,7 @@ class CatTestMetatype(TestMetatype): @METATYPES_FOR_TEST.register() class LinearTestMetatype(TestMetatype): name = "linear" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() @@ -96,11 +99,13 @@ class IdentityTestMetatype(TestMetatype): @METATYPES_FOR_TEST.register() class ReshapeTestMetatype(TestMetatype): name = "reshape" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() class QuantizerTestMetatype(TestMetatype): name = "quantizer" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() @@ -116,6 +121,7 @@ class ReluTestMetatype(TestMetatype): @METATYPES_FOR_TEST.register() class AddTestMetatype(TestMetatype): name = "add" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() @@ -126,16 +132,19 @@ class ShapeOfTestMetatype(TestMetatype): @METATYPES_FOR_TEST.register() class PowerTestMetatype(TestMetatype): name = "power" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() class MultiplyTestMetatype(TestMetatype): name = "multiply" + input_edges_num_expected = 2 @METATYPES_FOR_TEST.register() class InterpolateTestMetatype(TestMetatype): name = "interpolate" + input_edges_num_expected = 3 @METATYPES_FOR_TEST.register() diff --git a/tests/common/quantization/test_quantizer_removal.py b/tests/common/quantization/test_quantizer_removal.py index 159095fd0ae..fe42596bf0d 100644 --- a/tests/common/quantization/test_quantizer_removal.py +++ b/tests/common/quantization/test_quantizer_removal.py @@ -18,6 +18,7 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph.layer_attributes import Dtype from nncf.common.quantization.quantizer_removal import find_quantizer_nodes_to_cut +from nncf.quantization.passes import filter_constant_nodes from nncf.quantization.passes import remove_shapeof_subgraphs from tests.common.quantization.metatypes import CONSTANT_METATYPES from tests.common.quantization.metatypes import METATYPES_FOR_TEST @@ -226,7 +227,8 @@ def create_test_params(): @pytest.mark.parametrize("nncf_graph,test_case", create_test_params()) def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: TestCase): quantizer_node = nncf_graph.get_node_by_name(test_case.node_name) - nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), SHAPEOF_METATYPES) + nncf_graph_without_shapeof = filter_constant_nodes(deepcopy(nncf_graph), CONSTANT_METATYPES) + nncf_graph_without_shapeof = remove_shapeof_subgraphs(nncf_graph_without_shapeof, SHAPEOF_METATYPES) nodes, ops = find_quantizer_nodes_to_cut( nncf_graph_without_shapeof, quantizer_node,