Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OnnxToTorch lowering resize op (cherry-picks from upstream) (CR-1196776) #168

Merged
merged 12 commits into from
Jun 7, 2024
Merged
54 changes: 54 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6833,6 +6833,35 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
}];
}

def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$size,
AnyTorchOptionalListOfTorchFloatType:$scale_factor,
Torch_StringType:$mode,
AnyTorchOptionalBoolType:$align_corners,
AnyTorchOptionalBoolType:$recompute_scale_factor,
Torch_BoolType:$antialias
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -13249,6 +13278,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [
}];
}

def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`";
let arguments = (ins
AnyTorchListOfTorchStringType:$l,
Torch_StringType:$item
);
let results = (outs
Torch_BoolType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
31 changes: 31 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl<bool> &bind_values) {
return detail::torch_list_of_constant_bools_op_binder(bind_values);
}

namespace detail {
/// Matches the constant strs stored in a `torch.ListConstruct`.
struct torch_list_of_constant_strs_op_binder {
SmallVectorImpl<std::string> &bind_values;

/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_list_of_constant_strs_op_binder(SmallVectorImpl<std::string> &bvs)
: bind_values(bvs) {}

bool match(Operation *op) {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
for (Value value : listConstruct.getElements()) {
std::string str;
if (matchPattern(value, m_TorchConstantStr(str)))
bind_values.push_back(str);
else
return false;
}
return true;
}
};
} // namespace detail

/// Matches the constant strs stored in a `torch.prim.ListConstruct`.
inline detail::torch_list_of_constant_strs_op_binder
m_TorchListOfConstantStrs(SmallVectorImpl<std::string> &bind_values) {
return detail::torch_list_of_constant_strs_op_binder(bind_values);
}

namespace detail {
/// Matches the expected tensor and dim from `torch.aten.size.int`.
struct torch_tensor_size_int_op_binder {
Expand Down
180 changes: 180 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2099,4 +2099,184 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

if (auto attr = binder.op->getAttr("torch.onnx.antialias")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for antialias attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for axes attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"exclude_outside attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"extrapolation_value attribute");
}
if (auto attr =
binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"keep_aspect_ratio_policy attribute");
}

if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode", "nearest") ||
binder.customOpNameStringAttr(
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "round_prefer_floor"))
return failure();
if (coordTfMode == "tf_crop_and_resize")
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: coordinate transformation mode: "
"tf_crop_and_resize");

if (mode == "nearest" && coordTfMode != "asymmetric" && coordTfMode != "half_pixel") {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for coord tf mode "
"except asymmetric and half_pixel");
}

unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
.getSizes()
.size();

Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value modeStrValue;

auto extract = [&rewriter, &binder](Value x, Value v) {
auto xTy = x.getType().cast<Torch::ValueTensorType>();
Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(xTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();

return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
v);
};

auto getValueList = [&](Value operand) {
SmallVector<Value> itemList;
auto sizes =
dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
Torch::BaseTensorType operandType =
operand.getType().cast<Torch::BaseTensorType>();

SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = operandType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());

MLIRContext *context = binder.op->getContext();
for (int i = 2; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value ext = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, operand, zero, selectIndex);
Value item = extract(operand, ext);
itemList.push_back(item);
}
auto xTy = operand.getType().cast<Torch::ValueTensorType>();
Value ValueList;
if (isa<IntegerType>(xTy.getDtype())) {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)), itemList);
} else {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::FloatType::get(context)), itemList);
}
return ValueList;
};

Value scalesValueList = noneVal;
Value sizesValueList = noneVal;
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
}
// supported modes:
// bilinear (half_pixel), bilinear with align_corners,
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
// (asymmetric), nearest with align_corners, nearest_half_pixel,
// nearest_pytorch_half_pixel
if (mode == "linear") {
std::string modeStr;
switch (rank) {
case 3:
modeStr = "linear";
break;
case 4:
modeStr = "bilinear";
break;
case 5:
modeStr = "trilinear";
break;
default:
return failure();
}
// Confusingly enough, the default coordTfMode for pytorch bilinear
// mode is apparently half_pixel, NOT pytorch_half_pixel
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
modeStr = (modeStr + "_") + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
if (mode == "nearest") {
std::string modeStr = "nearest";
// The default coordTfMode for pytorch with mode = nearest is
// apparently asymmetric
if (coordTfMode != "asymmetric" && coordTfMode != "align_corners")
modeStr = (modeStr + "_") + coordTfMode;
if (nearest_mode != "floor")
modeStr = modeStr + "," + nearest_mode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
if (scalesValueList.getType().isa<Torch::NoneType>() &&
sizesValueList.getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}
rewriter
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
binder.op, resultType, operands[0], sizesValueList,
scalesValueList, modeStrValue,
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
}
Loading
Loading