From 027cfb1602e11ce559ce8a888995b808d65197dd Mon Sep 17 00:00:00 2001 From: Georg Stefan Schmid Date: Fri, 27 Sep 2024 06:56:47 -0700 Subject: [PATCH] PR #17704: [jax.distributed] Allow enabling grpc channel compression Imported from GitHub PR https://github.com/openxla/xla/pull/17704 Allows passing an additional boolean argument `use_compression` via `xla_extension.get_distributed_runtime_client(...)` that controls whether compression is enabled on the gRPC channels created for each distributed runtime client. Motivation: XLA sends O(mesh) [device topologies](https://github.com/openxla/xla/blob/9fb4f21c3542c10b6a5bd98144801bbeec10b489/xla/pjrt/distributed/protocol.proto#L84) through its centralized coordination service and we have reason to believe that this becomes a bottleneck at large scale. Compression of the underlying gRPC communication is currently implicitly disabled, and might give us a cheap avenue to scale a bit further with the centralized KV store design. One small note: I refrained from adding `use_compression` to `DistributedRuntimeClient::Options` because the new flag is only relevant during channel creation in `distributed.cc`, but not within `DistributedRuntimeClient`. If we added `use_compression` to Options then the `GetDistributedRuntimeClient(channel, options)` defined in `client.cc` would seem to allow controlling compression, but it's really ignored. Let me know if you'd rather go that way. Corresponding JAX PR: https://github.com/jax-ml/jax/pull/23969 Copybara import of the project: -- 609e21a4c416afae7468a89cc44988db7b9828ac by Georg Stefan Schmid : [jax.distributed] Allow enabling grpc channel compression Merging this change closes #17704 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17704 from gspschmid:gschmid/dist-compression 609e21a4c416afae7468a89cc44988db7b9828ac PiperOrigin-RevId: 679573391 --- xla/pjrt/distributed/client_server_test.cc | 18 ++++++++++++++++++ xla/pjrt/distributed/distributed.cc | 11 ++++++++--- xla/pjrt/distributed/distributed.h | 3 ++- xla/python/xla.cc | 9 ++++++--- xla/python/xla_extension/__init__.pyi | 1 + .../multihost_hlo_runner/create_client.cc | 3 ++- 6 files changed, 37 insertions(+), 8 deletions(-) diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index 462680d5fc764a..20f365194761d9 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "grpcpp/server_builder.h" #include "grpcpp/support/channel_arguments.h" #include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/service.h" #include "xla/pjrt/distributed/topology_util.h" @@ -996,5 +997,22 @@ TEST_F(ClientServerTest, KeyValueDelete_Directory) { EXPECT_THAT(kvs.value(), IsEmpty()); } +TEST_F(ClientServerTest, UseCompression) { + int port = tsl::testing::PickUnusedPortOrDie(); + StartService(/*num_nodes=*/1, /*service_options=*/{}, + absl::StrCat("[::]:", port)); + + // Sanity check that the client can connect with compression enabled. + DistributedRuntimeClient::Options client_options; + client_options.node_id = 0; + auto client = + GetDistributedRuntimeClient(absl::StrCat("dns:///localhost:", port), + /*use_compression=*/true, client_options); + + TF_ASSERT_OK(client->Connect()); + TF_ASSERT_OK(client->KeyValueSet("foo/bar/1", "1")); + TF_ASSERT_OK(client->Shutdown()); +} + } // namespace } // namespace xla diff --git a/xla/pjrt/distributed/distributed.cc b/xla/pjrt/distributed/distributed.cc index 69f9f2e249b402..9e0603f1b029f1 100644 --- a/xla/pjrt/distributed/distributed.cc +++ b/xla/pjrt/distributed/distributed.cc @@ -38,9 +38,14 @@ GetDistributedRuntimeService(std::string address, } std::shared_ptr GetDistributedRuntimeClient( - std::string address, const DistributedRuntimeClient::Options& options) { - std::shared_ptr channel = grpc::CreateChannel( - address, tsl::GetClientCredentials(kVerifySecureCredentials)); + std::string address, bool use_compression, + const DistributedRuntimeClient::Options& options) { + grpc::ChannelArguments args; + if (use_compression) { + args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP); + } + std::shared_ptr channel = grpc::CreateCustomChannel( + address, tsl::GetClientCredentials(kVerifySecureCredentials), args); return GetDistributedRuntimeClient(channel, options); } diff --git a/xla/pjrt/distributed/distributed.h b/xla/pjrt/distributed/distributed.h index 8145ddaa5c699e..a365635bb28198 100644 --- a/xla/pjrt/distributed/distributed.h +++ b/xla/pjrt/distributed/distributed.h @@ -40,7 +40,8 @@ GetDistributedRuntimeService(std::string address, // Builds a distributed runtime client, connecting to a service at `address`, // where address is a gRPC-style address such as `dns:///localhost:1234`. std::shared_ptr GetDistributedRuntimeClient( - std::string address, const DistributedRuntimeClient::Options& options); + std::string address, bool use_compression, + const DistributedRuntimeClient::Options& options); } // namespace xla diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 868a3aa9d74016..f861f5297d6fdf 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -779,8 +779,10 @@ NB_MODULE(xla_extension, m_nb) { std::optional> missed_heartbeat_callback, - std::optional shutdown_on_destruction) + std::optional shutdown_on_destruction, + std::optional use_compression) -> std::shared_ptr { + bool compression = use_compression.value_or(false); DistributedRuntimeClient::Options options; options.node_id = node_id; if (rpc_timeout.has_value()) { @@ -805,7 +807,7 @@ NB_MODULE(xla_extension, m_nb) { if (shutdown_on_destruction.has_value()) { options.shutdown_on_destruction = *shutdown_on_destruction; } - return GetDistributedRuntimeClient(address, options); + return GetDistributedRuntimeClient(address, compression, options); }, nb::arg("address"), nb::arg("node_id"), nb::arg("rpc_timeout").none() = std::nullopt, @@ -814,7 +816,8 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("heartbeat_interval").none() = std::nullopt, nb::arg("max_missing_heartbeats").none() = std::nullopt, nb::arg("missed_heartbeat_callback").none() = std::nullopt, - nb::arg("shutdown_on_destruction").none() = std::nullopt); + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); m_nb.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index b5ae4c6431ca66..8dd5816d425ef4 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -821,6 +821,7 @@ def get_distributed_runtime_client( max_missing_heartbeats: Optional[int] = ..., missed_heartbeat_callback: Optional[Any] = ..., shutdown_on_destruction: Optional[bool] = ..., + use_compression: Optional[bool] = ..., ) -> DistributedRuntimeClient: ... class PreemptionSyncManager: diff --git a/xla/tools/multihost_hlo_runner/create_client.cc b/xla/tools/multihost_hlo_runner/create_client.cc index 3aabf56650af23..7637aa2a125c3c 100644 --- a/xla/tools/multihost_hlo_runner/create_client.cc +++ b/xla/tools/multihost_hlo_runner/create_client.cc @@ -83,7 +83,8 @@ static absl::StatusOr> GetPjRtClient( options.node_id = node_id; options.init_timeout = init_timeout; distributed_client = - GetDistributedRuntimeClient(std::string(address), options); + GetDistributedRuntimeClient(std::string(address), + /*use_compression=*/false, options); TF_QCHECK_OK(distributed_client->Connect()); kv_store = GetDistributedKeyValueStore(distributed_client, /*key_prefix=*/"gpu:");