[onnx] Lowerings from onnx.relu and onnx.selu
Started work on the `relu` and `selu` lowerings for ONNX to Torch.
rsuderman committed Dec 12, 2023
1 parent 099e1f4 commit 56b093a
Showing 3 changed files with 85 additions and 1 deletion.
19 changes: 19 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -113,6 +113,25 @@ struct OpBinder {
    return failure();
  }

  // Binds the float attribute named "torch.onnx.<nameSuffix>", falling back
  // to `defaultValue` when the attribute is absent; attributes that are
  // present but not f32 cause the match to fail rather than silently
  // truncating.
  ParseResult f32FloatAttr(float &value, StringRef nameSuffix,
                           float defaultValue = 0.0f) {
    SmallString<64> name("torch.onnx.");
    name.append(nameSuffix);
    auto attr = op->getAttr(name);
    if (!attr) {
      value = defaultValue;
      return success();
    }
    if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
      FloatType t = cast<FloatType>(floatAttr.getType());
      if (t.getWidth() != 32)
        return failure();
      value = floatAttr.getValueAsDouble();
      return success();
    }
    return failure();
  }

  ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                     std::string defaultValue = "") {
    SmallString<64> name("torch.onnx.");
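As a usage sketch (not part of this diff), a conversion callback would pull an optional float attribute through the new helper like this; `binder` is the callback's OpBinder, and the 1.0f fallback is illustrative:

  float alpha;
  // Bind "torch.onnx.alpha" when present, else fall back to the default;
  // a present-but-non-f32 attribute makes the whole pattern fail to match.
  if (binder.f32FloatAttr(alpha, "alpha", /*defaultValue=*/1.0f))
    return failure();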
42 changes: 41 additions & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -26,4 +26,44 @@ using namespace mlir::torch::onnx_c;
// results in a lot of ONNX test cases that all reduce to the exact same
// thing here, so we simplify.
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
    OnnxCustomOpConversionPattern &patterns) {
patterns.onOp("Relu", 6,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenReluOp>(
binder.op, resultType, operand);
return success();
});

  patterns.onOp(
      "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        float alpha, gamma;
        Value operand;
        if (binder.tensorOperand(operand) ||
            binder.f32FloatAttr(alpha, "alpha") ||
            binder.f32FloatAttr(gamma, "gamma") ||
            binder.tensorResultType(resultType))
          return failure();

        Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), alpha));

        Value vScale = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), gamma));

        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
            rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));

        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
        return success();
      });
}
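For reference (a short derivation, not part of the diff): the Selu-to-Elu mapping above is exact because `aten.elu` carries explicit `scale` and `input_scale` scalars. Writing both activations piecewise,

  Selu(x) = gamma * (max(0, x) + min(0, alpha * (exp(x) - 1)))
  elu(x; alpha, scale, input_scale) = scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1)))

so passing `alpha` through and choosing scale = gamma, input_scale = 1.0 makes `torch.aten.elu` compute exactly ONNX Selu; those are the three constants the lowering materializes.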
25 changes: 25 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -0,0 +1,25 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
// Generally, the test cases accumulated here come from running the importer
// over all included backend tests that involve simple ops with no model
// level constants. This is a pragmatic choice which lets us have a lot
// of tests in this file, whereas the others tend to be more bespoke.


// CHECK-LABEL: func.func @test_relu
func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 14 : si64} {
  // CHECK: torch.aten.relu %arg0
  %0 = torch.operator "onnx.Relu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_selu
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
  %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}
