Skip to content

Commit

Permalink
Enable polling for error from coordination service at startup by defa…
Browse files Browse the repository at this point in the history
…ult.

PiperOrigin-RevId: 679145273
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
1 parent f7d0b34 commit d4cd250
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
5 changes: 2 additions & 3 deletions xla/pjrt/distributed/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ class DistributedRuntimeClient {

// Whether the client should send a request to wait for error from the
// coordination service at the startup.
// TODO(b/355706798): Enable this by default once we confirm this works for
// all cases and eventually remove this option.
bool poll_for_error_from_service_at_startup = false;
// TODO(b/355706798): eventually remove this option.
bool poll_for_error_from_service_at_startup = true;
};

virtual ~DistributedRuntimeClient() = default;
Expand Down
18 changes: 8 additions & 10 deletions xla/pjrt/distributed/client_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) {
}
}

TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
TEST_F(ClientServerTest,
ClientsTerminateShutdownIfAnyClientGoesAway_WithoutErrorPolling) {
int num_nodes = 3;
StartService(num_nodes);

Expand Down Expand Up @@ -425,8 +426,7 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
}
}

TEST_F(ClientServerTest,
ClientsTerminateShutdownIfAnyClientGoesAway_WithErrorPolling) {
TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
int num_nodes = 3;
StartService(num_nodes);

Expand All @@ -435,7 +435,6 @@ TEST_F(ClientServerTest,
client_options.shutdown_on_destruction = node_id != 0;
client_options.missed_heartbeat_callback =
[&](absl::Status status, bool coordinator_initiated) {};
client_options.poll_for_error_from_service_at_startup = true;
auto client = GetClient(node_id, client_options);

TF_RETURN_IF_ERROR(client->Connect());
Expand Down Expand Up @@ -466,7 +465,7 @@ TEST_F(ClientServerTest,
}
}

TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) {
TEST_F(ClientServerTest, ClientsShutdownSuccessfully) {
int num_nodes = 3;
StartService(num_nodes);

Expand All @@ -475,7 +474,6 @@ TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) {
client_options.shutdown_on_destruction = true;
client_options.missed_heartbeat_callback =
[&](absl::Status status, bool coordinator_initiated) {};
client_options.poll_for_error_from_service_at_startup = true;
auto client = GetClient(node_id, client_options);

TF_RETURN_IF_ERROR(client->Connect());
Expand All @@ -497,8 +495,7 @@ TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) {
}
}

TEST_F(ClientServerTest,
MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway_WithErrorPolling) {
TEST_F(ClientServerTest, MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway) {
int num_nodes = 3;
StartService(num_nodes);

Expand All @@ -510,7 +507,6 @@ TEST_F(ClientServerTest,
bool coordinator_initiated) {
shutdown.Notify();
};
client_options.poll_for_error_from_service_at_startup = true;
auto client = GetClient(node_id, client_options);

TF_RETURN_IF_ERROR(client->Connect());
Expand All @@ -535,7 +531,8 @@ TEST_F(ClientServerTest,
}
}

TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) {
TEST_F(ClientServerTest,
ClientsReceiveMissedHeartbeatIfAnyClientGoesAway_WithoutErrorPolling) {
int num_nodes = 3;
StartService(num_nodes);

Expand All @@ -547,6 +544,7 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) {
bool coordinator_initiated) {
shutdown.Notify();
};
client_options.poll_for_error_from_service_at_startup = false;
auto client = GetClient(node_id, client_options);

TF_RETURN_IF_ERROR(client->Connect());
Expand Down

0 comments on commit d4cd250

Please sign in to comment.