PR #17704: [jax.distributed] Allow enabling grpc channel compression
Imported from GitHub PR #17704

Allows passing an additional boolean argument `use_compression` via `xla_extension.get_distributed_runtime_client(...)` that controls whether compression is enabled on the gRPC channels created for each distributed runtime client.
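
As a rough illustration of what the flag controls, the sketch below models the mapping from the optional `use_compression` argument to the gRPC channel arguments. The helper `channel_args_for` is hypothetical and not part of XLA or gRPC; the key string and the enum value are, to the best of my knowledge, what `ChannelArguments::SetCompressionAlgorithm(GRPC_COMPRESS_GZIP)` sets in gRPC's C core, but treat them as assumptions.

```python
from typing import Dict, Optional

# Assumed to match gRPC's C core: the channel-argument key for the default
# compression algorithm and the enum value of GRPC_COMPRESS_GZIP.
GRPC_DEFAULT_COMPRESSION_ALGORITHM_KEY = "grpc.default_compression_algorithm"
GRPC_COMPRESS_GZIP = 2


def channel_args_for(use_compression: Optional[bool] = None) -> Dict[str, int]:
    """Hypothetical helper modelling the channel arguments built per client."""
    # An omitted (None) flag behaves like False, mirroring
    # `use_compression.value_or(false)` in the Python binding.
    if not use_compression:
        return {}
    return {GRPC_DEFAULT_COMPRESSION_ALGORITHM_KEY: GRPC_COMPRESS_GZIP}
```

With this model, omitting the flag leaves the channel arguments empty, so existing callers keep the current uncompressed behavior.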

Motivation: XLA sends O(mesh) [device topologies](https://github.com/openxla/xla/blob/9fb4f21c3542c10b6a5bd98144801bbeec10b489/xla/pjrt/distributed/protocol.proto#L84) through its centralized coordination service, and we have reason to believe that this becomes a bottleneck at large scale. Compression of the underlying gRPC communication is currently implicitly disabled; enabling it might give us a cheap way to scale a bit further with the centralized KV store design.
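
To give a feel for why compression could help here, the following sketch gzip-compresses a made-up topology payload with one near-identical entry per device. Everything in it is an assumption for illustration: the real messages are protobufs rather than JSON, the field names and values are stand-ins, and the ratio it yields says nothing about actual savings on the wire.

```python
import gzip
import json


def make_topology_payload(num_devices: int) -> bytes:
    """Build a made-up, JSON-encoded topology with one entry per device."""
    devices = [
        {
            "global_device_id": i,
            "local_device_ordinal": i % 8,
            "name": "example-accelerator",  # placeholder device name
            "vendor": "example-vendor",  # placeholder vendor string
        }
        for i in range(num_devices)
    ]
    return json.dumps({"devices": devices}).encode("utf-8")


def gzip_ratio(payload: bytes) -> float:
    """Return compressed size divided by original size (lower is better)."""
    return len(gzip.compress(payload)) / len(payload)


if __name__ == "__main__":
    for n in (64, 1024, 16384):
        print(n, round(gzip_ratio(make_topology_payload(n)), 3))
```

Because the per-device entries differ only in a few fields, a payload like this is highly repetitive, which is the kind of data gzip handles well.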

One small note: I refrained from adding `use_compression` to `DistributedRuntimeClient::Options` because the new flag is only relevant during channel creation in `distributed.cc`, not within `DistributedRuntimeClient` itself. If we added `use_compression` to `Options`, the `GetDistributedRuntimeClient(channel, options)` overload defined in `client.cc` would appear to allow controlling compression, even though the flag would be ignored there because the channel already exists. Let me know if you'd rather go that way.

Corresponding JAX PR: jax-ml/jax#23969
Copybara import of the project:

--
609e21a by Georg Stefan Schmid:

[jax.distributed] Allow enabling grpc channel compression

Merging this change closes #17704

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17704 from gspschmid:gschmid/dist-compression 609e21a
PiperOrigin-RevId: 679573391
gspschmid authored and Google-ML-Automation committed Sep 27, 2024
1 parent 6c3194e commit 027cfb1
Showing 6 changed files with 37 additions and 8 deletions.
18 changes: 18 additions & 0 deletions xla/pjrt/distributed/client_server_test.cc
@@ -39,6 +39,7 @@ limitations under the License.
 #include "grpcpp/server_builder.h"
 #include "grpcpp/support/channel_arguments.h"
 #include "xla/pjrt/distributed/client.h"
+#include "xla/pjrt/distributed/distributed.h"
 #include "xla/pjrt/distributed/protocol.pb.h"
 #include "xla/pjrt/distributed/service.h"
 #include "xla/pjrt/distributed/topology_util.h"
@@ -996,5 +997,22 @@ TEST_F(ClientServerTest, KeyValueDelete_Directory) {
   EXPECT_THAT(kvs.value(), IsEmpty());
 }
 
+TEST_F(ClientServerTest, UseCompression) {
+  int port = tsl::testing::PickUnusedPortOrDie();
+  StartService(/*num_nodes=*/1, /*service_options=*/{},
+               absl::StrCat("[::]:", port));
+
+  // Sanity check that the client can connect with compression enabled.
+  DistributedRuntimeClient::Options client_options;
+  client_options.node_id = 0;
+  auto client =
+      GetDistributedRuntimeClient(absl::StrCat("dns:///localhost:", port),
+                                  /*use_compression=*/true, client_options);
+
+  TF_ASSERT_OK(client->Connect());
+  TF_ASSERT_OK(client->KeyValueSet("foo/bar/1", "1"));
+  TF_ASSERT_OK(client->Shutdown());
+}
+
 }  // namespace
 }  // namespace xla
11 changes: 8 additions & 3 deletions xla/pjrt/distributed/distributed.cc
@@ -38,9 +38,14 @@ GetDistributedRuntimeService(std::string address,
 }
 
 std::shared_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
-    std::string address, const DistributedRuntimeClient::Options& options) {
-  std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
-      address, tsl::GetClientCredentials(kVerifySecureCredentials));
+    std::string address, bool use_compression,
+    const DistributedRuntimeClient::Options& options) {
+  grpc::ChannelArguments args;
+  if (use_compression) {
+    args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
+  }
+  std::shared_ptr<grpc::Channel> channel = grpc::CreateCustomChannel(
+      address, tsl::GetClientCredentials(kVerifySecureCredentials), args);
   return GetDistributedRuntimeClient(channel, options);
 }

3 changes: 2 additions & 1 deletion xla/pjrt/distributed/distributed.h
@@ -40,7 +40,8 @@ GetDistributedRuntimeService(std::string address,
 // Builds a distributed runtime client, connecting to a service at `address`,
 // where address is a gRPC-style address such as `dns:///localhost:1234`.
 std::shared_ptr<DistributedRuntimeClient> GetDistributedRuntimeClient(
-    std::string address, const DistributedRuntimeClient::Options& options);
+    std::string address, bool use_compression,
+    const DistributedRuntimeClient::Options& options);
 
 }  // namespace xla

9 changes: 6 additions & 3 deletions xla/python/xla.cc
@@ -779,8 +779,10 @@ NB_MODULE(xla_extension, m_nb) {
          std::optional<std::function<void(absl::Status,
                                           bool coordinator_reported_failure)>>
              missed_heartbeat_callback,
-         std::optional<bool> shutdown_on_destruction)
+         std::optional<bool> shutdown_on_destruction,
+         std::optional<bool> use_compression)
           -> std::shared_ptr<DistributedRuntimeClient> {
+        bool compression = use_compression.value_or(false);
         DistributedRuntimeClient::Options options;
         options.node_id = node_id;
         if (rpc_timeout.has_value()) {
@@ -805,7 +807,7 @@ NB_MODULE(xla_extension, m_nb) {
         if (shutdown_on_destruction.has_value()) {
           options.shutdown_on_destruction = *shutdown_on_destruction;
         }
-        return GetDistributedRuntimeClient(address, options);
+        return GetDistributedRuntimeClient(address, compression, options);
       },
       nb::arg("address"), nb::arg("node_id"),
       nb::arg("rpc_timeout").none() = std::nullopt,
@@ -814,7 +816,8 @@ NB_MODULE(xla_extension, m_nb) {
       nb::arg("heartbeat_interval").none() = std::nullopt,
       nb::arg("max_missing_heartbeats").none() = std::nullopt,
       nb::arg("missed_heartbeat_callback").none() = std::nullopt,
-      nb::arg("shutdown_on_destruction").none() = std::nullopt);
+      nb::arg("shutdown_on_destruction").none() = std::nullopt,
+      nb::arg("use_compression").none() = std::nullopt);
 
   m_nb.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });

1 change: 1 addition & 0 deletions xla/python/xla_extension/__init__.pyi
@@ -821,6 +821,7 @@ def get_distributed_runtime_client(
     max_missing_heartbeats: Optional[int] = ...,
     missed_heartbeat_callback: Optional[Any] = ...,
     shutdown_on_destruction: Optional[bool] = ...,
+    use_compression: Optional[bool] = ...,
 ) -> DistributedRuntimeClient: ...
 
 class PreemptionSyncManager:
3 changes: 2 additions & 1 deletion xla/tools/multihost_hlo_runner/create_client.cc
@@ -83,7 +83,8 @@ static absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetPjRtClient(
     options.node_id = node_id;
     options.init_timeout = init_timeout;
     distributed_client =
-        GetDistributedRuntimeClient(std::string(address), options);
+        GetDistributedRuntimeClient(std::string(address),
+                                    /*use_compression=*/false, options);
     TF_QCHECK_OK(distributed_client->Connect());
     kv_store = GetDistributedKeyValueStore(distributed_client,
                                            /*key_prefix=*/"gpu:");
