Skip to content

Commit

Permalink
Move backend-specific logic to input nodes search
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 16, 2023
1 parent 4a244f3 commit 5de90a2
Show file tree
Hide file tree
Showing 18 changed files with 116 additions and 156 deletions.
1 change: 0 additions & 1 deletion nncf/common/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class OperatorMetatype:
hw_config_names: List[str] = []
output_channel_axis: Optional[int] = None
ignored_input_ports: List[int] = []
input_edges_num_expected = None

@classmethod
def get_all_aliases(cls) -> List[str]:
Expand Down
41 changes: 0 additions & 41 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
class ONNXOpMetatype(OperatorMetatype):
op_names: List[str] = []
subtypes: List[Type[OperatorMetatype]] = []
input_edges_num_expected = None

@classmethod
def get_all_aliases(cls) -> List[str]:
Expand Down Expand Up @@ -79,7 +78,6 @@ class ONNXDepthwiseConvolutionMetatype(ONNXOpWithWeightsMetatype):
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
input_edges_num_expected = 2

@classmethod
def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
Expand All @@ -96,7 +94,6 @@ class ONNXConvolutionMetatype(ONNXOpWithWeightsMetatype):
bias_port_id = 2
output_channel_axis = 1
subtypes = [ONNXDepthwiseConvolutionMetatype]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -108,7 +105,6 @@ class ONNXConvolutionTransposeMetatype(ONNXOpWithWeightsMetatype):
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -121,7 +117,6 @@ class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
bias_port_id = 2
possible_weight_ports = [0, 1]
output_channel_axis = -1
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -134,7 +129,6 @@ class ONNXMatMulMetatype(ONNXOpMetatype):
bias_port_id = 2
possible_weight_ports = [0, 1]
output_channel_axis = -1
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -190,31 +184,27 @@ class ONNXGlobalAveragePoolMetatype(ONNXOpMetatype):
name = "GlobalAveragePoolOp"
op_names = ["GlobalAveragePool"]
hw_config_names = [HWConfigOpName.AVGPOOL]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXAveragePoolMetatype(ONNXOpMetatype):
name = "AveragePoolOp"
op_names = ["AveragePool"]
hw_config_names = [HWConfigOpName.AVGPOOL]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXGlobalMaxPoolMetatype(ONNXOpMetatype):
name = "GlobalMaxPoolOp"
op_names = ["GlobalMaxPool"]
hw_config_names = [HWConfigOpName.MAXPOOL]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXMaxPoolMetatype(ONNXOpMetatype):
name = "MaxPoolOp"
op_names = ["MaxPool"]
hw_config_names = [HWConfigOpName.MAXPOOL]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -228,31 +218,27 @@ class ONNXAddLayerMetatype(ONNXOpMetatype):
name = "AddOp"
op_names = ["Add", "Sum"]
hw_config_names = [HWConfigOpName.ADD]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXSubMetatype(ONNXOpMetatype):
name = "SubOp"
op_names = ["Sub"]
hw_config_names = [HWConfigOpName.SUBTRACT]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXMulLayerMetatype(ONNXOpMetatype):
name = "MulOp"
op_names = ["Mul"]
hw_config_names = [HWConfigOpName.MULTIPLY]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXDivLayerMetatype(ONNXOpMetatype):
name = "DivOp"
op_names = ["Div"]
hw_config_names = [HWConfigOpName.DIVIDE]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -266,7 +252,6 @@ class ONNXConcatMetatype(ONNXOpMetatype):
class ONNXBatchNormMetatype(ONNXOpMetatype):
name = "BatchNormalizationOp"
op_names = ["BatchNormalization"]
input_edges_num_expected = 5


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -300,7 +285,6 @@ class ONNXTileMetatype(ONNXOpMetatype):
class ONNXUpsampleMetatype(ONNXOpMetatype):
name = "UpsampleOp"
op_names = ["Upsample"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -319,7 +303,6 @@ class ONNXShapeMetatype(ONNXOpMetatype):
class ONNXExpandMetatype(ONNXOpMetatype):
name = "ExpandOp"
op_names = ["Expand"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -340,23 +323,20 @@ class ONNXLessMetatype(ONNXOpMetatype):
name = "LessOp"
op_names = ["Less"]
hw_config_names = [HWConfigOpName.LESS]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXGreaterMetatype(ONNXOpMetatype):
name = "GreaterOp"
op_names = ["Greater"]
hw_config_names = [HWConfigOpName.GREATER]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXEqualMetatype(ONNXOpMetatype):
name = "EqualOp"
op_names = ["Equal"]
hw_config_names = [HWConfigOpName.EQUAL]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -371,31 +351,27 @@ class ONNXAndMetatype(ONNXOpMetatype):
name = "AndOp"
op_names = ["And"]
hw_config_names = [HWConfigOpName.LOGICALAND]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXOrMetatype(ONNXOpMetatype):
name = "OrOp"
op_names = ["Or"]
hw_config_names = [HWConfigOpName.LOGICALOR]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXMaximumMetatype(ONNXOpMetatype):
name = "MaxOp"
op_names = ["Max"]
hw_config_names = [HWConfigOpName.MAXIMUM]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXMinimumMetatype(ONNXOpMetatype):
name = "MinOp"
op_names = ["Min"]
hw_config_names = [HWConfigOpName.MINIMUM]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -410,7 +386,6 @@ class ONNXPowMetatype(ONNXOpMetatype):
name = "PowOp"
op_names = ["Pow"]
hw_config_names = [HWConfigOpName.POWER]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -433,7 +408,6 @@ class ONNXEmbeddingMetatype(ONNXOpMetatype):
hw_config_names = [HWConfigOpName.EMBEDDING]
weight_port_ids = [0]
weight_channel_axis = 0
input_edges_num_expected = 2

@classmethod
def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
Expand All @@ -444,7 +418,6 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
class ONNXLogMetatype(ONNXOpMetatype):
name = "LogOp"
op_names = ["Log"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand Down Expand Up @@ -475,29 +448,25 @@ class ONNXScatterNDMetatype(ONNXOpMetatype):
class ONNXRoiAlignMetatype(ONNXOpMetatype):
name = "RoiAlignOp"
op_names = ["RoiAlign"]
input_edges_num_expected = 3


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherMetatype(ONNXOpMetatype):
name = "GatherOp"
op_names = ["Gather"]
subtypes = [ONNXEmbeddingMetatype]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherNDMetatype(ONNXOpMetatype):
name = "GatherNDOp"
op_names = ["GatherND"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXGatherElementsMetatype(ONNXOpMetatype):
name = "GatherElementsOp"
op_names = ["GatherElements"]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -518,7 +487,6 @@ class ONNXSqueezeMetatype(ONNXOpMetatype):
class ONNXNonMaxSuppressionMetatype(ONNXOpMetatype):
name = "NonMaxSuppressionOp"
op_names = ["NonMaxSuppression"]
# input_edges_num_expected = from 2 to 5


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -537,30 +505,26 @@ class ONNXCastLikeMetatype(ONNXOpMetatype):
class ONNXReduceMinMetatype(ONNXOpMetatype):
name = "ReduceMinOp"
op_names = ["ReduceMin"]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
class ONNXReduceMaxMetatype(ONNXOpMetatype):
name = "ReduceMaxOp"
op_names = ["ReduceMax"]
hw_config_names = [HWConfigOpName.REDUCEMAX]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
class ONNXReduceSumMetatype(ONNXOpMetatype):
name = "ReduceSumOp"
op_names = ["ReduceSum"]
hw_config_names = [HWConfigOpName.REDUCESUM]
input_edges_num_expected = 1


class ONNXReduceL2Metatype(ONNXOpMetatype):
name = "ReduceL2Op"
op_names = ["ReduceL2"]
hw_config_names = [HWConfigOpName.REDUCEL2]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -580,7 +544,6 @@ class ONNXReduceMeanMetatype(ONNXOpMetatype):
name = "ReduceMeanOp"
op_names = ["ReduceMean"]
hw_config_names = [HWConfigOpName.REDUCEMEAN]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand All @@ -606,29 +569,25 @@ class ONNXTransposeMetatype(ONNXOpMetatype):
name = "TransposeOp"
op_names = ["Transpose"]
hw_config_names = [HWConfigOpName.TRANSPOSE]
input_edges_num_expected = 2


@ONNX_OPERATION_METATYPES.register()
class ONNXDropoutMetatype(ONNXOpMetatype):
name = "DropoutOp"
op_names = ["Dropout"]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
class ONNXFlattenMetatype(ONNXOpMetatype):
name = "FlattenOp"
op_names = ["Flatten"]
hw_config_names = [HWConfigOpName.FLATTEN]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
class ONNXSoftmaxMetatype(ONNXOpMetatype):
name = "SoftmaxOp"
op_names = ["Softmax"]
input_edges_num_expected = 1


@ONNX_OPERATION_METATYPES.register()
Expand Down
Loading

0 comments on commit 5de90a2

Please sign in to comment.