diff --git a/include/faabric/util/batch.h b/include/faabric/util/batch.h index de8f906e2..56b8e964a 100644 --- a/include/faabric/util/batch.h +++ b/include/faabric/util/batch.h @@ -17,6 +17,9 @@ std::shared_ptr batchExecFactory( bool isBatchExecRequestValid(std::shared_ptr ber); +void updateBatchExecAppId(std::shared_ptr ber, + int newAppId); + void updateBatchExecGroupId(std::shared_ptr ber, int newGroupId); diff --git a/src/util/batch.cpp b/src/util/batch.cpp index 5262cbe0b..de5ad2d8b 100644 --- a/src/util/batch.cpp +++ b/src/util/batch.cpp @@ -61,12 +61,24 @@ bool isBatchExecRequestValid(std::shared_ptr ber) return true; } +void updateBatchExecAppId(std::shared_ptr ber, + int newAppId) +{ + ber->set_appid(newAppId); + for (int i = 0; i < ber->messages_size(); i++) { + ber->mutable_messages(i)->set_appid(newAppId); + } + + // Sanity-check in debug mode + assert(isBatchExecRequestValid(ber)); +} + void updateBatchExecGroupId(std::shared_ptr ber, int newGroupId) { ber->set_groupid(newGroupId); - for (auto msg : *ber->mutable_messages()) { - msg.set_groupid(newGroupId); + for (int i = 0; i < ber->messages_size(); i++) { + ber->mutable_messages(i)->set_groupid(newGroupId); } // Sanity-check in debug mode diff --git a/tests/test/util/test_batch.cpp b/tests/test/util/test_batch.cpp index f664653b7..c4a7f8328 100644 --- a/tests/test/util/test_batch.cpp +++ b/tests/test/util/test_batch.cpp @@ -88,6 +88,21 @@ TEST_CASE("Test batch. exec request sanity checks") REQUIRE(isBerValid == isBatchExecRequestValid(ber)); } +TEST_CASE("Test updating the app ID of a BER") +{ + int nMessages = 4; + std::shared_ptr ber = + batchExecFactory("demo", "echo", nMessages); + + // By default the BER is valid + REQUIRE(isBatchExecRequestValid(ber)); + + int newAppId = 1337; + updateBatchExecAppId(ber, newAppId); + REQUIRE(isBatchExecRequestValid(ber)); + REQUIRE(ber->appid() == newAppId); +} + TEST_CASE("Test updating the group ID of a BER") { int nMessages = 4;