diff --git a/include/faabric/batch-scheduler/SchedulingDecision.h b/include/faabric/batch-scheduler/SchedulingDecision.h index 7bbaa3099..ffb2cd078 100644 --- a/include/faabric/batch-scheduler/SchedulingDecision.h +++ b/include/faabric/batch-scheduler/SchedulingDecision.h @@ -99,6 +99,8 @@ class SchedulingDecision int32_t appIdx, int32_t groupIdx); + void removeMessage(int32_t messageId); + std::set uniqueHosts(); void print(const std::string& logLevel = "debug"); diff --git a/src/batch-scheduler/SchedulingDecision.cpp b/src/batch-scheduler/SchedulingDecision.cpp index dafd6f8d5..22d268561 100644 --- a/src/batch-scheduler/SchedulingDecision.cpp +++ b/src/batch-scheduler/SchedulingDecision.cpp @@ -54,6 +54,25 @@ SchedulingDecision SchedulingDecision::fromPointToPointMappings( return decision; } +void SchedulingDecision::removeMessage(int32_t messageId) +{ + // Work out the index for the to-be-deleted message + auto idxItr = std::find(messageIds.begin(), messageIds.end(), messageId); + if (idxItr == messageIds.end()) { + SPDLOG_ERROR("Attempting to remove a message id ({}) that is not in " + "the scheduling decision!", + messageId); + throw std::runtime_error("Removing non-existant message!"); + } + int idx = std::distance(messageIds.begin(), idxItr); + + nFunctions--; + hosts.erase(hosts.begin() + idx); + messageIds.erase(messageIds.begin() + idx); + appIdxs.erase(appIdxs.begin() + idx); + groupIdxs.erase(groupIdxs.begin() + idx); +} + std::set SchedulingDecision::uniqueHosts() { return std::set(hosts.begin(), hosts.end()); diff --git a/tests/test/batch-scheduler/test_scheduling_decisions.cpp b/tests/test/batch-scheduler/test_scheduling_decisions.cpp index e36694812..a58ca52ca 100644 --- a/tests/test/batch-scheduler/test_scheduling_decisions.cpp +++ b/tests/test/batch-scheduler/test_scheduling_decisions.cpp @@ -141,6 +141,46 @@ TEST_CASE("Test converting point-to-point mappings to scheduling decisions", REQUIRE(actual.hosts == expectedHosts); } +TEST_CASE("Test removing a message from a scheduling decision", + "[batch-scheduler]") +{ + // Build a scheduling decision + auto req = faabric::util::batchExecFactory("foo", "bar", 3); + SchedulingDecision decision(req->appid(), req->groupid()); + decision.addMessage("foo", req->messages(0)); + decision.addMessage("bar", req->messages(1)); + decision.addMessage("baz", req->messages(2)); + + // Record the original values + int nFunctions = decision.nFunctions; + int nHosts = decision.hosts.size(); + int nMessageIds = decision.messageIds.size(); + int nAppIdxs = decision.appIdxs.size(); + int nGroupIdxs = decision.groupIdxs.size(); + + // Remove message from scheduling decision + decision.removeMessage(req->messages(1).id()); + + // Check decision after removal + REQUIRE(decision.nFunctions == (nFunctions - 1)); + REQUIRE(decision.hosts.size() == (nHosts - 1)); + REQUIRE(decision.messageIds.size() == (nMessageIds - 1)); + REQUIRE(decision.appIdxs.size() == (nAppIdxs - 1)); + REQUIRE(decision.groupIdxs.size() == (nGroupIdxs - 1)); + + // Removing a non-existant id throws an exception + REQUIRE_THROWS(decision.removeMessage(req->messages(1).id())); + + // Lastly, drain the decision and check again + decision.removeMessage(req->messages(0).id()); + decision.removeMessage(req->messages(2).id()); + REQUIRE(decision.nFunctions == 0); + REQUIRE(decision.hosts.empty()); + REQUIRE(decision.messageIds.empty()); + REQUIRE(decision.appIdxs.empty()); + REQUIRE(decision.groupIdxs.empty()); +} + TEST_CASE_METHOD(CachedDecisionTestFixture, "Test caching scheduling decisions", "[util]")