Skip to content

Commit

Permalink
Rollback ONNX/PT changes
Browse files Browse the repository at this point in the history
  • Loading branch information
KodiaqQ committed Nov 5, 2024
1 parent 54c750d commit 6a6434d
Show file tree
Hide file tree
Showing 12 changed files with 2 additions and 267 deletions.
12 changes: 0 additions & 12 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,18 +677,6 @@ class ONNXErfMetatype(ONNXOpMetatype):
op_names = ["Erf"]


@ONNX_OPERATION_METATYPES.register()
class ONNXCosMetatype(ONNXOpMetatype):
name = "CosOp"
op_names = ["Cos"]


@ONNX_OPERATION_METATYPES.register()
class ONNXSinMetatype(ONNXOpMetatype):
name = "SinOp"
op_names = ["Sin"]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
22 changes: 0 additions & 22 deletions nncf/onnx/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,25 +157,3 @@ def create_se_block() -> GraphPattern:
multiply_node = node_id
pattern.add_edge(non_pattern_node, multiply_node)
return pattern


@ONNX_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.ONNXMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.ONNXTransposeMetatype}
)
concat_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.ONNXConcatMetatype}
)
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.ONNXCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.ONNXSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
12 changes: 0 additions & 12 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,18 +1110,6 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
target_input_ports = [0, 1]


@PT_OPERATOR_METATYPES.register()
class PTCosMetatype(PTOperatorMetatype):
name = "CosOp"
module_to_function_names = {NamespaceTarget.TORCH: ["cos"]}


@PT_OPERATOR_METATYPES.register()
class PTSinMetatype(PTOperatorMetatype):
name = "SinOp"
module_to_function_names = {NamespaceTarget.TORCH: ["sin"]}


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
20 changes: 0 additions & 20 deletions nncf/torch/quantization/ignored_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,23 +227,3 @@ def get_se_block_with_bias_and_reshape() -> GraphPattern:
main_pattern.add_pattern_alternative(get_se_block_with_reshape())
main_pattern.add_pattern_alternative(get_se_block_with_bias_and_reshape())
return main_pattern


@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.ROPE)
def create_rope() -> GraphPattern:
pattern = GraphPattern()
matmul_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: om.PTMatMulMetatype}
)
transpose_node = pattern.add_node(
**{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: om.PTTransposeMetatype}
)
concat_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: om.PTCatMetatype})
cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: om.PTCosMetatype})
sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: om.PTSinMetatype})

pattern.add_edge(matmul_node, transpose_node)
pattern.add_edge(transpose_node, concat_node)
pattern.add_edge(concat_node, cos_node)
pattern.add_edge(concat_node, sin_node)
return pattern
19 changes: 0 additions & 19 deletions tests/cross_fw/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,22 +466,3 @@ def __init__(self):

def forward(self, x):
return self.conv(x)


class RoPEModel(nn.Module):
INPUT_SIZE = [1, 10]

def __init__(self):
super().__init__()
with set_torch_seed():
self.data = torch.randn([5])

def forward(self, x):
x = torch.unsqueeze(x, dim=0)
reshape = torch.reshape(self.data, [1, 5, 1])
x = torch.matmul(reshape, x)
x = torch.transpose(x, 2, 1)
x = torch.cat([x], dim=2)
x1 = x.sin()
x2 = x.cos()
return x1, x2

This file was deleted.

123 changes: 0 additions & 123 deletions tests/onnx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,126 +1757,3 @@ def __init__(self):
model = onnx.helper.make_model(graph_def, opset_imports=[op])
onnx.checker.check_model(model)
super().__init__(model, [input_shape], "unified_embedding_model.dot")


class RoPEModel(ONNXReferenceModel):
def __init__(self):
rng = np.random.default_rng(seed=0)

input_shape = [1, 10]
input_name = "model_in"
model_in = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.INT64, input_shape)

output_shape = [1, 5, 10]
cos_out_name = "cos_out"
cos_out = onnx.helper.make_tensor_value_info(cos_out_name, onnx.TensorProto.FLOAT, output_shape)
sin_out_name = "sin_out"
sin_out = onnx.helper.make_tensor_value_info(sin_out_name, onnx.TensorProto.FLOAT, output_shape)

unsqueeze_out_name = "un_out"
unsqueeze_tensor_name = "un_tensor"
unsqueeze_tensor = create_initializer_tensor(
name=unsqueeze_tensor_name, tensor_array=np.int64([2]), data_type=onnx.TensorProto.INT64
)
unsqueeze_node = onnx.helper.make_node(
name="unsqueeze",
op_type="Unsqueeze",
inputs=[input_name, unsqueeze_tensor_name],
outputs=[unsqueeze_out_name],
)

cast_out_name = "cast_out"
cast_node = onnx.helper.make_node(
name="cast",
op_type="Cast",
to=onnx.TensorProto.FLOAT,
inputs=[unsqueeze_out_name],
outputs=[cast_out_name],
)

