Skip to content

Commit

Permalink
separate build test
Browse files Browse the repository at this point in the history
  • Loading branch information
smjleo committed Jan 25, 2025
1 parent a33e314 commit fb4b24f
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 67 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ http_archive(
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
http_archive(
name = "xla",
sha256 = XLA_SHA256,
# sha256 = XLA_SHA256,
strip_prefix = "xla-" + XLA_COMMIT,
urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
urls = ["https://github.com/smjleo/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
patch_cmds = XLA_PATCHES,
)

Expand Down
110 changes: 110 additions & 0 deletions src/enzyme_ad/jax/AnalyticalCostModel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include <cstdint>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>

#include "AnalyticalCostModel.h"
#include "RunXlaGpuPasses.h"

using namespace mlir;

/**
 * Estimate the runtime (in nanoseconds) of the single computation inside
 * `wrapperModule` using XLA's GPU analytical cost model.
 *
 * The module is converted to an HloModule and run through the XLA GPU pass
 * pipeline (layout, fusion, simplification) so the cost is measured on the
 * code that would actually execute.
 *
 * @param wrapperModule module expected to wrap exactly one computation; the
 *        op being measured is the computation's root instruction.
 * @return estimated execution time in nanoseconds.
 * @throws std::runtime_error if HLO conversion or the cost analysis fails.
 * @throws std::invalid_argument if more than one computation is found.
 */
uint64_t AnalyticalCostModel::getAnalyticalCost(ModuleOp &wrapperModule) {
  std::unique_ptr<xla::HloModule> preOpt =
      wrapperModuleToHloModule(wrapperModule);
  if (!preOpt) {
    // wrapperModuleToHloModule reports the failure on stderr and returns
    // nullptr; previously this would have been dereferenced downstream.
    throw std::runtime_error("failed to convert wrapper module to HLO");
  }

  // Run XLA passes (layout, fusion, simplification) to ensure what is being
  // measured is what will be run.
  auto hloModule = runXlaGpuPasses(std::move(preOpt));

  auto deviceDescription = getDeviceDescription();

  xla::HloCostAnalysis::ShapeSizeFunction shapeSizeFunction =
      [](const xla::Shape &shape) {
        // NOTE(review): assumes a 4-byte pointer size — confirm this matches
        // the target platform.
        return xla::gpu::GetSizeOfShape(shape, 4);
      };
  xla::gpu::GpuHloCostAnalysis costAnalysis(
      xla::gpu::GpuHloCostAnalysis::Options{shapeSizeFunction, {}, {}, true},
      *deviceDescription);

  assert(hloModule->computation_count() == 1);

  // Explicit sentinel instead of the previous `uint64_t cost = -1` /
  // `cost != -1` pair, which relied on implicit signed-to-unsigned wrapping.
  constexpr uint64_t kUnset = std::numeric_limits<uint64_t>::max();
  uint64_t cost = kUnset;

  for (auto c : hloModule->computations()) {
    // Reject a second computation before analyzing it (the original check
    // only fired after the second computation had already been processed).
    if (cost != kUnset) {
      throw std::invalid_argument("found two computations");
    }

    // Populate the cost analysis; the returned status was previously ignored.
    if (!c->Accept(&costAnalysis).ok()) {
      throw std::runtime_error("cost analysis failed");
    }

    // The op we are measuring should always be the return value, which is
    // at the root.
    auto op = c->root_instruction();

    auto runtime = xla::gpu::GpuPerformanceModel::EstimateRunTimeForInstruction(
        op, *deviceDescription, &costAnalysis,
        xla::gpu::GpuPerformanceModelOptions::ForModule(op->GetModule()));
    cost = absl::ToInt64Nanoseconds(runtime.exec_time);
  }

  return cost;
}

/**
 * Create XLA internal HloModule for the analytical cost model.
 *
 * Legalizes StableHLO to MHLO in place, then converts the result to an
 * HloModule with layouts propagated and without a tuple return.
 *
 * @param wrapperModule module to convert (mutated by the legalization pass).
 * @return the converted module, or nullptr on failure (an error is printed
 *         to stderr).
 */
std::unique_ptr<xla::HloModule>
AnalyticalCostModel::wrapperModuleToHloModule(ModuleOp &wrapperModule) {
  auto context = wrapperModule.getContext();
  PassManager pm(context);
  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
  // The LogicalResult was previously ignored; a failed legalization would
  // have fed an only partially-converted module to the HLO conversion below.
  if (failed(pm.run(wrapperModule))) {
    llvm::errs() << "Couldn't legalize StableHLO to HLO\n";
    return nullptr;
  }

  MlirToHloConversionOptions options;
  options.propagate_layouts = true;
  options.return_tuple = false;

  auto hloModule = ConvertMlirHloToHloModule(wrapperModule, options);
  if (!hloModule.ok()) {
    llvm::errs() << "Couldn't create hloModule: "
                 << hloModule.status().message();
    return nullptr;
  }
  return std::move(hloModule.value());
}

// Look up the StreamExecutor platform named "cuda" (non-owning pointer).
// NOTE(review): .value() will crash if the CUDA platform is unavailable —
// confirm all callers run only on CUDA-capable hosts.
stream_executor::Platform *AnalyticalCostModel::getXlaPlatform() {
  return xla::PlatformUtil::GetPlatform("cuda").value();
}

/**
 * Get DeviceDescription for current device.
 *
 * NOTE(review): .value() will crash if the device cannot be queried —
 * confirm this is acceptable for all callers.
 */
std::unique_ptr<stream_executor::DeviceDescription>
AnalyticalCostModel::getDeviceDescription() {
  // assume ordinal 0. The previous `return std::move(...)` was redundant:
  // value() on the rvalue StatusOr already yields an rvalue.
  return getXlaPlatform()->DescriptionForDevice(0).value();
}
46 changes: 46 additions & 0 deletions src/enzyme_ad/jax/AnalyticalCostModel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"

#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// NOTE(review): this header has no include guard / #pragma once — add one to
// protect against multiple inclusion.
/**
 * Static-only wrapper around XLA's GPU analytical cost model: converts an
 * MLIR wrapper module to HLO and estimates its runtime on the current device.
 */
class AnalyticalCostModel {
public:
  /**
   * Estimate the runtime (in nanoseconds) of the single computation in
   * `wrapperModule`.
   */
  static uint64_t getAnalyticalCost(mlir::ModuleOp &wrapperModule);

private:
  /**
   * Create XLA internal HloModule for the analytical cost model
   */
  static std::unique_ptr<xla::HloModule>
  wrapperModuleToHloModule(mlir::ModuleOp &wrapperModule);

  // Returns the CUDA StreamExecutor platform (non-owning).
  static stream_executor::Platform *getXlaPlatform();

  /**
   * Get DeviceDescription for current device.
   */
  static std::unique_ptr<stream_executor::DeviceDescription>
  getDeviceDescription();
};
46 changes: 46 additions & 0 deletions src/enzyme_ad/jax/AnalyticalCostModelStub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "AnalyticalCostModel.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"

#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/backend.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/model/gpu_performance_model.h"
#include "xla/service/platform_util.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Stub implementation; presumably linked on platforms where the XLA GPU
// pipeline is unavailable (cf. the darwin select() in BUILD) — TODO confirm,
// as the visible BUILD does not reference this file. Always throws.
uint64_t AnalyticalCostModel::getAnalyticalCost(mlir::ModuleOp &wrapperModule) {
  throw std::runtime_error("stub");
}

// Stub: unconditionally throws; never produces an HloModule.
std::unique_ptr<xla::HloModule>
AnalyticalCostModel::wrapperModuleToHloModule(mlir::ModuleOp &wrapperModule) {
  throw std::runtime_error("stub");
}

// Stub: unconditionally throws; no platform lookup is performed.
stream_executor::Platform *AnalyticalCostModel::getXlaPlatform() {
  throw std::runtime_error("stub");
}

// Stub: unconditionally throws; no device is queried.
std::unique_ptr<stream_executor::DeviceDescription>
AnalyticalCostModel::getDeviceDescription() {
  throw std::runtime_error("stub");
}
71 changes: 71 additions & 0 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,76 @@ gentbl_cc_library(
tblgen = "//:enzymexlamlir-tblgen",
)

# Header-only interface for the XLA GPU pass pipeline; lets the real
# implementation and the stub below be swapped per platform.
cc_library(
    name = "RunXlaGpuPassesHdr",
    hdrs = ["RunXlaGpuPasses.h"],
    deps = [
        "@xla//xla/hlo/ir:hlo",
    ],
    # NOTE(review): alwayslink/linkstatic on a header-only target have no
    # object code to act on — confirm they are intentional.
    alwayslink = True,
    linkstatic = True,
)

# Throwing stub for platforms that cannot build the CUDA-backed pipeline
# (selected for darwin in :AnalyticalCostModel below).
cc_library(
    name = "RunXlaGpuPassesStub",
    srcs = ["RunXlaGpuPassesStub.cpp"],
    deps = [":RunXlaGpuPassesHdr"],
    alwayslink = True,
    linkstatic = True,
)

# Real implementation of the XLA GPU pass pipeline. Pulls in the NVPTX
# compiler and StreamExecutor, so it only builds where CUDA support exists.
cc_library(
    name = "RunXlaGpuPasses",
    srcs = ["RunXlaGpuPasses.cpp"],
    deps = [
        ":RunXlaGpuPassesHdr",

        "@xla//xla/stream_executor:stream_executor_impl",
        "@xla//xla/translate/stablehlo_to_hlo:translate",
        "@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",

        "@xla//xla/service:backend",
        "@xla//xla/service/gpu/model:gpu_performance_model",
        "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
        "@xla//xla/service/gpu/model:gpu_hlo_cost_analysis",
        "@xla//xla/service/gpu:gpu_latency_hiding_scheduler",
        "@xla//xla/service/gpu:nvptx_compiler",
        "@xla//xla/stream_executor:device_description",
    ]
)


# Analytical cost model built on XLA's GPU performance model. The darwin
# select() swaps only the pass-pipeline dependency for a stub.
# NOTE(review): AnalyticalCostModelStub.cpp exists in this commit but is not
# referenced by any visible target — confirm whether srcs should also select
# the stub on darwin, since AnalyticalCostModel.cpp includes CUDA headers.
cc_library(
    name = "AnalyticalCostModel",
    srcs = ["AnalyticalCostModel.cpp"],
    hdrs = ["AnalyticalCostModel.h"],
    deps = [
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",

        "@stablehlo//:reference_ops",
        "@stablehlo//:stablehlo_ops",
        "@stablehlo//:stablehlo_passes",
        "@xla//xla/mlir_hlo",

        "@xla//xla/stream_executor:stream_executor_impl",
        "@xla//xla/translate/stablehlo_to_hlo:translate",
        "@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo",

        "@xla//xla/service:backend",
        "@xla//xla/service/gpu/model:gpu_performance_model",
        "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
        "@xla//xla/service/gpu/model:gpu_hlo_cost_analysis",
        "@xla//xla/service/gpu:gpu_latency_hiding_scheduler",
        "@xla//xla/stream_executor:device_description",
    ] + select({
        "@bazel_tools//src/conditions:darwin": [":RunXlaGpuPassesStub"],
        "//conditions:default": [":RunXlaGpuPasses"]}),
    alwayslink = True,
    linkstatic = True,
)

cc_library(
name = "XLADerivatives",
srcs = glob(
Expand All @@ -246,6 +316,7 @@ cc_library(
"-Werror=unused-result",
],
deps = [
":AnalyticalCostModel",
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
":mhlo-derivatives",
Expand Down
Loading

0 comments on commit fb4b24f

Please sign in to comment.