Skip to content

Commit

Permalink
Workaround for PTP message overflow (#336)
Browse files Browse the repository at this point in the history
* makespan: helpful debugging

* mpi: more cleanup

* endpoint: try to catch bug

* mpi: add contains method to registry

* endpoint: be less strict with empty requests

* debug: add try/catch around mpi/transport

* mpi: know what everybody is doing in case of an error

* makespan: finally working

* nits: run clang format
  • Loading branch information
csegarragonz authored Jul 25, 2023
1 parent 4bbd28e commit 4a079db
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 27 deletions.
2 changes: 2 additions & 0 deletions include/faabric/mpi/MpiWorldRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class MpiWorldRegistry

MpiWorld& getWorld(int worldId);

bool worldExists(int worldId);

void clear();

private:
Expand Down
8 changes: 8 additions & 0 deletions include/faabric/transport/Message.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <map>
#include <nng/nng.h>
#include <span>
#include <string>
Expand Down Expand Up @@ -36,6 +37,13 @@ enum class MessageResponseCode
ERROR
};

// Human-readable text for each MessageResponseCode, used when logging
// transport-layer receive errors.
// Declared `inline const` (C++17) so every translation unit that includes
// this header shares one immutable instance; the previous `static` (internal
// linkage, non-const) gave each TU its own mutable copy.
inline const std::map<MessageResponseCode, std::string> MessageResponseCodeText = {
    { MessageResponseCode::SUCCESS, "Success" },
    { MessageResponseCode::TERM, "Connection terminated" },
    { MessageResponseCode::TIMEOUT, "Message timed out" },
    { MessageResponseCode::ERROR, "Error" },
};

