Skip to content

Commit

Permalink
Propagate shutdown errors before destroying agent.
Browse files Browse the repository at this point in the history
Also add a short buffer to the RPC timeouts so that service-related errors may get propagated before the RPC layer times it out.

PiperOrigin-RevId: 695572673
  • Loading branch information
Google-ML-Automation committed Nov 12, 2024
1 parent c9211b1 commit 9e0fe15
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 13 deletions.
40 changes: 40 additions & 0 deletions xla/tsl/distributed_runtime/coordination/client_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,46 @@ TEST_F(ClientServerTest, MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway) {
}
}

TEST_F(ClientServerTest, ShutdownErrorIsPropagatedToClients) {
int num_nodes = 2;
StartService(num_nodes);
std::vector<absl::Status> statuses = {
absl::UnknownError("Uninitialized status."),
absl::UnknownError("Uninitialized status.")};
absl::Notification shutdown;

auto thread_fn = [&](int node_id) {
auto client = GetClient(
node_id,
/*init_and_shutdown_timeout=*/absl::Seconds(2),
/*shutdown_on_destruction=*/true,
/*error_fn=*/[&statuses, node_id](const absl::Status& status) {
statuses[node_id] = status;
});

TF_ASSERT_OK(client->Connect());

if (node_id == 0) {
// Shut down early.
client = nullptr;
shutdown.Notify();
} else {
// Block until shutdown barrier times out.
shutdown.WaitForNotification();
}
};

{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
for (int i = 0; i < num_nodes; ++i) {
thread_pool.Schedule([&, i]() { thread_fn(i); });
}
}
EXPECT_THAT(statuses[0], StatusIs(absl::StatusCode::kInternal));
EXPECT_THAT(statuses[1], StatusIs(absl::StatusCode::kInternal));
}

TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) {
#if defined(ADDRESS_SANITIZER)
GTEST_SKIP()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,9 @@ absl::Status CoordinationServiceAgentImpl::Connect() {
configs_.cluster_register_timeout_in_ms() > 0
? configs_.cluster_register_timeout_in_ms()
: absl::ToInt64Milliseconds(kDefaultClusterRegisterTimeout);
// Give 1 second for any service-related timeouts to propagate.
const absl::Time deadline =
absl::Now() + absl::Milliseconds(register_timeout);
absl::Now() + absl::Milliseconds(register_timeout) + absl::Seconds(1);
int attempt = 0;
std::default_random_engine generator;
std::uniform_real_distribution<double> distribution(0.0, 1.0);
Expand Down Expand Up @@ -590,9 +591,11 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() {
ShutdownTaskResponse response;
CallOptions call_opts;
const int64_t shutdown_timeout =
configs_.shutdown_barrier_timeout_in_ms() > 0
? configs_.shutdown_barrier_timeout_in_ms()
: absl::ToInt64Milliseconds(kDefaultShutdownTimeout);
(configs_.shutdown_barrier_timeout_in_ms() > 0
? configs_.shutdown_barrier_timeout_in_ms()
: absl::ToInt64Milliseconds(kDefaultShutdownTimeout)) +
// Add 1s for service-related errors to propagate.
1000;
call_opts.SetTimeout(shutdown_timeout);

absl::Notification n;
Expand All @@ -605,14 +608,20 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() {
if (status.ok()) {
LOG(INFO) << "Coordination agent has successfully shut down.";
} else {
LOG(ERROR)
<< "Failed to disconnect from coordination service with status: "
<< TrimCoordinationErrorMessage(status)
<< "\nProceeding with agent shutdown anyway. This is usually caused "
"by an earlier error during execution. Check the logs of (a) this "
"task, (b) the leader (usually slice 0 task 0) and (c) the "
"scheduler (e.g. preemption, eviction) for an earlier error to "
"debug further.";
status = MakeCoordinationError(absl::Status(
status.code(),
absl::StrCat(
"Failed to disconnect from coordination service with "
"status: ",
TrimCoordinationErrorMessage(status).ToString(),
"Proceeding with agent shutdown anyway. This is usually caused "
"by an "
"earlier error during execution. Check the logs of (a) this "
"task, "
"(b) the leader (usually slice 0 task 0) and (c) the scheduler "
"(e.g. "
"preemption, eviction) for an earlier error to debug further.")));
SetError(status);
}
}

Expand All @@ -621,7 +630,7 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() {
StopErrorPolling();
{
absl::MutexLock l(&state_mu_);
if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
if (status.ok() && state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
const std::string status_message = absl::StrCat(
"Shutdown() was called while coordination agent is in error state, "
"implying that distributed execution failed. Note: agent will "
Expand Down

0 comments on commit 9e0fe15

Please sign in to comment.