From fb483c06f697990c60cc3c0bda7fb1d730fca3de Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 16 Dec 2024 19:27:49 -0600 Subject: [PATCH] Custom kernel lowering (#191) * Custom kernel lowering * fix * actually add file * gpu kernel generation * almost complete * use the same context * fixup * Doing things properly * cleaning up * cleanup * cleanup * fmt * final clean * cleanup * fmt * Now with dynamic shmem * fmt --------- Co-authored-by: Alex Zinenko --- BUILD | 41 +- src/enzyme_ad/jax/BUILD | 66 ++- src/enzyme_ad/jax/Dialect/Dialect.cpp | 52 ++ src/enzyme_ad/jax/Dialect/Dialect.h | 17 + src/enzyme_ad/jax/Dialect/Dialect.td | 36 ++ src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td | 74 +++ src/enzyme_ad/jax/Dialect/Ops.cpp | 53 ++ src/enzyme_ad/jax/Dialect/Ops.h | 36 ++ src/enzyme_ad/jax/Passes/LowerKernel.cpp | 578 ++++++++++++++++++++++ src/enzyme_ad/jax/Passes/Passes.h | 30 ++ src/enzyme_ad/jax/Passes/Passes.td | 31 ++ src/enzyme_ad/jax/RegistryUtils.cpp | 113 +++++ src/enzyme_ad/jax/RegistryUtils.h | 7 + src/enzyme_ad/jax/compile_with_xla.cc | 30 +- src/enzyme_ad/jax/compile_with_xla.h | 1 + src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 53 +- test/lit_tests/lowering/gpu.mlir | 39 ++ workspace.bzl | 2 +- 18 files changed, 1192 insertions(+), 67 deletions(-) create mode 100644 src/enzyme_ad/jax/Dialect/Dialect.cpp create mode 100644 src/enzyme_ad/jax/Dialect/Dialect.h create mode 100644 src/enzyme_ad/jax/Dialect/Dialect.td create mode 100644 src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td create mode 100644 src/enzyme_ad/jax/Dialect/Ops.cpp create mode 100644 src/enzyme_ad/jax/Dialect/Ops.h create mode 100644 src/enzyme_ad/jax/Passes/LowerKernel.cpp create mode 100644 src/enzyme_ad/jax/RegistryUtils.cpp create mode 100644 src/enzyme_ad/jax/RegistryUtils.h create mode 100644 test/lit_tests/lowering/gpu.mlir diff --git a/BUILD b/BUILD index 0cdac784b..44afdb728 100644 --- a/BUILD +++ b/BUILD @@ -2,6 +2,15 @@ load("@rules_python//python:packaging.bzl", "py_wheel") 
load(":package.bzl", "py_package") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") +load( + "@xla//xla/tsl/platform:build_config_root.bzl", + "if_llvm_aarch32_available", + "if_llvm_aarch64_available", + "if_llvm_powerpc_available", + "if_llvm_system_z_available", + "if_llvm_x86_available", +) + licenses(["notice"]) package( @@ -24,7 +33,10 @@ py_package( cc_binary( name = "enzymexlamlir-opt", - srcs = ["//src/enzyme_ad/jax:enzymexlamlir-opt.cpp"], + srcs = [ + "//src/enzyme_ad/jax:enzymexlamlir-opt.cpp", + "//src/enzyme_ad/jax:RegistryUtils.cpp", + ], visibility = ["//visibility:public"], deps = [ "@enzyme//:EnzymeMLIR", @@ -44,6 +56,7 @@ cc_binary( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:NVGPUDialect", "@llvm-project//mlir:OpenMPDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", @@ -52,8 +65,30 @@ cc_binary( "//src/enzyme_ad/jax:TransformOps", "//src/enzyme_ad/jax:XLADerivatives", "@stablehlo//:chlo_ops", - "@stablehlo//stablehlo/tests:check_ops" - ], + "@stablehlo//stablehlo/tests:check_ops", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUToLLVMIRTranslation", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + ] + if_llvm_aarch32_available([ + "@llvm-project//llvm:ARMAsmParser", + "@llvm-project//llvm:ARMCodeGen", + ]) + if_llvm_aarch64_available([ + "@llvm-project//llvm:AArch64AsmParser", + "@llvm-project//llvm:AArch64CodeGen", + ]) + if_llvm_powerpc_available([ + "@llvm-project//llvm:PowerPCAsmParser", + "@llvm-project//llvm:PowerPCCodeGen", + ]) + if_llvm_system_z_available([ + "@llvm-project//llvm:SystemZAsmParser", + "@llvm-project//llvm:SystemZCodeGen", + ]) + if_llvm_x86_available([ + "@llvm-project//llvm:X86AsmParser", + 
"@llvm-project//llvm:X86CodeGen", + ]), copts = [ "-Wno-unused-variable", "-Wno-return-type", diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 06230368b..7b3748296 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -2,7 +2,7 @@ load("@jax//jaxlib:symlink_files.bzl", "symlink_inputs") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -exports_files(["enzymexlamlir-opt.cpp"]) +exports_files(["enzymexlamlir-opt.cpp", "RegistryUtils.cpp"]) licenses(["notice"]) @@ -205,6 +205,55 @@ gentbl_cc_library( deps = [":EnzymeXLAPassesTdFiles"], ) + +td_library( + name = "EnzymeXLADialectTdFiles", + srcs = [ + "Dialect/Dialect.td", + ], + includes = ["."], + deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:MemorySlotInterfacesTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "EnzymeXLAOpsIncGen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "Dialect/EnzymeXLAOps.h.inc", + ), + ( + ["-gen-op-defs"], + "Dialect/EnzymeXLAOps.cpp.inc", + ), + ( + [ + "-gen-dialect-decls", + "-dialect=enzymexla", + ], + "Dialect/EnzymeXLADialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=enzymexla", + ], + "Dialect/EnzymeXLADialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Dialect/EnzymeXLAOps.td", + deps = [":EnzymeXLADialectTdFiles", "@enzyme//:EnzymeDialectTdFiles", "@stablehlo//:stablehlo_ops_td_files"], +) + gentbl_cc_library( name = "EnzyeHLOPatternsIncGen", tbl_outs = [ @@ -228,11 +277,13 @@ cc_library( [ "Implementations/*.cpp", "Passes/*.cpp", + "Dialect/*.cpp", ], ), hdrs = glob([ "Implementations/*.h", "Passes/*.h", + 
"Dialect/*.h", ]), copts = [ "-Werror=unused-variable", @@ -241,8 +292,14 @@ cc_library( "-Werror=unused-result", ], deps = [ + ":EnzymeXLAOpsIncGen", ":EnzymeXLAPassesIncGen", ":EnzyeHLOPatternsIncGen", + "@llvm-project//mlir:GPUPipelines", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:OrcJIT", + "@llvm-project//llvm:OrcTargetProcess", ":mhlo-derivatives", ":stablehlo-derivatives", ":chlo-derivatives", @@ -271,11 +328,12 @@ cc_library( pybind_library( name = "compile_with_xla", - srcs = ["compile_with_xla.cc"], + srcs = ["compile_with_xla.cc", "RegistryUtils.cpp"], hdrs = glob([ "compile_with_xla.h", "Implementations/*.h", "Passes/*.h", + "RegistryUtils.h" ]), deps = [ ":XLADerivatives", @@ -368,7 +426,9 @@ pybind_library( pybind_extension( name = "enzyme_call", - srcs = ["enzyme_call.cc"], + srcs = ["enzyme_call.cc", + "RegistryUtils.cpp" + ], visibility = ["//visibility:public"], deps = [ ":clang_compile", diff --git a/src/enzyme_ad/jax/Dialect/Dialect.cpp b/src/enzyme_ad/jax/Dialect/Dialect.cpp new file mode 100644 index 000000000..342440ef3 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Dialect.cpp @@ -0,0 +1,52 @@ +//===- EnzymeXLADialect.cpp - EnzymeXLA dialect -----------------------*- C++ +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Dialect.h" +#include "Ops.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/IR/Builders.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "mlir/IR/Dialect.h" + +// #include "Dialect/EnzymeEnums.cpp.inc" +#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.cpp.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc" + +// #define GET_TYPEDEF_CLASSES +// #include "Dialect/EnzymeXLAOpsTypes.cpp.inc" +// #include "Dialect/EnzymeTypes.cpp.inc" + +using namespace mlir; +using namespace mlir::enzymexla; + +//===----------------------------------------------------------------------===// +// Enzyme dialect. +//===----------------------------------------------------------------------===// + +void EnzymeXLADialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc" + >(); + // addAttributes< + // #define GET_ATTRDEF_LIST + // #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc" + // >(); + // addTypes< + // #define GET_TYPEDEF_LIST + // #include "src/enzyme_ad/jax/Dialect/EnzymeXLAOpsTypes.cpp.inc" + // >(); +} + +// #define GET_ATTRDEF_CLASSES +// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc" diff --git a/src/enzyme_ad/jax/Dialect/Dialect.h b/src/enzyme_ad/jax/Dialect/Dialect.h new file mode 100644 index 000000000..f376c4d78 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Dialect.h @@ -0,0 +1,17 @@ +//===- Dialect.h - EnzymeXLA dialect -------------------------------*- C++ +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYMEXLA_DIALECT_H +#define ENZYMEXLA_DIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.h.inc" + +#endif // ENZYMEXLA_DIALECT_H diff --git a/src/enzyme_ad/jax/Dialect/Dialect.td b/src/enzyme_ad/jax/Dialect/Dialect.td new file mode 100644 index 000000000..cd899b0b3 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Dialect.td @@ -0,0 +1,36 @@ +//===- EnzymeXLA.td - EnzymeXLA dialect --------------------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYMEXLA_DIALECT +#define ENZYMEXLA_DIALECT + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// EnzymeXLA dialect definition. +//===----------------------------------------------------------------------===// + +def EnzymeXLA_Dialect : Dialect { + let name = "enzymexla"; + let description = [{}]; + let cppNamespace = "::mlir::enzymexla"; + // let useDefaultAttributePrinterParser = 1; + // let useDefaultTypePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// Base EnzymeXLA operation definition. 
+//===----------------------------------------------------------------------===// + +class EnzymeXLA_Op traits = []> + : Op; + +class EnzymeXLA_Type : TypeDef; + +#endif // ENZYMEXLA_DIALECT diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td new file mode 100644 index 000000000..fadf59032 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td @@ -0,0 +1,74 @@ +//===- EnzymeXLAOps.td - EnzymeXLA dialect ops ------------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYMEXLA_OPS +#define ENZYMEXLA_OPS + +include "Enzyme/MLIR/Dialect/Dialect.td" +include "Dialect.td" +include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/EnumAttr.td" + +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" + +include "mlir/IR/AttrTypeBase.td" + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//include "stablehlo/dialect/Base.td" +//include "stablehlo/dialect/StablehloAttrs.td" + +def TensorI64 : Type($_self) && ::llvm::cast<::mlir::TensorType>($_self).getShape().size() == 0 && ::llvm::cast<::mlir::TensorType>($_self).getElementType().isSignlessInteger(64)">, "tensor", + "::mlir::TensorType">, + BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">; + +def KernelCallOp: EnzymeXLA_Op<"kernel_call", [DeclareOpInterfaceMethods, Pure]> { + let summary = "Kernel Call operation"; + let description = [{ + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + TensorI64:$gridx, + 
TensorI64:$gridy, + TensorI64:$gridz, + TensorI64:$blockx, + TensorI64:$blocky, + TensorI64:$blockz, + TensorI64:$shmem, + Variadic:$inputs, + DefaultValuedStrAttr:$backend_config, + OptionalAttr:$operand_layouts, + OptionalAttr:$result_layouts, + DefaultValuedOptionalAttr< + ArrayAttr, "{}">:$output_operand_aliases + //OptionalAttr:$operand_layouts, + //OptionalAttr:$result_layouts, + //DefaultValuedOptionalAttr< + // TypedArrayAttrBase< + // StableHLO_OutputOperandAlias, + // "Aliasing attribute for outputs and operands of CustomCall">, + // "{}">:$output_operand_aliases + ); + + let results = (outs Variadic); + + + let assemblyFormat = [{ + $fn ` ` `blocks` `in` `(` $gridx `,` $gridy `,` $gridz `)` ` ` `threads` `in` `(` $blockx `,` $blocky `,` $blockz `)` ` ` `shmem` `=` $shmem ` ` `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; + +} + +#endif // ENZYMEXLA_OPS diff --git a/src/enzyme_ad/jax/Dialect/Ops.cpp b/src/enzyme_ad/jax/Dialect/Ops.cpp new file mode 100644 index 000000000..37ec52a76 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Ops.cpp @@ -0,0 +1,53 @@ +//===- EnzymeXLAOps.cpp - EnzymeXLA dialect ops -----------------------*- C++ +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Ops.h" +#include "Dialect.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/IntegerSet.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "enzymexla" + +using namespace mlir; +using namespace enzymexla; +using namespace mlir::arith; + +LogicalResult +KernelCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO: Verify that the result type is same as the type of the referenced + // func.func op. + auto global = symbolTable.lookupNearestSymbolFrom( + *this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} diff --git a/src/enzyme_ad/jax/Dialect/Ops.h b/src/enzyme_ad/jax/Dialect/Ops.h new file mode 100644 index 000000000..91648d1e9 --- /dev/null +++ b/src/enzyme_ad/jax/Dialect/Ops.h @@ -0,0 +1,36 @@ +//===- EnzymeXLAOps.h - EnzymeXLA dialect ops -------------------------*- C++ +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYMEXLAOPS_H +#define ENZYMEXLAOPS_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Bytecode/BytecodeOpInterface.h" + +// #include "Dialect/EnzymeXLAEnums.h.inc" + +// #define GET_ATTRDEF_CLASSES +// #include "Dialect/EnzymeXLAAttributes.h.inc" + +// #define GET_TYPEDEF_CLASSES +// #include "Dialect/EnzymeXLAOpsTypes.h.inc" + +#define GET_OP_CLASSES +#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.h.inc" + +// #include "Dialect/EnzymeXLATypes.h.inc" + +#endif // ENZYMEXLAOPS_H diff --git a/src/enzyme_ad/jax/Passes/LowerKernel.cpp b/src/enzyme_ad/jax/Passes/LowerKernel.cpp new file mode 100644 index 000000000..5ca38ab48 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/LowerKernel.cpp @@ -0,0 +1,578 @@ +//===- LowerKernel.cpp - Lower kernel calls to custom calls ------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower kernel calls to custom calls +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +#include "src/enzyme_ad/jax/Dialect/Ops.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" +#include "src/enzyme_ad/jax/Passes/PassDetails.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" + +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/LLJIT.h" +#include 
"llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h" +#include "llvm/ExecutionEngine/Orc/TargetProcess/RegisterEHFrames.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" + +#include "mlir/Target/LLVMIR/Export.h" + +#define DEBUG_TYPE "lower-kernel" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace enzyme; +using namespace mlir::enzymexla; +using namespace enzymexla; + +using namespace stablehlo; + +typedef void XlaCustomCallStatus; + +llvm::StringMap kernels; +llvm::sys::SmartRWMutex kernel_mutex; +std::unique_ptr JIT = nullptr; + +void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) { + if (!JIT) { + auto tJIT = + llvm::orc::LLJITBuilder() + .setLinkProcessSymbolsByDefault(true) + .setObjectLinkingLayerCreator( + [](llvm::orc::ExecutionSession &ES, const llvm::Triple &OLL) + -> llvm::Expected> { + auto obj = std::make_unique< + llvm::orc::RTDyldObjectLinkingLayer>(ES, []() { + return std::make_unique(); + }); + if (getenv("ENABLE_GDBLISTENER")) { + auto list = + llvm::JITEventListener::createGDBRegistrationListener(); + obj->registerJITEventListener(*list); + } + return obj; + }) + .create(); + if (!tJIT) { + llvm::errs() << " jit creating error: " << tJIT.takeError() << "\n"; + return nullptr; + } + JIT = std::move(tJIT.get()); + assert(JIT); + } + + std::unique_ptr ctx(new llvm::LLVMContext); + auto llvmModule = translateModuleToLLVMIR(modOp, *ctx); + if (!llvmModule) { + llvm::errs() << "could not convert to LLVM IR" + << "\n"; + return nullptr; + } + llvmModule->setDataLayout(JIT->getDataLayout()); + llvmModule->setTargetTriple(JIT->getTargetTriple().getTriple()); + + auto LibA = + JIT->createJITDylib("enzymecudadl_" + std::to_string(kernels.size())); + if (auto Err = JIT->addIRModule( + LibA.get(), + llvm::orc::ThreadSafeModule(std::move(llvmModule), std::move(ctx)))) { + llvm::errs() << " 
addIRModuleError " << Err << "\n"; + return nullptr; + } + + // Look up the JIT'd code entry point. + auto EntrySym = JIT->lookup(LibA.get(), "entry"); + if (!EntrySym) { + llvm::errs() << " lookupError " << EntrySym.takeError() << "\n"; + return nullptr; + } + + auto ptr = (void *)EntrySym->getValue(); + + kernels[key] = ptr; + return ptr; +} + +// See API details at +// https://github.com/openxla/xla/blob/37fb0612d36ac3d08ff984b1d61e4bc4dedf4809/xla/service/hlo.proto#L73 +extern "C" void EnzymeGPUCustomCall(void *__restrict__ stream, + void **__restrict__ buffers, + void **__restrict__ opaqueptr, + size_t opaque_len, + XlaCustomCallStatus *__restrict__ status) { + auto ptr = (void (*)(void *, void **))(opaqueptr[0]); + // auto ptr = (void(*)(void*, void**, size_t, size_t, size_t, size_t, size_t, + // size_t)) (opaqueptr[0][0]); + + // size_t gridx = opaqueptr[0][1]; + // size_t gridy = opaqueptr[0][2]; + // size_t gridz = opaqueptr[0][3]; + + // size_t blockx = opaqueptr[0][4]; + // size_t blocky = opaqueptr[0][5]; + // size_t blockz = opaqueptr[0][6]; + + ptr(stream, buffers); //, gridx, gridy, gridz, blockx, blocky, blockz); +} + +gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) { + ArrayRef objects = op.getObjectsAttr().getValue(); + + // Obtain the index of the object to select. + int64_t index = -1; + if (Attribute target = + cast(op.getOffloadingHandlerAttr()) + .getTarget()) { + // If the target attribute is a number it is the index. Otherwise compare + // the attribute to every target inside the object array to find the index. + if (auto indexAttr = mlir::dyn_cast(target)) { + index = indexAttr.getInt(); + } else { + for (auto [i, attr] : llvm::enumerate(objects)) { + auto obj = mlir::dyn_cast(attr); + if (obj.getTarget() == target) { + index = i; + } + } + } + } else { + // If the target attribute is null then it's selecting the first object in + // the object array. 
+ index = 0; + } + + if (index < 0 || index >= static_cast(objects.size())) { + op->emitError("the requested target object couldn't be found"); + return nullptr; + } + return mlir::dyn_cast(objects[index]); +} + +void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc, + FunctionOpInterface op, bool jit, size_t gridx, + size_t gridy, size_t gridz, size_t blockx, size_t blocky, + size_t blockz, size_t shmem) { + + OpBuilder builder(op); + + auto ptrty = LLVM::LLVMPointerType::get(builder.getContext()); + mlir::Type intys[] = {ptrty, ptrty}; + FunctionType calleeType = builder.getFunctionType(intys, {}); + + FunctionType gpuTy = dyn_cast(op.getFunctionType()); + if (!gpuTy) { + if (auto lty = dyn_cast(op.getFunctionType())) { + gpuTy = builder.getFunctionType(lty.getParams(), {}); + } else { + op.emitError( + "Require target operand to have functiontype or llvmfunctiontype"); + return nullptr; + } + } + + auto submod = builder.create(loc, "offload"); + submod->setAttr("gpu.container_module", builder.getUnitAttr()); + builder.setInsertionPointToStart(&submod.getBodyRegion().front()); + + auto gpumod = builder.create(loc, "gpumodname"); + builder.setInsertionPointToStart(&gpumod.getBodyRegion().front()); + + auto gpufunc = builder.create(loc, "kernel", gpuTy); + { + IRMapping map; + map.map(op.getArguments(), gpufunc.getArguments()); + op.getFunctionBody().cloneInto(&gpufunc.getBody(), map); + gpufunc->setAttr("gpu.kernel", builder.getUnitAttr()); + + auto entry = &gpufunc.getBody().front(); + auto second = entry->getNextNode(); + entry->getOperations().splice(entry->getOperations().end(), + second->getOperations()); + + second->erase(); + + gpufunc->walk([](LLVM::ReturnOp op) { + OpBuilder rewriter(op); + rewriter.create(op.getLoc()); + op.erase(); + }); + + gpufunc->walk([](LLVM::UnreachableOp op) { + OpBuilder rewriter(op); + rewriter.create(op.getLoc()); + op.erase(); + }); + + gpufunc->walk([](func::ReturnOp op) { + OpBuilder rewriter(op); + 
rewriter.create(op.getLoc()); + op.erase(); + }); + } + SmallVector tocopy; + op->walk([&](CallOpInterface cop) { + if (auto op2 = cop.resolveCallable()) + tocopy.push_back(op2); + }); + SmallPtrSet done; + + builder.setInsertionPointToStart(&gpumod.getBodyRegion().front()); + while (tocopy.size()) { + auto cur = tocopy.pop_back_val(); + if (done.count(cur)) + continue; + done.insert(cur); + builder.clone(*cur); + cur->walk([&](CallOpInterface cop) { + if (auto op2 = cop.resolveCallable()) + tocopy.push_back(op2); + }); + } + + builder.setInsertionPointToEnd(&submod.getBodyRegion().front()); + + auto func = builder.create(loc, "entry", calleeType); + + auto &entryBlock = *func.addEntryBlock(); + builder.setInsertionPointToStart(&entryBlock); + + mlir::Value stream = entryBlock.getArgument(0); + auto buffers = entryBlock.getArgument(1); + + auto idx = builder.getIntegerType(64); + auto i32 = builder.getIntegerType(32); + gpu::KernelDim3 gridSize{ + builder.create(loc, gridx, idx), + builder.create(loc, gridy, idx), + builder.create(loc, gridz, idx), + }; + + gpu::KernelDim3 blockSize{ + builder.create(loc, blockx, idx), + builder.create(loc, blocky, idx), + builder.create(loc, blockz, idx), + }; + + SmallVector arguments; + for (auto arg : op.getArguments()) { + LLVM::GEPArg args[1] = {arg.getArgNumber()}; + auto gep = + builder.create(loc, ptrty, ptrty, buffers, args, true); + auto ld = builder.create(loc, arg.getType(), gep); + arguments.push_back(ld); + } + auto dynshmem = builder.create(loc, shmem, i32); + stream = builder + .create( + loc, gpu::AsyncTokenType::get(stream.getContext()), stream) + ->getResult(0); + builder.create(loc, gpufunc, gridSize, blockSize, dynshmem, + arguments, stream.getType(), + ValueRange(stream)); + + builder.create(loc); + + std::string modstr; + llvm::raw_string_ostream ss(modstr); + + ss << submod; + + if (!jit) + return nullptr; + + void *ptr = nullptr; + { + llvm::sys::SmartScopedWriter lock(kernel_mutex); + + auto found = 
kernels.find(ss.str()); + if (found != kernels.end()) { + ptr = found->second; + } else { + // mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + + PassManager pm(submod.getContext()); + mlir::gpu::GPUToNVVMPipelineOptions options; + options.indexBitWidth = 64; + options.cubinTriple = "nvptx64-nvidia-cuda"; + options.cubinChip = "sm_50"; + options.cubinFeatures = "+ptx60"; + options.cubinFormat = "fatbin"; + options.optLevel = 2; + options.kernelUseBarePtrCallConv = false; + options.hostUseBarePtrCallConv = false; + mlir::gpu::buildLowerToNVVMPassPipeline(pm, options); + + pm.run(submod); + + OpBuilder builder(submod); + builder.setInsertionPointToStart(&submod.getBodyRegion().front()); + auto ptrty = LLVM::LLVMPointerType::get(builder.getContext()); + auto i64 = builder.getIntegerType(64); + auto i32 = builder.getIntegerType(32); + auto idx = i64; + auto voidty = LLVM::LLVMVoidType::get(submod.getContext()); + + auto glob = builder.create(loc, ptrty, /*constant*/ false, + LLVM::Linkage::Private, + "nv_func", mlir::Attribute()); + + mlir::Type cumodtys[] = {ptrty, ptrty}; + auto modload = builder.create( + loc, "cuModuleLoadData", LLVM::LLVMFunctionType::get(i32, cumodtys)); + + mlir::Type cutys[] = {ptrty, idx, idx, idx, idx, idx, + idx, i32, ptrty, ptrty, ptrty}; + auto launch = builder.create( + loc, "cuLaunchKernel", LLVM::LLVMFunctionType::get(voidty, cutys)); + + mlir::Type cufunctys[] = {ptrty, ptrty, ptrty}; + auto funcload = builder.create( + loc, "cuModuleGetFunction", + LLVM::LLVMFunctionType::get(i32, cufunctys)); + + LLVM::GlobalOp kernStr; + { + std::string value = "kernel"; + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1); + kernStr = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "str", + builder.getStringAttr(value + '\0')); + } + + builder.setInsertionPointToStart(&submod.getBodyRegion().front()); + + auto initfn = builder.create( + loc, 
"nv_func_init", LLVM::LLVMFunctionType::get(voidty, {}, false), + LLVM::Linkage::Private); + + mlir::Attribute funcs[] = {FlatSymbolRefAttr::get(initfn)}; + mlir::Attribute idxs[] = {builder.getI32IntegerAttr(0)}; + builder.create(loc, builder.getArrayAttr(funcs), + builder.getArrayAttr(idxs)); + + LLVM::GlobalOp binary; + submod.walk([&](gpu::BinaryOp op) { + gpu::ObjectAttr object = getSelectedObject(op); + auto value = object.getObject().getValue(); + auto type = LLVM::LLVMArrayType::get( + mlir::IntegerType::get(builder.getContext(), 8), value.size()); + binary = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "binary", + builder.getStringAttr(value)); + + if (object.getProperties()) { + if (auto section = mlir::dyn_cast_or_null( + object.getProperties().get("section"))) { + binary.setSectionAttr(section); + } + } + + binary.setAlignmentAttr(builder.getI64IntegerAttr(8)); + binary.setUnnamedAddrAttr(LLVM::UnnamedAddrAttr::get( + builder.getContext(), mlir::LLVM::UnnamedAddr::None)); + op.erase(); + }); + + { + auto blk = new Block(); + initfn.getRegion().push_back(blk); + builder.setInsertionPointToEnd(blk); + + auto one = builder.create( + loc, i64, builder.getI64IntegerAttr(1)); + auto modptr = builder.create(loc, ptrty, ptrty, one); + auto funcptr = builder.create(loc, ptrty, ptrty, one); + + auto addr_modbin = builder.create(loc, binary); + mlir::Value modargs[] = {modptr->getResult(0), + addr_modbin->getResult(0)}; + builder.create(loc, modload, modargs); + auto mod = builder.create(loc, ptrty, modptr); + + auto addr_kernstr = + builder.create(loc, ptrty, "str"); + + mlir::Value funcargs[] = {funcptr->getResult(0), mod->getResult(0), + addr_kernstr->getResult(0)}; + builder.create(loc, funcload, funcargs); + auto func = builder.create(loc, ptrty, funcptr); + + auto addr_glob = builder.create(loc, glob); + builder.create(loc, func, addr_glob); + builder.create(loc, ValueRange()); + } + + submod.walk([&](gpu::LaunchFuncOp op) { + 
builder.setInsertionPoint(op); + auto ldop = + op.getKernelOperands().front().getDefiningOp(); + assert(ldop); + auto params = ldop.getOperand(); + auto addr_glob = builder.create(loc, glob); + auto cufunc = builder.create(loc, ptrty, addr_glob); + mlir::Value args[] = {cufunc, + op.getGridSizeX(), + op.getGridSizeY(), + op.getGridSizeZ(), + op.getBlockSizeX(), + op.getBlockSizeY(), + op.getBlockSizeZ(), + op.getDynamicSharedMemorySize(), + op.getAsyncObject(), + params, + builder.create(loc, ptrty)}; + builder.create(loc, launch, args); + op.erase(); + ldop.erase(); + }); + + ptr = CompileHostModule(ss.str(), submod); + + submod.erase(); + } + } + + return ptr; +}; + +namespace { + +struct LowerKernelPass : public LowerKernelPassBase { + + void getDependentDialects(DialectRegistry ®istry) const override { + OpPassManager pm; + mlir::gpu::GPUToNVVMPipelineOptions options; + options.indexBitWidth = 64; + options.cubinTriple = "nvptx64-nvidia-cuda"; + options.cubinChip = "sm_50"; + options.cubinFeatures = "+ptx60"; + options.cubinFormat = "fatbin"; + options.optLevel = 2; + options.kernelUseBarePtrCallConv = false; + options.hostUseBarePtrCallConv = false; + mlir::gpu::buildLowerToNVVMPassPipeline(pm, options); + pm.getDependentDialects(registry); + + registry.insert(); + } + + void runOnOperation() override { + auto context = getOperation()->getContext(); + + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(getOperation()); + + getOperation()->walk([&](KernelCallOp op) { + mlir::ArrayAttr operand_layouts = + op.getOperandLayouts() + ? cast(*op.getOperandLayouts()) + : nullptr; + mlir::ArrayAttr result_layouts = + op.getResultLayouts() ? 
cast(*op.getResultLayouts()) + : nullptr; + mlir::ArrayAttr output_operand_aliases = op.getOutputOperandAliases(); + + size_t data[8]; + + auto *symbolOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr()); + auto fn = cast(symbolOp); + + Value vals[] = {op.getGridx(), op.getGridy(), op.getGridz(), + op.getBlockx(), op.getBlocky(), op.getBlockz(), + op.getShmem()}; + for (auto en : llvm::enumerate(vals)) { + DenseIntElementsAttr stepAttr; + if (!matchPattern(en.value(), m_Constant(&stepAttr))) { + op->emitError() << "Cannot lower kernel with a grid/block size which " + "is not a constant integer tensor"; + return; + } + if (stepAttr.size() != 1) { + op->emitError() << "Cannot lower kernel with a grid/block size which " + "is not a constant integer tensor of size 1"; + return; + } + auto val = (*stepAttr.begin()).getZExtValue(); + data[1 + en.index()] = val; + } + + // Compiled kernel goes here once ready + data[0] = (size_t)CompileKernel(symbolTable, op.getLoc(), fn, jit, + data[1], data[2], data[3], data[4], + data[5], data[6], data[7]); + + std::string backendinfo((char *)&data, sizeof(void *)); + + OpBuilder rewriter(op); + auto replacement = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getInputs(), + rewriter.getStringAttr("enzymexla_gpu"), + /* has_side_effect*/ rewriter.getBoolAttr(false), + /*backend_config*/ rewriter.getStringAttr(backendinfo), + /* api_version*/ + CustomCallApiVersionAttr::get(rewriter.getContext(), + mlir::stablehlo::CustomCallApiVersion:: + API_VERSION_STATUS_RETURNING), + /*calledcomputations*/ nullptr, operand_layouts, result_layouts, + output_operand_aliases); + + op.replaceAllUsesWith(replacement); + op.erase(); + }); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createLowerKernelPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index 
584ec22ce..8693981c9 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -21,7 +21,12 @@ std::unique_ptr createArithRaisingPass(); std::unique_ptr createEnzymeHLOOptPass(); std::unique_ptr createEnzymeHLOUnrollPass(); std::unique_ptr createPrintPass(); +std::unique_ptr createLowerKernelPass(); } // namespace enzyme + +// namespace enzymexla { +// class EnzymeXLADialect; +//} } // namespace mlir namespace mlir { @@ -41,6 +46,22 @@ namespace tensor { class TensorDialect; } // namespace tensor +namespace math { +class MathDialect; +} // namespace math + +namespace vector { +class VectorDialect; +} // namespace vector + +namespace nvgpu { +class NVGPUDialect; +} // namespace nvgpu + +namespace NVVM { +class NVVMDialect; +} // namespace NVVM + namespace stablehlo { class StablehloDialect; } // namespace stablehlo @@ -53,6 +74,10 @@ namespace cf { class ControlFlowDialect; } // end namespace cf +namespace gpu { +class GPUDialect; +} // end namespace gpu + namespace scf { class SCFDialect; } // end namespace scf @@ -61,6 +86,10 @@ namespace memref { class MemRefDialect; } // end namespace memref +namespace async { +class AsyncDialect; +} // namespace async + namespace func { class FuncDialect; } @@ -81,5 +110,6 @@ static void regsiterenzymeXLAPasses() { registerPrintPass(); registerEnzymeHLOOptPass(); registerEnzymeHLOUnrollPass(); + registerLowerKernelPass(); } #endif // ENZYMEXLA_PASSES_H diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index 2aafc894e..bc20f6a68 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -116,4 +116,35 @@ def PrintPass : Pass<"print"> { ]; } +def LowerKernelPass : Pass<"lower-kernel"> { + let summary = "Lower kernel to custom call"; + let dependentDialects = [ + ]; + let constructor = "mlir::enzyme::createLowerKernelPass()"; + let dependentDialects = [ + "stablehlo::StablehloDialect", + "gpu::GPUDialect", + "func::FuncDialect", + 
"math::MathDialect", + "memref::MemRefDialect", + "scf::SCFDialect", + "vector::VectorDialect", + "nvgpu::NVGPUDialect", + "NVVM::NVVMDialect", + "LLVM::LLVMDialect", + "arith::ArithDialect", + "tensor::TensorDialect", + ]; + + let options = [ + Option< + /*C++ variable name=*/"jit", + /*CLI argument=*/"jit", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to jit the kernel" + > + ]; +} + #endif diff --git a/src/enzyme_ad/jax/RegistryUtils.cpp b/src/enzyme_ad/jax/RegistryUtils.cpp new file mode 100644 index 000000000..1eabc8b22 --- /dev/null +++ b/src/enzyme_ad/jax/RegistryUtils.cpp @@ -0,0 +1,113 @@ +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" + +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Implementations/XLADerivatives.h" + +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "Dialect/Dialect.h" +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include 
"mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace enzyme { +void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); +} // namespace enzyme +} // namespace mlir + +void prepareRegistry(mlir::DialectRegistry ®istry) { + + // Register MLIR stuff + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + + registry.insert(); + registry.insert(); + + mlir::enzyme::registerXLAAutoDiffInterfaces(registry); + + mlir::func::registerInlinerExtension(registry); + + mlir::registerConvertNVVMToLLVMInterface(registry); + + registry.insert(); + mlir::registerConvertNVVMToLLVMInterface(registry); + mlir::registerConvertComplexToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::registerConvertMathToLLVMInterface(registry); + mlir::registerConvertFuncToLLVMInterface(registry); + mlir::index::registerConvertIndexToLLVMInterface(registry); + 
mlir::cf::registerConvertControlFlowToLLVMInterface(registry); + mlir::ub::registerConvertUBToLLVMInterface(registry); + mlir::arith::registerConvertArithToLLVMInterface(registry); + mlir::registerConvertMemRefToLLVMInterface(registry); + mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); + mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); + mlir::registerBuiltinDialectTranslation(registry); + mlir::registerGPUDialectTranslation(registry); + mlir::registerLLVMDialectTranslation(registry); + mlir::registerNVVMDialectTranslation(registry); + + // Register the autodiff interface implementations for upstream dialects. + mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); + + mlir::linalg::registerTransformDialectExtension(registry); + + mlir::enzyme::registerEnzymeJaxTransformExtension(registry); +} diff --git a/src/enzyme_ad/jax/RegistryUtils.h b/src/enzyme_ad/jax/RegistryUtils.h new file mode 100644 index 000000000..b99b943a4 --- /dev/null +++ b/src/enzyme_ad/jax/RegistryUtils.h @@ -0,0 +1,7 @@ +#pragma once + +namespace mlir { +class DialectRegistry; +} + +void prepareRegistry(mlir::DialectRegistry ®istry); diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index 258d68891..953211e84 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -18,6 +18,7 @@ #include "absl/status/statusor.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -25,6 +26,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" +#include "pybind11/pybind11.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/client/client_library.h" @@ -40,25 +42,13 @@ #include 
"xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "pybind11/pybind11.h" - #include "compile_with_xla.h" -#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" -#include "Implementations/XLADerivatives.h" #include "TransformOps/TransformOps.h" -#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" - #include "pybind11/stl.h" -void prepareRegistry(mlir::DialectRegistry ®istry) { - mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); - mlir::enzyme::registerXLAAutoDiffInterfaces(registry); - mlir::linalg::registerTransformDialectExtension(registry); - mlir::enzyme::registerEnzymeJaxTransformExtension(registry); - mlir::func::registerInlinerExtension(registry); -} +#include "RegistryUtils.h" /// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric /// suffix in `lastUsedID`. @@ -147,13 +137,6 @@ run_pass_pipeline(const std::vector &oldsym_vec, mlir::DialectRegistry registry; prepareRegistry(registry); MLIRContext context(registry); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); mlir::ParserConfig parser_config(&context); mlir::OwningOpRef parsed_module = mlir::parseSourceString(mlir, parser_config); @@ -301,13 +284,6 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, mlir::DialectRegistry registry; prepareRegistry(registry); mlir::MLIRContext context(registry); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); mlir::ParserConfig parser_config(&context); mlir::OwningOpRef parsed_module = mlir::parseSourceString(mhlo_text, parser_config); diff --git a/src/enzyme_ad/jax/compile_with_xla.h b/src/enzyme_ad/jax/compile_with_xla.h index 4ff1fec79..5626da7b6 100644 --- 
a/src/enzyme_ad/jax/compile_with_xla.h +++ b/src/enzyme_ad/jax/compile_with_xla.h @@ -19,3 +19,4 @@ namespace mlir { class Operation; } void run_pass_pipeline(mlir::Operation *mod, const std::string &pass_pipeline); +void prepareRegistry(mlir::DialectRegistry ®istry); diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index 33c0c73c5..44d2463c6 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -11,12 +11,15 @@ // //===----------------------------------------------------------------------===// +#include "Dialect/Dialect.h" #include "Enzyme/MLIR/Dialect/Dialect.h" #include "Enzyme/MLIR/Dialect/Ops.h" #include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" #include "Enzyme/MLIR/Passes/Passes.h" #include "Implementations/XLADerivatives.h" #include "Passes/Passes.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -33,25 +36,34 @@ #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/TargetSelect.h" + +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "stablehlo/dialect/ChloOps.h" 
#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/tests/CheckOps.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" using namespace mlir; namespace mlir { namespace enzyme { -void registerEnzymeJaxTransformExtension(mlir::DialectRegistry ®istry); void registerGenerateApplyPatternsPass(); void registerRemoveTransformPass(); } // namespace enzyme @@ -65,38 +77,18 @@ struct PtrElementModel : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< PtrElementModel, T> {}; +void prepareRegistry(mlir::DialectRegistry ®istry); + int main(int argc, char **argv) { - mlir::DialectRegistry registry; + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); - // Register MLIR stuff - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); + mlir::DialectRegistry registry; registry.insert(); - - registry.insert(); + prepareRegistry(registry); mlir::registerenzymePasses(); regsiterenzymeXLAPasses(); - mlir::enzyme::registerXLAAutoDiffInterfaces(registry); - - mlir::func::registerInlinerExtension(registry); // Register the standard passes we want. mlir::registerCSEPass(); @@ -124,15 +116,10 @@ int main(int argc, char **argv) { *ctx); }); - // Register the autodiff interface implementations for upstream dialects. - enzyme::registerCoreDialectAutodiffInterfaces(registry); - // Transform dialect and extensions. 
mlir::transform::registerInterpreterPass(); - mlir::linalg::registerTransformDialectExtension(registry); mlir::enzyme::registerGenerateApplyPatternsPass(); mlir::enzyme::registerRemoveTransformPass(); - mlir::enzyme::registerEnzymeJaxTransformExtension(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Enzyme modular optimizer driver", registry)); diff --git a/test/lit_tests/lowering/gpu.mlir b/test/lit_tests/lowering/gpu.mlir new file mode 100644 index 000000000..c262627b6 --- /dev/null +++ b/test/lit_tests/lowering/gpu.mlir @@ -0,0 +1,39 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-kernel{jit=false})" | FileCheck %s + +module { + llvm.func internal unnamed_addr fastcc @throw_boundserror_2676() attributes {dso_local, no_inline, sym_visibility = "private"} { + llvm.unreachable + } + llvm.func internal ptx_kernelcc @kern(%arg0: !llvm.ptr<1>) { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = nvvm.read.ptx.sreg.tid.x : i32 + %2 = llvm.icmp "ugt" %1, %0 : i32 + llvm.cond_br %2, ^bb2, ^bb1 + ^bb1: // pred: ^bb0 + %4 = llvm.zext %1 : i32 to i64 + %5 = llvm.getelementptr inbounds %arg0[%4] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64 + %6 = llvm.load %5 {alignment = 1 : i64} : !llvm.ptr<1> -> i64 + %7 = llvm.mul %6, %6 : i64 + llvm.store %7, %5 {alignment = 1 : i64} : i64, !llvm.ptr<1> + llvm.return + ^bb2: // pred: ^bb0 + llvm.call fastcc @throw_boundserror_2676() : () -> () + llvm.unreachable + } + func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { + %c0 = stablehlo.constant dense<0> : tensor + %c1 = stablehlo.constant dense<1> : tensor + %c40 = stablehlo.constant dense<40> : tensor + %0 = enzymexla.kernel_call @kern blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c40) shmem=%c0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> + return %0 : tensor<64xi64> + } +} + +// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> { +// CHECK-NEXT: 
stablehlo.constant +// CHECK-NEXT: stablehlo.constant +// CHECK-NEXT: stablehlo.constant +// CHECK-NEXT: %0 = stablehlo.custom_call @enzymexla_gpu(%arg0) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00", output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor<64xi64>) -> tensor<64xi64> +// CHECK-NEXT: return %0 : tensor<64xi64> +// CHECK-NEXT: } +// CHECK-NEXT:} diff --git a/workspace.bzl b/workspace.bzl index 38e091233..1b0428bd7 100644 --- a/workspace.bzl +++ b/workspace.bzl @@ -1,7 +1,7 @@ JAX_COMMIT = "99b390ce962599da44c221169e0df709920f141c" JAX_SHA256 = "" -ENZYME_COMMIT = "068ad9c6f8bc7d8c7ad3806fd148492e323bc4b1" +ENZYME_COMMIT = "b387a389b11040b2d8de7849aa063e0087f0ae05" ENZYME_SHA256 = "" XLA_PATCHES = [