Skip to content

Commit

Permalink
Add support for the onnx.SequenceLength op. (llvm#3362)
Browse files Browse the repository at this point in the history
  • Loading branch information
AWoloszyn authored and sjarus committed Jun 6, 2024
1 parent 9f9830e commit cc464fd
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
13 changes: 13 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ struct OpBinder {
return success();
}

// Operand matches of different arities.
ParseResult tensorListOperand(Value &value0) {
if (op->getNumOperands() != 1)
return failure();
value0 = op->getOperand(0);
auto tt = dyn_cast<Torch::ListType>(value0.getType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
return success();
}

ParseResult tensorListResultType(Torch::ListType &type0) {
if (op->getNumResults() != 1)
return failure();
Expand Down
24 changes: 24 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operands);
return success();
});
patterns.onOp(
"SequenceLength", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
// onnx.SequenceLength takes a sequence(list) of tensors, and returns
// a zero rank tensor with the length.
Torch::ValueTensorType resultType;
Value x;
if (binder.tensorListOperand(x) || binder.tensorResultType(resultType))
return failure();

Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Value len = rewriter.create<Torch::AtenLenTOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), x);

// AtenLenTOp returns a torch.int, so we have to
// put that in a tensor.
rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>(
binder.op, resultType, len, none, none, cstFalse);

return success();
});
patterns.onOp(
"Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
15 changes: 15 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,21 @@ module {

// -----

// CHECK-LABEL: func.func @test_sequence_length
module {
func.func @test_sequence_length(%arg0: !torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list<vtensor<[?,?,?],f32>> -> !torch.int
// CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
// CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64>
%0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list<vtensor<[?,?,?],f32>>) -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}
}

// -----

// CHECK-LABEL: func.func @test_sce_mean_3d
func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.+]] = torch.constant.none
Expand Down

0 comments on commit cc464fd

Please sign in to comment.