Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support Gather indices of type int32 #2488

Merged
merged 2 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,21 +629,29 @@ ElementsAttr ElementsAttrBuilder::gather(
ArrayRef<int64_t> inputShape = inputType.getShape();
assert(axis < inputShape.size() && "gather axis out of range");
auto postAxisShape = inputShape.drop_front(axis + 1);
ArrayRef<int64_t> indicesShape = indices.getShapedType().getShape();
ShapedType indicesType = indices.getShapedType();
assert(indicesType.getElementType().isSignlessInteger() &&
"gather indices must be i32 or i64");
ArrayRef<int64_t> indicesShape = indicesType.getShape();
SmallVector<int64_t> outShape(inputShape.take_front(axis));
outShape.append(indicesShape.begin(), indicesShape.end());
outShape.append(postAxisShape.begin(), postAxisShape.end());
auto outType = inputType.clone(outShape);
return fromWideNums(outType, [&](MutableArrayRef<WideNum> dst) {
size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape);
ArrayBuffer<WideNum> src = getElementsWideNums(input);
ArrayBuffer<int64_t> indicesArray = getElementsArray<int64_t>(indices);
// Convert indices of any signed int element type to int64 by
// first promoting to WideNum and then casting to int64.
// In practice we support both int32 and int64 in this way.
ArrayBuffer<WideNum> indicesWideNums = getElementsWideNums(indices);
ArrayRef<int64_t> indicesArray =
castArrayRef<int64_t, WideNum>(indicesWideNums.get());
size_t axisInputSize = inputShape[axis];
size_t inputBlockLen = axisInputSize * postAxisNumElements;
size_t outBlockLen = indicesArray.get().size() * postAxisNumElements;
size_t outBlockLen = indicesArray.size() * postAxisNumElements;
size_t start = 0;
WideNum *out = dst.begin();
for (int64_t idx : indicesArray.get()) {
for (int64_t idx : indicesArray) {
int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx;
const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements;
for (size_t offset = start; offset < dst.size(); offset += outBlockLen) {
Expand Down
15 changes: 15 additions & 0 deletions test/mlir/onnx/onnx_constprop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,21 @@ func.func @test_gather_negative_index() -> tensor<*xf32>{

// -----

func.func @test_gather_rank0_int32_indices() -> tensor<*xf32>{
%0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32>
%1 = onnx.Constant dense<1> : tensor<i32>
%2 = "onnx.Gather"(%0, %1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<i32>) -> tensor<*xf32>
"onnx.Return"(%2) : (tensor<*xf32>) -> ()

// CHECK-LABEL: func @test_gather_rank0_int32_indices
// CHECK-SAME: () -> tensor<2xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2.300000e+00, 3.400000e+00]> : tensor<2xf32>
// CHECK: onnx.Return [[VAR_0_]] : tensor<2xf32>
// CHECK: }
}

// -----

func.func @test_reshape() -> tensor<*xf32> {
%0 = onnx.Constant dense<[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]]> : tensor<3x3xf32>
%1 = onnx.Constant dense<[1, -1]> : tensor<2xi64>
Expand Down
Loading