/**
* Represents message data passed around the transport layer. Essentially an
* array of bytes, with a size and a flag to say whether there's more data to
Expand Down
18 changes: 14 additions & 4 deletions src/endpoint/FaabricEndpointHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,21 @@ void FaabricEndpointHandler::onRequest(
// Text response type
response.set(header::content_type, "text/plain");

// Request body contains a string that is formatted as a JSON
std::string requestStr = request.body();

// Handle JSON
// TODO: for the moment we keep the endpoint handler, but we are not meant
// to receive any requests here. Eventually we will delete it
SPDLOG_ERROR("Faabric handler received empty request");
response.result(beast::http::status::bad_request);
response.body() = std::string("Empty request");
ctx.sendFunction(std::move(response));
if (requestStr.empty()) {
SPDLOG_ERROR("Planner handler received empty request");
response.result(beast::http::status::bad_request);
response.body() = std::string("Empty request");
return ctx.sendFunction(std::move(response));
}

SPDLOG_ERROR("Worker HTTP handler received non-empty request (body: {})",
request.body());
throw std::runtime_error("Worker HTTP handler received non-empty request");
}
}
78 changes: 60 additions & 18 deletions src/mpi/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,42 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
if (!msg->SerializeToString(&serialisedBuffer)) {
throw std::runtime_error("Error serialising message");
}
broker.sendMessage(
thisRankMsg->groupid(),
sendRank,
recvRank,
reinterpret_cast<const uint8_t*>(serialisedBuffer.data()),
serialisedBuffer.size(),
dstHost,
true);
try {
broker.sendMessage(
thisRankMsg->groupid(),
sendRank,
recvRank,
reinterpret_cast<const uint8_t*>(serialisedBuffer.data()),
serialisedBuffer.size(),
dstHost,
true);
} catch (std::runtime_error& e) {
SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - send {} -> {}",
thisRankMsg->appid(),
thisRankMsg->groupid(),
thisRankMsg->groupidx(),
sendRank,
recvRank);
throw e;
}
}

std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
                                                           int recvRank)
{
    // Receive the serialised message bytes for this (group id, send rank,
    // recv rank) tuple from the point-to-point broker. The final `true`
    // argument is passed straight through to the broker's recvMessage.
    std::vector<uint8_t> msg;
    try {
        msg =
          broker.recvMessage(thisRankMsg->groupid(), sendRank, recvRank, true);
    } catch (const std::runtime_error& e) {
        // Log the app:group:groupIdx triple and the rank pair so a hang can
        // be attributed to a specific communication edge from the logs alone
        SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - recv (remote) {} -> {}",
                     thisRankMsg->appid(),
                     thisRankMsg->groupid(),
                     thisRankMsg->groupidx(),
                     sendRank,
                     recvRank);
        // Re-throw with `throw;` to preserve the original exception's dynamic
        // type (`throw e;` would copy-construct and slice it down to
        // std::runtime_error)
        throw;
    }
    // Deserialise the protobuf payload and hand back a shared copy
    PARSE_MSG(MPIMessage, msg.data(), msg.size());
    return std::make_shared<MPIMessage>(parsedMsg);
}
Expand Down Expand Up @@ -1456,18 +1477,39 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
if (isLocal) {
// First receive messages that happened before us
for (int i = 0; i < batchSize - 1; i++) {
SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank);
auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue();

// Put the unacked message in the UMB
assert(!msgIt->isAcknowledged());
msgIt->acknowledge(pendingMsg);
msgIt++;
try {
SPDLOG_TRACE("MPI - pending recv {} -> {}", sendRank, recvRank);
auto pendingMsg = getLocalQueue(sendRank, recvRank)->dequeue();

// Put the unacked message in the UMB
assert(!msgIt->isAcknowledged());
msgIt->acknowledge(pendingMsg);
msgIt++;
} catch (faabric::util::QueueTimeoutException& e) {
SPDLOG_ERROR(
"{}:{}:{} Timed out with: MPI - pending recv {} -> {}",
thisRankMsg->appid(),
thisRankMsg->groupid(),
thisRankMsg->groupidx(),
sendRank,
recvRank);
throw e;
}
}

// Finally receive the message corresponding to us
SPDLOG_TRACE("MPI - recv {} -> {}", sendRank, recvRank);
ourMsg = getLocalQueue(sendRank, recvRank)->dequeue();
try {
ourMsg = getLocalQueue(sendRank, recvRank)->dequeue();
} catch (faabric::util::QueueTimeoutException& e) {
SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - recv {} -> {}",
thisRankMsg->appid(),
thisRankMsg->groupid(),
thisRankMsg->groupidx(),
sendRank,
recvRank);
throw e;
}
} else {
// First receive messages that happened before us
for (int i = 0; i < batchSize - 1; i++) {
Expand Down
7 changes: 7 additions & 0 deletions src/mpi/MpiWorldRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ MpiWorld& MpiWorldRegistry::getWorld(int worldId)
return worldMap[worldId];
}

bool MpiWorldRegistry::worldExists(int worldId)
{
    // Purely a read-only query on the registry map, so a shared (reader)
    // lock is sufficient
    faabric::util::SharedLock lock(registryMutex);

    bool isRegistered = worldMap.find(worldId) != worldMap.end();
    return isRegistered;
}

void MpiWorldRegistry::clear()
{
faabric::util::FullLock lock(registryMutex);
Expand Down
28 changes: 24 additions & 4 deletions src/transport/PointToPointBroker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,18 @@ void PointToPointBroker::sendMessage(int groupId,
localSendSeqNum,
endpoint.getAddress());

endpoint.send(NO_HEADER, buffer, bufferSize, localSendSeqNum);

try {
endpoint.send(NO_HEADER, buffer, bufferSize, localSendSeqNum);
} catch (std::runtime_error& e) {
SPDLOG_ERROR("Timed-out with local point-to-point message {}:{}:{} "
"(seq: {}) to {}",
groupId,
sendIdx,
recvIdx,
localSendSeqNum,
endpoint.getAddress());
throw e;
}
} else {
auto cli = getClient(host);
faabric::PointToPointMessage msg;
Expand All @@ -689,7 +699,17 @@ void PointToPointBroker::sendMessage(int groupId,
remoteSendSeqNum,
host);

cli->sendMessage(msg, remoteSendSeqNum);
try {
cli->sendMessage(msg, remoteSendSeqNum);
} catch (std::runtime_error& e) {
SPDLOG_TRACE("Timed-out with remote point-to-point message "
"{}:{}:{} (seq: {}) to {}",
groupId,
sendIdx,
recvIdx,
remoteSendSeqNum,
host);
}
}
}

Expand Down Expand Up @@ -756,7 +776,7 @@ std::vector<uint8_t> PointToPointBroker::recvMessage(int groupId,
SPDLOG_WARN(
"Error {} ({}) when awaiting a message ({}:{} seq: {} label: {})",
static_cast<int>(recvMsg.getResponseCode()),
nng_strerror(static_cast<int>(recvMsg.getResponseCode())),
MessageResponseCodeText.at(recvMsg.getResponseCode()),
sendIdx,
recvIdx,
expectedSeqNum,
Expand Down
4 changes: 3 additions & 1 deletion src/util/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ void SystemConfig::initialise()
this->getSystemConfIntParam("STATE_SERVER_THREADS", "2");
snapshotServerThreads =
this->getSystemConfIntParam("SNAPSHOT_SERVER_THREADS", "2");
// FIXME: temporarily set this value to a higher number to work-around:
// https://github.com/faasm/faabric/issues/335
pointToPointServerThreads =
this->getSystemConfIntParam("POINT_TO_POINT_SERVER_THREADS", "2");
this->getSystemConfIntParam("POINT_TO_POINT_SERVER_THREADS", "8");

// Dirty tracking
dirtyTrackingMode = getEnvVar("DIRTY_TRACKING_MODE", "segfault");
Expand Down

0 comments on commit 4a079db

Please sign in to comment.