PR #19237: [GPU] Fix passing of key-value store handle from client to compiler. #19254
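In short: StreamExecutorGpuClient::Compile(const XlaComputation&, ...) was the only compile path that copied the client's key-value store into ExecutableBuildOptions, so compilations entering through the mlir::ModuleOp overload reached the GPU compiler without the store. This change moves the hand-off into PjRtStreamExecutorClient::CompileInternal, which both overloads funnel through, and adds a two-node test that exercises the store via sharded autotuning. A minimal sketch of the client-side setup this affects, following the new test below (distributed-store plumbing and error handling elided):

```c++
// Sketch only, not part of the diff: a multi-node GPU client whose
// key-value store should reach the compiler on every Compile overload.
xla::GpuClientOptions options;
options.node_id = node_id;    // this process's node index
options.num_nodes = 2;        // total participating processes
options.kv_store = kv_store;  // e.g. from GetDistributedKeyValueStore()
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
                    xla::GetStreamExecutorGpuClient(options));
// Before this fix, only the XlaComputation overload forwarded kv_store to
// the compiler; the MLIR overload below dropped it. Both now go through
// PjRtStreamExecutorClient::CompileInternal.
TF_ASSIGN_OR_RETURN(auto executable,
                    client->Compile(*mlir_module, compile_options));
```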

Status: Open · wants to merge 1 commit into main
6 changes: 5 additions & 1 deletion xla/pjrt/gpu/BUILD
@@ -179,6 +179,8 @@ xla_cc_test(
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/pjrt:pjrt_stream_executor_client",
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:gpu_plugin",
@@ -187,13 +189,15 @@ xla_cc_test(
"//xla/stream_executor:stream",
"//xla/tests:literal_test_util",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/util:command_line_flags",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:env",
@@ -202,7 +206,7 @@
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
"@tsl//tsl/platform:subprocess",
],
)
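(Dependency note: the test below defines its own main() so the binary can re-invoke itself once per node, hence the switch from @tsl//tsl/platform:test_main to an explicit gtest dependency, plus subprocess, command-line-flag, and distributed-runtime support.)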

1 change: 0 additions & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -779,7 +779,6 @@ PjRtFuture<> StreamExecutorGpuClient::CopyRawSubBufferToHost(
absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>>
StreamExecutorGpuClient::Compile(const XlaComputation& computation,
CompileOptions options) {
- options.executable_build_options.set_key_value_store(kv_store_);
auto executable = PjRtStreamExecutorClient::Compile(computation, options);

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
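Rationale for the deletion: this unconditional assignment ran only in the XlaComputation overload of Compile, so the mlir::ModuleOp overload never forwarded the store. The equivalent (now conditional) assignment moves into PjRtStreamExecutorClient::CompileInternal in the last file of this diff, which both overloads reach.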
127 changes: 127 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -44,6 +44,8 @@ limitations under the License.
#include "xla/layout.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/gpu/gpu_topology.h"
#include "xla/pjrt/host_memory_spaces.h"
@@ -73,6 +75,7 @@ limitations under the License.
#include "tsl/platform/status.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/subprocess.h"
#include "tsl/platform/threadpool.h"

namespace xla {
@@ -1786,5 +1789,129 @@ TEST(StreamExecutorGpuClientTest, AutoLayoutIsSupported) {
EXPECT_NE(layouts[1]->ToString(), "{2,1,0}");
}

class ShardedAutotuningTest : public ::testing::TestWithParam<bool> {
public:
static constexpr int kNumNodes = 2;
};

static const char* test_binary_name;

TEST_P(ShardedAutotuningTest, ShardedAutotuningWorks) {
tsl::SubProcess child[ShardedAutotuningTest::kNumNodes];
for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes; ++node_id) {
std::vector<std::string> argv;
argv.push_back(test_binary_name);
argv.push_back(absl::StrFormat("--node_id=%d", node_id));
argv.push_back(absl::StrFormat("--use_xla_computation=%d", GetParam()));
child[node_id].SetProgram(test_binary_name, argv);
child[node_id].SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE);
child[node_id].SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE);
ASSERT_TRUE(child[node_id].Start()) << "node " << node_id;
}
for (int node_id = 0; node_id < ShardedAutotuningTest::kNumNodes; ++node_id) {
std::string stdout_str;
std::string stderr_str;
int child_status =
child[node_id].Communicate(nullptr, &stdout_str, &stderr_str);
EXPECT_EQ(child_status, 0) << " node " << node_id << "\nstdout:\n"
<< stdout_str << "\nstderr:\n"
<< stderr_str;
}
}
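The parent test above spawns kNumNodes copies of its own binary, one per node, passing --node_id and --use_xla_computation, and fails if any child exits non-zero; each child executes ShardedAutotuningWorksTestBody below on its own GPU.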

absl::Status ShardedAutotuningWorksTestBody(const int node_id,
bool use_xla_computation) {
tsl::setenv("CUDA_VISIBLE_DEVICES", std::to_string(node_id).data(),
/*overwrite=*/true);
std::unique_ptr<xla::DistributedRuntimeService> service;
if (node_id == 0) {
TF_ASSIGN_OR_RETURN(
service,
xla::GetDistributedRuntimeService(
"[::]:12345", xla::CoordinationServiceImpl::Options{
.num_nodes = ShardedAutotuningTest::kNumNodes}));
}

xla::DistributedRuntimeClient::Options distributed_options;
distributed_options.node_id = node_id;
distributed_options.init_timeout = absl::Seconds(120);
auto distributed_client =
GetDistributedRuntimeClient("127.0.0.1:12345", distributed_options);
TF_QCHECK_OK(distributed_client->Connect());
GpuClientOptions options;
options.node_id = node_id;
options.num_nodes = ShardedAutotuningTest::kNumNodes;
options.kv_store = GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"gpu:");
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
GetStreamExecutorGpuClient(options));
TF_RET_CHECK(client->platform_name() == "cuda");
TF_RET_CHECK(client->addressable_device_count() == 1);
TF_RET_CHECK(client->device_count() == 2);

CompileOptions compile_options;
compile_options.executable_build_options.mutable_debug_options()
->set_xla_gpu_shard_autotuning(true);
compile_options.executable_build_options.mutable_debug_options()
->set_xla_gpu_triton_gemm_any(true);
compile_options.executable_build_options.mutable_debug_options()
->set_xla_gpu_cublas_fallback(false);

mlir::MLIRContext context;
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(R"mlir(
func.func public @main(%arg0: tensor<2x2048x2048xf32>) ->
(tensor<2x2048x2048xf32> {jax.result_info = ""}) {
%0 = stablehlo.dot_general %arg0, %arg0, batching_dims = [0] x [0],
contracting_dims = [2] x [1]
: (tensor<2x2048x2048xf32>, tensor<2x2048x2048xf32>) ->
tensor<2x2048x2048xf32>
return %0 : tensor<2x2048x2048xf32>
})mlir",
context));
std::unique_ptr<PjRtLoadedExecutable> executable;
if (use_xla_computation) {
XlaComputation computation;
TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation,
/*use_tuple_args=*/false,
/*return_tuple=*/false,
/*use_shardy=*/false));
TF_ASSIGN_OR_RETURN(executable,
client->Compile(computation, compile_options));
} else {
TF_ASSIGN_OR_RETURN(executable, client->Compile(*module, compile_options));
}

TF_RET_CHECK(absl::StrContains(
executable->GetHloModules()->front()->ToString(), "triton_gemm"));

return absl::OkStatus();
}
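Flow of the per-node body: node 0 hosts the coordination service on port 12345; every node connects to it, wraps the connection in a key-value store under the "gpu:" key prefix, and passes that store to the GPU client. Each process is pinned to a single GPU via CUDA_VISIBLE_DEVICES, sharded autotuning is enabled, and the cuBLAS fallback is disabled so the dot must compile as a Triton GEMM; the final check asserts the compiled module contains a triton_gemm fusion. Sharded autotuning relies on the key-value store to exchange autotuning results between the two processes, which is exactly the handle this PR fixes.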

INSTANTIATE_TEST_SUITE_P(ShardedAutotuningTest, ShardedAutotuningTest,
::testing::Values(false, true));

} // namespace
} // namespace xla

int main(int argc, char* argv[]) {
// Save name of binary so that it may invoke itself.
xla::test_binary_name = argv[0];
int node_id = -1;
bool use_xla_computation = false;
std::vector<tsl::Flag> flag_list = {
tsl::Flag("node_id", &node_id,
"Node ID for ShardedAutotuningWorks test."),
tsl::Flag("use_xla_computation", &use_xla_computation,
"Test parameter for ShardedAutotuningWorks."),
};
xla::AppendDebugOptionsFlags(&flag_list);
std::string usage = tsl::Flags::Usage(argv[0], flag_list);
tsl::Flags::Parse(&argc, argv, flag_list);
testing::InitGoogleTest(&argc, argv);
if (node_id >= 0) {
return !xla::ShardedAutotuningWorksTestBody(node_id, use_xla_computation)
.ok();
}
return RUN_ALL_TESTS();
}
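Run without --node_id, the binary executes the test suite, which re-launches it once per node; a child invocation looks like (binary name assumed from the test file) ./se_gpu_pjrt_client_test --node_id=0 --use_xla_computation=1, and its exit code is the status of ShardedAutotuningWorksTestBody.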
3 changes: 3 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
@@ -3514,6 +3514,9 @@ PjRtStreamExecutorClient::CompileInternal(
CompileOptions options) {
tsl::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
VLOG(1) << "PjRtStreamExecutorClient::Compile";
if (!options.executable_build_options.key_value_store()) {
options.executable_build_options.set_key_value_store(*key_value_store());
}
options.executable_build_options.set_process_index(process_index());
TF_RET_CHECK(device_count() % addressable_device_count() == 0)
<< "Each process is expected to have the same number of devices";