diff --git a/include/faabric/planner/PlannerApi.h b/include/faabric/planner/PlannerApi.h index ccbcbe47c..65e1a478f 100644 --- a/include/faabric/planner/PlannerApi.h +++ b/include/faabric/planner/PlannerApi.h @@ -16,5 +16,6 @@ enum PlannerCalls GetBatchResults = 10, GetSchedulingDecision = 11, CallBatch = 12, + PreloadSchedulingDecision = 13, }; } diff --git a/include/faabric/planner/PlannerClient.h b/include/faabric/planner/PlannerClient.h index 8e2f56c66..88d70d31b 100644 --- a/include/faabric/planner/PlannerClient.h +++ b/include/faabric/planner/PlannerClient.h @@ -92,6 +92,9 @@ class PlannerClient final : public faabric::transport::MessageEndpointClient faabric::batch_scheduler::SchedulingDecision getSchedulingDecision( std::shared_ptr req); + void preloadSchedulingDecision( + std::shared_ptr preloadDec); + private: std::mutex plannerCacheMx; PlannerCache cache; diff --git a/include/faabric/planner/PlannerServer.h b/include/faabric/planner/PlannerServer.h index 56f953070..11f859649 100644 --- a/include/faabric/planner/PlannerServer.h +++ b/include/faabric/planner/PlannerServer.h @@ -40,6 +40,9 @@ class PlannerServer final : public faabric::transport::MessageEndpointServer std::unique_ptr recvGetSchedulingDecision( std::span buffer); + std::unique_ptr recvPreloadSchedulingDecision( + std::span buffer); + std::unique_ptr recvCallBatch( std::span buffer); diff --git a/include/faabric/util/ptp.h b/include/faabric/util/ptp.h new file mode 100644 index 000000000..edabb04bc --- /dev/null +++ b/include/faabric/util/ptp.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace faabric::util { +faabric::PointToPointMappings ptpMappingsFromSchedulingDecision( + std::shared_ptr decision); +} diff --git a/src/planner/PlannerClient.cpp b/src/planner/PlannerClient.cpp index 4d5760909..db4c03dc9 100644 --- a/src/planner/PlannerClient.cpp +++ b/src/planner/PlannerClient.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace faabric::planner { @@ -380,6 +381,15 @@ PlannerClient::getSchedulingDecision( return decision; } +void PlannerClient::preloadSchedulingDecision( + std::shared_ptr preloadDec) +{ + faabric::EmptyResponse response; + auto mappings = + faabric::util::ptpMappingsFromSchedulingDecision(preloadDec); + syncSend(PlannerCalls::PreloadSchedulingDecision, &mappings, &response); +} + // ----------------------------------- // Static setter/getters // ----------------------------------- diff --git a/src/planner/PlannerServer.cpp b/src/planner/PlannerServer.cpp index d57013925..088496c2b 100644 --- a/src/planner/PlannerServer.cpp +++ b/src/planner/PlannerServer.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -60,6 +61,9 @@ std::unique_ptr PlannerServer::doSyncRecv( case PlannerCalls::GetSchedulingDecision: { return recvGetSchedulingDecision(message.udata()); } + case PlannerCalls::PreloadSchedulingDecision: { + return recvPreloadSchedulingDecision(message.udata()); + } case PlannerCalls::CallBatch: { return recvCallBatch(message.udata()); } @@ -185,20 +189,29 @@ PlannerServer::recvGetSchedulingDecision(std::span buffer) } // Build PointToPointMappings from scheduling decision - faabric::PointToPointMappings mappings; - mappings.set_appid(decision->appId); - mappings.set_groupid(decision->groupId); - for (int i = 0; i < decision->hosts.size(); i++) { - auto* mapping = mappings.add_mappings(); - mapping->set_host(decision->hosts.at(i)); - mapping->set_messageid(decision->messageIds.at(i)); - mapping->set_appidx(decision->appIdxs.at(i)); - mapping->set_groupidx(decision->groupIdxs.at(i)); - } + faabric::PointToPointMappings mappings = + faabric::util::ptpMappingsFromSchedulingDecision(decision); return std::make_unique(mappings); } +std::unique_ptr +PlannerServer::recvPreloadSchedulingDecision(std::span buffer) +{ + PARSE_MSG(PointToPointMappings, buffer.data(), buffer.size()); + + auto preloadDecision = + faabric::batch_scheduler::SchedulingDecision::fromPointToPointMappings( + parsedMsg); + + planner.preloadSchedulingDecision( + preloadDecision.appId, + std::make_shared( + preloadDecision)); + + return std::make_unique(); +} + std::unique_ptr PlannerServer::recvCallBatch( std::span buffer) { diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index ee674e3cf..0e6121dd8 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -20,6 +20,7 @@ faabric_lib(util logging.cpp memory.cpp network.cpp + ptp.cpp queue.cpp random.cpp snapshot.cpp diff --git a/src/util/ptp.cpp b/src/util/ptp.cpp new file mode 100644 index 000000000..c5518bb19 --- /dev/null +++ b/src/util/ptp.cpp @@ -0,0 +1,20 @@ +#include + +namespace faabric::util { +faabric::PointToPointMappings ptpMappingsFromSchedulingDecision( + std::shared_ptr decision) +{ + faabric::PointToPointMappings mappings; + mappings.set_appid(decision->appId); + mappings.set_groupid(decision->groupId); + for (int i = 0; i < decision->hosts.size(); i++) { + auto* mapping = mappings.add_mappings(); + mapping->set_host(decision->hosts.at(i)); + mapping->set_messageid(decision->messageIds.at(i)); + mapping->set_appidx(decision->appIdxs.at(i)); + mapping->set_groupidx(decision->groupIdxs.at(i)); + } + + return mappings; +} +} diff --git a/tests/test/planner/test_planner_client_server.cpp b/tests/test/planner/test_planner_client_server.cpp index a99ddfe9c..e13ab00b0 100644 --- a/tests/test/planner/test_planner_client_server.cpp +++ b/tests/test/planner/test_planner_client_server.cpp @@ -258,4 +258,43 @@ TEST_CASE_METHOD(PlannerClientServerExecTestFixture, checkMessageEquality(messageResults[msg.id()], msg); } } + +TEST_CASE_METHOD(PlannerClientServerExecTestFixture, + "Test preloading a scheduling decision from the client", + "[planner]") +{ + int nFuncs = 4; + faabric::HostResources res; + res.set_slots(nFuncs); + sch.setThisHostResources(res); + auto req = faabric::util::batchExecFactory("foo", "bar", nFuncs); + + // Preload a scheduling decision + auto decision = + std::make_shared( + req->appid(), req->groupid()); + for (int i = 0; i < nFuncs; i++) { + decision->addMessage( + faabric::util::getSystemConfig().endpointHost, 0, 0, i); + } + plannerCli.preloadSchedulingDecision(decision); + + // Now call the request with a preloaded decision + plannerCli.callFunctions(req); + std::map messageResults; + for (int i = 0; i < req->messages_size(); i++) { + auto result = + plannerCli.getMessageResult(req->appid(), req->messages(i).id(), 500); + REQUIRE(result.returnvalue() == 0); + messageResults[result.id()] = result; + } + + // Now, all results for the batch should be registered + auto berStatus = plannerCli.getBatchResults(req); + REQUIRE(berStatus->appid() == req->appid()); + for (const auto& msg : berStatus->messageresults()) { + REQUIRE(messageResults.contains(msg.id())); + checkMessageEquality(messageResults[msg.id()], msg); + } +} }