From 1a0f9faaa4702b2461ae40e376c423a65b29cba8 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Fri, 8 Sep 2023 00:03:57 -0700 Subject: [PATCH] support Gather indices of type int32 Signed-off-by: Soren Lassen --- .../ONNX/ElementsAttr/ElementsAttrBuilder.cpp | 14 ++++++++++---- test/mlir/onnx/onnx_constprop.mlir | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 5c327182b7..d2e24a2dbe 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -629,7 +629,10 @@ ElementsAttr ElementsAttrBuilder::gather( ArrayRef inputShape = inputType.getShape(); assert(axis < inputShape.size() && "gather axis out of range"); auto postAxisShape = inputShape.drop_front(axis + 1); - ArrayRef indicesShape = indices.getShapedType().getShape(); + ShapedType indicesType = indices.getShapedType(); + assert(indicesType.getElementType().isSignlessInteger() && + "gather indices must be i32 or i64"); + ArrayRef indicesShape = indicesType.getShape(); SmallVector outShape(inputShape.take_front(axis)); outShape.append(indicesShape.begin(), indicesShape.end()); outShape.append(postAxisShape.begin(), postAxisShape.end()); @@ -637,13 +640,16 @@ ElementsAttr ElementsAttrBuilder::gather( return fromWideNums(outType, [&](MutableArrayRef dst) { size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape); ArrayBuffer src = getElementsWideNums(input); - ArrayBuffer indicesArray = getElementsArray(indices); + // + ArrayBuffer indicesWideNums = getElementsWideNums(indices); + ArrayRef indicesArray = + castArrayRef(indicesWideNums.get()); size_t axisInputSize = inputShape[axis]; size_t inputBlockLen = axisInputSize * postAxisNumElements; - size_t outBlockLen = indicesArray.get().size() * postAxisNumElements; + size_t outBlockLen = indicesArray.size() * postAxisNumElements; size_t start = 0; WideNum *out = dst.begin(); - for (int64_t idx : indicesArray.get()) { + for (int64_t idx : indicesArray) { int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx; const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements; for (size_t offset = start; offset < dst.size(); offset += outBlockLen) { diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 1342fd9cb7..e96f8eb238 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -1663,6 +1663,21 @@ func.func @test_gather_negative_index() -> tensor<*xf32>{ // ----- +func.func @test_gather_rank0_int32_indices() -> tensor<*xf32>{ + %0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %1 = onnx.Constant dense<1> : tensor + %2 = "onnx.Gather"(%0, %1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () + + // CHECK-LABEL: func @test_gather_rank0_int32_indices + // CHECK-SAME: () -> tensor<2xf32> { + // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2.300000e+00, 3.400000e+00]> : tensor<2xf32> + // CHECK: onnx.Return [[VAR_0_]] : tensor<2xf32> + // CHECK: } +} + +// ----- + func.func @test_reshape() -> tensor<*xf32> { %0 = onnx.Constant dense<[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]]> : tensor<3x3xf32> %1 = onnx.Constant dense<[1, -1]> : tensor<2xi64>