Skip to content

Commit

Permalink
separate build test
Browse files Browse the repository at this point in the history
  • Loading branch information
smjleo committed Jan 25, 2025
1 parent a33e314 commit 5d513c2
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 67 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ http_archive(
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
http_archive(
    name = "xla",
    # TODO(review): sha256 verification is disabled while pointing at a fork;
    # without it the downloaded archive's integrity is not checked. Re-pin the
    # hash of the fork's archive and re-enable before merging.
    # sha256 = XLA_SHA256,
    strip_prefix = "xla-" + XLA_COMMIT,
    urls = ["https://github.com/smjleo/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
    patch_cmds = XLA_PATCHES,
)

Expand Down
123 changes: 123 additions & 0 deletions src/enzyme_ad/jax/AnalyticalCostModel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>

#include "AnalyticalCostModel.h"
#include "RunXlaGpuPasses.h"

using namespace mlir;

/**
 * Estimate the runtime (nanoseconds) of the single computation in
 * `wrapperModule` using XLA's analytical GPU performance model.
 *
 * The module is lowered to an HloModule and run through the XLA GPU pass
 * pipeline first, so the estimate reflects the post-layout/fusion program.
 *
 * Throws std::invalid_argument if the module holds more than one
 * computation, and std::runtime_error if cost analysis fails.
 */
uint64_t AnalyticalCostModel::getAnalyticalCost(ModuleOp &wrapperModule) {
  std::unique_ptr<xla::HloModule> preOpt =
      wrapperModuleToHloModule(wrapperModule);

  // Run XLA passes (layout, fusion, simplification) to ensure what is being
  // measured is what will be run.
  auto hloModule = runXlaGpuPasses(std::move(preOpt));

  auto deviceDescription = getDeviceDescription();

  // Shape-size callback used by the cost analysis; pointer size 4 matches
  // the original implementation.
  xla::HloCostAnalysis::ShapeSizeFunction shapeSizeFunction =
      [](const xla::Shape &shape) {
        return xla::gpu::GetSizeOfShape(shape, 4);
      };
  xla::gpu::GpuHloCostAnalysis costAnalysis(
      xla::gpu::GpuHloCostAnalysis::Options{shapeSizeFunction, {}, {}, true},
      *deviceDescription);

  assert(hloModule->computation_count() == 1);

  // Explicit sentinel for "no computation measured yet". The original
  // assigned -1 to an unsigned value and relied on wraparound.
  constexpr uint64_t kUnset = std::numeric_limits<uint64_t>::max();
  uint64_t cost = kUnset;

  for (auto c : hloModule->computations()) {
    // Reject a second computation before doing any expensive estimation.
    if (cost != kUnset) {
      throw std::invalid_argument("found two computations");
    }

    // Accept returns a status; the original silently discarded it.
    auto status = c->Accept(&costAnalysis);
    if (!status.ok()) {
      throw std::runtime_error("cost analysis failed: " +
                               std::string(status.message()));
    }

    // The op we are measuring should always be the return value, which is
    // at the root.
    auto op = c->root_instruction();

    auto runtime = xla::gpu::GpuPerformanceModel::EstimateRunTimeForInstruction(
        op, *deviceDescription, &costAnalysis,
        xla::gpu::GpuPerformanceModelOptions::ForModule(op->GetModule()));
    cost = absl::ToInt64Nanoseconds(runtime.exec_time);
  }

  return cost;
}

/**
 * Create XLA internal HloModule for the analytical cost model.
 *
 * Legalizes StableHLO to MHLO, then converts to an HloModule.
 * Returns nullptr if legalization or conversion fails.
 */
std::unique_ptr<xla::HloModule>
AnalyticalCostModel::wrapperModuleToHloModule(ModuleOp &wrapperModule) {
  auto context = wrapperModule.getContext();
  PassManager pm(context);
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
  // The original discarded pm.run's LogicalResult; a failed legalization
  // would silently feed an un-legalized module to the HLO converter.
  if (mlir::failed(pm.run(wrapperModule))) {
    llvm::errs() << "StableHLO-to-MHLO legalization failed\n";
    return nullptr;
  }

  MlirToHloConversionOptions options;
  options.propagate_layouts = true;
  options.return_tuple = false;

  auto hloModule = ConvertMlirHloToHloModule(wrapperModule, options);
  if (!hloModule.ok()) {
    llvm::errs() << "Couldn't create hloModule: "
                 << hloModule.status().message();
    return nullptr;
  }
  return std::move(hloModule.value());
}

/**
 * Get the XLA Platform handle for CUDA.
 * NOTE(review): .value() on the StatusOr will terminate the process if the
 * CUDA platform is unavailable — confirm that crashing is acceptable here.
 */
stream_executor::Platform *AnalyticalCostModel::getXlaPlatform() {
  return xla::PlatformUtil::GetPlatform("cuda").value();
}

/**
 * Get DeviceDescription for the current device (device ordinal 0 assumed).
 */
std::unique_ptr<stream_executor::DeviceDescription>
AnalyticalCostModel::getDeviceDescription() {
  // The previous std::move around the return expression was redundant:
  // .value() on the rvalue StatusOr already yields an rvalue.
  return getXlaPlatform()->DescriptionForDevice(0).value();
}

/**
 * Get the StreamExecutor for the current device.
 * Throws std::runtime_error if no executor is available.
 */
stream_executor::StreamExecutor *AnalyticalCostModel::getStreamExecutor() {
  // Device ordinal 0 is assumed throughout this cost model.
  stream_executor::StreamExecutor *exec =
      getXlaPlatform()->ExecutorForDevice(0).value();
  if (exec == nullptr)
    throw std::runtime_error("Couldn't get executor");
  return exec;
}
51 changes: 51 additions & 0 deletions src/enzyme_ad/jax/AnalyticalCostModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once
// NOTE(review): this header previously had no include guard; a second
// inclusion in the same TU would redefine AnalyticalCostModel.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"

#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include <cstdint>  // uint64_t (was previously pulled in only transitively)
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

/**
 * Static-only helper exposing XLA's analytical GPU cost model to the
 * equality-saturation pass.
 */
class AnalyticalCostModel {
public:
  /**
   * Estimate the runtime cost (nanoseconds) of the single computation in
   * `wrapperModule`.
   */
  static uint64_t getAnalyticalCost(mlir::ModuleOp &wrapperModule);

private:
  /**
   * Create XLA internal HloModule for the analytical cost model
   */
  static std::unique_ptr<xla::HloModule>
  wrapperModuleToHloModule(mlir::ModuleOp &wrapperModule);

  // Platform handle used by the device/executor accessors below.
  static stream_executor::Platform *getXlaPlatform();

  /**
   * Get DeviceDescription for current device.
   */
  static std::unique_ptr<stream_executor::DeviceDescription>
  getDeviceDescription();

  /**
   * Get StreamExecutor for current device.
   */
  static stream_executor::StreamExecutor *getStreamExecutor();
};
71 changes: 71 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,76 @@ gentbl_cc_library(
tblgen = "//:enzymexlamlir-tblgen",
)

# Header-only interface for the XLA GPU pass runner, so targets can compile
# against runXlaGpuPasses without linking the CUDA-dependent implementation.
cc_library(
name = "RunXlaGpuPassesHdr",
hdrs = ["RunXlaGpuPasses.h"],
deps = [
"@xla//xla/hlo/ir:hlo",
],
alwayslink = True,
linkstatic = True,
)

# Stub implementation of the pass-runner interface for platforms where the
# CUDA-backed implementation cannot build (selected for darwin below).
cc_library(
name = "RunXlaGpuPassesStub",
srcs = ["RunXlaGpuPassesStub.cpp"],
deps = [":RunXlaGpuPassesHdr"],
alwayslink = True,
linkstatic = True,
)

# Real implementation of the pass runner; depends on the NVPTX compiler and
# therefore only builds where CUDA toolchain support is available.
cc_library(
name = "RunXlaGpuPasses",
srcs = ["RunXlaGpuPasses.cpp"],
deps = [
":RunXlaGpuPassesHdr",

"@xla//xla/stream_executor:stream_executor_impl",
"@xla//xla/translate/stablehlo_to_hlo:translate",
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",

"@xla//xla/service:backend",
"@xla//xla/service/gpu/model:gpu_performance_model",
"@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
"@xla//xla/service/gpu/model:gpu_hlo_cost_analysis",
"@xla//xla/service/gpu:gpu_latency_hiding_scheduler",
"@xla//xla/service/gpu:nvptx_compiler",
"@xla//xla/stream_executor:device_description",
]
)


# Analytical cost model library. On darwin the CUDA-dependent pass runner is
# swapped for a stub via select(); everywhere else the real runner is linked.
cc_library(
name = "AnalyticalCostModel",
srcs = ["AnalyticalCostModel.cpp"],
hdrs = ["AnalyticalCostModel.h"],
deps = [
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",

"@stablehlo//:reference_ops",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:stablehlo_passes",
"@xla//xla/mlir_hlo",

"@xla//xla/stream_executor:stream_executor_impl",
"@xla//xla/translate/stablehlo_to_hlo:translate",
"@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",

"@xla//xla/service:backend",
"@xla//xla/service/gpu/model:gpu_performance_model",
"@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
"@xla//xla/service/gpu/model:gpu_hlo_cost_analysis",
"@xla//xla/service/gpu:gpu_latency_hiding_scheduler",
"@xla//xla/stream_executor:device_description",
] + select({
"@bazel_tools//src/conditions:darwin": [":RunXlaGpuPassesStub"],
"//conditions:default": [":RunXlaGpuPasses"]}),
alwayslink = True,
linkstatic = True,
)

cc_library(
name = "XLADerivatives",
srcs = glob(
Expand All @@ -246,6 +316,7 @@ cc_library(
"-Werror=unused-result",
],
deps = [
":AnalyticalCostModel",
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
":mhlo-derivatives",
Expand Down
70 changes: 5 additions & 65 deletions src/enzyme_ad/jax/Passes/EqualitySaturation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,21 @@
#include "xla/pjrt/status_casters.h"
#include "xla/python/ifrt/executable.h"
#include "xla/python/pjrt_ifrt/xla_compiler.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include "cxxbridge/deps/tensat/src/input.rs.h"
#include "rust/cxx.h"

#include "../AnalyticalCostModel.h"
#include "Passes.h"

#include <chrono>
Expand Down Expand Up @@ -265,35 +269,7 @@ class OperationTimer {
throw std::invalid_argument("gpu only");
}
case GPU: {
std::unique_ptr<xla::HloModule> hloModule =
wrapperModuleToHloModule(wrapperModule);

auto deviceDescription = getDeviceDescription();

xla::HloCostAnalysis::ShapeSizeFunction shapeSizeFunction =
[](const xla::Shape &shape) {
return xla::gpu::GetSizeOfShape(shape, 4);
};
xla::gpu::GpuHloCostAnalysis costAnalysis(
xla::gpu::GpuHloCostAnalysis::Options{
shapeSizeFunction, {}, {}, true},
*deviceDescription);

assert(hloModule->computation_count() == 1);
for (auto c : hloModule->computations()) {
c->Accept(&costAnalysis);
// The op we are measuring should always be the return value, which is
// at the root.
auto op = c->root_instruction();

auto runtime =
xla::gpu::GpuPerformanceModel::EstimateRunTimeForInstruction(
op, *deviceDescription, &costAnalysis,
xla::gpu::GpuPerformanceModelOptions::ForModule(
op->GetModule()));
cost = absl::ToInt64Nanoseconds(runtime.exec_time);
fus_cost = absl::ToInt64Nanoseconds(runtime.compute_time);
}
cost = AnalyticalCostModel::getAnalyticalCost(wrapperModule);

std::cout << "Empirical: " << empirical_cost << ", modelled: " << cost
<< std::endl;
Expand Down Expand Up @@ -430,42 +406,6 @@ class OperationTimer {
return buffer;
}

/**
 * Create XLA internal HloModule for the analytical cost model
 */
static std::unique_ptr<xla::HloModule>
wrapperModuleToHloModule(ModuleOp &wrapperModule) {
  auto context = wrapperModule.getContext();
  PassManager pm(context);
  // Legalize StableHLO ops down to MHLO before the HLO conversion below.
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
  // NOTE(review): the LogicalResult of pm.run is discarded — a failed
  // pipeline falls through to conversion; confirm this is intentional.
  pm.run(wrapperModule);

  // propagate_layouts keeps layout info; return_tuple=false keeps the root
  // as a plain value rather than a tuple.
  MlirToHloConversionOptions options;
  options.propagate_layouts = true;
  options.return_tuple = false;

  auto hloModule = ConvertMlirHloToHloModule(wrapperModule, options);
  if (!hloModule.ok()) {
    llvm::errs() << "Couldn't create hloModule: "
                 << hloModule.status().message();
    return nullptr;
  } else {
    return std::move(hloModule.value());
  }
}

/**
 * Get DeviceDescription for current device.
 */
static std::unique_ptr<stream_executor::DeviceDescription>
getDeviceDescription() {
  // Select the platform matching the configured eqsat target (CPU vs CUDA).
  auto platform = xla::PlatformUtil::GetPlatform(
                      (getPlatform() == EqsatPlatform::CPU ? "cpu" : "cuda"))
                      .value();
  // assume ordinal 0
  return std::move(platform->DescriptionForDevice(0).value());
}

/**
* Measure cost of operation empirically.
*/
Expand Down
24 changes: 24 additions & 0 deletions src/enzyme_ad/jax/RunXlaGpuPasses.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/backend.h"
#include "xla/service/compiler.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/gpu/nvptx_compiler.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include "RunXlaGpuPasses.h"

/**
 * Run the XLA GPU pass pipeline (layout assignment, fusion, simplification)
 * over `hloModule` and return the optimized module.
 */
std::unique_ptr<xla::HloModule>
runXlaGpuPasses(std::unique_ptr<xla::HloModule> hloModule) {
  xla::gpu::NVPTXCompiler compiler;
  auto executor = getStreamExecutor();
  xla::gpu::NVPTXCompiler::CompileOptions options;
  // RunHloPasses takes the module by value (ownership transfer), so the
  // unique_ptr must be moved in; it returns a StatusOr which must be
  // unwrapped — the original returned the StatusOr where a unique_ptr was
  // expected, which does not compile.
  auto res = compiler.RunHloPasses(std::move(hloModule), executor, options);
  return std::move(res).value();
}
Loading

0 comments on commit 5d513c2

Please sign in to comment.