[Torch] fold aten.log #3223

Merged · 1 commit · Apr 26, 2024
91 changes: 46 additions & 45 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -256,51 +256,6 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -4085,6 +4040,52 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
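The added `let hasFolder = 1;` is the only difference from the definition deleted above; the op was moved down into the block of ops emitted with folders. In MLIR ODS, `hasFolder` makes TableGen declare a fold hook on the generated C++ op class, which the PR implements in TorchOps.cpp below. A sketch of the generated declaration, for orientation only (not part of the diff):

```cpp
// Declared on the generated AtenLogOp class when `let hasFolder = 1;` is set:
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
```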

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
AllowsTypeRefinement,
HasValueSemantics,
91 changes: 85 additions & 6 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1241,27 +1241,31 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
llvm::SmallVector<APInt> splattrs;

for (auto attr : attrs) {
- bool isunsigned = false;
+ // Note that i1 is neither signed nor unsigned.
+ // We should treat i1 as unsigned; otherwise
+ // APInt(1, 1).getSExtValue() returns an all-ones 64-bit integer.
+ // So we only distinguish signed integers here.
+ bool isSigned = false;
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
- isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
+ isSigned = dyn_cast<IntegerType>(dense.getElementType()).isSigned();
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APInt>());
} else {
splattrs.push_back(dense.getValues<APInt>()[idx]);
}
} else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
- isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
+ isSigned = cast<IntegerType>(intattr.getType()).isSigned();
splattrs.push_back(intattr.getValue());
} else {
return {};
}

auto &apint = splattrs.back();
if (apint.getBitWidth() < bitwidth) {
- if (isunsigned) {
- apint = apint.zextOrTrunc(bitwidth);
- } else {
+ if (isSigned) {
apint = apint.sextOrTrunc(bitwidth);
+ } else {
+ apint = apint.zextOrTrunc(bitwidth);
}
}
}
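A minimal standalone sketch of the i1 pitfall the new comment describes, assuming only LLVM's APInt API (illustrative, not part of the diff):

```cpp
#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  // A 1-bit APInt holding `true` has its single bit set, and that bit is
  // also the sign bit, so sign-extension yields an all-ones value, i.e. -1.
  llvm::APInt one(/*numBits=*/1, /*val=*/1);
  assert(one.getSExtValue() == -1); // surprising for a boolean
  assert(one.getZExtValue() == 1);  // the intended value of `true`
  return 0;
}
```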
@@ -1795,6 +1799,81 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenLogOp
//===----------------------------------------------------------------------===//

using UnaryPromoteFpOperator = std::function<double(double)>;
using UnaryPromoteIntOperator = std::function<double(APInt, bool)>;

static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
ValueTensorType resultTy,
UnaryPromoteFpOperator fpFolder,
UnaryPromoteIntOperator intFolder) {
constexpr int64_t kMaxFold = 16;
if (!resultTy.hasDtype() || !resultTy.hasSizes())
return nullptr;
if (!isa<mlir::FloatType>(resultTy.getDtype()))
return nullptr;

auto fpTy = dyn_cast<mlir::FloatType>(operand.getType().getElementType());
auto intTy = dyn_cast<mlir::IntegerType>(operand.getType().getElementType());
if (!fpTy && !intTy)
return nullptr;

auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype());
bool splat = operand.isSplat();
bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
if (!splat && !withinMaxFold)
return nullptr;

const int64_t numValues = splat ? 1 : resultBTy.getNumElements();

llvm::SmallVector<Attribute> operands = {operand};
llvm::SmallVector<APFloat> folded;
for (int i = 0, s = numValues; i < s; ++i) {
double fold = 0.0;
if (fpTy) {
auto inputs = getFoldValueAtIndexFp(operands, i);
fold = fpFolder(inputs[0]);
}
if (intTy) {
auto inputs =
getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i);
fold = intFolder(inputs[0], intTy.isSigned());
}

APFloat val(fold);
bool unused;
val.convert(
cast<mlir::FloatType>(resultBTy.getElementType()).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &unused);
folded.push_back(val);
}
return DenseElementsAttr::get(resultBTy, folded);
}
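Because `unaryPromoteFolder` is parameterized over the two scalar lambdas, other unary ops that promote to float could reuse it unchanged. A hypothetical folder for `aten.exp` (not part of this PR; shown only to illustrate the intended reuse) would differ solely in the lambdas:

```cpp
// Hypothetical reuse of unaryPromoteFolder for another unary op (aten.exp).
OpFoldResult AtenExpOp::fold(FoldAdaptor adaptor) {
  auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
  auto resultType = dyn_cast<ValueTensorType>(getType());
  if (!self || !resultType)
    return nullptr;

  // Same signed/unsigned promotion rule as the aten.log folder below.
  auto intFold = [](APInt a, bool isSigned) -> double {
    return isSigned ? std::exp(a.getSExtValue()) : std::exp(a.getZExtValue());
  };
  auto fpFold = [](double a) -> double { return std::exp(a); };
  return unaryPromoteFolder(self, resultType, fpFold, intFold);
}
```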

OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
auto resultType = dyn_cast<ValueTensorType>(getType());
if (!self || !resultType)
return nullptr;

// Note that i1 is neither signed nor unsigned.
// We should treat i1 as unsigned; otherwise
// APInt(1, 1).getSExtValue() returns an all-ones 64-bit integer.
auto intFold = [](APInt a, bool isSigned) -> double {
if (isSigned)
return std::log(a.getSExtValue());
else
return std::log(a.getZExtValue());
};
auto fpFold = [](double a) -> double { return std::log(a); };

return unaryPromoteFolder(self, resultType, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenFloorOp
//===----------------------------------------------------------------------===//
@@ -268,7 +268,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::relu : (Tensor) -> (Tensor)",
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sinh : (Tensor) -> (Tensor)",
@@ -356,6 +355,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
33 changes: 33 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -2870,3 +2870,36 @@ func.func @aten_tensor_tensor_ne() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4
%fpBool = torch.aten.ne.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1>
return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>
}

// -----

// CHECK-LABEL: @aten_log$fold_splat_i1
func.func @aten_log$fold_splat_i1() -> !torch.vtensor<[4], f32> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
%cst = torch.vtensor.literal(dense<true> : tensor<4xi1>) : !torch.vtensor<[4], i1>
%result = torch.aten.log %cst : !torch.vtensor<[4], i1> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

// -----

// CHECK-LABEL: @aten_log$fold_splat_si32
func.func @aten_log$fold_splat_si32() -> !torch.vtensor<[4], f32> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
%cst = torch.vtensor.literal(dense<3> : tensor<4xsi32>) : !torch.vtensor<[4], si32>
%result = torch.aten.log %cst : !torch.vtensor<[4], si32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

// -----

// CHECK-LABEL: @aten_log$fold_splat_f32
func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
%cst = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4], f32>
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}
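As a sanity check on the CHECK constants: log(1) = 0 exactly, so the i1 splat (true promotes to 1) folds to 0.000000e+00, and log(3) = 1.0986122886..., which rounds to 1.09861231 as the nearest f32 value in both the si32 and f32 cases.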