Skip to content

Commit

Permalink
[Torch] fold aten.log
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Apr 24, 2024
1 parent fab2696 commit aa96a08
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 52 deletions.
91 changes: 46 additions & 45 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,51 +256,6 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -4085,6 +4040,52 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
91 changes: 85 additions & 6 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,27 +1241,31 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
llvm::SmallVector<APInt> splattrs;

for (auto attr : attrs) {
bool isunsigned = false;
// Note that i1 is neither signed nor unsigned.
// But we should trait i1 as unsigned, otherwise that
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
// So here only distinguish signed integer.
bool isSigned = false;
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
isSigned = dyn_cast<IntegerType>(dense.getElementType()).isSigned();
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APInt>());
} else {
splattrs.push_back(dense.getValues<APInt>()[idx]);
}
} else if (auto intattr = dyn_cast<IntegerAttr>(attr)) {
isunsigned = cast<IntegerType>(intattr.getType()).isUnsigned();
isSigned = cast<IntegerType>(intattr.getType()).isSigned();
splattrs.push_back(intattr.getValue());
} else {
return {};
}

auto &apint = splattrs.back();
if (apint.getBitWidth() < bitwidth) {
if (isunsigned) {
apint = apint.zextOrTrunc(bitwidth);
} else {
if (isSigned) {
apint = apint.sextOrTrunc(bitwidth);
} else {
apint = apint.zextOrTrunc(bitwidth);
}
}
}
Expand Down Expand Up @@ -1795,6 +1799,81 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenLogOp
//===----------------------------------------------------------------------===//

using UnaryPromoteFpOperator = std::function<double(double)>;
using UnaryPromoteIntOperator = std::function<double(APInt, bool)>;

static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
ValueTensorType resultTy,
UnaryPromoteFpOperator fpFolder,
UnaryPromoteIntOperator intFolder) {
constexpr int64_t kMaxFold = 16;
if (!resultTy.hasDtype() || !resultTy.hasSizes())
return nullptr;
if (!isa<mlir::FloatType>(resultTy.getDtype()))
return nullptr;

auto fpTy = dyn_cast<mlir::FloatType>(operand.getType().getElementType());
auto intTy = dyn_cast<mlir::IntegerType>(operand.getType().getElementType());
if (!fpTy && !intTy)
return nullptr;

auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype());
bool splat = operand.isSplat();
bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
if (!splat && !withinMaxFold)
return nullptr;

const int64_t numValues = splat ? 1 : resultBTy.getNumElements();

llvm::SmallVector<Attribute> operands = {operand};
llvm::SmallVector<APFloat> folded;
for (int i = 0, s = numValues; i < s; ++i) {
double fold = 0.0;
if (fpTy) {
auto inputs = getFoldValueAtIndexFp(operands, i);
fold = fpFolder(inputs[0]);
}
if (intTy) {
auto inputs =
getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i);
fold = intFolder(inputs[0], intTy.isSigned());
}

APFloat val(fold);
bool unused;
val.convert(
cast<mlir::FloatType>(resultBTy.getElementType()).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &unused);
folded.push_back(val);
}
return DenseElementsAttr::get(resultBTy, folded);
}

OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
auto self = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
auto resultType = dyn_cast<ValueTensorType>(getType());
if (!self || !resultType)
return nullptr;

// Note that i1 is neither signed nor unsigned.
// But we should trait i1 as unsigned, otherwise that
// APInt(1,1).getSExtValue() return allOnes 64-bit integer.
auto intFold = [](APInt a, bool isSigned) -> double {
if (isSigned)
return std::log(a.getSExtValue());
else
return std::log(a.getZExtValue());
};
auto fpFold = [](double a) -> double { return std::log(a); };

return unaryPromoteFolder(self, resultType, fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenFloorOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::relu : (Tensor) -> (Tensor)",
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sinh : (Tensor) -> (Tensor)",
Expand Down Expand Up @@ -356,6 +355,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
Expand Down
33 changes: 33 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2870,3 +2870,36 @@ func.func @aten_tensor_tensor_ne() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4
%fpBool = torch.aten.ne.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1>
return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>
}

// -----

// CHECK-LABEL: @aten_log$fold_splat_i1
func.func @aten_log$fold_splat_i1() -> !torch.vtensor<[4], f32> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
%cst = torch.vtensor.literal(dense<true> : tensor<4xi1>) : !torch.vtensor<[4], i1>
%result = torch.aten.log %cst : !torch.vtensor<[4], i1> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

// -----

// CHECK-LABEL: @aten_log$fold_splat_si32
func.func @aten_log$fold_splat_si32() -> !torch.vtensor<[4], f32> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
%cst = torch.vtensor.literal(dense<3> : tensor<4xsi32>) : !torch.vtensor<[4], si32>
%result = torch.aten.log %cst : !torch.vtensor<[4], si32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

// -----

// CHECK-LABEL: @aten_log$fold_splat_f32
func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> {
// CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.09861231> : tensor<4xf32>) : !torch.vtensor<[4],f32>
// CHECK: return %[[RET]] : !torch.vtensor<[4],f32>
%cst = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4], f32>
%result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32>
return %result : !torch.vtensor<[4], f32>
}

0 comments on commit aa96a08

Please sign in to comment.