OnnxToTorch lowering resize op (llvm#3013)
nod-ai/SHARK-ModelDev#358
Adds a lowering from ONNX to linalg for bilinear and nearest resize, with
support for using either scales or sizes to compute the output shape. Uses the
half_pixel coordinate transformation for bilinear mode and asymmetric for
nearest mode. See
https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Adds
two passes -- one for bilinear and the other for nearest.
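
For reference, the two coordinate transformations used here map an output
coordinate back to an input coordinate as follows; a minimal standalone C++
sketch of the formulas from the ONNX spec linked above (illustration only,
not part of this commit):

// half_pixel (bilinear path):  x_orig = (x_resized + 0.5) / scale - 0.5
// asymmetric (nearest path):   x_orig = x_resized / scale
float halfPixel(float xResized, float scale) {
  return (xResized + 0.5f) / scale - 0.5f;
}
float asymmetric(float xResized, float scale) {
  return xResized / scale;
}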
aldesilv authored May 8, 2024
1 parent bce800a commit ec6d7aa
Showing 8 changed files with 776 additions and 0 deletions.
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7260,6 +7260,35 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
}];
}

def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$size,
AnyTorchOptionalListOfTorchFloatType:$scale_factor,
Torch_StringType:$mode,
AnyTorchOptionalBoolType:$align_corners,
AnyTorchOptionalBoolType:$recompute_scale_factor,
Torch_BoolType:$antialias
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
152 changes: 152 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2637,4 +2637,156 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.replaceOp(binder.op, {loss, logProb});
return success();
});
patterns.onOp(
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

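// Bail out on Resize attributes this lowering does not support yet.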
if (auto attr = binder.op->getAttr("torch.onnx.antialias")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for antialias attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for axes attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"exclude_outside attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"extrapolation_value attribute");
}
if (auto attr =
binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"keep_aspect_ratio_policy attribute");
}

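// Gather the operands, result type, and string attributes, falling back to
// the ONNX defaults "nearest" and "half_pixel" when mode or
// coordinate_transformation_mode is absent.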
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode", "nearest") ||
binder.customOpNameStringAttr(
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode", ""))
return failure();

if (mode == "nearest" && nearest_mode != "floor") {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: only nearest_mode 'floor' is supported");
}

Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value modeStrValue;

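// Extract the scalar from the single-element tensor `v`, using the dtype
// of `x` to choose between a !torch.int and a !torch.float result.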
auto extract = [&rewriter, &binder](Value x, Value v) {
auto xTy = x.getType().cast<Torch::ValueTensorType>();
Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(xTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();

return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
v);
};

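// Build a !torch.list of scalars from the 1-D scales/sizes tensor,
// keeping only its trailing (spatial) entries.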
auto getValueList = [&](Value operand) {
SmallVector<Value> itemList;
auto sizes =
dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
Torch::BaseTensorType operandType =
operand.getType().cast<Torch::BaseTensorType>();

SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = operandType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());

MLIRContext *context = binder.op->getContext();
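// Skip the leading batch and channel entries; only the two spatial dims
// feed the interpolate op.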
for (int i = sizes[0] - 2; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value ext = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, operand, zero, selectIndex);
Value item = extract(operand, ext);
itemList.push_back(item);
}
auto xTy = operand.getType().cast<Torch::ValueTensorType>();
Value valueList;
if (isa<IntegerType>(xTy.getDtype())) {
valueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)), itemList);
} else {
valueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::FloatType::get(context)), itemList);
}
return valueList;
};

Value scalesValueList = noneVal;
Value sizesValueList = noneVal;
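// Only the align_corners coordinate transform maps onto the op's
// align_corners flag; all other modes lower with align_corners = false.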
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;

if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
}
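// ONNX Resize-11 operands are (X, roi, scales, sizes); when the optional
// sizes operand is absent, scales (operand 2) determines the output shape.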
if (mode == "linear") {
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
"bilinear");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
}
if (mode == "nearest") {
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "nearest");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizesOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizesOperand);
}
}
if (scalesValueList.getType().isa<Torch::NoneType>() &&
sizesValueList.getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}
rewriter
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
binder.op, resultType, operands[0], sizesValueList,
scalesValueList, modeStrValue,
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
}