Skip to content

Commit

Permalink
[onnx] Add onnx-to-torch lowering for random ops (#3193)
Browse files Browse the repository at this point in the history
This commit adds the OnnxToTorch lowering for Onnx's RandomNormal, RandomNormalLike, RandomUniform, and RandomUniformLike op.
  • Loading branch information
vivekkhandelwal1 authored Apr 22, 2024
1 parent 6abc737 commit 3c252cd
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 83 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
return detail::onnx_list_of_constant_ints_op_binder(bind_values);
}

std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
92 changes: 22 additions & 70 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"
Expand All @@ -17,56 +18,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

class Endian {
private:
static constexpr uint32_t uint32_ = 0x01020304;
static constexpr uint8_t magic_ = (const uint8_t &)uint32_;

public:
static constexpr bool little = magic_ == 0x04;
static constexpr bool big = magic_ == 0x01;
static_assert(little || big, "Cannot determine endianness!");

private:
Endian() = delete;
};

static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
// TODO: Add complete mapping.
// Where are the ONNX and PyTorch dtype enums defined?
// ONNX:
// https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
// PyTorch:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88

int64_t dtypeIntTorch = [dtypeIntOnnx]() {
switch (dtypeIntOnnx) {
case 1:
return 6; // float
case 2:
return 0; // uint8
case 3:
return 1; // int8
case 6:
return 3; // int32
case 7:
return 4; // int64
case 9:
return 11; // bool
case 10:
return 5; // half
case 11:
return 7; // double
case 16:
return 15; // bfloat16
default:
return -1; // No dtype
}
}();

return dtypeIntTorch;
}

static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t dimA, int64_t dimB,
Expand Down Expand Up @@ -428,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
int64_t dtypeIntOnnx, dtypeIntTorch;
int64_t dtypeIntOnnx;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) ||
binder.tensorResultType(resultType))
Expand All @@ -452,16 +403,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
rewriter.replaceOp(binder.op, bernoulli);
return success();
}
dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch));
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
Expand Down Expand Up @@ -539,25 +489,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t dtypeIntOnnx, dtypeIntTorch;
int64_t dtypeIntOnnx;
if (binder.tensorOperand(operand) ||
binder.s64IntegerAttr(dtypeIntOnnx, "to") ||
binder.tensorResultType(resultType))
return failure();

dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
auto message = llvm::formatv("unimplemented support for the given "
"dtype conversion (onnx 'type' = {0})",
dtypeIntOnnx);
auto y = rewriter.notifyMatchFailure(binder.op, message);

return y;
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch));
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Expand Down Expand Up @@ -1768,9 +1714,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
Value mVal = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
operand, cst1);
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
int64_t dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch));
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

// diagonalIndex = 0 populates the main diagonal
// diagonalIndex > 0 populates an upper diagonal
Expand Down
214 changes: 214 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2274,4 +2274,218 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, input, cstAlpha, value);
return success();
});
patterns.onOp(
"RandomNormal", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");

Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float mean, scale;
SmallVector<int64_t> shape;
if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(mean, "mean", 0.0) ||
binder.f32FloatAttr(scale, "scale", 1.0) ||
binder.s64IntegerArrayAttr(shape, "shape", {}) ||
binder.tensorResultType(resultType)) {
return failure();
}

std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

Value shapeList = createConstantIntList(binder, rewriter, shape);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
binder.op->getLoc(), resultType, shapeList,
/*dtype=*/constDtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);

Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), mean));
Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), scale));

rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
binder.op, resultType, self, cstMean, cstStd,
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"RandomNormalLike", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");

Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float mean, scale;
SmallVector<int64_t> shape;
Value input;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(mean, "mean", 0.0) ||
binder.f32FloatAttr(scale, "scale", 1.0) ||
binder.tensorResultType(resultType)) {
return failure();
}

std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
input = rewriter.create<Torch::AtenToDtypeOp>(
binder.op->getLoc(), resultType, input, constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/cstNone);

Value cstMean = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), mean));
Value cstStd = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), scale));

rewriter.replaceOpWithNewOp<Torch::AtenNormalFunctionalOp>(
binder.op, resultType, input, cstMean, cstStd,
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"RandomUniform", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");

Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float high, low;
SmallVector<int64_t> shape;
if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(high, "high", 1.0) ||
binder.f32FloatAttr(low, "low", 0.0) ||
binder.s64IntegerArrayAttr(shape, "shape", {}) ||
binder.tensorResultType(resultType)) {
return failure();
}

std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

Value shapeList = createConstantIntList(binder, rewriter, shape);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Value self = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
binder.op->getLoc(), resultType, shapeList,
/*dtype=*/constDtype,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);

Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), high));
Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), low));

rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
binder.op, resultType, self, cstLow, cstHigh,
/*generator=*/cstNone);
return success();
});
patterns.onOp(
"RandomUniformLike", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
SmallString<64> name("torch.onnx.seed");
auto seedAttr = binder.op->getAttr(name);
if (seedAttr)
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");

Torch::ValueTensorType resultType;
int64_t dtypeIntOnnx;
float high, low;
SmallVector<int64_t> shape;
Value input;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) ||
binder.f32FloatAttr(high, "high", 1.0) ||
binder.f32FloatAttr(low, "low", 0.0) ||
binder.tensorResultType(resultType)) {
return failure();
}

std::optional<int64_t> dtypeIntTorch =
onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (!dtypeIntTorch.has_value()) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value()));

Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
input = rewriter.create<Torch::AtenToDtypeOp>(
binder.op->getLoc(), resultType, input, constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/cstNone);

Value cstHigh = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), high));
Value cstLow = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), low));

rewriter.replaceOpWithNewOp<Torch::AtenUniformOp>(
binder.op, resultType, input, cstLow, cstHigh,
/*generator=*/cstNone);
return success();
});
}
Loading

0 comments on commit 3c252cd

Please sign in to comment.