diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 17bf5b909..d752722ce 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -119,12 +119,16 @@ jobs: run: ./bin/inv_wrapper.sh dev.cc faabric_tests - name: "Run tests" run: ./bin/inv_wrapper.sh tests + timeout-minutes: 10 dist-tests: if: github.event.pull_request.draft == false needs: [conan-cache] - runs-on: ubuntu-latest + runs-on: self-hosted env: + # Make a unique per-job cluster name, so that different instances can + # run in parallel + COMPOSE_PROJECT_NAME: faabric-gha-${{ github.job }}-${{ github.run_id }}-${{ github.run_attempt }} CONAN_CACHE_MOUNT_SOURCE: ~/.conan/ steps: # --- Code update --- @@ -136,9 +140,13 @@ jobs: run: ./dist-test/build.sh - name: "Run the distributed tests" run: ./dist-test/run.sh + timeout-minutes: 10 - name: "Print planner logs" if: always() run: docker compose logs planner + - name: "Chown all files to avoid docker-related root-owned files" + if: always() + run: sudo chown -R $(id -u):$(id -g) . examples: if: github.event.pull_request.draft == false diff --git a/include/faabric/mpi/MpiMessage.h b/include/faabric/mpi/MpiMessage.h new file mode 100644 index 000000000..7c85fde48 --- /dev/null +++ b/include/faabric/mpi/MpiMessage.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include + +namespace faabric::mpi { + +enum MpiMessageType : int32_t +{ + NORMAL = 0, + BARRIER_JOIN = 1, + BARRIER_DONE = 2, + SCATTER = 3, + GATHER = 4, + ALLGATHER = 5, + REDUCE = 6, + SCAN = 7, + ALLREDUCE = 8, + ALLTOALL = 9, + SENDRECV = 10, + BROADCAST = 11, +}; + +struct MpiMessage +{ + int32_t id; + int32_t worldId; + int32_t sendRank; + int32_t recvRank; + int32_t typeSize; + int32_t count; + MpiMessageType messageType; + void* buffer; +}; + +inline size_t payloadSize(const MpiMessage& msg) +{ + return msg.typeSize * msg.count; +} + +inline size_t msgSize(const MpiMessage& msg) +{ + return sizeof(MpiMessage) + payloadSize(msg); +} + +void serializeMpiMsg(std::vector& buffer, const MpiMessage& msg); + +void parseMpiMsg(const std::vector& bytes, MpiMessage* msg); +} diff --git a/include/faabric/mpi/MpiMessageBuffer.h b/include/faabric/mpi/MpiMessageBuffer.h index 9fc67b644..c36f89887 100644 --- a/include/faabric/mpi/MpiMessageBuffer.h +++ b/include/faabric/mpi/MpiMessageBuffer.h @@ -1,8 +1,9 @@ +#include #include -#include #include #include +#include namespace faabric::mpi { /* The MPI message buffer (MMB) keeps track of the asyncrhonous @@ -25,17 +26,20 @@ class MpiMessageBuffer { public: int requestId = -1; - std::shared_ptr msg = nullptr; + std::shared_ptr msg = nullptr; int sendRank = -1; int recvRank = -1; uint8_t* buffer = nullptr; faabric_datatype_t* dataType = nullptr; int count = -1; - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL; + MpiMessageType messageType = MpiMessageType::NORMAL; bool isAcknowledged() { return msg != nullptr; } - void acknowledge(std::shared_ptr msgIn) { msg = msgIn; } + void acknowledge(const MpiMessage& msgIn) + { + msg = std::make_shared(msgIn); + } }; /* Interface to query the buffer size */ diff --git a/include/faabric/mpi/MpiWorld.h b/include/faabric/mpi/MpiWorld.h index adee54137..97fb24f18 100644 --- a/include/faabric/mpi/MpiWorld.h +++ b/include/faabric/mpi/MpiWorld.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include #include @@ -26,10 +26,9 @@ namespace faabric::mpi { // ----------------------------------- // MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker // 
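[Editor's note, not part of the patch] A minimal round-trip sketch of the new MpiMessage helpers introduced above. It assumes the byte buffers are std::vector<uint8_t> and that faabric's util malloc/free wrappers live in faabric/util/memory.h, as src/mpi/MpiMessage.cpp below suggests.

```cpp
#include <faabric/mpi/MpiMessage.h>
#include <faabric/util/memory.h>

#include <cstring>
#include <vector>

using namespace faabric::mpi;

int main()
{
    // Fixed-size header describing a payload of three 4-byte elements
    std::vector<int32_t> payload = { 1, 2, 3 };
    MpiMessage msg = { .id = 1,
                       .worldId = 123,
                       .sendRank = 0,
                       .recvRank = 1,
                       .typeSize = (int32_t)sizeof(int32_t),
                       .count = (int32_t)payload.size(),
                       .messageType = MpiMessageType::NORMAL,
                       .buffer = payload.data() };

    // Wire format: the struct itself, followed by payloadSize(msg) bytes
    std::vector<uint8_t> wire(msgSize(msg));
    serializeMpiMsg(wire, msg);

    // parseMpiMsg mallocs a fresh buffer for the payload, so the receiver
    // owns it and must free it once consumed
    MpiMessage parsed;
    parseMpiMsg(wire, &parsed);
    bool same =
      std::memcmp(parsed.buffer, payload.data(), payloadSize(msg)) == 0;
    faabric::util::free(parsed.buffer);

    return same ? 0 : 1;
}
```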
as the broker already has mocking capabilities -std::vector> getMpiMockedMessages(int sendRank); +std::vector getMpiMockedMessages(int sendRank); -typedef faabric::util::FixedCapacityQueue> - InMemoryMpiQueue; +typedef faabric::util::SpinLockQueue InMemoryMpiQueue; class MpiWorld { @@ -73,21 +72,21 @@ class MpiWorld const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); int isend(int sendRank, int recvRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void broadcast(int rootRank, int thisRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void recv(int sendRank, int recvRank, @@ -95,14 +94,14 @@ class MpiWorld faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); int irecv(int sendRank, int recvRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); void awaitAsyncRequest(int requestId); @@ -185,8 +184,6 @@ class MpiWorld std::shared_ptr getLocalQueue(int sendRank, int recvRank); - long getLocalQueueSize(int sendRank, int recvRank); - void overrideHost(const std::string& newHost); double getWTime(); @@ -240,29 +237,36 @@ class MpiWorld void sendRemoteMpiMessage(std::string dstHost, int sendRank, int recvRank, - const std::shared_ptr& msg); + const MpiMessage& msg); - std::shared_ptr recvRemoteMpiMessage(int sendRank, - int recvRank); + MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank); // Support for asyncrhonous communications std::shared_ptr getUnackedMessageBuffer(int sendRank, int recvRank); - std::shared_ptr recvBatchReturnLast(int sendRank, - int recvRank, - int batchSize = 0); + MpiMessage recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize = 0); /* Helper methods */ void checkRanksRange(int sendRank, int recvRank); // Abstraction of the bulk of the recv work, shared among various functions - void doRecv(std::shared_ptr& m, + void doRecv(const MpiMessage& m, + uint8_t* buffer, + faabric_datatype_t* dataType, + int count, + MPI_Status* status, + MpiMessageType messageType = MpiMessageType::NORMAL); + + // Abstraction of the bulk of the recv work, shared among various functions + void doRecv(std::unique_ptr m, uint8_t* buffer, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL); + MpiMessageType messageType = MpiMessageType::NORMAL); }; } diff --git a/include/faabric/transport/PointToPointBroker.h b/include/faabric/transport/PointToPointBroker.h index 95f6cba17..87a47ca3b 100644 --- a/include/faabric/transport/PointToPointBroker.h +++ b/include/faabric/transport/PointToPointBroker.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -120,27 +121,16 @@ class PointToPointBroker void updateHostForIdx(int groupId, int groupIdx, std::string newHost); - void sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, + void sendMessage(const PointToPointMessage& msg, std::string hostHint, bool 
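[Editor's note, not part of the patch] Call sites migrate mechanically from the nested protobuf enum to the new plain enum, as the changes to tests/dist/mpi/mpi_native.cpp later in this diff show. A before/after sketch, where world, buffer and count stand in for the surrounding MPI context:

```cpp
// Before: enum nested inside the generated protobuf class
world.send(0, 1, buffer, MPI_INT, count, MPIMessage::NORMAL);

// After: plain enum defined in faabric/mpi/MpiMessage.h
world.send(0, 1, buffer, MPI_INT, count, MpiMessageType::NORMAL);
```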
mustOrderMsg = false); - void sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, + void sendMessage(const PointToPointMessage& msg, bool mustOrderMsg = false, int sequenceNum = NO_SEQUENCE_NUM, std::string hostHint = ""); - std::vector recvMessage(int groupId, - int sendIdx, - int recvIdx, - bool mustOrderMsg = false); + void recvMessage(PointToPointMessage& msg, bool mustOrderMsg = false); void clearGroup(int groupId); @@ -163,7 +153,8 @@ class PointToPointBroker std::shared_ptr getGroupFlag(int groupId); - Message doRecvMessage(int groupId, int sendIdx, int recvIdx); + // Returns the message response code and the sequence number + std::pair doRecvMessage(PointToPointMessage& msg); void initSequenceCounters(int groupId); diff --git a/include/faabric/transport/PointToPointClient.h b/include/faabric/transport/PointToPointClient.h index 634b41579..5e5add933 100644 --- a/include/faabric/transport/PointToPointClient.h +++ b/include/faabric/transport/PointToPointClient.h @@ -3,18 +3,19 @@ #include #include #include +#include namespace faabric::transport { std::vector> getSentMappings(); -std::vector> +std::vector> getSentPointToPointMessages(); std::vector> + PointToPointMessage>> getSentLockMessages(); void clearSentMessages(); @@ -26,7 +27,7 @@ class PointToPointClient : public faabric::transport::MessageEndpointClient void sendMappings(faabric::PointToPointMappings& mappings); - void sendMessage(faabric::PointToPointMessage& msg, + void sendMessage(const PointToPointMessage& msg, int sequenceNum = NO_SEQUENCE_NUM); void groupLock(int appId, diff --git a/include/faabric/transport/PointToPointMessage.h b/include/faabric/transport/PointToPointMessage.h new file mode 100644 index 000000000..e61e2c509 --- /dev/null +++ b/include/faabric/transport/PointToPointMessage.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +namespace faabric::transport { + +/* Simple fixed-size C-struct to capture the state of a PTP message moving + * through Faabric. + * + * We require fixed-size, and no unique pointers to be able to use + * high-throughput ring-buffers to send the messages around. This also means + * that we manually malloc/free the data pointer. The message size is: + * 4 * int32_t = 4 * 4 bytes = 16 bytes + * 1 * size_t = 1 * 8 bytes = 8 bytes + * 1 * void* = 1 * 8 bytes = 8 bytes + * total = 32 bytes = 4 * 8 so the struct is naturally 8 byte-aligned + */ +struct PointToPointMessage +{ + int32_t appId; + int32_t groupId; + int32_t sendIdx; + int32_t recvIdx; + size_t dataSize; + void* dataPtr; +}; +static_assert((sizeof(PointToPointMessage) % 8) == 0, + "PTP message mus be 8-aligned!"); + +// The wire format for a PTP message is very simple: the fixed-size struct, +// followed by dataSize bytes containing the payload. +void serializePtpMsg(std::span buffer, const PointToPointMessage& msg); + +// This parsing function mallocs space for the message payload. 
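[Editor's note, not part of the patch] A short round-trip sketch of the PTP wire format described above: the fixed-size struct followed by dataSize payload bytes. It assumes the span/vector parameters are byte-typed (uint8_t) and that the util memory wrappers live in faabric/util/memory.h.

```cpp
#include <faabric/transport/PointToPointMessage.h>
#include <faabric/util/memory.h>

#include <cstring>
#include <vector>

using namespace faabric::transport;

int main()
{
    std::vector<uint8_t> payload(32, 0xff);
    PointToPointMessage msg = { .appId = 1,
                                .groupId = 2,
                                .sendIdx = 0,
                                .recvIdx = 3,
                                .dataSize = payload.size(),
                                .dataPtr = payload.data() };

    // Wire format: fixed-size struct followed by dataSize payload bytes
    std::vector<uint8_t> wire(sizeof(PointToPointMessage) + msg.dataSize);
    serializePtpMsg(wire, msg);

    // The parse overload without a pre-allocated buffer mallocs the
    // payload, so the caller has to free it after use
    PointToPointMessage parsed;
    parsePtpMsg(wire, &parsed);
    bool same =
      std::memcmp(parsed.dataPtr, payload.data(), payload.size()) == 0;
    faabric::util::free(parsed.dataPtr);

    return same ? 0 : 1;
}
```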
This is to +// keep the PTP message at fixed-size, and be able to efficiently move it +// around in-memory queues +void parsePtpMsg(std::span bytes, PointToPointMessage* msg); + +// Alternative signature for parsing PTP messages for when the caller can +// provide an already-allocated buffer to write into +void parsePtpMsg(std::span bytes, + PointToPointMessage* msg, + std::span preAllocBuffer); +} diff --git a/include/faabric/util/queue.h b/include/faabric/util/queue.h index 6d89aab18..9f9e2f164 100644 --- a/include/faabric/util/queue.h +++ b/include/faabric/util/queue.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -215,6 +216,48 @@ class FixedCapacityQueue moodycamel::BlockingReaderWriterCircularBuffer mq; }; +// High-performance, spin-lock single-producer, single-consumer queue. This +// queue spin-locks, so use at your own risk! +template +class SpinLockQueue +{ + public: + void enqueue(T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) + { + while (!mq.push(value)) { + ; + }; + } + + T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS) + { + T value; + + while (!mq.pop(value)) { + ; + } + + return value; + } + + long size() + { + throw std::runtime_error("Size for fast queue unimplemented!"); + } + + void drain() + { + while (mq.pop()) { + ; + } + } + + void reset() { ; } + + private: + boost::lockfree::spsc_queue> mq; +}; + class TokenPool { public: diff --git a/src/mpi/CMakeLists.txt b/src/mpi/CMakeLists.txt index dfd434c87..3ac5b98c9 100644 --- a/src/mpi/CMakeLists.txt +++ b/src/mpi/CMakeLists.txt @@ -38,32 +38,12 @@ endif() # ----------------------------------------------- if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi")) - # Generate protobuf headers - set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h") - - protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto) - - # Copy the generated headers into place - add_custom_command( - OUTPUT "${MPI_PB_HEADER_COPIED}" - DEPENDS "${MPI_PB_HEADER}" - COMMAND ${CMAKE_COMMAND} - ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/ - ) - - add_custom_target( - mpi_pbh_copied - DEPENDS ${MPI_PB_HEADER_COPIED} - ) - - add_dependencies(faabric_common_dependencies mpi_pbh_copied) - faabric_lib(mpi MpiContext.cpp + MpiMessage.cpp MpiMessageBuffer.cpp MpiWorld.cpp MpiWorldRegistry.cpp - ${MPI_PB_SRC} ) target_link_libraries(mpi PRIVATE diff --git a/src/mpi/MpiMessage.cpp b/src/mpi/MpiMessage.cpp new file mode 100644 index 000000000..57ee8c85e --- /dev/null +++ b/src/mpi/MpiMessage.cpp @@ -0,0 +1,36 @@ +#include +#include + +#include +#include +#include + +namespace faabric::mpi { + +void parseMpiMsg(const std::vector& bytes, MpiMessage* msg) +{ + assert(msg != nullptr); + assert(bytes.size() >= sizeof(MpiMessage)); + std::memcpy(msg, bytes.data(), sizeof(MpiMessage)); + size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage); + assert(thisPayloadSize == payloadSize(*msg)); + + if (thisPayloadSize == 0) { + msg->buffer = nullptr; + return; + } + + msg->buffer = faabric::util::malloc(thisPayloadSize); + std::memcpy( + msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize); +} + +void serializeMpiMsg(std::vector& buffer, const MpiMessage& msg) +{ + std::memcpy(buffer.data(), &msg, sizeof(MpiMessage)); + size_t payloadSz = payloadSize(msg); + if (payloadSz > 0 && msg.buffer != nullptr) { + std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz); + } +} +} diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp index d50344c40..cc8705dfc 100644 --- 
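[Editor's note, not part of the patch] A usage sketch for the new SpinLockQueue: a single producer and a single consumer busy-wait on the lock-free ring buffer, which is the intended pattern for the per-rank-pair MPI queues. The int instantiation here is purely illustrative.

```cpp
#include <faabric/util/queue.h>

#include <thread>
#include <vector>

int main()
{
    // Single-producer, single-consumer: one thread pushes, one pops
    faabric::util::SpinLockQueue<int> queue;

    std::thread producer([&queue] {
        for (int i = 0; i < 100; i++) {
            queue.enqueue(i);
        }
    });

    std::vector<int> received;
    for (int i = 0; i < 100; i++) {
        received.push_back(queue.dequeue());
    }

    producer.join();
    return received.size() == 100 ? 0 : 1;
}
```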
a/src/mpi/MpiWorld.cpp +++ b/src/mpi/MpiWorld.cpp @@ -1,7 +1,8 @@ #include +#include #include -#include #include +#include #include #include #include @@ -34,10 +35,9 @@ static std::mutex mockMutex; // The identifier in this map is the sending rank. For the receiver's rank // we can inspect the MPIMessage object -static std::map>> - mpiMockedMessages; +static std::map> mpiMockedMessages; -std::vector> getMpiMockedMessages(int sendRank) +std::vector getMpiMockedMessages(int sendRank) { faabric::util::UniqueLock lock(mockMutex); return mpiMockedMessages[sendRank]; @@ -53,21 +53,23 @@ MpiWorld::MpiWorld() void MpiWorld::sendRemoteMpiMessage(std::string dstHost, int sendRank, int recvRank, - const std::shared_ptr& msg) + const MpiMessage& msg) { - std::string serialisedBuffer; - if (!msg->SerializeToString(&serialisedBuffer)) { - throw std::runtime_error("Error serialising message"); - } + // Serialise + std::vector serialisedBuffer(msgSize(msg)); + serializeMpiMsg(serialisedBuffer, msg); + try { - broker.sendMessage( - thisRankMsg->groupid(), - sendRank, - recvRank, - reinterpret_cast(serialisedBuffer.data()), - serialisedBuffer.size(), - dstHost, - true); + // It is safe to send a pointer to a stack-allocated object + // because the broker will make an additional copy (and so will NNG!) + faabric::transport::PointToPointMessage msg( + { .groupId = thisRankMsg->groupid(), + .sendIdx = sendRank, + .recvIdx = recvRank, + .dataSize = serialisedBuffer.size(), + .dataPtr = (void*)serialisedBuffer.data() }); + + broker.sendMessage(msg, dstHost, true); } catch (std::runtime_error& e) { SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - send {} -> {}", thisRankMsg->appid(), @@ -79,13 +81,14 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost, } } -std::shared_ptr MpiWorld::recvRemoteMpiMessage(int sendRank, - int recvRank) +MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank) { - std::vector msg; + faabric::transport::PointToPointMessage msg( + { .groupId = thisRankMsg->groupid(), + .sendIdx = sendRank, + .recvIdx = recvRank }); try { - msg = - broker.recvMessage(thisRankMsg->groupid(), sendRank, recvRank, true); + broker.recvMessage(msg, true); } catch (std::runtime_error& e) { SPDLOG_ERROR("{}:{}:{} Timed out with: MPI - recv (remote) {} -> {}", thisRankMsg->appid(), @@ -95,8 +98,13 @@ std::shared_ptr MpiWorld::recvRemoteMpiMessage(int sendRank, recvRank); throw e; } - PARSE_MSG(MPIMessage, msg.data(), msg.size()); - return std::make_shared(parsedMsg); + + // TODO(mpi-opt): make sure we minimze copies here + MpiMessage parsedMsg; + std::vector msgBytes((uint8_t*) msg.dataPtr, (uint8_t*) msg.dataPtr + msg.dataSize); + parseMpiMsg(msgBytes, &parsedMsg); + + return parsedMsg; } std::shared_ptr MpiWorld::getUnackedMessageBuffer( @@ -447,7 +455,7 @@ int MpiWorld::isend(int sendRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { int requestId = (int)faabric::util::generateGid(); iSendRequests.insert(requestId); @@ -462,7 +470,7 @@ int MpiWorld::irecv(int sendRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { int requestId = (int)faabric::util::generateGid(); reqIdToRanks.try_emplace(requestId, sendRank, recvRank); @@ -489,7 +497,7 @@ void MpiWorld::send(int sendRank, const uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // 
Sanity-check input parameters checkRanksRange(sendRank, recvRank); @@ -506,45 +514,39 @@ void MpiWorld::send(int sendRank, // Generate a message ID int msgId = (localMsgCount + 1) % INT32_MAX; - // Create the message - auto m = std::make_shared(); - m->set_id(msgId); - m->set_worldid(id); - m->set_sender(sendRank); - m->set_destination(recvRank); - m->set_type(dataType->id); - m->set_count(count); - m->set_messagetype(messageType); - - // Set up message data - bool mustSendData = count > 0 && buffer != nullptr; + MpiMessage msg = { .id = msgId, + .worldId = id, + .sendRank = sendRank, + .recvRank = recvRank, + .typeSize = dataType->size, + .count = count, + .messageType = messageType, + .buffer = (void*)buffer }; // Mock the message sending in tests if (faabric::util::isMockMode()) { - mpiMockedMessages[sendRank].push_back(m); + mpiMockedMessages[sendRank].push_back(msg); return; } // Dispatch the message locally or globally if (isLocal) { - if (mustSendData) { + // Take control over the buffer data if we are gonna move it to + // the in-memory queues for local messaging + if (count > 0 && buffer != nullptr) { void* bufferPtr = faabric::util::malloc(count * dataType->size); std::memcpy(bufferPtr, buffer, count * dataType->size); - m->set_bufferptr((uint64_t)bufferPtr); + msg.buffer = bufferPtr; } SPDLOG_TRACE( "MPI - send {} -> {} ({})", sendRank, recvRank, messageType); - getLocalQueue(sendRank, recvRank)->enqueue(std::move(m)); + getLocalQueue(sendRank, recvRank)->enqueue(msg); } else { - if (mustSendData) { - m->set_buffer(buffer, dataType->size * count); - } - SPDLOG_TRACE( "MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType); - sendRemoteMpiMessage(otherHost, sendRank, recvRank, m); + sendRemoteMpiMessage(otherHost, sendRank, recvRank, msg); } /* 02/05/2022 - The following bit of code fails randomly with a protobuf @@ -572,7 +574,7 @@ void MpiWorld::recv(int sendRank, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Sanity-check input parameters checkRanksRange(sendRank, recvRank); @@ -582,54 +584,47 @@ void MpiWorld::recv(int sendRank, return; } - // Recv message from underlying transport - std::shared_ptr m = recvBatchReturnLast(sendRank, recvRank); + auto msg = recvBatchReturnLast(sendRank, recvRank); - // Do the processing - doRecv(m, buffer, dataType, count, status, messageType); + doRecv(std::move(msg), buffer, dataType, count, status, messageType); } -void MpiWorld::doRecv(std::shared_ptr& m, +void MpiWorld::doRecv(const MpiMessage& m, uint8_t* buffer, faabric_datatype_t* dataType, int count, MPI_Status* status, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { // Assert message integrity // Note - this checks won't happen in Release builds - if (m->messagetype() != messageType) { + if (m.messageType != messageType) { SPDLOG_ERROR("Different message types (got: {}, expected: {})", - m->messagetype(), + m.messageType, messageType); } - assert(m->messagetype() == messageType); - assert(m->count() <= count); - - const std::string otherHost = getHostForRank(m->destination()); - bool isLocal = - getHostForRank(m->destination()) == getHostForRank(m->sender()); - - if (m->count() > 0) { - if (isLocal) { - // Make sure we do not overflow the recepient buffer - auto bytesToCopy = std::min(m->count() * dataType->size, - count * dataType->size); - std::memcpy(buffer, (void*)m->bufferptr(), bytesToCopy); - faabric::util::free((void*)m->bufferptr()); - } else { - 
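[Editor's note, not part of the patch] The local-messaging path adopts a simple ownership convention: send() mallocs a copy of the caller's buffer before enqueuing the fixed-size struct, and doRecv() frees it after copying into the application buffer. A condensed sketch of that convention with hypothetical helper names:

```cpp
#include <faabric/mpi/MpiMessage.h>
#include <faabric/util/memory.h>

#include <algorithm>
#include <cstdint>
#include <cstring>

using namespace faabric::mpi;

// Sender side: copy the user buffer into heap memory the queue can own
MpiMessage makeLocalMsg(const uint8_t* userBuffer, int count, int typeSize)
{
    MpiMessage msg = { .typeSize = typeSize,
                       .count = count,
                       .messageType = MpiMessageType::NORMAL,
                       .buffer = nullptr };
    if (count > 0 && userBuffer != nullptr) {
        msg.buffer = faabric::util::malloc(count * typeSize);
        std::memcpy(msg.buffer, userBuffer, count * typeSize);
    }
    return msg;
}

// Receiver side: copy out at most the receive count, then release
void consumeLocalMsg(const MpiMessage& msg, uint8_t* recvBuffer, int recvCount)
{
    if (msg.count > 0 && msg.buffer != nullptr) {
        size_t bytes = std::min<size_t>(msg.count, recvCount) * msg.typeSize;
        std::memcpy(recvBuffer, msg.buffer, bytes);
        faabric::util::free(msg.buffer);
    }
}
```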
// TODO - avoid copy here - std::move(m->buffer().begin(), m->buffer().end(), buffer); - } + assert(m.messageType == messageType); + assert(m.count <= count); + + // We must copy the data into the application-provided buffer + if (m.count > 0 && m.buffer != nullptr) { + // Make sure we do not overflow the recepient buffer + auto bytesToCopy = + std::min(m.count * dataType->size, count * dataType->size); + std::memcpy(buffer, m.buffer, bytesToCopy); + + // This buffer has been malloc-ed either as part of a local `send` + // or as part of a remote `parseMpiMsg` + faabric::util::free((void*)m.buffer); } // Set status values if required if (status != nullptr) { - status->MPI_SOURCE = m->sender(); + status->MPI_SOURCE = m.sendRank; status->MPI_ERROR = MPI_SUCCESS; // Take the message size here as the receive count may be larger - status->bytesSize = m->count() * dataType->size; + status->bytesSize = m.count * dataType->size; // TODO - thread through tag status->MPI_TAG = -1; @@ -667,14 +662,14 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer, recvBuffer, recvDataType, recvCount, - MPIMessage::SENDRECV); + MpiMessageType::SENDRECV); // Then send the message send(myRank, sendRank, sendBuffer, sendDataType, sendCount, - MPIMessage::SENDRECV); + MpiMessageType::SENDRECV); // And wait awaitAsyncRequest(recvId); } @@ -684,7 +679,7 @@ void MpiWorld::broadcast(int sendRank, uint8_t* buffer, faabric_datatype_t* dataType, int count, - MPIMessage::MPIMessageType messageType) + MpiMessageType messageType) { SPDLOG_TRACE("MPI - bcast {} -> {}", sendRank, recvRank); @@ -795,7 +790,7 @@ void MpiWorld::scatter(int sendRank, startPtr, sendType, sendCount, - MPIMessage::SCATTER); + MpiMessageType::SCATTER); } } } else { @@ -806,7 +801,7 @@ void MpiWorld::scatter(int sendRank, recvType, recvCount, nullptr, - MPIMessage::SCATTER); + MpiMessageType::SCATTER); } } @@ -880,7 +875,7 @@ void MpiWorld::gather(int sendRank, recvType, recvCount, nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); } } } else { @@ -894,7 +889,7 @@ void MpiWorld::gather(int sendRank, recvType, recvCount * it.second.size(), nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); // Copy each received chunk to its offset for (int r = 0; r < it.second.size(); r++) { @@ -924,7 +919,7 @@ void MpiWorld::gather(int sendRank, sendType, sendCount, nullptr, - MPIMessage::GATHER); + MpiMessageType::GATHER); } } @@ -934,7 +929,7 @@ void MpiWorld::gather(int sendRank, rankData.get(), sendType, sendCount * ranksForHost[thisHost].size(), - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (isLocalLeader && isLocalGather) { // Scenario 3 @@ -943,7 +938,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (!isLocalLeader && !isLocalGather) { // Scenario 4 send(sendRank, @@ -951,7 +946,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else if (!isLocalLeader && isLocalGather) { // Scenario 5 send(sendRank, @@ -959,7 +954,7 @@ void MpiWorld::gather(int sendRank, sendBuffer + sendBufferOffset, sendType, sendCount, - MPIMessage::GATHER); + MpiMessageType::GATHER); } else { SPDLOG_ERROR("Don't know how to gather rank's data."); SPDLOG_ERROR("- sendRank: {}\n- recvRank: {}\n- isGatherReceiver: " @@ -1001,7 +996,7 @@ void MpiWorld::allGather(int rank, // Do a broadcast with a hard-coded root broadcast( - root, rank, recvBuffer, recvType, fullCount, 
MPIMessage::ALLGATHER); + root, rank, recvBuffer, recvType, fullCount, MpiMessageType::ALLGATHER); } void MpiWorld::awaitAsyncRequest(int requestId) @@ -1033,10 +1028,10 @@ void MpiWorld::awaitAsyncRequest(int requestId) std::list::iterator msgIt = umb->getRequestPendingMsg(requestId); - std::shared_ptr m; + MpiMessage m; if (msgIt->msg != nullptr) { // This id has already been acknowledged by a recv call, so do the recv - m = msgIt->msg; + m = *(msgIt->msg); } else { // We need to acknowledge all messages not acknowledged from the // begining until us @@ -1094,7 +1089,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce( operation, datatype, count, rankData.get(), recvBuffer); @@ -1108,7 +1103,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce( operation, datatype, count, rankData.get(), recvBuffer); @@ -1138,7 +1133,7 @@ void MpiWorld::reduce(int sendRank, datatype, count, nullptr, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); op_reduce(operation, datatype, @@ -1152,7 +1147,7 @@ void MpiWorld::reduce(int sendRank, sendBufferCopy.get(), datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } else { // Send to the receiver rank send(sendRank, @@ -1160,7 +1155,7 @@ void MpiWorld::reduce(int sendRank, sendBuffer, datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } } else { // If we are neither the receiver of the reduce nor a local leader, we @@ -1175,7 +1170,7 @@ void MpiWorld::reduce(int sendRank, sendBuffer, datatype, count, - MPIMessage::REDUCE); + MpiMessageType::REDUCE); } } @@ -1191,7 +1186,7 @@ void MpiWorld::allReduce(int rank, reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation); // Second, 0 broadcasts the result to all ranks - broadcast(0, rank, recvBuffer, datatype, count, MPIMessage::ALLREDUCE); + broadcast(0, rank, recvBuffer, datatype, count, MpiMessageType::ALLREDUCE); } void MpiWorld::op_reduce(faabric_op_t* operation, @@ -1350,14 +1345,14 @@ void MpiWorld::scan(int rank, datatype, count, nullptr, - MPIMessage::SCAN); + MpiMessageType::SCAN); // Reduce with our own value op_reduce(operation, datatype, count, currentAcc.get(), recvBuffer); } // If not the last process, send to the next one if (rank < this->size - 1) { - send(rank, rank + 1, recvBuffer, MPI_INT, count, MPIMessage::SCAN); + send(rank, rank + 1, recvBuffer, MPI_INT, count, MpiMessageType::SCAN); } } @@ -1385,7 +1380,12 @@ void MpiWorld::allToAll(int rank, sendChunk, sendChunk + sendOffset, recvBuffer + rankOffset); } else { // Send message to other rank - send(rank, r, sendChunk, sendType, sendCount, MPIMessage::ALLTOALL); + send(rank, + r, + sendChunk, + sendType, + sendCount, + MpiMessageType::ALLTOALL); } } @@ -1405,7 +1405,7 @@ void MpiWorld::allToAll(int rank, recvType, recvCount, nullptr, - MPIMessage::ALLTOALL); + MpiMessageType::ALLTOALL); } } @@ -1416,15 +1416,17 @@ void MpiWorld::allToAll(int rank, // queues. 
void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status) { + throw std::runtime_error("Probe not implemented!"); + /* const std::shared_ptr& queue = getLocalQueue(sendRank, recvRank); - // 30/12/21 - Peek will throw a runtime error std::shared_ptr m = *(queue->peek()); faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type()); status->bytesSize = m->count() * datatype->size; status->MPI_ERROR = 0; status->MPI_SOURCE = m->sender(); + */ } void MpiWorld::barrier(int thisRank) @@ -1437,17 +1439,17 @@ void MpiWorld::barrier(int thisRank) // Await messages from all others for (int r = 1; r < size; r++) { MPI_Status s{}; - recv(r, 0, nullptr, MPI_INT, 0, &s, MPIMessage::BARRIER_JOIN); + recv(r, 0, nullptr, MPI_INT, 0, &s, MpiMessageType::BARRIER_JOIN); SPDLOG_TRACE("MPI - recv barrier join {}", s.MPI_SOURCE); } } else { // Tell the root that we're waiting SPDLOG_TRACE("MPI - barrier join {}", thisRank); - send(thisRank, 0, nullptr, MPI_INT, 0, MPIMessage::BARRIER_JOIN); + send(thisRank, 0, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_JOIN); } // Rank 0 broadcasts that the barrier is done (the others block here) - broadcast(0, thisRank, nullptr, MPI_INT, 0, MPIMessage::BARRIER_DONE); + broadcast(0, thisRank, nullptr, MPI_INT, 0, MpiMessageType::BARRIER_DONE); SPDLOG_TRACE("MPI - barrier done {}", thisRank); } @@ -1477,9 +1479,10 @@ void MpiWorld::initLocalQueues() } } -std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, - int recvRank, - int batchSize) +// TODO(mpi-opt): double-check that the fast (no-async) path is fast +MpiMessage MpiWorld::recvBatchReturnLast(int sendRank, + int recvRank, + int batchSize) { std::shared_ptr umb = getUnackedMessageBuffer(sendRank, recvRank); @@ -1499,7 +1502,7 @@ std::shared_ptr MpiWorld::recvBatchReturnLast(int sendRank, // Recv message: first we receive all messages for which there is an id // in the unacknowleged buffer but no msg. Note that these messages // (batchSize - 1) were `irecv`-ed before ours. 
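[Editor's note, not part of the patch] The asynchronous path keeps its shape: irecv returns a request id that is later awaited. A hedged fragment against the MpiWorld API shown earlier; world, ranks and count are assumed to come from the surrounding MPI context, and MPI_INT is faabric's datatype handle from its mpi.h.

```cpp
#include <faabric/mpi/MpiWorld.h>

#include <vector>

void asyncRecvExample(faabric::mpi::MpiWorld& world,
                      int srcRank,
                      int myRank,
                      int count)
{
    std::vector<int> recvData(count);

    // Post the receive and get a request id back
    int requestId = world.irecv(srcRank,
                                myRank,
                                reinterpret_cast<uint8_t*>(recvData.data()),
                                MPI_INT,
                                count,
                                faabric::mpi::MpiMessageType::NORMAL);

    // ... unrelated work can overlap with the in-flight message here ...

    // Blocks until the message has been copied into recvData
    world.awaitAsyncRequest(requestId);
}
```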
- std::shared_ptr ourMsg; + MpiMessage ourMsg; auto msgIt = umb->getFirstNullMsg(); if (isLocal) { // First receive messages that happened before us @@ -1565,13 +1568,6 @@ int MpiWorld::getIndexForRanks(int sendRank, int recvRank) const return index; } -long MpiWorld::getLocalQueueSize(int sendRank, int recvRank) -{ - const std::shared_ptr& queue = - getLocalQueue(sendRank, recvRank); - return queue->size(); -} - double MpiWorld::getWTime() { double t = faabric::util::getTimeDiffMillis(creationTime); diff --git a/src/mpi/mpi.proto b/src/mpi/mpi.proto deleted file mode 100644 index 80a690820..000000000 --- a/src/mpi/mpi.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; - -package faabric.mpi; - -message MPIMessage { - enum MPIMessageType { - NORMAL = 0; - BARRIER_JOIN = 1; - BARRIER_DONE = 2; - SCATTER = 3; - GATHER = 4; - ALLGATHER = 5; - REDUCE = 6; - SCAN = 7; - ALLREDUCE = 8; - ALLTOALL = 9; - SENDRECV = 10; - BROADCAST = 11; - }; - - MPIMessageType messageType = 1; - - int32 id = 2; - int32 worldId = 3; - int32 sender = 4; - int32 destination = 5; - int32 type = 6; - int32 count = 7; - - // For remote messaging - optional bytes buffer = 8; - - // For local messaging - optional int64 bufferPtr = 9; -} diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 5daa2b5cb..8ed729a8e 100644 --- a/src/proto/faabric.proto +++ b/src/proto/faabric.proto @@ -199,14 +199,6 @@ message StateAppendedResponse { // POINT-TO-POINT // --------------------------------------------- -message PointToPointMessage { - int32 appId = 1; - int32 groupId = 2; - int32 sendIdx = 3; - int32 recvIdx = 4; - bytes data = 5; -} - message PointToPointMappings { int32 appId = 1; int32 groupId = 2; diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 63a87c49d..6ce6f4afd 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -459,12 +459,32 @@ Scheduler::checkForMigrationOpportunities(faabric::Message& msg, auto groupIdxs = broker.getIdxsRegisteredForGroup(groupId); groupIdxs.erase(0); for (const auto& recvIdx : groupIdxs) { - broker.sendMessage( - groupId, 0, recvIdx, BYTES_CONST(&newGroupId), sizeof(int)); + // It is safe to send a pointer to the stack, because the + // transport layer will perform an additional copy of the PTP + // message to put it in the message body + // TODO(no-inproc): this may not be true once we move the inproc + // sockets to in-memory queues + faabric::transport::PointToPointMessage msg( + { .groupId = groupId, + .sendIdx = 0, + .recvIdx = recvIdx, + .dataSize = sizeof(int), + .dataPtr = &newGroupId }); + + broker.sendMessage(msg); } } else if (overwriteNewGroupId == 0) { - std::vector bytes = broker.recvMessage(groupId, 0, groupIdx); + faabric::transport::PointToPointMessage msg( + { .groupId = groupId, .sendIdx = 0, .recvIdx = groupIdx }); + // TODO(no-order): when we remove the need to order ptp messages we + // should be able to call recv giving it a pre-allocated buffer, + // avoiding the hassle of malloc-ing and free-ing + broker.recvMessage(msg); + std::vector bytes((uint8_t*)msg.dataPtr, + (uint8_t*)msg.dataPtr + msg.dataSize); newGroupId = faabric::util::bytesToInt(bytes); + // The previous call makes a copy, so safe to free now + faabric::util::free(msg.dataPtr); } else { // In some settings, like tests, we already know the new group id, so // we can set it here (and in fact, we need to do so when faking two diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index e8b7c339d..e68fa72bd 100644 --- 
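[Editor's note, not part of the patch] Until recvMessage can take a pre-allocated buffer (the TODO(no-order) note in the Scheduler change above), every caller repeats the same copy-then-free dance. A hypothetical convenience wrapper capturing the pattern used in Scheduler.cpp and in the distributed tests:

```cpp
#include <faabric/transport/PointToPointBroker.h>
#include <faabric/transport/PointToPointMessage.h>
#include <faabric/util/memory.h>

#include <vector>

// Receive a PTP message for the given group/indexes and hand back the
// payload as an owned vector
std::vector<uint8_t> recvPtpToVector(
  faabric::transport::PointToPointBroker& broker,
  int groupId,
  int sendIdx,
  int recvIdx)
{
    faabric::transport::PointToPointMessage msg(
      { .groupId = groupId, .sendIdx = sendIdx, .recvIdx = recvIdx });

    broker.recvMessage(msg);

    std::vector<uint8_t> data;
    if (msg.dataPtr != nullptr && msg.dataSize > 0) {
        data.assign((uint8_t*)msg.dataPtr,
                    (uint8_t*)msg.dataPtr + msg.dataSize);

        // recvMessage mallocs the payload, so free it once copied out
        faabric::util::free(msg.dataPtr);
    }

    return data;
}
```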
a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -9,6 +9,7 @@ faabric_lib(transport MessageEndpointServer.cpp PointToPointBroker.cpp PointToPointClient.cpp + PointToPointMessage.cpp PointToPointServer.cpp ) diff --git a/src/transport/MessageEndpointClient.cpp b/src/transport/MessageEndpointClient.cpp index 7984c951b..4bee05763 100644 --- a/src/transport/MessageEndpointClient.cpp +++ b/src/transport/MessageEndpointClient.cpp @@ -36,6 +36,7 @@ void MessageEndpointClient::asyncSend(int header, sequenceNum); } +// TODO: consider making an iovec-style scatter/gather alternative signature void MessageEndpointClient::asyncSend(int header, const uint8_t* buffer, size_t bufferSize, diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp index 9581fc27a..d2c5a0cc3 100644 --- a/src/transport/PointToPointBroker.cpp +++ b/src/transport/PointToPointBroker.cpp @@ -53,7 +53,8 @@ thread_local std::vector sentMsgCount; thread_local std::vector recvMsgCount; -thread_local std::vector> outOfOrderMsgs; +thread_local std::vector>> + outOfOrderMsgs; static std::shared_ptr getClient(const std::string& host) { @@ -202,8 +203,12 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) groupId, recursive); - ptpBroker.recvMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } else { // Notify remote locker that they've acquired the lock SPDLOG_TRACE( @@ -217,10 +222,6 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) } } else { auto cli = getClient(mainHost); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(groupIdx); - msg.set_recvidx(POINT_TO_POINT_MAIN_IDX); SPDLOG_TRACE("Remote lock {}:{}:{} to {}", groupId, @@ -232,7 +233,12 @@ void PointToPointGroup::lock(int groupIdx, bool recursive) // acquired cli->groupLock(appId, groupId, groupIdx, recursive); - ptpBroker.recvMessage(groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } } @@ -285,10 +291,6 @@ void PointToPointGroup::unlock(int groupIdx, bool recursive) } } else { auto cli = getClient(host); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(groupIdx); - msg.set_recvidx(POINT_TO_POINT_MAIN_IDX); SPDLOG_TRACE("Remote unlock {}:{}:{} to {}", groupId, @@ -308,9 +310,13 @@ void PointToPointGroup::localUnlock() void PointToPointGroup::notifyLocked(int groupIdx) { std::vector data(1, 0); - - ptpBroker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size()); + PointToPointMessage msg = { .appId = 0, + .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }; + ptpBroker.sendMessage(msg); } void PointToPointGroup::barrier(int groupIdx) @@ -324,23 +330,40 @@ void PointToPointGroup::barrier(int groupIdx) if (groupIdx == POINT_TO_POINT_MAIN_IDX) { // Receive from all for (int i = 1; i < groupSize; i++) { - ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MAIN_IDX); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = i, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); } // Reply to all std::vector data(1, 0); for (int i = 1; 
i < groupSize; i++) { - ptpBroker.sendMessage( - groupId, POINT_TO_POINT_MAIN_IDX, i, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = i, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); } } else { // Do the send - std::vector data(1, 0); - ptpBroker.sendMessage( - groupId, groupIdx, POINT_TO_POINT_MAIN_IDX, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); // Await the response - ptpBroker.recvMessage(groupId, POINT_TO_POINT_MAIN_IDX, groupIdx); + PointToPointMessage response({ .groupId = groupId, + .sendIdx = POINT_TO_POINT_MAIN_IDX, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(response); } } @@ -351,15 +374,23 @@ void PointToPointGroup::notify(int groupIdx) SPDLOG_TRACE( "Master group {} waiting for notify from index {}", groupId, i); - ptpBroker.recvMessage(groupId, i, POINT_TO_POINT_MAIN_IDX); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = i, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.recvMessage(msg); SPDLOG_TRACE("Master group {} notified by index {}", groupId, i); } } else { - std::vector data(1, 0); SPDLOG_TRACE("Notifying group {} from index {}", groupId, groupIdx); - ptpBroker.sendMessage( - groupId, groupIdx, POINT_TO_POINT_MAIN_IDX, data.data(), data.size()); + PointToPointMessage msg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); + ptpBroker.sendMessage(msg); } } @@ -581,22 +612,11 @@ void PointToPointBroker::updateHostForIdx(int groupId, mappings[key] = newHost; } -void PointToPointBroker::sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, +void PointToPointBroker::sendMessage(const PointToPointMessage& msg, std::string hostHint, bool mustOrderMsg) { - sendMessage(groupId, - sendIdx, - recvIdx, - buffer, - bufferSize, - mustOrderMsg, - NO_SEQUENCE_NUM, - hostHint); + sendMessage(msg, mustOrderMsg, NO_SEQUENCE_NUM, hostHint); } // Gets or creates a pair of inproc endpoints (recv&send) in the endpoints map. @@ -634,11 +654,7 @@ auto getEndpointPtrs(const std::string& label) return endpointPtrs; } -void PointToPointBroker::sendMessage(int groupId, - int sendIdx, - int recvIdx, - const uint8_t* buffer, - size_t bufferSize, +void PointToPointBroker::sendMessage(const PointToPointMessage& msg, bool mustOrderMsg, int sequenceNum, std::string hostHint) @@ -647,19 +663,21 @@ void PointToPointBroker::sendMessage(int groupId, // sender thread, and another time from the point-to-point server to route // it to the receiver thread - waitForMappingsOnThisHost(groupId); + waitForMappingsOnThisHost(msg.groupId); // If the application code knows which host does the receiver live in // (cached for performance) we allow it to provide a hint to avoid // acquiring a shared lock here - std::string host = - hostHint.empty() ? getHostForReceiver(groupId, recvIdx) : hostHint; + std::string host = hostHint.empty() + ? 
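[Editor's note, not part of the patch] Lock, barrier and notify traffic is now pure signalling, so the same zero-payload struct is built in several places. A hypothetical helper that captures that control-message shape:

```cpp
#include <faabric/transport/PointToPointMessage.h>

// A PTP message that carries no payload, used purely as a signal between
// a group index and the main index
faabric::transport::PointToPointMessage makeControlMsg(int groupId,
                                                       int sendIdx,
                                                       int recvIdx)
{
    return faabric::transport::PointToPointMessage{ .groupId = groupId,
                                                    .sendIdx = sendIdx,
                                                    .recvIdx = recvIdx,
                                                    .dataSize = 0,
                                                    .dataPtr = nullptr };
}

// Usage (sketch): ptpBroker.sendMessage(makeControlMsg(groupId, idx, 0));
```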
getHostForReceiver(msg.groupId, msg.recvIdx) + : hostHint; // Set the sequence number if we need ordering and one is not provided bool mustSetSequenceNum = mustOrderMsg && sequenceNum == NO_SEQUENCE_NUM; if (host == conf.endpointHost) { - std::string label = getPointToPointKey(groupId, sendIdx, recvIdx); + std::string label = + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx); auto endpointPtrs = getEndpointPtrs(label); auto& endpoint = @@ -671,46 +689,49 @@ void PointToPointBroker::sendMessage(int groupId, // the sender thread we add a sequence number (if needed) int localSendSeqNum = sequenceNum; if (mustSetSequenceNum) { - localSendSeqNum = getAndIncrementSentMsgCount(groupId, recvIdx); + localSendSeqNum = + getAndIncrementSentMsgCount(msg.groupId, msg.recvIdx); } SPDLOG_TRACE("Local point-to-point message {}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, localSendSeqNum, endpoint.getAddress()); try { - endpoint.send(NO_HEADER, buffer, bufferSize, localSendSeqNum); + // TODO(no-inproc): once we convert the inproc endpoints to a queue + // we should be able to just push the whole message to the queue + std::vector buffer(sizeof(PointToPointMessage) + + msg.dataSize); + serializePtpMsg(buffer, msg); + endpoint.send( + NO_HEADER, buffer.data(), buffer.size(), localSendSeqNum); } catch (std::runtime_error& e) { SPDLOG_ERROR("Timed-out with local point-to-point message {}:{}:{} " "(seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, localSendSeqNum, endpoint.getAddress()); throw e; } } else { auto cli = getClient(host); - faabric::PointToPointMessage msg; - msg.set_groupid(groupId); - msg.set_sendidx(sendIdx); - msg.set_recvidx(recvIdx); - msg.set_data(buffer, bufferSize); // When sending a remote message, we set a sequence number if required int remoteSendSeqNum = NO_SEQUENCE_NUM; if (mustSetSequenceNum) { - remoteSendSeqNum = getAndIncrementSentMsgCount(groupId, recvIdx); + remoteSendSeqNum = + getAndIncrementSentMsgCount(msg.groupId, msg.recvIdx); } SPDLOG_TRACE("Remote point-to-point message {}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, remoteSendSeqNum, host); @@ -719,59 +740,81 @@ void PointToPointBroker::sendMessage(int groupId, } catch (std::runtime_error& e) { SPDLOG_TRACE("Timed-out with remote point-to-point message " "{}:{}:{} (seq: {}) to {}", - groupId, - sendIdx, - recvIdx, + msg.groupId, + msg.sendIdx, + msg.recvIdx, remoteSendSeqNum, host); } } } -Message PointToPointBroker::doRecvMessage(int groupId, int sendIdx, int recvIdx) +std::pair PointToPointBroker::doRecvMessage( + PointToPointMessage& msg) { - std::string label = getPointToPointKey(groupId, sendIdx, recvIdx); + std::string label = + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx); auto endpointPtrs = getEndpointPtrs(label); auto& endpoint = *std::get>( *endpointPtrs); - return endpoint.recv(); + // TODO(no-inproc): this will become a pop from a queue, not a read from + // an in-proc socket + Message bytes = endpoint.recv(); + + // WARNING: this call mallocs + parsePtpMsg(bytes.udata(), &msg); + + /* TODO(no-order): for the moment always parse and malloc memory, as it is + * not easy to track when did we malloc or not. 
This is gonna become + * simpler once we remove the need to order messages in the PTP layer + * + if (hasPreAllocBuffer) { + std::span msgDataSpan((uint8_t*) msg.dataPtr, msg.dataSize); + parsePtpMsg(bytes.udata(), &msg, msgDataSpan); + } else { + parsePtpMsg(bytes.udata(), &msg); + } + */ + + assert(getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx) == label); + + return std::make_pair(bytes.getResponseCode(), + bytes.getSequenceNum()); } -std::vector PointToPointBroker::recvMessage(int groupId, - int sendIdx, - int recvIdx, - bool mustOrderMsg) +void PointToPointBroker::recvMessage(PointToPointMessage& msg, + bool mustOrderMsg) { // If we don't need to receive messages in order, return here if (!mustOrderMsg) { - // TODO - can we avoid this copy? - return doRecvMessage(groupId, sendIdx, recvIdx).dataCopy(); + doRecvMessage(msg); + return; } // Get the sequence number we expect to receive - int expectedSeqNum = getExpectedSeqNum(groupId, sendIdx); + int expectedSeqNum = getExpectedSeqNum(msg.groupId, msg.sendIdx); // We first check if we have already received the message. We only need to // check this once. - auto foundIterator = - std::find_if(outOfOrderMsgs.at(sendIdx).begin(), - outOfOrderMsgs.at(sendIdx).end(), - [expectedSeqNum](const Message& msg) { - return msg.getSequenceNum() == expectedSeqNum; - }); - if (foundIterator != outOfOrderMsgs.at(sendIdx).end()) { + auto foundIterator = std::find_if( + outOfOrderMsgs.at(msg.sendIdx).begin(), + outOfOrderMsgs.at(msg.sendIdx).end(), + [expectedSeqNum](const std::pair& pair) { + return pair.first == expectedSeqNum; + }); + if (foundIterator != outOfOrderMsgs.at(msg.sendIdx).end()) { SPDLOG_TRACE("Retrieved the expected message ({}:{} seq: {}) from the " "out-of-order buffer", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - incrementRecvMsgCount(groupId, sendIdx); - Message returnMsg = std::move(*foundIterator); - outOfOrderMsgs.at(sendIdx).erase(foundIterator); - return returnMsg.dataCopy(); + incrementRecvMsgCount(msg.groupId, msg.sendIdx); + msg = foundIterator->second; + outOfOrderMsgs.at(msg.sendIdx).erase(foundIterator); + return; } // Given that we don't have the message, we query the transport layer until @@ -779,47 +822,52 @@ std::vector PointToPointBroker::recvMessage(int groupId, while (true) { SPDLOG_TRACE( "Entering loop to query transport layer for msg ({}:{} seq: {})", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - // Receive from the transport layer - Message recvMsg = doRecvMessage(groupId, sendIdx, recvIdx); + + // Receive from the transport layer with the same group id and + // send/recv indexes + PointToPointMessage tmpMsg({ .groupId = msg.groupId, + .sendIdx = msg.sendIdx, + .recvIdx = msg.recvIdx }); + auto [responseCode, seqNum] = doRecvMessage(tmpMsg); // If the receive was not successful, exit the loop - if (recvMsg.getResponseCode() != - faabric::transport::MessageResponseCode::SUCCESS) { + if (responseCode != faabric::transport::MessageResponseCode::SUCCESS) { SPDLOG_WARN( "Error {} ({}) when awaiting a message ({}:{} seq: {} label: {})", - static_cast(recvMsg.getResponseCode()), - MessageResponseCodeText.at(recvMsg.getResponseCode()), - sendIdx, - recvIdx, + static_cast(responseCode), + MessageResponseCodeText.at(responseCode), + msg.sendIdx, + msg.recvIdx, expectedSeqNum, - getPointToPointKey(groupId, sendIdx, recvIdx)); + getPointToPointKey(msg.groupId, msg.sendIdx, msg.recvIdx)); throw std::runtime_error("Error when awaiting a PTP message"); } // If the 
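[Editor's note, not part of the patch] The ordering logic in recvMessage can be hard to follow interleaved with the transport calls; here is the same idea in miniature, detached from the broker. All names are illustrative: early messages are parked in a (sequence number, message) buffer and delivered once their turn comes.

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

struct Msg
{
    int seqNum;
    int payload;
};

int main()
{
    // Messages arriving from the transport layer, possibly out of order
    std::vector<Msg> arrivals = { { 1, 10 }, { 0, 0 }, { 3, 30 }, { 2, 20 } };

    std::vector<std::pair<int, Msg>> outOfOrder;
    int expectedSeqNum = 0;
    size_t next = 0;

    while (expectedSeqNum < 4) {
        // First check whether the expected message is already buffered
        auto it = std::find_if(outOfOrder.begin(),
                               outOfOrder.end(),
                               [expectedSeqNum](const std::pair<int, Msg>& p) {
                                   return p.first == expectedSeqNum;
                               });
        if (it != outOfOrder.end()) {
            std::printf("delivered buffered seq %d\n", it->first);
            outOfOrder.erase(it);
            expectedSeqNum++;
            continue;
        }

        // Otherwise pull from the transport; buffer anything that is early
        Msg received = arrivals.at(next++);
        if (received.seqNum == expectedSeqNum) {
            std::printf("delivered seq %d\n", received.seqNum);
            expectedSeqNum++;
        } else {
            outOfOrder.emplace_back(received.seqNum, received);
        }
    }

    return 0;
}
```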
sequence numbers match, exit the loop - int seqNum = recvMsg.getSequenceNum(); if (seqNum == expectedSeqNum) { SPDLOG_TRACE("Received the expected message ({}:{} seq: {})", - sendIdx, - recvIdx, + msg.sendIdx, + msg.recvIdx, expectedSeqNum); - incrementRecvMsgCount(groupId, sendIdx); - return recvMsg.dataCopy(); + incrementRecvMsgCount(msg.groupId, msg.sendIdx); + + msg = tmpMsg; + return; } // If not, we must insert the received message in the out of order // received messages SPDLOG_TRACE("Received out-of-order message ({}:{} seq: {}) (expected: " "{} - got: {})", - sendIdx, - recvIdx, + tmpMsg.sendIdx, + tmpMsg.recvIdx, seqNum, expectedSeqNum, seqNum); - outOfOrderMsgs.at(sendIdx).emplace_back(std::move(recvMsg)); + outOfOrderMsgs.at(tmpMsg.sendIdx).emplace_back(seqNum, tmpMsg); } } @@ -874,10 +922,10 @@ void PointToPointBroker::resetThreadLocalCache() void PointToPointBroker::postMigrationHook(int groupId, int groupIdx) { + /* int postMigrationOkCode = 1337; int recvCode = 0; - // TODO: implement this as a broadcast in the PTP broker int mainIdx = 0; if (groupIdx == mainIdx) { auto groupIdxs = getIdxsRegisteredForGroup(groupId); @@ -902,6 +950,8 @@ void PointToPointBroker::postMigrationHook(int groupId, int groupIdx) recvCode); throw std::runtime_error("Error in post-migration hook"); } + */ + PointToPointGroup::getGroup(groupId)->barrier(groupIdx); SPDLOG_DEBUG("{}:{} exiting post-migration hook", groupId, groupIdx); } diff --git a/src/transport/PointToPointClient.cpp b/src/transport/PointToPointClient.cpp index d0b7188f8..506fc9874 100644 --- a/src/transport/PointToPointClient.cpp +++ b/src/transport/PointToPointClient.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -13,12 +14,11 @@ static std::mutex mockMutex; static std::vector> sentMappings; -static std::vector> - sentMessages; +static std::vector> sentMessages; static std::vector> + PointToPointMessage>> sentLockMessages; std::vector> @@ -27,7 +27,7 @@ getSentMappings() return sentMappings; } -std::vector> +std::vector> getSentPointToPointMessages() { return sentMessages; @@ -35,7 +35,7 @@ getSentPointToPointMessages() std::vector> + PointToPointMessage>> getSentLockMessages() { return sentLockMessages; @@ -64,13 +64,18 @@ void PointToPointClient::sendMappings(faabric::PointToPointMappings& mappings) } } -void PointToPointClient::sendMessage(faabric::PointToPointMessage& msg, +void PointToPointClient::sendMessage(const PointToPointMessage& msg, int sequenceNum) { if (faabric::util::isMockMode()) { sentMessages.emplace_back(host, msg); } else { - asyncSend(PointToPointCall::MESSAGE, &msg, sequenceNum); + // TODO(FIXME): consider how we can avoid serialising once, and then + // copying again into NNG's buffer + std::vector buffer(sizeof(msg) + msg.dataSize); + serializePtpMsg(buffer, msg); + asyncSend( + PointToPointCall::MESSAGE, buffer.data(), buffer.size(), sequenceNum); } } @@ -80,11 +85,12 @@ void PointToPointClient::makeCoordinationRequest( int groupIdx, faabric::transport::PointToPointCall call) { - faabric::PointToPointMessage req; - req.set_appid(appId); - req.set_groupid(groupId); - req.set_sendidx(groupIdx); - req.set_recvidx(POINT_TO_POINT_MAIN_IDX); + PointToPointMessage req({ .appId = appId, + .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = POINT_TO_POINT_MAIN_IDX, + .dataSize = 0, + .dataPtr = nullptr }); switch (call) { case (faabric::transport::PointToPointCall::LOCK_GROUP): { @@ -115,7 +121,11 @@ void PointToPointClient::makeCoordinationRequest( 
faabric::util::UniqueLock lock(mockMutex); sentLockMessages.emplace_back(host, call, req); } else { - asyncSend(call, &req); + // TODO(FIXME): consider how we can avoid serialising once, and then + // copying again into NNG's buffer + std::vector buffer(sizeof(PointToPointMessage) + req.dataSize); + serializePtpMsg(buffer, req); + asyncSend(call, buffer.data(), buffer.size()); } } diff --git a/src/transport/PointToPointMessage.cpp b/src/transport/PointToPointMessage.cpp new file mode 100644 index 000000000..e9415db11 --- /dev/null +++ b/src/transport/PointToPointMessage.cpp @@ -0,0 +1,62 @@ +#include +#include + +#include +#include +#include + +namespace faabric::transport { + +void serializePtpMsg(std::span buffer, const PointToPointMessage& msg) +{ + assert(buffer.size() == sizeof(PointToPointMessage) + msg.dataSize); + std::memcpy(buffer.data(), &msg, sizeof(PointToPointMessage)); + + if (msg.dataSize > 0 && msg.dataPtr != nullptr) { + std::memcpy(buffer.data() + sizeof(PointToPointMessage), + msg.dataPtr, + msg.dataSize); + } +} + +// Parse all the fixed-size parts of the struct +static void parsePtpMsgCommon(std::span bytes, + PointToPointMessage* msg) +{ + assert(msg != nullptr); + assert(bytes.size() >= sizeof(PointToPointMessage)); + std::memcpy(msg, bytes.data(), sizeof(PointToPointMessage)); + size_t thisDataSize = bytes.size() - sizeof(PointToPointMessage); + assert(thisDataSize == msg->dataSize); + + if (thisDataSize == 0) { + msg->dataPtr = nullptr; + } +} + +void parsePtpMsg(std::span bytes, PointToPointMessage* msg) +{ + parsePtpMsgCommon(bytes, msg); + + if (msg->dataSize == 0) { + return; + } + + // malloc memory for the PTP message payload + msg->dataPtr = faabric::util::malloc(msg->dataSize); + std::memcpy( + msg->dataPtr, bytes.data() + sizeof(PointToPointMessage), msg->dataSize); +} + +void parsePtpMsg(std::span bytes, + PointToPointMessage* msg, + std::span preAllocBuffer) +{ + parsePtpMsgCommon(bytes, msg); + + assert(msg->dataSize == preAllocBuffer.size()); + msg->dataPtr = preAllocBuffer.data(); + std::memcpy( + msg->dataPtr, bytes.data() + sizeof(PointToPointMessage), msg->dataSize); +} +} diff --git a/src/transport/PointToPointServer.cpp b/src/transport/PointToPointServer.cpp index 173fec0bf..6224eed84 100644 --- a/src/transport/PointToPointServer.cpp +++ b/src/transport/PointToPointServer.cpp @@ -1,12 +1,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include namespace faabric::transport { @@ -25,9 +27,11 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) int sequenceNum = message.getSequenceNum(); switch (header) { case (faabric::transport::PointToPointCall::MESSAGE): { - PARSE_MSG(faabric::PointToPointMessage, - message.udata().data(), - message.udata().size()) + // Here we are copying the message from the transport layer (NNG) + // into our PTP message structure + // NOTE: this mallocs + PointToPointMessage parsedMsg; + parsePtpMsg(message.udata(), &parsedMsg); // If the sequence number is set, we must also set the ordering // flag @@ -35,13 +39,15 @@ void PointToPointServer::doAsyncRecv(transport::Message& message) // Send the message locally to the downstream socket, add the // sequence number for in-order reception - broker.sendMessage(parsedMsg.groupid(), - parsedMsg.sendidx(), - parsedMsg.recvidx(), - BYTES_CONST(parsedMsg.data().c_str()), - parsedMsg.data().size(), - mustOrderMsg, - sequenceNum); + broker.sendMessage(parsedMsg, mustOrderMsg, sequenceNum); + + // TODO(no-inproc): for 
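[Editor's note, not part of the patch] The second parsePtpMsg overload exists so that callers which already know the payload size can skip the malloc. A brief sketch of the difference, again assuming byte-typed (uint8_t) spans and vectors:

```cpp
#include <faabric/transport/PointToPointMessage.h>

#include <vector>

using namespace faabric::transport;

void parseIntoPreAllocated(const std::vector<uint8_t>& wire)
{
    // The caller owns this buffer (e.g. an application-provided receive
    // buffer), so nothing needs to be freed afterwards
    std::vector<uint8_t> payload(wire.size() - sizeof(PointToPointMessage));

    PointToPointMessage msg{};
    parsePtpMsg(wire, &msg, payload);

    // msg.dataPtr now points into `payload` rather than malloc-ed memory
}
```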
the moment, the downstream (inproc) + // socket makes a copy of this message, so we can free it now + // after sending. This will not be the case once we move to + // in-memory queues + if (parsedMsg.dataPtr != nullptr) { + faabric::util::free(parsedMsg.dataPtr); + } break; } case faabric::transport::PointToPointCall::LOCK_GROUP: { @@ -101,28 +107,33 @@ std::unique_ptr PointToPointServer::doRecvMappings( void PointToPointServer::recvGroupLock(std::span buffer, bool recursive) { - PARSE_MSG(faabric::PointToPointMessage, buffer.data(), buffer.size()) + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + assert(parsedMsg.dataPtr == nullptr && parsedMsg.dataSize == 0); + SPDLOG_TRACE("Receiving lock on {} for idx {} (recursive {})", - parsedMsg.groupid(), - parsedMsg.sendidx(), + parsedMsg.groupId, + parsedMsg.sendIdx, recursive); - PointToPointGroup::getGroup(parsedMsg.groupid()) - ->lock(parsedMsg.sendidx(), recursive); + PointToPointGroup::getGroup(parsedMsg.groupId) + ->lock(parsedMsg.sendIdx, recursive); } void PointToPointServer::recvGroupUnlock(std::span buffer, bool recursive) { - PARSE_MSG(faabric::PointToPointMessage, buffer.data(), buffer.size()) + PointToPointMessage parsedMsg; + parsePtpMsg(buffer, &parsedMsg); + assert(parsedMsg.dataPtr == nullptr && parsedMsg.dataSize == 0); SPDLOG_TRACE("Receiving unlock on {} for idx {} (recursive {})", - parsedMsg.groupid(), - parsedMsg.sendidx(), + parsedMsg.groupId, + parsedMsg.sendIdx, recursive); - PointToPointGroup::getGroup(parsedMsg.groupid()) - ->unlock(parsedMsg.sendidx(), recursive); + PointToPointGroup::getGroup(parsedMsg.groupId) + ->unlock(parsedMsg.sendIdx, recursive); } void PointToPointServer::onWorkerStop() diff --git a/tests/dist/mpi/mpi_native.cpp b/tests/dist/mpi/mpi_native.cpp index d41235940..a499fb357 100644 --- a/tests/dist/mpi/mpi_native.cpp +++ b/tests/dist/mpi/mpi_native.cpp @@ -2,9 +2,9 @@ #include #include +#include #include #include -#include #include #include #include @@ -126,7 +126,7 @@ int MPI_Send(const void* buf, (uint8_t*)buf, datatype, count, - MPIMessage::NORMAL); + MpiMessageType::NORMAL); return MPI_SUCCESS; } @@ -159,7 +159,7 @@ int MPI_Recv(void* buf, datatype, count, status, - MPIMessage::NORMAL); + MpiMessageType::NORMAL); return MPI_SUCCESS; } @@ -245,7 +245,7 @@ int MPI_Bcast(void* buffer, int rank = executingContext.getRank(); world.broadcast( - root, rank, (uint8_t*)buffer, datatype, count, MPIMessage::BROADCAST); + root, rank, (uint8_t*)buffer, datatype, count, MpiMessageType::BROADCAST); return MPI_SUCCESS; } diff --git a/tests/dist/transport/functions.cpp b/tests/dist/transport/functions.cpp index 8c99b05b2..1485f5f47 100644 --- a/tests/dist/transport/functions.cpp +++ b/tests/dist/transport/functions.cpp @@ -4,9 +4,9 @@ #include "faabric_utils.h" #include "init.h" -#include #include #include +#include #include #include #include @@ -43,12 +43,25 @@ int handlePointToPointFunction( std::vector expectedRecvData(10, recvFromIdx); // Do the sending - broker.sendMessage( - groupId, groupIdx, sendToIdx, sendData.data(), sendData.size()); + PointToPointMessage sendMsg({ .groupId = groupId, + .sendIdx = groupIdx, + .recvIdx = sendToIdx, + .dataSize = sendData.size(), + .dataPtr = sendData.data() }); + broker.sendMessage(sendMsg); // Do the receiving - std::vector actualRecvData = - broker.recvMessage(groupId, recvFromIdx, groupIdx); + PointToPointMessage recvMsg({ .groupId = groupId, + .sendIdx = recvFromIdx, + .recvIdx = groupIdx, + .dataSize = 0, + .dataPtr = nullptr }); + 
+    broker.recvMessage(recvMsg);
+    std::vector<uint8_t> actualRecvData(recvMsg.dataSize);
+    std::memcpy(actualRecvData.data(), recvMsg.dataPtr, recvMsg.dataSize);
+    // TODO(no-order): we will be able to change the signature of recvMessage
+    // to take in a pre-allocated buffer to read into
+    faabric::util::free(recvMsg.dataPtr);
 
     // Check data is as expected
     if (actualRecvData != expectedRecvData) {
@@ -82,19 +95,31 @@ int handleManyPointToPointMsgFunction(
         // Send loop
         for (int i = 0; i < numMsg; i++) {
            std::vector<uint8_t> sendData(5, i);
-            broker.sendMessage(groupId,
-                               sendIdx,
-                               recvIdx,
-                               sendData.data(),
-                               sendData.size(),
-                               true);
+            PointToPointMessage sendMsg({ .groupId = groupId,
+                                          .sendIdx = sendIdx,
+                                          .recvIdx = recvIdx,
+                                          .dataSize = sendData.size(),
+                                          .dataPtr = sendData.data() });
+            broker.sendMessage(sendMsg, true);
         }
     } else if (groupIdx == recvIdx) {
         // Recv loop
         for (int i = 0; i < numMsg; i++) {
             std::vector<uint8_t> expectedData(5, i);
-            auto actualData =
-              broker.recvMessage(groupId, sendIdx, recvIdx, true);
+
+            PointToPointMessage recvMsg({ .groupId = groupId,
+                                          .sendIdx = sendIdx,
+                                          .recvIdx = recvIdx,
+                                          .dataSize = 0,
+                                          .dataPtr = nullptr });
+            broker.recvMessage(recvMsg, true);
+
+            std::vector<uint8_t> actualData(recvMsg.dataSize);
+            std::memcpy(actualData.data(), recvMsg.dataPtr, recvMsg.dataSize);
+            // TODO(no-order): we will be able to change the signature of
+            // recvMessage to take in a pre-allocated buffer to read into
+            faabric::util::free(recvMsg.dataPtr);
+
             if (actualData != expectedData) {
                 SPDLOG_ERROR(
                   "Out-of-order message reception (got: {}, expected: {})",
diff --git a/tests/dist/transport/test_point_to_point.cpp b/tests/dist/transport/test_point_to_point.cpp
index 28ab5b6c7..8cd0f6e49 100644
--- a/tests/dist/transport/test_point_to_point.cpp
+++ b/tests/dist/transport/test_point_to_point.cpp
@@ -5,7 +5,6 @@
 #include "init.h"
 
 #include
-#include
 #include
 #include
 #include
diff --git a/tests/test/mpi/test_mpi_message.cpp b/tests/test/mpi/test_mpi_message.cpp
new file mode 100644
index 000000000..9c79f8d3b
--- /dev/null
+++ b/tests/test/mpi/test_mpi_message.cpp
@@ -0,0 +1,123 @@
+#include
+
+#include
+#include
+
+#include
+
+using namespace faabric::mpi;
+
+namespace tests {
+
+bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
+{
+    auto sizeA = msgSize(msgA);
+    auto sizeB = msgSize(msgB);
+
+    if (sizeA != sizeB) {
+        return false;
+    }
+
+    // First, compare the message body (excluding the pointer, which we
+    // know is at the end)
+    if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) {
+        return false;
+    }
+
+    // Check that if one buffer points to null, so must the other
+    if (msgA.buffer == nullptr || msgB.buffer == nullptr) {
+        return msgA.buffer == msgB.buffer;
+    }
+
+    // If neither buffer is null, they must point to the same data
+    auto payloadSizeA = payloadSize(msgA);
+    auto payloadSizeB = payloadSize(msgB);
+    // Assert, as this should pass given the previous comparisons
+    assert(payloadSizeA == payloadSizeB);
+
+    return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
+}
+
+TEST_CASE("Test getting a message size", "[mpi]")
+{
+    MpiMessage msg = { .id = 1,
+                       .worldId = 3,
+                       .sendRank = 3,
+                       .recvRank = 7,
+                       .typeSize = 1,
+                       .count = 3,
+                       .messageType = MpiMessageType::NORMAL };
+
+    size_t expectedMsgSize = 0;
+    size_t expectedPayloadSize = 0;
+
+    SECTION("Empty message")
+    {
+        msg.buffer = nullptr;
+        msg.count = 0;
+        expectedMsgSize = sizeof(MpiMessage);
+        expectedPayloadSize = 0;
+    }
+
+    SECTION("Non-empty message")
+    {
+        std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
+        msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
+        std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
+
+        expectedPayloadSize = sizeof(int) * nums.size();
+        expectedMsgSize = sizeof(MpiMessage) + expectedPayloadSize;
+    }
+
+    REQUIRE(expectedMsgSize == msgSize(msg));
+    REQUIRE(expectedPayloadSize == payloadSize(msg));
+
+    if (msg.buffer != nullptr) {
+        faabric::util::free(msg.buffer);
+    }
+}
+
+TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
+{
+    MpiMessage msg = { .id = 1,
+                       .worldId = 3,
+                       .sendRank = 3,
+                       .recvRank = 7,
+                       .typeSize = 1,
+                       .count = 3,
+                       .messageType = MpiMessageType::NORMAL };
+
+    SECTION("Empty message")
+    {
+        msg.count = 0;
+        msg.buffer = nullptr;
+    }
+
+    SECTION("Non-empty message")
+    {
+        std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
+        msg.count = nums.size();
+        msg.typeSize = sizeof(int);
+        msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
+        std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
+    }
+
+    // Serialise and de-serialise
+    std::vector<uint8_t> buffer(msgSize(msg));
+    serializeMpiMsg(buffer, msg);
+
+    MpiMessage parsedMsg;
+    parseMpiMsg(buffer, &parsedMsg);
+
+    REQUIRE(areMpiMsgEqual(msg, parsedMsg));
+
+    if (msg.buffer != nullptr) {
+        faabric::util::free(msg.buffer);
+    }
+    if (parsedMsg.buffer != nullptr) {
+        faabric::util::free(parsedMsg.buffer);
+    }
+}
+}
diff --git a/tests/test/mpi/test_mpi_message_buffer.cpp b/tests/test/mpi/test_mpi_message_buffer.cpp
index 1674172fd..710a3c259 100644
--- a/tests/test/mpi/test_mpi_message_buffer.cpp
+++ b/tests/test/mpi/test_mpi_message_buffer.cpp
@@ -21,7 +21,7 @@ MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
     pendingMsg.requestId = requestId;
 
     if (!nullMsg) {
-        pendingMsg.msg = std::make_shared<MPIMessage>();
+        pendingMsg.msg = std::make_shared<MpiMessage>();
     }
 
     return pendingMsg;
diff --git a/tests/test/mpi/test_mpi_world.cpp b/tests/test/mpi/test_mpi_world.cpp
index 2c0030b5f..8094cae7e 100644
--- a/tests/test/mpi/test_mpi_world.cpp
+++ b/tests/test/mpi/test_mpi_world.cpp
@@ -212,23 +212,22 @@ TEST_CASE_METHOD(MpiBaseTestFixture, "Test local barrier", "[mpi]")
     world.destroy();
 }
 
-void checkMessage(MPIMessage& actualMessage,
+void checkMessage(MpiMessage& actualMessage,
                   int worldId,
                   int senderRank,
                   int destRank,
                   const std::vector<int>& data)
 {
     // Check the message contents
-    REQUIRE(actualMessage.worldid() == worldId);
-    REQUIRE(actualMessage.count() == data.size());
-    REQUIRE(actualMessage.destination() == destRank);
-    REQUIRE(actualMessage.sender() == senderRank);
-    REQUIRE(actualMessage.type() == FAABRIC_INT);
+    REQUIRE(actualMessage.worldId == worldId);
+    REQUIRE(actualMessage.count == data.size());
+    REQUIRE(actualMessage.recvRank == destRank);
+    REQUIRE(actualMessage.sendRank == senderRank);
+    REQUIRE(actualMessage.typeSize == FAABRIC_INT);
 
     // Check data
-    const auto* rawInts =
-      reinterpret_cast<const int*>(actualMessage.buffer().c_str());
-    size_t nInts = actualMessage.buffer().size() / sizeof(int);
+    const auto* rawInts = reinterpret_cast<const int*>(actualMessage.buffer);
+    size_t nInts = payloadSize(actualMessage) / sizeof(int);
     std::vector<int> actualData(rawInts, rawInts + nInts);
     REQUIRE(actualData == data);
 }
@@ -396,10 +395,10 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send/recv message with no data", "[mpi]")
     SECTION("Check on queue")
     {
         // Check message content
-        MPIMessage actualMessage =
-          *(world.getLocalQueue(rankA1, rankA2)->dequeue());
-        REQUIRE(actualMessage.count() == 0);
-        REQUIRE(actualMessage.type() == FAABRIC_INT);
+        MpiMessage actualMessage =
+          world.getLocalQueue(rankA1, rankA2)->dequeue();
+        REQUIRE(actualMessage.count == 0);
+        REQUIRE(actualMessage.typeSize == FAABRIC_INT);
     }
 
     SECTION("Check receiving with null ptr")
@@ -502,7 +501,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]")
                     BYTES(messageData.data()),
                     MPI_INT,
                     messageData.size(),
-                    MPIMessage::BROADCAST);
+                    MpiMessageType::BROADCAST);
 
     // Recv on all non-root ranks
     for (int rank = 0; rank < worldSize; rank++) {
@@ -515,7 +514,7 @@ TEST_CASE_METHOD(MpiTestFixture, "Test collective messaging locally", "[mpi]")
                    BYTES(actual.data()),
                    MPI_INT,
                    3,
-                   MPIMessage::BROADCAST);
+                   MpiMessageType::BROADCAST);
         REQUIRE(actual == messageData);
     }
 }
diff --git a/tests/test/mpi/test_multiple_mpi_worlds.cpp b/tests/test/mpi/test_multiple_mpi_worlds.cpp
index 735556f6e..a1062e81b 100644
--- a/tests/test/mpi/test_multiple_mpi_worlds.cpp
+++ b/tests/test/mpi/test_multiple_mpi_worlds.cpp
@@ -155,29 +155,6 @@ TEST_CASE_METHOD(MultiWorldMpiTestFixture,
     worldB.send(
       rankA1, rankA2, BYTES(messageData.data()), MPI_INT, messageData.size());
 
-    SECTION("Test queueing")
-    {
-        // Check for world A
-        REQUIRE(worldA.getLocalQueueSize(rankA1, rankA2) == 1);
-        REQUIRE(worldA.getLocalQueueSize(rankA2, rankA1) == 0);
-        REQUIRE(worldA.getLocalQueueSize(rankA1, 0) == 0);
-        REQUIRE(worldA.getLocalQueueSize(rankA2, 0) == 0);
-        const std::shared_ptr<InMemoryMpiQueue>& queueA2 =
-          worldA.getLocalQueue(rankA1, rankA2);
-        MPIMessage actualMessage = *(queueA2->dequeue());
-        // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData);
-
-        // Check for world B
-        REQUIRE(worldB.getLocalQueueSize(rankA1, rankA2) == 1);
-        REQUIRE(worldB.getLocalQueueSize(rankA2, rankA1) == 0);
-        REQUIRE(worldB.getLocalQueueSize(rankA1, 0) == 0);
-        REQUIRE(worldB.getLocalQueueSize(rankA2, 0) == 0);
-        const std::shared_ptr<InMemoryMpiQueue>& queueA2B =
-          worldB.getLocalQueue(rankA1, rankA2);
-        actualMessage = *(queueA2B->dequeue());
-        // checkMessage(actualMessage, worldId, rankA1, rankA2, messageData);
-    }
-
     SECTION("Test recv")
     {
         MPI_Status status{};
diff --git a/tests/test/mpi/test_remote_mpi_worlds.cpp b/tests/test/mpi/test_remote_mpi_worlds.cpp
index 1e56b48b1..54662929f 100644
--- a/tests/test/mpi/test_remote_mpi_worlds.cpp
+++ b/tests/test/mpi/test_remote_mpi_worlds.cpp
@@ -21,12 +21,11 @@ using namespace faabric::mpi;
 using namespace faabric::scheduler;
 
 namespace tests {
-std::set<int> getReceiversFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getReceiversFromMessages(std::vector<MpiMessage> msgs)
 {
     std::set<int> receivers;
     for (const auto& msg : msgs) {
-        receivers.insert(msg->destination());
+        receivers.insert(msg.recvRank);
     }
 
     return receivers;
@@ -108,14 +107,14 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
                             BYTES(messageData.data()),
                             MPI_INT,
                             messageData.size(),
-                            MPIMessage::BROADCAST);
+                            MpiMessageType::BROADCAST);
     } else {
         otherWorld.broadcast(sendRank,
                              recvRank,
                              BYTES(messageData.data()),
                              MPI_INT,
                              messageData.size(),
-                             MPIMessage::BROADCAST);
+                             MpiMessageType::BROADCAST);
     }
     auto msgs = getMpiMockedMessages(recvRank);
     REQUIRE(msgs.size() == expectedNumMsg);
@@ -219,12 +218,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     thisWorld.destroy();
 }
 
-std::set<int> getMsgCountsFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getMsgCountsFromMessages(std::vector<MpiMessage> msgs)
 {
     std::set<int> counts;
     for (const auto& msg : msgs) {
-        counts.insert(msg->count());
+        counts.insert(msg.count);
     }
 
     return counts;
diff --git a/tests/test/transport/test_point_to_point.cpp b/tests/test/transport/test_point_to_point.cpp
index 98b16b9f7..6a23b90c5 100644
--- a/tests/test/transport/test_point_to_point.cpp
+++ b/tests/test/transport/test_point_to_point.cpp
@@ -120,9 +120,7 @@ TEST_CASE_METHOD(PointToPointClientServerFixture,
     std::vector<uint8_t> sentDataA = { 0, 1, 2, 3 };
     std::vector<uint8_t> receivedDataA;
     std::vector<uint8_t> sentDataB = { 3, 4, 5 };
-    std::vector<uint8_t> receivedDataB;
     std::vector<uint8_t> sentDataC = { 6, 7, 8 };
-    std::vector<uint8_t> receivedDataC;
 
     std::shared_ptr msgLatch = std::make_shared(2, 1000);
 
@@ -131,34 +129,60 @@ TEST_CASE_METHOD(PointToPointClientServerFixture,
         PointToPointBroker& broker = getPointToPointBroker();
 
         // Receive the first message
-        receivedDataA = broker.recvMessage(groupId, idxA, idxB);
+        PointToPointMessage msgAB(
+          { .groupId = groupId, .sendIdx = idxA, .recvIdx = idxB });
+        broker.recvMessage(msgAB);
+        receivedDataA.resize(msgAB.dataSize);
+        std::memcpy(receivedDataA.data(), msgAB.dataPtr, msgAB.dataSize);
+        faabric::util::free(msgAB.dataPtr);
 
         msgLatch->wait();
 
         // Send a message back
-        broker.sendMessage(
-          groupId, idxB, idxA, sentDataB.data(), sentDataB.size());
+        PointToPointMessage msgBA({ .groupId = groupId,
+                                    .sendIdx = idxB,
+                                    .recvIdx = idxA,
+                                    .dataSize = sentDataB.size(),
+                                    .dataPtr = sentDataB.data() });
+        broker.sendMessage(msgBA);
 
         // Lastly, send another message specifying the recipient host to avoid
         // an extra check in the broker
-        broker.sendMessage(groupId,
-                           idxB,
-                           idxA,
-                           sentDataC.data(),
-                           sentDataC.size(),
-                           std::string(LOCALHOST));
+        PointToPointMessage msgBA2({ .groupId = groupId,
+                                     .sendIdx = idxB,
+                                     .recvIdx = idxA,
+                                     .dataSize = sentDataC.size(),
+                                     .dataPtr = sentDataC.data() });
+        broker.sendMessage(msgBA2, std::string(LOCALHOST));
 
         broker.resetThreadLocalCache();
     });
 
     // Only send the message after the thread creates a receiving socket to
     // avoid deadlock
-    broker.sendMessage(groupId, idxA, idxB, sentDataA.data(), sentDataA.size());
+    PointToPointMessage msgAB({ .groupId = groupId,
+                                .sendIdx = idxA,
+                                .recvIdx = idxB,
+                                .dataSize = sentDataA.size(),
+                                .dataPtr = sentDataA.data() });
+    broker.sendMessage(msgAB);
 
     // Wait for the thread to handle the message
     msgLatch->wait();
 
     // Receive the two messages sent back
-    receivedDataB = broker.recvMessage(groupId, idxB, idxA);
-    receivedDataC = broker.recvMessage(groupId, idxB, idxA);
+
+    PointToPointMessage msgBA1(
+      { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA });
+    broker.recvMessage(msgBA1);
+    std::vector<uint8_t> receivedDataB(
+      (uint8_t*)msgBA1.dataPtr, (uint8_t*)msgBA1.dataPtr + msgBA1.dataSize);
+    faabric::util::free(msgBA1.dataPtr);
+
+    PointToPointMessage msgBA2(
+      { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA });
+    broker.recvMessage(msgBA2);
+    std::vector<uint8_t> receivedDataC(
+      (uint8_t*)msgBA2.dataPtr, (uint8_t*)msgBA2.dataPtr + msgBA2.dataSize);
+    faabric::util::free(msgBA2.dataPtr);
 
     if (t.joinable()) {
         t.join();
@@ -236,22 +260,28 @@ TEST_CASE_METHOD(PointToPointClientServerFixture,
        std::vector<uint8_t> recvData;
 
        for (int i = 0; i < numMsg; i++) {
-            recvData =
-              broker.recvMessage(groupId, idxA, idxB, isMessageOrderingOn);
            sendData = std::vector<uint8_t>(3, i);
+            PointToPointMessage msg(
+              { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA });
+            broker.recvMessage(msg, isMessageOrderingOn);
+            recvData.resize(msg.dataSize);
+            // TODO(no-order): when we remove the need to order PTP messages
+            // we will be able to provide a buffer to receive the message into
+            std::memcpy(recvData.data(), msg.dataPtr, msg.dataSize);
            REQUIRE(recvData == sendData);
+            faabric::util::free(msg.dataPtr);
        }
 
        msgLatch->wait();
 
        for (int i = 0; i < numMsg; i++) {
            sendData = std::vector<uint8_t>(3, i);
-            broker.sendMessage(groupId,
-                               idxB,
-                               idxA,
-                               sendData.data(),
-                               sendData.size(),
-                               isMessageOrderingOn);
+            PointToPointMessage msg({ .groupId = groupId,
+                                      .sendIdx = idxB,
+                                      .recvIdx = idxA,
+                                      .dataSize = sendData.size(),
+                                      .dataPtr = sendData.data() });
+            broker.sendMessage(msg, isMessageOrderingOn);
        }
 
        broker.resetThreadLocalCache();
@@ -262,20 +292,26 @@
 
    for (int i = 0; i < numMsg; i++) {
        sendData = std::vector<uint8_t>(3, i);
-        broker.sendMessage(groupId,
-                           idxA,
-                           idxB,
-                           sendData.data(),
-                           sendData.size(),
-                           isMessageOrderingOn);
+        PointToPointMessage msg({ .groupId = groupId,
+                                  .sendIdx = idxB,
+                                  .recvIdx = idxA,
+                                  .dataSize = sendData.size(),
+                                  .dataPtr = sendData.data() });
+        broker.sendMessage(msg, isMessageOrderingOn);
    }
 
    msgLatch->wait();
 
    for (int i = 0; i < numMsg; i++) {
        sendData = std::vector<uint8_t>(3, i);
-        recvData = broker.recvMessage(groupId, idxB, idxA, isMessageOrderingOn);
+        PointToPointMessage msg(
+          { .groupId = groupId, .sendIdx = idxB, .recvIdx = idxA });
+        broker.recvMessage(msg, isMessageOrderingOn);
+        recvData.resize(msg.dataSize);
+        // REQUIRE(msg.dataSize == recvData.size());
+        std::memcpy(recvData.data(), msg.dataPtr, msg.dataSize);
        REQUIRE(sendData == recvData);
+        faabric::util::free(msg.dataPtr);
    }
 
    if (t.joinable()) {
diff --git a/tests/test/transport/test_point_to_point_groups.cpp b/tests/test/transport/test_point_to_point_groups.cpp
index 8d9761335..fa583e70b 100644
--- a/tests/test/transport/test_point_to_point_groups.cpp
+++ b/tests/test/transport/test_point_to_point_groups.cpp
@@ -135,8 +135,12 @@ TEST_CASE_METHOD(PointToPointGroupFixture,
        op = PointToPointCall::LOCK_GROUP;
 
        // Prepare response
-        broker.sendMessage(
-          groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size());
+        PointToPointMessage msg({ .groupId = groupId,
+                                  .sendIdx = POINT_TO_POINT_MAIN_IDX,
+                                  .recvIdx = groupIdx,
+                                  .dataSize = 0,
+                                  .dataPtr = nullptr });
+        broker.sendMessage(msg);
 
        group->lock(groupIdx, false);
    }
@@ -147,8 +151,12 @@ TEST_CASE_METHOD(PointToPointGroupFixture,
        recursive = true;
 
        // Prepare response
-        broker.sendMessage(
-          groupId, POINT_TO_POINT_MAIN_IDX, groupIdx, data.data(), data.size());
+        PointToPointMessage msg({ .groupId = groupId,
+                                  .sendIdx = POINT_TO_POINT_MAIN_IDX,
+                                  .recvIdx = groupIdx,
+                                  .dataSize = 0,
+                                  .dataPtr = nullptr });
+        broker.sendMessage(msg);
 
        group->lock(groupIdx, recursive);
    }
@@ -166,8 +174,7 @@ TEST_CASE_METHOD(PointToPointGroupFixture,
        group->unlock(groupIdx, recursive);
    }
 
-    std::vector<
-      std::tuple<std::string, PointToPointCall, faabric::PointToPointMessage>>
+    std::vector<std::tuple<std::string, PointToPointCall, PointToPointMessage>>
      actualRequests = getSentLockMessages();
 
    REQUIRE(actualRequests.size() == 1);
@@ -176,11 +183,11 @@ TEST_CASE_METHOD(PointToPointGroupFixture,
    PointToPointCall actualOp = std::get<1>(actualRequests.at(0));
    REQUIRE(actualOp == op);
 
-    faabric::PointToPointMessage req = std::get<2>(actualRequests.at(0));
-    REQUIRE(req.appid() == appId);
-    REQUIRE(req.groupid() == groupId);
-    REQUIRE(req.sendidx() == groupIdx);
-    REQUIRE(req.recvidx() == POINT_TO_POINT_MAIN_IDX);
+    PointToPointMessage req = std::get<2>(actualRequests.at(0));
+    REQUIRE(req.appId == appId);
+    REQUIRE(req.groupId == groupId);
+    REQUIRE(req.sendIdx == groupIdx);
+    REQUIRE(req.recvIdx == POINT_TO_POINT_MAIN_IDX);
 }
 
 TEST_CASE_METHOD(PointToPointGroupFixture,
diff --git a/tests/test/transport/test_point_to_point_message.cpp b/tests/test/transport/test_point_to_point_message.cpp
new file mode 100644
index 000000000..51b1cb87c
--- /dev/null
+++ b/tests/test/transport/test_point_to_point_message.cpp
@@ -0,0 +1,95 @@
+#include
+
+#include
+#include
+
+#include
+
+using namespace faabric::transport;
+
+namespace tests {
+
+bool arePtpMsgEqual(const PointToPointMessage& msgA, const PointToPointMessage& msgB)
+{
+    // First, compare the message body (excluding the pointer, which we
+    // know is at the end)
+    if (std::memcmp(&msgA, &msgB, sizeof(PointToPointMessage) - sizeof(void*)) != 0) {
+        return false;
+    }
+
+    // Check that if one buffer points to null, so must the other
+    if (msgA.dataPtr == nullptr || msgB.dataPtr == nullptr) {
+        return msgA.dataPtr == msgB.dataPtr;
+    }
+
+    return std::memcmp(msgA.dataPtr, msgB.dataPtr, msgA.dataSize) == 0;
+}
+
+TEST_CASE("Test (de)serialising a PTP message", "[ptp]")
+{
+    PointToPointMessage msg({ .appId = 1,
+                              .groupId = 2,
+                              .sendIdx = 3,
+                              .recvIdx = 4,
+                              .dataSize = 0,
+                              .dataPtr = nullptr });
+
+    SECTION("Empty message")
+    {
+        msg.dataSize = 0;
+        msg.dataPtr = nullptr;
+    }
+
+    SECTION("Non-empty message")
+    {
+        std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
+        msg.dataSize = nums.size() * sizeof(int);
+        msg.dataPtr = faabric::util::malloc(msg.dataSize);
+        std::memcpy(msg.dataPtr, nums.data(), msg.dataSize);
+    }
+
+    // Serialise and de-serialise
+    std::vector<uint8_t> buffer(sizeof(PointToPointMessage) + msg.dataSize);
+    serializePtpMsg(buffer, msg);
+
+    PointToPointMessage parsedMsg;
+    parsePtpMsg(buffer, &parsedMsg);
+
+    REQUIRE(arePtpMsgEqual(msg, parsedMsg));
+
+    if (msg.dataPtr != nullptr) {
+        faabric::util::free(msg.dataPtr);
+    }
+    if (parsedMsg.dataPtr != nullptr) {
+        faabric::util::free(parsedMsg.dataPtr);
+    }
+}
+
+TEST_CASE("Test (de)serialising a PTP message into prealloc buffer", "[ptp]")
+{
+    PointToPointMessage msg({ .appId = 1,
+                              .groupId = 2,
+                              .sendIdx = 3,
+                              .recvIdx = 4,
+                              .dataSize = 0,
+                              .dataPtr = nullptr });
+
+    std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
+    msg.dataSize = nums.size() * sizeof(int);
+    msg.dataPtr = faabric::util::malloc(msg.dataSize);
+    std::memcpy(msg.dataPtr, nums.data(), msg.dataSize);
+
+    // Serialise and de-serialise
+    std::vector<uint8_t> buffer(sizeof(PointToPointMessage) + msg.dataSize);
+    serializePtpMsg(buffer, msg);
+
+    std::vector<uint8_t> preAllocBuffer(msg.dataSize);
+
+    PointToPointMessage parsedMsg;
+    parsePtpMsg(buffer, &parsedMsg, preAllocBuffer);
+
+    REQUIRE(arePtpMsgEqual(msg, parsedMsg));
+    REQUIRE(parsedMsg.dataPtr == preAllocBuffer.data());
+
+    faabric::util::free(msg.dataPtr);
+}
+}
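A note on the equality helpers used in the tests above: the single std::memcmp over the first sizeof(struct) - sizeof(void*) bytes only covers every fixed-size field if the payload pointer really is the last member. A hypothetical compile-time guard is sketched below (not part of the patch; it assumes PointToPointMessage is standard-layout, declares dataPtr last as the designated initialisers suggest, and lives in the header path shown in the diff):

// Sketch: assert that dataPtr is the trailing member, so the prefix compared
// by arePtpMsgEqual() spans all of the fixed-size fields.
#include <cstddef> // offsetof

#include <faabric/transport/PointToPointMessage.h> // assumed header path

static_assert(offsetof(faabric::transport::PointToPointMessage, dataPtr) +
                  sizeof(void*) ==
                sizeof(faabric::transport::PointToPointMessage),
              "dataPtr is expected to be the last member of PointToPointMessage");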