Skip to content

Commit

Permalink
Fix common pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 16, 2023
1 parent 81ad932 commit c3efc6a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
9 changes: 9 additions & 0 deletions tests/common/quantization/metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ class BatchNormTestMetatype(TestMetatype):
@METATYPES_FOR_TEST.register()
class Conv2dTestMetatype(TestMetatype):
name = "conv2d"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
class MatMulTestMetatype(TestMetatype):
name = "matmul"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
Expand Down Expand Up @@ -76,6 +78,7 @@ class CatTestMetatype(TestMetatype):
@METATYPES_FOR_TEST.register()
class LinearTestMetatype(TestMetatype):
name = "linear"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
Expand All @@ -96,11 +99,13 @@ class IdentityTestMetatype(TestMetatype):
@METATYPES_FOR_TEST.register()
class ReshapeTestMetatype(TestMetatype):
name = "reshape"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
class QuantizerTestMetatype(TestMetatype):
name = "quantizer"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
Expand All @@ -116,6 +121,7 @@ class ReluTestMetatype(TestMetatype):
@METATYPES_FOR_TEST.register()
class AddTestMetatype(TestMetatype):
name = "add"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
Expand All @@ -126,16 +132,19 @@ class ShapeOfTestMetatype(TestMetatype):
@METATYPES_FOR_TEST.register()
class PowerTestMetatype(TestMetatype):
name = "power"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
class MultiplyTestMetatype(TestMetatype):
name = "multiply"
input_edges_num_expected = 2


@METATYPES_FOR_TEST.register()
class InterpolateTestMetatype(TestMetatype):
name = "interpolate"
input_edges_num_expected = 3


@METATYPES_FOR_TEST.register()
Expand Down
4 changes: 3 additions & 1 deletion tests/common/quantization/test_quantizer_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf.common.graph import NNCFGraph
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.quantization.quantizer_removal import find_quantizer_nodes_to_cut
from nncf.quantization.passes import filter_constant_nodes
from nncf.quantization.passes import remove_shapeof_subgraphs
from tests.common.quantization.metatypes import CONSTANT_METATYPES
from tests.common.quantization.metatypes import METATYPES_FOR_TEST
Expand Down Expand Up @@ -226,7 +227,8 @@ def create_test_params():
@pytest.mark.parametrize("nncf_graph,test_case", create_test_params())
def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: TestCase):
quantizer_node = nncf_graph.get_node_by_name(test_case.node_name)
nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), SHAPEOF_METATYPES)
nncf_graph_without_shapeof = filter_constant_nodes(deepcopy(nncf_graph), CONSTANT_METATYPES)
nncf_graph_without_shapeof = remove_shapeof_subgraphs(nncf_graph_without_shapeof, SHAPEOF_METATYPES)
nodes, ops = find_quantizer_nodes_to_cut(
nncf_graph_without_shapeof,
quantizer_node,
Expand Down

0 comments on commit c3efc6a

Please sign in to comment.