diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc index 462680d5fc764a..20f365194761d9 100644 --- a/xla/pjrt/distributed/client_server_test.cc +++ b/xla/pjrt/distributed/client_server_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "grpcpp/server_builder.h" #include "grpcpp/support/channel_arguments.h" #include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/service.h" #include "xla/pjrt/distributed/topology_util.h" @@ -996,5 +997,22 @@ TEST_F(ClientServerTest, KeyValueDelete_Directory) { EXPECT_THAT(kvs.value(), IsEmpty()); } +TEST_F(ClientServerTest, UseCompression) { + int port = tsl::testing::PickUnusedPortOrDie(); + StartService(/*num_nodes=*/1, /*service_options=*/{}, + absl::StrCat("[::]:", port)); + + // Sanity check that the client can connect with compression enabled. + DistributedRuntimeClient::Options client_options; + client_options.node_id = 0; + auto client = + GetDistributedRuntimeClient(absl::StrCat("dns:///localhost:", port), + /*use_compression=*/true, client_options); + + TF_ASSERT_OK(client->Connect()); + TF_ASSERT_OK(client->KeyValueSet("foo/bar/1", "1")); + TF_ASSERT_OK(client->Shutdown()); +} + } // namespace } // namespace xla diff --git a/xla/pjrt/distributed/distributed.cc b/xla/pjrt/distributed/distributed.cc index 69f9f2e249b402..9e0603f1b029f1 100644 --- a/xla/pjrt/distributed/distributed.cc +++ b/xla/pjrt/distributed/distributed.cc @@ -38,9 +38,14 @@ GetDistributedRuntimeService(std::string address, } std::shared_ptr GetDistributedRuntimeClient( - std::string address, const DistributedRuntimeClient::Options& options) { - std::shared_ptr channel = grpc::CreateChannel( - address, tsl::GetClientCredentials(kVerifySecureCredentials)); + std::string address, bool use_compression, + const DistributedRuntimeClient::Options& options) { + grpc::ChannelArguments args; + if (use_compression) { + args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP); + } + std::shared_ptr channel = grpc::CreateCustomChannel( + address, tsl::GetClientCredentials(kVerifySecureCredentials), args); return GetDistributedRuntimeClient(channel, options); } diff --git a/xla/pjrt/distributed/distributed.h b/xla/pjrt/distributed/distributed.h index 8145ddaa5c699e..a365635bb28198 100644 --- a/xla/pjrt/distributed/distributed.h +++ b/xla/pjrt/distributed/distributed.h @@ -40,7 +40,8 @@ GetDistributedRuntimeService(std::string address, // Builds a distributed runtime client, connecting to a service at `address`, // where address is a gRPC-style address such as `dns:///localhost:1234`. std::shared_ptr GetDistributedRuntimeClient( - std::string address, const DistributedRuntimeClient::Options& options); + std::string address, bool use_compression, + const DistributedRuntimeClient::Options& options); } // namespace xla diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 868a3aa9d74016..f861f5297d6fdf 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -779,8 +779,10 @@ NB_MODULE(xla_extension, m_nb) { std::optional> missed_heartbeat_callback, - std::optional shutdown_on_destruction) + std::optional shutdown_on_destruction, + std::optional use_compression) -> std::shared_ptr { + bool compression = use_compression.value_or(false); DistributedRuntimeClient::Options options; options.node_id = node_id; if (rpc_timeout.has_value()) { @@ -805,7 +807,7 @@ NB_MODULE(xla_extension, m_nb) { if (shutdown_on_destruction.has_value()) { options.shutdown_on_destruction = *shutdown_on_destruction; } - return GetDistributedRuntimeClient(address, options); + return GetDistributedRuntimeClient(address, compression, options); }, nb::arg("address"), nb::arg("node_id"), nb::arg("rpc_timeout").none() = std::nullopt, @@ -814,7 +816,8 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("heartbeat_interval").none() = std::nullopt, nb::arg("max_missing_heartbeats").none() = std::nullopt, nb::arg("missed_heartbeat_callback").none() = std::nullopt, - nb::arg("shutdown_on_destruction").none() = std::nullopt); + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); m_nb.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index b5ae4c6431ca66..8dd5816d425ef4 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -821,6 +821,7 @@ def get_distributed_runtime_client( max_missing_heartbeats: Optional[int] = ..., missed_heartbeat_callback: Optional[Any] = ..., shutdown_on_destruction: Optional[bool] = ..., + use_compression: Optional[bool] = ..., ) -> DistributedRuntimeClient: ... class PreemptionSyncManager: diff --git a/xla/tools/multihost_hlo_runner/create_client.cc b/xla/tools/multihost_hlo_runner/create_client.cc index 3aabf56650af23..7637aa2a125c3c 100644 --- a/xla/tools/multihost_hlo_runner/create_client.cc +++ b/xla/tools/multihost_hlo_runner/create_client.cc @@ -83,7 +83,8 @@ static absl::StatusOr> GetPjRtClient( options.node_id = node_id; options.init_timeout = init_timeout; distributed_client = - GetDistributedRuntimeClient(std::string(address), options); + GetDistributedRuntimeClient(std::string(address), + /*use_compression=*/false, options); TF_QCHECK_OK(distributed_client->Connect()); kv_store = GetDistributedKeyValueStore(distributed_client, /*key_prefix=*/"gpu:");