From fb4b24f6aaa8793edd0eb11ec56bce2f4245b45e Mon Sep 17 00:00:00 2001
From: smjleo
Date: Sat, 25 Jan 2025 14:50:42 +0000
Subject: [PATCH] separate build test

---
 WORKSPACE                                     |   4 +-
 src/enzyme_ad/jax/AnalyticalCostModel.cpp     | 110 ++++++++++++++++++
 src/enzyme_ad/jax/AnalyticalCostModel.h       |  46 ++++++++
 src/enzyme_ad/jax/AnalyticalCostModelStub.cpp |  46 ++++++++
 src/enzyme_ad/jax/BUILD                       |  71 +++++++++++
 .../jax/Passes/EqualitySaturation.cpp         |  70 +----------
 src/enzyme_ad/jax/RunXlaGpuPasses.cpp         |  38 ++++++
 src/enzyme_ad/jax/RunXlaGpuPasses.h           |   5 +
 src/enzyme_ad/jax/RunXlaGpuPassesStub.cpp     |  10 ++
 workspace.bzl                                 |   2 +
 10 files changed, 335 insertions(+), 67 deletions(-)
 create mode 100644 src/enzyme_ad/jax/AnalyticalCostModel.cpp
 create mode 100644 src/enzyme_ad/jax/AnalyticalCostModel.h
 create mode 100644 src/enzyme_ad/jax/AnalyticalCostModelStub.cpp
 create mode 100644 src/enzyme_ad/jax/RunXlaGpuPasses.cpp
 create mode 100644 src/enzyme_ad/jax/RunXlaGpuPasses.h
 create mode 100644 src/enzyme_ad/jax/RunXlaGpuPassesStub.cpp

diff --git a/WORKSPACE b/WORKSPACE
index 7c6745cdb..77ecde451 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -28,9 +28,9 @@ http_archive(
 load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
 http_archive(
     name = "xla",
-    sha256 = XLA_SHA256,
+    # sha256 = XLA_SHA256,
     strip_prefix = "xla-" + XLA_COMMIT,
-    urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
+    urls = ["https://github.com/smjleo/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
     patch_cmds = XLA_PATCHES,
 )

diff --git a/src/enzyme_ad/jax/AnalyticalCostModel.cpp b/src/enzyme_ad/jax/AnalyticalCostModel.cpp
new file mode 100644
index 000000000..b7b1ed2d2
--- /dev/null
+++ b/src/enzyme_ad/jax/AnalyticalCostModel.cpp
@@ -0,0 +1,110 @@
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/service/backend.h"
+#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
+
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+
+#include "AnalyticalCostModel.h"
+#include "RunXlaGpuPasses.h"
+
+using namespace mlir;
+
+uint64_t AnalyticalCostModel::getAnalyticalCost(ModuleOp &wrapperModule) {
+  std::unique_ptr<xla::HloModule> preOpt =
+      wrapperModuleToHloModule(wrapperModule);
+
+  // Run XLA passes (layout, fusion, simplification) to ensure that what is
+  // being measured is what will actually be run.
+  auto hloModule = runXlaGpuPasses(std::move(preOpt));
+
+  auto deviceDescription = getDeviceDescription();
+
+  xla::HloCostAnalysis::ShapeSizeFunction shapeSizeFunction =
+      [](const xla::Shape &shape) {
+        return xla::gpu::GetSizeOfShape(shape, 4);
+      };
+  xla::gpu::GpuHloCostAnalysis costAnalysis(
+      xla::gpu::GpuHloCostAnalysis::Options{shapeSizeFunction, {}, {}, true},
+      *deviceDescription);
+
+  assert(hloModule->computation_count() == 1);
+
+  uint64_t cost = -1;
+
+  for (auto c : hloModule->computations()) {
+    c->Accept(&costAnalysis);
+    // The op we are measuring should always be the return value, which is
+    // at the root.
+    auto op = c->root_instruction();
+
+    auto runtime = xla::gpu::GpuPerformanceModel::EstimateRunTimeForInstruction(
+        op, *deviceDescription, &costAnalysis,
+        xla::gpu::GpuPerformanceModelOptions::ForModule(op->GetModule()));
+    if (cost != -1) {
+      throw std::invalid_argument("found two computations");
+    }
+    cost = absl::ToInt64Nanoseconds(runtime.exec_time);
+  }
+
+  return cost;
+}
+
+/**
+ * Create an XLA-internal HloModule for the analytical cost model.
+ */
+std::unique_ptr<xla::HloModule>
+AnalyticalCostModel::wrapperModuleToHloModule(ModuleOp &wrapperModule) {
+  auto context = wrapperModule.getContext();
+  PassManager pm(context);
+  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
+  pm.run(wrapperModule);
+
+  MlirToHloConversionOptions options;
+  options.propagate_layouts = true;
+  options.return_tuple = false;
+
+  auto hloModule = ConvertMlirHloToHloModule(wrapperModule, options);
+  if (!hloModule.ok()) {
+    llvm::errs() << "Couldn't create hloModule: "
+                 << hloModule.status().message();
+    return nullptr;
+  } else {
+    return std::move(hloModule.value());
+  }
+}
+
+stream_executor::Platform *AnalyticalCostModel::getXlaPlatform() {
+  return xla::PlatformUtil::GetPlatform("cuda").value();
+}
+
+/**
+ * Get the DeviceDescription for the current device.
+ */
+std::unique_ptr<stream_executor::DeviceDescription>
+AnalyticalCostModel::getDeviceDescription() {
+  // assume ordinal 0
+  return std::move(getXlaPlatform()->DescriptionForDevice(0).value());
+}
diff --git a/src/enzyme_ad/jax/AnalyticalCostModel.h b/src/enzyme_ad/jax/AnalyticalCostModel.h
new file mode 100644
index 000000000..98e7b4ac2
--- /dev/null
+++ b/src/enzyme_ad/jax/AnalyticalCostModel.h
@@ -0,0 +1,46 @@
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
+
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/service/backend.h"
+#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
+
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+
+class AnalyticalCostModel {
+public:
+  static uint64_t getAnalyticalCost(mlir::ModuleOp &wrapperModule);
+
+private:
+  /**
+   * Create an XLA-internal HloModule for the analytical cost model.
+   */
+  static std::unique_ptr<xla::HloModule>
+  wrapperModuleToHloModule(mlir::ModuleOp &wrapperModule);
+
+  static stream_executor::Platform *getXlaPlatform();
+
+  /**
+   * Get the DeviceDescription for the current device.
+   */
+  static std::unique_ptr<stream_executor::DeviceDescription>
+  getDeviceDescription();
+};
diff --git a/src/enzyme_ad/jax/AnalyticalCostModelStub.cpp b/src/enzyme_ad/jax/AnalyticalCostModelStub.cpp
new file mode 100644
index 000000000..979ba62c8
--- /dev/null
+++ b/src/enzyme_ad/jax/AnalyticalCostModelStub.cpp
@@ -0,0 +1,46 @@
+#include "AnalyticalCostModel.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/Verifier.h"
+
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/service/backend.h"
+#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
+
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+
+uint64_t AnalyticalCostModel::getAnalyticalCost(mlir::ModuleOp &wrapperModule) {
+  throw std::runtime_error("stub");
+}
+
+std::unique_ptr<xla::HloModule>
+AnalyticalCostModel::wrapperModuleToHloModule(mlir::ModuleOp &wrapperModule) {
+  throw std::runtime_error("stub");
+}
+
+stream_executor::Platform *AnalyticalCostModel::getXlaPlatform() {
+  throw std::runtime_error("stub");
+}
+
+std::unique_ptr<stream_executor::DeviceDescription>
+AnalyticalCostModel::getDeviceDescription() {
+  throw std::runtime_error("stub");
+}
diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD
index f5978b72d..25e39b136 100644
--- a/src/enzyme_ad/jax/BUILD
+++ b/src/enzyme_ad/jax/BUILD
@@ -225,6 +225,76 @@ gentbl_cc_library(
     tblgen = "//:enzymexlamlir-tblgen",
 )
 
+cc_library(
+    name = "RunXlaGpuPassesHdr",
+    hdrs = ["RunXlaGpuPasses.h"],
+    deps = [
+        "@xla//xla/hlo/ir:hlo",
+    ],
+    alwayslink = True,
+    linkstatic = True,
+)
+
+cc_library(
+    name = "RunXlaGpuPassesStub",
+    srcs = ["RunXlaGpuPassesStub.cpp"],
+    deps = [":RunXlaGpuPassesHdr"],
+    alwayslink = True,
+    linkstatic = True,
+)
+
+cc_library(
+    name = "RunXlaGpuPasses",
+    srcs = ["RunXlaGpuPasses.cpp"],
+    deps = [
+        ":RunXlaGpuPassesHdr",
+
+        "@xla//xla/stream_executor:stream_executor_impl",
+        "@xla//xla/translate/stablehlo_to_hlo:translate",
+        "@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
+
+        "@xla//xla/service:backend",
+        "@xla//xla/service/gpu/model:gpu_performance_model",
+        "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
+        "@xla//xla/service/gpu/model:gpu_hlo_cost_analysis",
+        "@xla//xla/service/gpu:gpu_latency_hiding_scheduler",
+        "@xla//xla/service/gpu:nvptx_compiler",
+        "@xla//xla/stream_executor:device_description",
+    ],
+)
+
+cc_library(
+    name = "AnalyticalCostModel",
+    srcs = ["AnalyticalCostModel.cpp"],
+    hdrs = ["AnalyticalCostModel.h"],
+    deps = [
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+
+        "@stablehlo//:reference_ops",
+        "@stablehlo//:stablehlo_ops",
+        "@stablehlo//:stablehlo_passes",
+        "@xla//xla/mlir_hlo",
+
+        "@xla//xla/stream_executor:stream_executor_impl",
+        "@xla//xla/translate/stablehlo_to_hlo:translate",
+        "@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
+
+        "@xla//xla/service:backend",
+        "@xla//xla/service/gpu/model:gpu_performance_model",
+        "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
+ "@xla//xla/service/gpu/model:gpu_hlo_cost_analysis", + "@xla//xla/service/gpu:gpu_latency_hiding_scheduler", + "@xla//xla/stream_executor:device_description", + ] + select({ + "@bazel_tools//src/conditions:darwin": [":RunXlaGpuPassesStub"], + "//conditions:default": [":RunXlaGpuPasses"]}), + alwayslink = True, + linkstatic = True, +) + cc_library( name = "XLADerivatives", srcs = glob( @@ -246,6 +316,7 @@ cc_library( "-Werror=unused-result", ], deps = [ + ":AnalyticalCostModel", ":EnzymeXLAPassesIncGen", ":EnzyeHLOPatternsIncGen", ":mhlo-derivatives", diff --git a/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp b/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp index e30b11a9a..ada9711eb 100644 --- a/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp +++ b/src/enzyme_ad/jax/Passes/EqualitySaturation.cpp @@ -32,17 +32,21 @@ #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/executable.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/service/backend.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/platform_util.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "cxxbridge/deps/tensat/src/input.rs.h" #include "rust/cxx.h" +#include "../AnalyticalCostModel.h" #include "Passes.h" #include @@ -265,35 +269,7 @@ class OperationTimer { throw std::invalid_argument("gpu only"); } case GPU: { - std::unique_ptr hloModule = - wrapperModuleToHloModule(wrapperModule); - - auto deviceDescription = getDeviceDescription(); - - xla::HloCostAnalysis::ShapeSizeFunction shapeSizeFunction = - [](const xla::Shape &shape) { - return xla::gpu::GetSizeOfShape(shape, 4); - }; - xla::gpu::GpuHloCostAnalysis costAnalysis( - xla::gpu::GpuHloCostAnalysis::Options{ - shapeSizeFunction, {}, {}, true}, - *deviceDescription); - - assert(hloModule->computation_count() == 1); - for (auto c : hloModule->computations()) { - c->Accept(&costAnalysis); - // The op we are measuring should always be the return value, which is - // at the root. 
-        auto op = c->root_instruction();
-
-        auto runtime =
-            xla::gpu::GpuPerformanceModel::EstimateRunTimeForInstruction(
-                op, *deviceDescription, &costAnalysis,
-                xla::gpu::GpuPerformanceModelOptions::ForModule(
-                    op->GetModule()));
-        cost = absl::ToInt64Nanoseconds(runtime.exec_time);
-        fus_cost = absl::ToInt64Nanoseconds(runtime.compute_time);
-      }
+      cost = AnalyticalCostModel::getAnalyticalCost(wrapperModule);
 
       std::cout << "Empirical: " << empirical_cost
                 << ", modelled: " << cost << std::endl;
@@ -430,42 +406,6 @@ class OperationTimer {
     return buffer;
   }
 
-  /**
-   * Create XLA internal HloModule for the analytical cost model
-   */
-  static std::unique_ptr<xla::HloModule>
-  wrapperModuleToHloModule(ModuleOp &wrapperModule) {
-    auto context = wrapperModule.getContext();
-    PassManager pm(context);
-    pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
-    pm.run(wrapperModule);
-
-    MlirToHloConversionOptions options;
-    options.propagate_layouts = true;
-    options.return_tuple = false;
-
-    auto hloModule = ConvertMlirHloToHloModule(wrapperModule, options);
-    if (!hloModule.ok()) {
-      llvm::errs() << "Couldn't create hloModule: "
-                   << hloModule.status().message();
-      return nullptr;
-    } else {
-      return std::move(hloModule.value());
-    }
-  }
-
-  /**
-   * Get DeviceDescription for current device.
-   */
-  static std::unique_ptr<stream_executor::DeviceDescription>
-  getDeviceDescription() {
-    auto platform = xla::PlatformUtil::GetPlatform(
-        (getPlatform() == EqsatPlatform::CPU ? "cpu" : "cuda"))
-        .value();
-    // assume ordinal 0
-    return std::move(platform->DescriptionForDevice(0).value());
-  }
-
   /**
    * Measure cost of operation empirically.
    */
diff --git a/src/enzyme_ad/jax/RunXlaGpuPasses.cpp b/src/enzyme_ad/jax/RunXlaGpuPasses.cpp
new file mode 100644
index 000000000..7700cbe75
--- /dev/null
+++ b/src/enzyme_ad/jax/RunXlaGpuPasses.cpp
@@ -0,0 +1,38 @@
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/service/backend.h"
+#include "xla/service/compiler.h"
+#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
+#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
+#include "xla/service/gpu/model/gpu_performance_model.h"
+#include "xla/service/gpu/nvptx_compiler.h"
+#include "xla/service/platform_util.h"
+#include "xla/stream_executor/cuda/cuda_platform_id.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/platform_manager.h"
+#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
+
+#include "RunXlaGpuPasses.h"
+
+/**
+ * Get the StreamExecutor for the current device.
+ */
+stream_executor::StreamExecutor *getStreamExecutor() {
+  auto platform = xla::PlatformUtil::GetPlatform("cuda").value();
+  // assume ordinal 0
+  auto executor = platform->ExecutorForDevice(0).value();
+  if (executor == nullptr) {
+    throw std::runtime_error("Couldn't get executor");
+  }
+
+  return executor;
+}
+
+std::unique_ptr<xla::HloModule>
+runXlaGpuPasses(std::unique_ptr<xla::HloModule> hloModule) {
+  xla::gpu::NVPTXCompiler compiler;
+  auto executor = getStreamExecutor();
+  xla::gpu::NVPTXCompiler::CompileOptions options;
+  auto res = compiler.RunHloPasses(std::move(hloModule), executor, options);
+  return std::move(res.value());
+}
diff --git a/src/enzyme_ad/jax/RunXlaGpuPasses.h b/src/enzyme_ad/jax/RunXlaGpuPasses.h
new file mode 100644
index 000000000..d0bad97c0
--- /dev/null
+++ b/src/enzyme_ad/jax/RunXlaGpuPasses.h
@@ -0,0 +1,5 @@
+#include "xla/hlo/ir/hlo_module.h"
+#include <memory>
+
+std::unique_ptr<xla::HloModule>
+runXlaGpuPasses(std::unique_ptr<xla::HloModule> hloModule);
diff --git a/src/enzyme_ad/jax/RunXlaGpuPassesStub.cpp b/src/enzyme_ad/jax/RunXlaGpuPassesStub.cpp
new file mode 100644
index 000000000..d7706e668
--- /dev/null
+++ b/src/enzyme_ad/jax/RunXlaGpuPassesStub.cpp
@@ -0,0 +1,10 @@
+#include "xla/hlo/ir/hlo_module.h"
+#include <memory>
+
+#include "RunXlaGpuPasses.h"
+
+/** Dummy implementation to make it build on Mac. */
+std::unique_ptr<xla::HloModule>
+runXlaGpuPasses(std::unique_ptr<xla::HloModule> hloModule) {
+  throw std::runtime_error("stub");
+}
diff --git a/workspace.bzl b/workspace.bzl
index f39334463..13f462bc1 100644
--- a/workspace.bzl
+++ b/workspace.bzl
@@ -4,6 +4,8 @@ JAX_SHA256 = ""
 ENZYME_COMMIT = "f1f4d8e62856286efaa0df8c622711b17aa191c3"
 ENZYME_SHA256 = ""
 
+# XLA_COMMIT = "812837828b6c86036ba68d3cacc770aa13fb808d"
+
 XLA_PATCHES = [
     """
     sed -i.bak0 "s/\\/\\/third_party:repo.bzl/@bazel_tools\\/\\/tools\\/build_defs\\/repo:http.bzl/g" third_party/llvm/workspace.bzl
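
Usage note: with this split, pass code only ever calls
AnalyticalCostModel::getAnalyticalCost(); whether the CUDA-backed
runXlaGpuPasses() or the Darwin stub sits behind it is decided by the
select() in src/enzyme_ad/jax/BUILD. Below is a minimal sketch of a call
site, assuming a CUDA device is visible at runtime; the empty module is a
stand-in for the single-op wrapper module that EqualitySaturation.cpp
actually constructs around the operation being costed.

    #include "AnalyticalCostModel.h"

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/Location.h"
    #include "mlir/IR/MLIRContext.h"

    #include <cstdint>
    #include <iostream>

    int main() {
      mlir::MLIRContext ctx;
      // Stand-in for the wrapper module built around the op being costed.
      mlir::ModuleOp wrapper =
          mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx));
      // Legalizes StableHLO to HLO, runs the XLA GPU pipeline (layout,
      // fusion, simplification), then queries the GPU analytical
      // performance model; the result is in nanoseconds.
      uint64_t ns = AnalyticalCostModel::getAnalyticalCost(wrapper);
      std::cout << "modelled cost: " << ns << " ns\n";
      return 0;
    }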