Skip to content

Commit

Permalink
Merge branch 'llvm:main' into reducesumsquare
Browse files Browse the repository at this point in the history
  • Loading branch information
archana-ramalingam authored Apr 24, 2024
2 parents c03d9b7 + 7be22bb commit 36559db
Show file tree
Hide file tree
Showing 37 changed files with 2,962 additions and 708 deletions.
2 changes: 1 addition & 1 deletion docs/add_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

Collected links and contacts for how to add ops to torch-mlir.


<details>
<summary>Turbine Camp: Start Here</summary>
This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as part of the onboarding process, where new Nod.ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch-mlir because a lot of this is about torch-mlir.
Expand All @@ -27,6 +26,7 @@ The details of how we do it and helpful commands to help you set up each repo is
PS: IREE is pronounced Eerie, and hence the ghost icon.

## How to begin
0. Set up torch-mlir according to the instructions here: https://github.com/llvm/torch-mlir/blob/main/docs/development.md
1. You will start by adding support for 2 ops in torch-mlir, to get you familiar with the center of our pipeline. Begin by reading [torch-mlir's documentation on how to implement a new torch op](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md), and set up `llvm/torch_mlir` using https://github.com/llvm/torch-mlir/blob/main/docs/development.md
2. Pick 1 of the yet-unimplemented ops from the following list. You should choose something that looks easy to you. **Make sure you create an issue by clicking the little "target" icon to the right of the op, thereby marking the op as yours**
- [TorchToLinalg ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/347)
Expand Down
61 changes: 61 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,26 @@
#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

/// Compile-time host-endianness detector.
///
/// NOTE(review): `magic_` reads the first byte of a four-byte integer constant
/// through a reference cast. Standard C++ does not permit reinterpretation in
/// a constant expression, so acceptance here is compiler-specific — confirm on
/// every supported toolchain, or migrate to `std::endian` once C++20 is
/// available.
class Endian {
private:
  // 0x01020304 in memory: the lowest-addressed byte is 0x04 on a
  // little-endian host and 0x01 on a big-endian host.
  static constexpr uint32_t uint32_ = 0x01020304;
  static constexpr uint8_t magic_ = (const uint8_t &)uint32_;

public:
  static constexpr bool little = magic_ == 0x04;
  static constexpr bool big = magic_ == 0x01;
  // Fail the build outright on exotic (e.g. mixed-endian) hosts.
  static_assert(little || big, "Cannot determine endianness!");

private:
  // Pure compile-time trait; never instantiated.
  Endian() = delete;
};

namespace mlir::torch::onnx_c {

Value createConstantIntList(OpBinder binder,
Expand All @@ -28,6 +43,52 @@ LogicalResult OnnxLstmExpander(OpBinder binder,

bool areAllElementsDistinct(SmallVector<int64_t> array);

namespace detail {
/// Matches the constant integers stored in a `onnx.Constant`.
struct onnx_list_of_constant_ints_op_binder {
SmallVectorImpl<int64_t> &bind_values;

/// Creates a matcher instance that binds the value to bvs if match succeeds.
onnx_list_of_constant_ints_op_binder(SmallVectorImpl<int64_t> &bvs)
: bind_values(bvs) {}

bool match(Operation *op) {
auto constOp = dyn_cast<Torch::OperatorOp>(op);
if (!constOp || !constOp.getName().equals("onnx.Constant"))
return false;

if (DenseResourceElementsAttr attr =
constOp->getAttr("torch.onnx.value")
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
op->emitError("unimplemented: importing on big endian systems");
return false;
}

auto ty = cast<ShapedType>(attr.getType());
ElementsAttr denseAttr;
auto ptr = attr.getRawHandle().getBlob()->getData();
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
for (auto axis : denseAttr.getValues<llvm::APInt>()) {
bind_values.push_back(axis.getSExtValue());
}
return true;
}
return false;
}
};
} // namespace detail

/// Entry point for matching the integer contents of an `onnx.Constant`:
/// builds the detail binder, which appends each matched element to
/// `bind_values`.
inline detail::onnx_list_of_constant_ints_op_binder
m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
  return {bind_values};
}

std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
100 changes: 100 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4223,6 +4223,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
}];
}

// Value-semantics op for `aten::trunc : (Tensor) -> (Tensor)`.
// NOTE: lives in GeneratedTorchOps.td — presumably auto-generated, so manual
// edits may be overwritten on regeneration; change the generator instead.
def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::trunc : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenTruncOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenTruncOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
  // Folder implemented in TorchOps.cpp (e.g. constant folding).
  let hasFolder = 1;
}

// In-place (trailing-underscore) variant of `aten.trunc`, operating on a
// non-value tensor. NOTE: lives in GeneratedTorchOps.td — presumably
// auto-generated; change the generator rather than this file.
def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::trunc_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self
  );
  let results = (outs
    AnyTorchOptionalNonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenTrunc_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenTrunc_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenSignOp : Torch_Op<"aten.sign", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -9092,6 +9138,7 @@ def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenScalarTensorOp : Torch_Op<"aten.scalar_tensor", [
Expand Down Expand Up @@ -10712,6 +10759,30 @@ def Torch_AtenProdDimIntOp : Torch_Op<"aten.prod.dim_int", [
}];
}

// Full-reduction product op for `aten::prod : (Tensor, int?) -> (Tensor)`;
// `$dtype` optionally selects the accumulation/output dtype. NOTE: lives in
// GeneratedTorchOps.td — presumably auto-generated; change the generator
// rather than this file.
def Torch_AtenProdOp : Torch_Op<"aten.prod", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::prod : (Tensor, int?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchOptionalIntType:$dtype
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenProdOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenProdOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenMaxOp : Torch_Op<"aten.max", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -11577,6 +11648,7 @@ def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
Expand Down Expand Up @@ -15907,6 +15979,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
let hasFolder = 1;
}

// Op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`:
// takes length/start/step plus dtype, device, and requires_grad operands.
// NOTE: lives in GeneratedTorchOps.td — presumably auto-generated; change the
// generator rather than this file.
def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`";
  let arguments = (ins
    Torch_IntType:$length,
    Torch_IntType:$start,
    Torch_IntType:$step,
    Torch_IntType:$dtype,
    Torch_DeviceType:$device,
    Torch_BoolType:$requires_grad
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 6, 1);
    }
    void PrimsIotaOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 6, 1);
    }
  }];
}

def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
HasValueSemantics,
AllowsTypeRefinement,
Expand Down
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/IR/TorchOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ m_TorchConstantBool(bool *bind_value) {
}

namespace detail {
/// Matches the constant integers stored in a `torch.ListConstruct`.
/// Matches the constant integers stored in a `torch.prim.ListConstruct`.
struct torch_list_of_constant_ints_op_binder {
SmallVectorImpl<int64_t> &bind_values;

Expand Down
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
| torch.bool | i1 |
| torch.qint8 | !torch.qint8 |
| torch.quint8 | !torch.quint8 |
| torch.qint32 | !torch.qint32 |
| torch.complex64 | complex<f32> |
| torch.complex128 | complex<f64> |
|-------------------|--------------------|
Expand Down
Loading

0 comments on commit 36559db

Please sign in to comment.