Skip to content

Commit

Permalink
Add MatMulNBits shape infer to SymbolicShapeInference (#21246)
Browse files Browse the repository at this point in the history
### Description
Support MatMulNBits shape infer in SymbolicShapeInference

MatMulNBits's B input is rank-2, so implicit merge does not apply.

### Motivation and Context
[Issue with performing shape inference using symbolic_shape_infer.py
with Phi-3 ONNX Models · Issue #21194 · microsoft/onnxruntime
(github.com)](#21194)
  • Loading branch information
fajin-corp authored Jul 5, 2024
1 parent 9ef28f0 commit 83e0c6b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
24 changes: 22 additions & 2 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"GemmFloat8": self._infer_GemmFloat8,
"GroupNorm": self._infer_GroupNorm,
"GroupQueryAttention": self._infer_GroupQueryAttention,
"SparseAttention": self._infer_SparseAttention,
"SkipGroupNorm": self._infer_SkipGroupNorm,
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
"MatMulNBits": self._infer_MatMulNBits,
"MultiHeadAttention": self._infer_MultiHeadAttention,
"NhwcConv": self._infer_NhwcConv,
"PackedAttention": self._infer_PackedAttention,
Expand All @@ -223,8 +222,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"RestorePadding": self._infer_RestorePadding,
"RotaryEmbedding": self._infer_RotaryEmbedding,
"SimplifiedLayerNormalization": self._infer_LayerNormalization,
"SkipGroupNorm": self._infer_SkipGroupNorm,
"SkipLayerNormalization": self._infer_SkipLayerNormalization,
"SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
"SparseAttention": self._infer_SparseAttention,
}
self.aten_op_dispatcher_ = {
"embedding": self._infer_Gather,
Expand Down Expand Up @@ -1256,6 +1257,25 @@ def _infer_MatMul(self, node): # noqa: N802
def _infer_MatMulInteger(self, node): # noqa: N802
self._compute_matmul_shape(node, onnx.TensorProto.INT32)

def _infer_MatMulNBits(self, node): # noqa: N802
lhs_shape = self._get_shape(node, 0)
rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")]
lhs_rank = len(lhs_shape)
assert lhs_rank > 0
if lhs_rank == 1:
new_shape = rhs_shape[1:]
else:
new_shape = lhs_shape[:-1] + rhs_shape[1:]
# merge reduce dim
self._check_merged_dims(
[lhs_shape[-1], rhs_shape[0]],
allow_broadcast=False,
)
# infer output_dtype from input type when not specified
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))

def _infer_NonMaxSuppression(self, node): # noqa: N802
selected = str(self._new_symbolic_dim_from_output(node))
vi = self.known_vi_[node.output[0]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,55 @@ def test_dequantize_linear_ms_domain(self):
]
self._check_shapes(graph, inferred.graph, expected_shapes)

def test_matmulnbits(self):
"""
Test ORT MatMulNBits op.
Check that the output shape is propagated from the inputs and that the output data
type comes from the first input.
"""
b_np = numpy.random.randint(0, 255, (4, 1, 8), numpy.uint8)
b = numpy_helper.from_array(b_np, name="b")
scale_np = numpy.random.rand(4).astype(numpy.float32)
scale = numpy_helper.from_array(scale_np, name="scale")
zero_point_np = numpy.random.randint(0, 255, (4), numpy.uint8)
zero_point = numpy_helper.from_array(zero_point_np, name="zero_point")

initializers = [b, scale, zero_point]

kwargs = {"K": 10, "N": 4, "block_size": 16}

nodes = [
helper.make_node(
"MatMulNBits",
inputs=[
"input_f32",
"b",
"scale",
"zero_point",
],
outputs=["output_f32"],
**kwargs,
),
]

inputs = [
helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["x", 2, 3, 10]),
]

outputs = [
helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None),
]

graph = helper.make_graph(nodes, "MatMulNBits_Test", inputs, outputs, initializers)
model = helper.make_model(graph)

inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

expected_shapes = [
helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["x", 2, 3, 4]),
]
self._check_shapes(graph, inferred.graph, expected_shapes)


class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
Expand Down

0 comments on commit 83e0c6b

Please sign in to comment.