diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 5c327182b7..73b8fda8c8 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -629,7 +629,10 @@ ElementsAttr ElementsAttrBuilder::gather( ArrayRef inputShape = inputType.getShape(); assert(axis < inputShape.size() && "gather axis out of range"); auto postAxisShape = inputShape.drop_front(axis + 1); - ArrayRef indicesShape = indices.getShapedType().getShape(); + ShapedType indicesType = indices.getShapedType(); + assert(indicesType.getElementType().isSignlessInteger() && + "gather indices must be i32 or i64"); + ArrayRef indicesShape = indicesType.getShape(); SmallVector outShape(inputShape.take_front(axis)); outShape.append(indicesShape.begin(), indicesShape.end()); outShape.append(postAxisShape.begin(), postAxisShape.end()); @@ -637,13 +640,18 @@ ElementsAttr ElementsAttrBuilder::gather( return fromWideNums(outType, [&](MutableArrayRef dst) { size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape); ArrayBuffer src = getElementsWideNums(input); - ArrayBuffer indicesArray = getElementsArray(indices); + // Convert indices of any signed int element type to int64 by + // first promoting to WideNum and then casting to int64. + // In practice we support both int32 and int64 in this way. + ArrayBuffer indicesWideNums = getElementsWideNums(indices); + ArrayRef indicesArray = + castArrayRef(indicesWideNums.get()); size_t axisInputSize = inputShape[axis]; size_t inputBlockLen = axisInputSize * postAxisNumElements; - size_t outBlockLen = indicesArray.get().size() * postAxisNumElements; + size_t outBlockLen = indicesArray.size() * postAxisNumElements; size_t start = 0; WideNum *out = dst.begin(); - for (int64_t idx : indicesArray.get()) { + for (int64_t idx : indicesArray) { int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx; const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements; for (size_t offset = start; offset < dst.size(); offset += outBlockLen) { diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 1342fd9cb7..e96f8eb238 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -1663,6 +1663,21 @@ func.func @test_gather_negative_index() -> tensor<*xf32>{ // ----- +func.func @test_gather_rank0_int32_indices() -> tensor<*xf32>{ + %0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %1 = onnx.Constant dense<1> : tensor + %2 = "onnx.Gather"(%0, %1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () + + // CHECK-LABEL: func @test_gather_rank0_int32_indices + // CHECK-SAME: () -> tensor<2xf32> { + // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2.300000e+00, 3.400000e+00]> : tensor<2xf32> + // CHECK: onnx.Return [[VAR_0_]] : tensor<2xf32> + // CHECK: } +} + +// ----- + func.func @test_reshape() -> tensor<*xf32> { %0 = onnx.Constant dense<[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]]> : tensor<3x3xf32> %1 = onnx.Constant dense<[1, -1]> : tensor<2xi64>