diff --git a/include/faabric/mpi/MpiWorld.h b/include/faabric/mpi/MpiWorld.h index 0f7fc9e17..2ebb977d9 100644 --- a/include/faabric/mpi/MpiWorld.h +++ b/include/faabric/mpi/MpiWorld.h @@ -212,7 +212,11 @@ class MpiWorld std::string thisHost; faabric::util::TimePoint creationTime; - std::atomic<int> activeLocalRanks = 0; + // Latch used to clear the world from the registry when we are migrating + // out of it (i.e. evicting it). Note that this clean-up is only necessary + // for migration, as we want to clean things up in case we ever migrate + // again back into this host + std::atomic<int> evictionLatch = 0; std::atomic_flag isDestroyed = false; diff --git a/include/faabric/transport/PointToPointBroker.h b/include/faabric/transport/PointToPointBroker.h index 8dce115f7..fb7f91e92 100644 --- a/include/faabric/transport/PointToPointBroker.h +++ b/include/faabric/transport/PointToPointBroker.h @@ -119,6 +119,8 @@ class PointToPointBroker std::set<int> getIdxsRegisteredForGroup(int groupId); + std::set<std::string> getHostsRegisteredForGroup(int groupId); + void updateHostForIdx(int groupId, int groupIdx, std::string newHost); void sendMessage(int groupId, diff --git a/src/executor/Executor.cpp b/src/executor/Executor.cpp index 5312fd36b..32a91d776 100644 --- a/src/executor/Executor.cpp +++ b/src/executor/Executor.cpp @@ -426,19 +426,7 @@ void Executor::threadPoolThread(std::stop_token st, int threadPoolIdx) if (msg.ismpi()) { auto& mpiWorldRegistry = faabric::mpi::getMpiWorldRegistry(); if (mpiWorldRegistry.worldExists(msg.mpiworldid())) { - bool mustClear = - mpiWorldRegistry.getWorld(msg.mpiworldid()).destroy(); - - if (mustClear) { - SPDLOG_DEBUG("{}:{}:{} clearing world {} from host {}", - msg.appid(), - msg.groupid(), - msg.groupidx(), - msg.mpiworldid(), - msg.executedhost()); - - mpiWorldRegistry.clearWorld(msg.mpiworldid()); - } + mpiWorldRegistry.getWorld(msg.mpiworldid()).destroy(); } } } diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp index 8a1b387cd..1b31691fe 100644 --- 
a/src/mpi/MpiWorld.cpp +++ b/src/mpi/MpiWorld.cpp @@ -227,11 +227,15 @@ void MpiWorld::create(faabric::Message& call, int newId, int newSize) bool MpiWorld::destroy() { + int groupId = -1; + if (rankState.msg != nullptr) { SPDLOG_TRACE("{}:{}:{} destroying MPI world", rankState.msg->appid(), rankState.msg->groupid(), rankState.msg->mpirank()); + + groupId = rankState.msg->groupid(); } // ----- Per-rank cleanup ----- @@ -246,12 +250,19 @@ bool MpiWorld::destroy() } #endif - // ----- Global accounting ----- - - int numActiveLocalRanks = - activeLocalRanks.fetch_sub(1, std::memory_order_acquire); + // If we are evicting the host during a migration, use the eviction latch + // for proper resource clean-up in the event of a future migration back + // into this host + bool isEviction = + groupId != -1 && + !broker.getHostsRegisteredForGroup(groupId).contains(thisHost); + if (isEviction) { + int numActiveLocalRanks = + evictionLatch.fetch_sub(1, std::memory_order_acquire); + return numActiveLocalRanks == 1; + } - return numActiveLocalRanks == 1; + return false; } // Initialise shared (per-host) MPI world state. This method is called once @@ -276,7 +287,6 @@ void MpiWorld::initialiseFromMsg(faabric::Message& msg) void MpiWorld::initialiseRankFromMsg(faabric::Message& msg) { rankState.msg = &msg; - activeLocalRanks++; // Pin this thread to a free CPU #ifdef FAABRIC_USE_SPINLOCK @@ -341,8 +351,19 @@ void MpiWorld::initLocalRemoteLeaders() portForRank.at(rankId) = broker.getMpiPortForReceiver(groupId, rankId); } - // Persist the local leader in this host for further use - localLeader = (*ranksForHost[thisHost].begin()); + // Finally, set up the infrastructure for proper clean-up of the world in + // case we are migrating away from it. Note that we are preparing the + // latch one migration before we migrate away. 
This is because we will also + // call this method right before evicting, so we want to have the latch + // already set + int numInThisHost = + ranksForHost.contains(thisHost) ? ranksForHost.at(thisHost).size() : 0; + bool mustEvictHost = numInThisHost == 0; + + if (!mustEvictHost) { + evictionLatch.store(numInThisHost, std::memory_order_release); + localLeader = (*ranksForHost[thisHost].begin()); + } } void MpiWorld::getCartesianRank(int rank, @@ -1918,9 +1939,14 @@ void MpiWorld::initSendRecvSockets() // corresponding receiver is local to this host, for any sender void MpiWorld::initLocalQueues() { + // Nothing to do if we are migrating away from this host + if (!ranksForHost.contains(thisHost)) { + return; + } + localQueues.resize(size * size); for (int sendRank = 0; sendRank < size; sendRank++) { - for (const int recvRank : ranksForHost[thisHost]) { + for (const int recvRank : ranksForHost.at(thisHost)) { // We handle messages-to-self as memory copies if (sendRank == recvRank) { continue; diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp index 006ff9b8a..20aa48758 100644 --- a/src/transport/PointToPointBroker.cpp +++ b/src/transport/PointToPointBroker.cpp @@ -539,6 +539,21 @@ std::set<int> PointToPointBroker::getIdxsRegisteredForGroup(int groupId) return groupIdIdxsMap[groupId]; } +std::set<std::string> PointToPointBroker::getHostsRegisteredForGroup( + int groupId) +{ + faabric::util::SharedLock lock(brokerMutex); + std::set<int> indexes = groupIdIdxsMap[groupId]; + + std::set<std::string> hosts; + for (const auto& idx : indexes) { + std::string key = getPointToPointKey(groupId, idx); + hosts.insert(mappings.at(key)); + } + + return hosts; +} + void PointToPointBroker::initSequenceCounters(int groupId) { if (currentGroupId != NO_CURRENT_GROUP_ID) { diff --git a/tests/test/transport/test_point_to_point.cpp b/tests/test/transport/test_point_to_point.cpp index 98b16b9f7..159c073da 100644 --- 
b/tests/test/transport/test_point_to_point.cpp @@ -44,6 +44,8 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, REQUIRE(broker.getIdxsRegisteredForGroup(appIdA).empty()); REQUIRE(broker.getIdxsRegisteredForGroup(appIdB).empty()); + REQUIRE(broker.getHostsRegisteredForGroup(appIdA).empty()); + REQUIRE(broker.getHostsRegisteredForGroup(appIdB).empty()); faabric::PointToPointMappings mappingsA; mappingsA.set_appid(appIdA); @@ -73,6 +75,8 @@ TEST_CASE_METHOD(PointToPointClientServerFixture, REQUIRE(broker.getIdxsRegisteredForGroup(groupIdA).size() == 2); REQUIRE(broker.getIdxsRegisteredForGroup(groupIdB).size() == 1); + REQUIRE(broker.getHostsRegisteredForGroup(groupIdA).size() == 2); + REQUIRE(broker.getHostsRegisteredForGroup(groupIdB).size() == 1); REQUIRE(broker.getHostForReceiver(groupIdA, groupIdxA1) == hostA); REQUIRE(broker.getHostForReceiver(groupIdA, groupIdxA2) == hostB);