diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9bc2328cc71b6..ac959d5c061f7 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -206,10 +206,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, "GroupQueryAttention": self._infer_GroupQueryAttention, - "SparseAttention": self._infer_SparseAttention, - "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, + "MatMulNBits": self._infer_MatMulNBits, "MultiHeadAttention": self._infer_MultiHeadAttention, "NhwcConv": self._infer_NhwcConv, "PackedAttention": self._infer_PackedAttention, @@ -223,8 +222,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "RestorePadding": self._infer_RestorePadding, "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipGroupNorm": self._infer_SkipGroupNorm, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + "SparseAttention": self._infer_SparseAttention, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -1256,6 +1257,25 @@ def _infer_MatMul(self, node): # noqa: N802 def _infer_MatMulInteger(self, node): # noqa: N802 self._compute_matmul_shape(node, onnx.TensorProto.INT32) + def _infer_MatMulNBits(self, node): # noqa: N802 + lhs_shape = self._get_shape(node, 0) + rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")] + lhs_rank = len(lhs_shape) + assert lhs_rank > 0 + if lhs_rank == 1: + new_shape = rhs_shape[1:] + else: + new_shape = lhs_shape[:-1] + rhs_shape[1:] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[-1], rhs_shape[0]], + allow_broadcast=False, + ) + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + def _infer_NonMaxSuppression(self, node): # noqa: N802 selected = str(self._new_symbolic_dim_from_output(node)) vi = self.known_vi_[node.output[0]] diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index eca1430448e8e..29680c98fb4de 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -594,6 +594,55 @@ def test_dequantize_linear_ms_domain(self): ] self._check_shapes(graph, inferred.graph, expected_shapes) + def test_matmulnbits(self): + """ + Test ORT MatMulNBits op. + Check that the output shape is propagated from the inputs and that the output data + type comes from the first input. + """ + b_np = numpy.random.randint(0, 255, (4, 1, 8), numpy.uint8) + b = numpy_helper.from_array(b_np, name="b") + scale_np = numpy.random.rand(4).astype(numpy.float32) + scale = numpy_helper.from_array(scale_np, name="scale") + zero_point_np = numpy.random.randint(0, 255, (4), numpy.uint8) + zero_point = numpy_helper.from_array(zero_point_np, name="zero_point") + + initializers = [b, scale, zero_point] + + kwargs = {"K": 10, "N": 4, "block_size": 16} + + nodes = [ + helper.make_node( + "MatMulNBits", + inputs=[ + "input_f32", + "b", + "scale", + "zero_point", + ], + outputs=["output_f32"], + **kwargs, + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["x", 2, 3, 10]), + ] + + outputs = [ + helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "MatMulNBits_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["x", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):