Merge pull request #303 from Xilinx/bump_to_77d7f644
[AutoBump] Merge with fixes of 77d7f64 (Jun 13, needs LLVM bump) (66)
mgehre-amd authored Sep 11, 2024
2 parents e42cc87 + 2b86be6 commit 4670b65
Showing 78 changed files with 3,991 additions and 461 deletions.
8 changes: 4 additions & 4 deletions docs/development.md
@@ -71,10 +71,10 @@ cmake -GNinja -Bbuild \
`# use ccache to cache build results` \
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
`# use LLD to link in seconds, rather than minutes` \
- `# if using clang <= 13, replace --ld-path=lld with -fuse-ld=lld` \
- -DCMAKE_EXE_LINKER_FLAGS_INIT="--ld-path=lld" \
- -DCMAKE_MODULE_LINKER_FLAGS_INIT="--ld-path=lld" \
- -DCMAKE_SHARED_LINKER_FLAGS_INIT="--ld-path=lld" \
+ `# if using clang <= 13, replace --ld-path=ld.lld with -fuse-ld=lld` \
+ -DCMAKE_EXE_LINKER_FLAGS_INIT="--ld-path=ld.lld" \
+ -DCMAKE_MODULE_LINKER_FLAGS_INIT="--ld-path=ld.lld" \
+ -DCMAKE_SHARED_LINKER_FLAGS_INIT="--ld-path=ld.lld" \
`# Enabling libtorch binary cache instead of downloading the latest libtorch everytime.` \
`# Testing against a mismatched version of libtorch may cause failures` \
-DLIBTORCH_CACHE=ON \
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 4764 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 76 files
+2 −2 .github/workflows/publishWheelRelease.yml
+48 −2 BUILD.bazel
+3 −4 CODE_OF_CONDUCT.md
+2 −2 WORKSPACE.bazel
+7 −2 build_tools/github_actions/ci_build_docs.sh
+1 −1 build_tools/llvm_version.txt
+2 −0 docs/_toc.yaml
+31 −0 docs/generated/interpreter_passes.md
+14 −29 docs/generated/stablehlo_passes.md
+2 −1 docs/spec.md
+22 −0 stablehlo/conversions/linalg/tests/dot-product.mlir
+44 −0 stablehlo/conversions/linalg/tests/gather.mlir
+3 −2 stablehlo/conversions/linalg/transforms/CMakeLists.txt
+32 −11 stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+3 −3 stablehlo/conversions/linalg/transforms/StablehloToLinalgReduce.cpp
+8 −15 stablehlo/dialect/StablehloBytecode.cpp
+17 −2 stablehlo/dialect/StablehloOps.cpp
+2 −0 stablehlo/dialect/StablehloOps.td
+21 −0 stablehlo/dialect/TypeInference.cpp
+3 −0 stablehlo/dialect/TypeInference.h
+1 −1 stablehlo/dialect/Version.h
+3 −3 stablehlo/integrations/python/stablehlo/savedmodel/stablehlo_to_tf_saved_model.py
+20 −14 stablehlo/integrations/python/tests/stablehlo.py
+2 −0 stablehlo/integrations/python/tests/stablehlo_to_tf_saved_model_test.py
+20 −0 stablehlo/reference/CMakeLists.txt
+21 −23 stablehlo/reference/InterpreterInstrumentWithProbe.cpp
+1 −1 stablehlo/reference/InterpreterOps.td
+33 −0 stablehlo/reference/InterpreterPasses.h
+49 −0 stablehlo/reference/InterpreterPasses.td
+28 −0 stablehlo/tests/interpret/all_to_all.mlir
+0 −1 stablehlo/tests/interpret/dynamic_gather.mlir
+0 −5 stablehlo/tests/ops_speculatability.mlir
+14 −1 stablehlo/tests/ops_stablehlo.mlir
+16 −0 stablehlo/tests/ops_stablehlo_quantized.mlir
+0 −46 stablehlo/tests/ops_stablehlo_roundtrip.mlir
+112 −0 stablehlo/tests/stablehlo_aggressive_folder.mlir
+26 −0 stablehlo/tests/stablehlo_convert_to_signless.mlir
+2,229 −0 stablehlo/tests/stablehlo_legalize_quant_to_int.mlir
+1 −1 stablehlo/tests/stablehlo_probe_instrumentation.mlir
+12 −0 stablehlo/tests/stablehlo_refine_shapes.mlir
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_20_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_20_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_0_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_0_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_1_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_1_0.mlir.bc
+27 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+1 −0 stablehlo/tools/CMakeLists.txt
+2 −0 stablehlo/tools/StablehloOptMain.cpp
+4 −2 stablehlo/transforms/CMakeLists.txt
+4 −2 stablehlo/transforms/Passes.h
+21 −34 stablehlo/transforms/Passes.td
+169 −12 stablehlo/transforms/StablehloAggressiveFolder.cpp
+137 −0 stablehlo/transforms/StablehloConvertToSignless.cpp
+1,332 −0 stablehlo/transforms/StablehloLegalizeQuantToInt.cpp
+3 −2 stablehlo/transforms/VhloLegalizeToStablehlo.cpp
13 changes: 13 additions & 0 deletions include/torch-mlir-c/TorchTypes.h
@@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
/// Gets the !torch.quint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);

//===----------------------------------------------------------------------===//
// torch.qint16 type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.qint16 type
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t);

/// Gets the !torch.qint16 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context);

/// Gets the !torch.qint16 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void);

//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
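A minimal sketch of how the new !torch.qint16 C API entry points could be exercised from a standalone program; the dialect-handle hook mlirGetDialectHandle__torch__() and the context calls are assumed from torch-mlir-c/Dialects.h and mlir-c/IR.h, not from this diff:

#include <cassert>
#include "mlir-c/IR.h"
#include "torch-mlir-c/Dialects.h"
#include "torch-mlir-c/TorchTypes.h"

int main() {
  MlirContext ctx = mlirContextCreate();
  // Register and load the torch dialect (handle name assumed from Dialects.h).
  MlirDialectHandle torchHandle = mlirGetDialectHandle__torch__();
  mlirDialectHandleRegisterDialect(torchHandle, ctx);
  mlirDialectHandleLoadDialect(torchHandle, ctx);
  // Build a !torch.qint16 type and verify it with the new predicate.
  MlirType qint16 = torchMlirTorchQInt16TypeGet(ctx);
  assert(torchMlirTypeIsATorchQInt16(qint16));
  mlirContextDestroy(ctx);
  return 0;
}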
12 changes: 12 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -110,6 +110,18 @@ struct OpBinder {
return success();
}

ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) {
if (idx >= op->getNumOperands())
return failure();
valueIdx = op->getOperand(idx);
auto tt = dyn_cast<Torch::ListType>(valueIdx.getType());
if (!tt)
return failure();
if (!toValidTensorType(tt.getContainedType()))
return failure();
return success();
}

ParseResult tensorListResultType(Torch::ListType &type0) {
if (op->getNumResults() != 1)
return failure();
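A hypothetical call site for the new binder method, sketched in the style of the surrounding TorchOnnxToTorch patterns; the helper name bindSequenceInput and the operand index are illustrative, not part of the diff:

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"

using namespace mlir;
using namespace mlir::torch::onnx_c;

// Sketch: bind a tensor-list (ONNX sequence) operand at index 0, failing the
// match when the operand is absent or is not a !torch.list of tensors.
static LogicalResult bindSequenceInput(OpBinder binder,
                                       ConversionPatternRewriter &rewriter,
                                       Value &tensorList) {
  if (binder.tensorListOperandAtIndex(tensorList, /*idx=*/0))
    return rewriter.notifyMatchFailure(
        binder.op, "expected a tensor-list operand at index 0");
  return success();
}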
12 changes: 11 additions & 1 deletion include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -36,7 +36,7 @@ Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
SmallVector<int64_t> cstInput);

- Type getQTorchTypeFromTorchIntType(Type ty);
+ Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);

template <typename T>
Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
@@ -96,6 +96,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {

std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);

LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input, int64_t dimA,
int64_t dimB, Value &transposed);

LogicalResult createTorchPermuteOp(OpBinder binder,
ConversionPatternRewriter &rewriter,
Location loc, Value input,
SmallVector<int64_t> permuteDims,
Value &permuted);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
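A hedged sketch of how the two new helpers compose; the NHWC-to-NCHW use case, the helper names, and all argument values are illustrative assumptions layered on the declared signatures:

#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"

using namespace mlir;
using namespace mlir::torch::onnx_c;

// Sketch: createTorchPermuteOp applies a full permutation in one step, while
// createTorchTransposeOp swaps a single pair of dimensions; both report
// problems through LogicalResult rather than asserting.
static LogicalResult toNCHW(OpBinder binder,
                            ConversionPatternRewriter &rewriter, Location loc,
                            Value nhwc, Value &nchw) {
  // Full permutation NHWC -> NCHW is (0, 3, 1, 2).
  return createTorchPermuteOp(binder, rewriter, loc, nhwc,
                              SmallVector<int64_t>{0, 3, 1, 2}, nchw);
}

static LogicalResult swapDims(ConversionPatternRewriter &rewriter,
                              Location loc, Value input, Value &transposed) {
  // Pairwise swap of dims 1 and 2 via the transpose helper.
  return createTorchTransposeOp(rewriter, loc, input, /*dimA=*/1, /*dimB=*/2,
                                transposed);
}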
203 changes: 154 additions & 49 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -256,6 +256,106 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
}];
}

def Torch_AtenRreluOp : Torch_Op<"aten.rrelu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenRreluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::rrelu_ : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRrelu_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenRrelu_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCeluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCelu_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -4810,53 +4910,6 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
}];
}

- def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
- AllowsTypeRefinement,
- HasValueSemantics,
- ReadOnly
- ]> {
- let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
- let arguments = (ins
- AnyTorchTensorType:$self,
- AnyTorchScalarType:$alpha
- );
- let results = (outs
- AnyTorchOptionalTensorType:$result
- );
- let hasCustomAssemblyFormat = 1;
- let extraClassDefinition = [{
- ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
- return parseDefaultTorchOp(parser, result, 2, 1);
- }
- void AtenCeluOp::print(OpAsmPrinter &printer) {
- printDefaultTorchOp(printer, *this, 2, 1);
- }
- }];
- }
-
- def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
- IsTrailingUnderscoreInplaceVariant,
- AllowsTypeRefinement
- ]> {
- let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
- let arguments = (ins
- Torch_NonValueTensorType:$self,
- AnyTorchScalarType:$alpha
- );
- let results = (outs
- AnyTorchOptionalNonValueTensorType:$result
- );
- let hasCustomAssemblyFormat = 1;
- let extraClassDefinition = [{
- ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
- return parseDefaultTorchOp(parser, result, 2, 1);
- }
- void AtenCelu_Op::print(OpAsmPrinter &printer) {
- printDefaultTorchOp(printer, *this, 2, 1);
- }
- }];
- }
-
def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement,
ReadOnly
@@ -6766,6 +6819,31 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}];
}

def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -6854,6 +6932,33 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
}];
}

def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -16197,11 +16302,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [
HasValueSemantics,
ReadOnly
]> {
- let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`";
+ let summary = "Generated op for `prims::var : (Tensor, int[]?, float?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$inp,
AnyTorchOptionalListOfTorchIntType:$dims,
- Torch_FloatType:$correction,
+ AnyTorchOptionalFloatType:$correction,
AnyTorchOptionalIntType:$output_dtype
);
let results = (outs
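To ground the generated definitions, a sketch of creating one of the new ops from C++ during a decomposition or conversion; the builder shapes follow the ODS above, while the ConstantFloatOp builder form, resultTy, and self are assumptions supplied by a hypothetical calling pattern:

#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

// Sketch: emit `aten.celu` with alpha = 1.0. Torch::ConstantFloatOp wraps the
// scalar as a !torch.float value, which satisfies the AnyTorchScalarType
// operand declared for AtenCeluOp.
static Value emitCelu(PatternRewriter &rewriter, Location loc,
                      Torch::ValueTensorType resultTy, Value self) {
  Value alpha = rewriter.create<Torch::ConstantFloatOp>(
      loc, rewriter.getF64FloatAttr(1.0));
  return rewriter.create<Torch::AtenCeluOp>(loc, resultTy, self, alpha);
}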
[Diffs for the remaining changed files are not shown here.]
