Skip to content

Commit

Permalink
support Gather indices of type int32
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed Sep 8, 2023
1 parent e3a8a67 commit 1a0f9fa
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,21 +629,27 @@ ElementsAttr ElementsAttrBuilder::gather(
ArrayRef<int64_t> inputShape = inputType.getShape();
assert(axis < inputShape.size() && "gather axis out of range");
auto postAxisShape = inputShape.drop_front(axis + 1);
ArrayRef<int64_t> indicesShape = indices.getShapedType().getShape();
ShapedType indicesType = indices.getShapedType();
assert(indicesType.getElementType().isSignlessInteger() &&
"gather indices must be i32 or i64");
ArrayRef<int64_t> indicesShape = indicesType.getShape();
SmallVector<int64_t> outShape(inputShape.take_front(axis));
outShape.append(indicesShape.begin(), indicesShape.end());
outShape.append(postAxisShape.begin(), postAxisShape.end());
auto outType = inputType.clone(outShape);
return fromWideNums(outType, [&](MutableArrayRef<WideNum> dst) {
size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape);
ArrayBuffer<WideNum> src = getElementsWideNums(input);
ArrayBuffer<int64_t> indicesArray = getElementsArray<int64_t>(indices);
//
ArrayBuffer<WideNum> indicesWideNums = getElementsWideNums(indices);
ArrayRef<int64_t> indicesArray =
castArrayRef<int64_t, WideNum>(indicesWideNums.get());
size_t axisInputSize = inputShape[axis];
size_t inputBlockLen = axisInputSize * postAxisNumElements;
size_t outBlockLen = indicesArray.get().size() * postAxisNumElements;
size_t outBlockLen = indicesArray.size() * postAxisNumElements;
size_t start = 0;
WideNum *out = dst.begin();
for (int64_t idx : indicesArray.get()) {
for (int64_t idx : indicesArray) {
int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx;
const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements;
for (size_t offset = start; offset < dst.size(); offset += outBlockLen) {
Expand Down
15 changes: 15 additions & 0 deletions test/mlir/onnx/onnx_constprop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,21 @@ func.func @test_gather_negative_index() -> tensor<*xf32>{

// -----

// Gather with a rank-0 (scalar) index whose element type is i32 rather than
// i64. The index value 1 along axis 0 selects the second row of the 3x2 input,
// and because the index tensor has rank 0 the result drops the gathered axis,
// giving a tensor<2xf32>.
// NOTE(review): presumably exercised by the constant-propagation pass (per the
// test file name); the RUN line is not visible here - confirm.
func.func @test_gather_rank0_int32_indices() -> tensor<*xf32>{
%0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32>
%1 = onnx.Constant dense<1> : tensor<i32>
%2 = "onnx.Gather"(%0, %1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<i32>) -> tensor<*xf32>
"onnx.Return"(%2) : (tensor<*xf32>) -> ()

// Expected output: the Gather is replaced by a single constant holding row 1
// of the input, and the function's return type is refined to tensor<2xf32>.
// CHECK-LABEL: func @test_gather_rank0_int32_indices
// CHECK-SAME: () -> tensor<2xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2.300000e+00, 3.400000e+00]> : tensor<2xf32>
// CHECK: onnx.Return [[VAR_0_]] : tensor<2xf32>
// CHECK: }
}

// -----

func.func @test_reshape() -> tensor<*xf32> {
%0 = onnx.Constant dense<[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]]> : tensor<3x3xf32>
%1 = onnx.Constant dense<[1, -1]> : tensor<2xi64>
Expand Down

0 comments on commit 1a0f9fa

Please sign in to comment.