reshape_shape_name = "re_shape"
reshape_shape = create_initializer_tensor(
name=reshape_shape_name,
tensor_array=np.array([1, 5]).astype(np.int64),
data_type=onnx.TensorProto.INT64,
)
reshape_tensor_name = "re_tensor"
reshape_tensor = create_initializer_tensor(
name=reshape_tensor_name,
tensor_array=rng.uniform(0, 1, (5)).astype(np.float32),
data_type=onnx.TensorProto.FLOAT,
)
reshape_out_name = "re_out"
reshape_node = onnx.helper.make_node(
name="reshape",
op_type="Reshape",
inputs=[reshape_tensor_name, reshape_shape_name],
outputs=[reshape_out_name],
)

matmul_out_name = "mm_out"
matmul_node = onnx.helper.make_node(
name="matmul",
op_type="MatMul",
inputs=[cast_out_name, reshape_out_name],
outputs=[matmul_out_name],
)

transpose_out_name = "trans_out"
transpose_node = onnx.helper.make_node(
name="transpose",
op_type="Transpose",
inputs=[matmul_out_name],
outputs=[transpose_out_name],
perm=[0, 2, 1],
)

concat_out_name = "concat_out"
concat_node = onnx.helper.make_node(
name="concat",
op_type="Concat",
inputs=[transpose_out_name],
outputs=[concat_out_name],
axis=-1,
)

sin_node = onnx.helper.make_node(
name="sin",
op_type="Sin",
inputs=[concat_out_name],
outputs=[sin_out_name],
)

cos_node = onnx.helper.make_node(
name="cos",
op_type="Cos",
inputs=[concat_out_name],
outputs=[cos_out_name],
)

graph_def = onnx.helper.make_graph(
nodes=[
unsqueeze_node,
cast_node,
reshape_node,
matmul_node,
transpose_node,
concat_node,
sin_node,
cos_node,
],
name="RoPEModel",
inputs=[model_in],
outputs=[sin_out, cos_out],
initializer=[
unsqueeze_tensor,
reshape_tensor,
reshape_shape,
],
)

op = onnx.OperatorSetIdProto()
op.version = OPSET_VERSION
model = onnx.helper.make_model(graph_def, opset_imports=[op])
onnx.checker.check_model(model)
super().__init__(model, [input_shape], "rope_model.dot")
12 changes: 0 additions & 12 deletions tests/onnx/quantization/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@

import pytest

from nncf.parameters import ModelType
from tests.onnx.models import ALL_SYNTHETIC_MODELS
from tests.onnx.models import RoPEModel
from tests.onnx.quantization.common import compare_nncf_graph
from tests.onnx.quantization.common import min_max_quantize_model
from tests.onnx.quantization.common import mock_collect_statistics
Expand All @@ -25,13 +23,3 @@ def test_synthetic_models_graph(model_cls_to_test, mocker):
model_to_test = model_cls_to_test()
quantized_model = min_max_quantize_model(model_to_test.onnx_model)
compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)


@pytest.mark.parametrize("model_cls_to_test", [RoPEModel])
def test_synthetic_models_graph_transformer(model_cls_to_test, mocker):
mock_collect_statistics(mocker)
model_to_test = model_cls_to_test()
quantized_model = min_max_quantize_model(
model_to_test.onnx_model, quantization_params={"model_type": ModelType.TRANSFORMER}
)
compare_nncf_graph(quantized_model, "synthetic/" + model_to_test.path_ref_graph)
1 change: 1 addition & 0 deletions tests/onnx/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
IGNORING_IGNORED_PATTERN_REASONS = {
IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for ONNX.",
IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for ONNX.",
IgnoredPatternNames.ROPE: "Not relevant for ONNX.",
}


Expand Down

This file was deleted.

3 changes: 0 additions & 3 deletions tests/torch/ptq/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
import torch

from nncf import Dataset
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.torch import wrap_model
from nncf.torch.layers import NNCF_RNN
from nncf.torch.layers import LSTMCellNNCF
from tests.cross_fw.test_templates.helpers import EmbeddingModel
from tests.cross_fw.test_templates.helpers import RoPEModel
from tests.cross_fw.test_templates.helpers import ScaledDotProductAttentionModel
from tests.torch import test_models
from tests.torch.quantization.test_algo_quantization import SharedLayersModel
Expand Down Expand Up @@ -52,7 +50,6 @@ def get_model_name(description):

TEST_MODELS_DESC = [
(ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}),
(ModelDesc("rope_model", RoPEModel, [1, 10]), {"model_type": ModelType.TRANSFORMER}),
(
ModelDesc(
"scaled_dot_product_attention_model",
Expand Down
1 change: 1 addition & 0 deletions tests/torch/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
IGNORING_IGNORED_PATTERN_REASONS = {
IgnoredPatternNames.FC_BN_HSWISH_ACTIVATION: "Not relevant for Torch.",
IgnoredPatternNames.EQUAL_LOGICALNOT: "Not relevant for Torch.",
IgnoredPatternNames.ROPE: "Not relevant for Torch.",
}


Expand Down

0 comments on commit 6a6434d

Please sign in to comment.