diff --git a/sandboxed_api/sandbox2/comms.cc b/sandboxed_api/sandbox2/comms.cc index a0ef45ad..4e0d045b 100644 --- a/sandboxed_api/sandbox2/comms.cc +++ b/sandboxed_api/sandbox2/comms.cc @@ -259,6 +259,7 @@ bool Comms::RecvString(std::string* v) { } if (tag != kTagString) { + v->clear(); SAPI_RAW_LOG(ERROR, "Expected (kTagString == 0x%x), got: 0x%x", kTagString, tag); return false; diff --git a/sandboxed_api/sandbox2/comms_test.cc b/sandboxed_api/sandbox2/comms_test.cc index 970ba3a3..a32604b9 100644 --- a/sandboxed_api/sandbox2/comms_test.cc +++ b/sandboxed_api/sandbox2/comms_test.cc @@ -38,14 +38,16 @@ #include "sandboxed_api/sandbox2/comms_test.pb.h" #include "sandboxed_api/util/status_matchers.h" +namespace sandbox2 { +namespace { + using ::sapi::IsOk; using ::sapi::StatusIs; using ::testing::Eq; +using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; -namespace sandbox2 { - using CommunicationHandler = std::function; constexpr char kProtoStr[] = "ABCD"; @@ -196,13 +198,22 @@ TEST(CommsTest, TestSendRecvArray) { }; auto b = [](Comms* comms) { // Send 1M bytes. - std::vector buffer(1024 * 1024); - memset(buffer.data(), 0, buffer.size()); + std::vector buffer(1024 * 1024, 0); ASSERT_THAT(comms->SendBytes(buffer), IsTrue()); }; HandleCommunication(a, b); } +TEST(CommsTest, TestSendRecvEmptyArray) { + auto a = [](Comms* comms) { + std::vector buffer; + ASSERT_THAT(comms->RecvBytes(&buffer), IsTrue()); + EXPECT_THAT(buffer, IsEmpty()); + }; + auto b = [](Comms* comms) { ASSERT_THAT(comms->SendBytes({}), IsTrue()); }; + HandleCommunication(a, b); +} + TEST(CommsTest, TestSendRecvFD) { auto a = [](Comms* comms) { // Receive FD and test it. @@ -210,6 +221,7 @@ TEST(CommsTest, TestSendRecvFD) { ASSERT_THAT(comms->RecvFD(&fd), IsTrue()); EXPECT_GE(fd, 0); EXPECT_NE(fcntl(fd, F_GETFD), -1); + close(fd); }; auto b = [](Comms* comms) { // Send our STDERR to the thread. @@ -371,4 +383,71 @@ TEST(CommsTest, TestSendRecvBytes) { HandleCommunication(a, b); } +TEST(CommsTest, SendRecvFailsAfterTerminate) { + auto a = [](Comms* comms) { + comms->Terminate(); + ASSERT_THAT(comms->IsTerminated(), IsTrue()); + EXPECT_THAT(comms->SendInt8(0), IsFalse()); + EXPECT_THAT(comms->SendFD(STDERR_FILENO), IsFalse()); + int8_t tmp; + EXPECT_THAT(comms->RecvInt8(&tmp), IsFalse()); + std::string s; + EXPECT_THAT(comms->RecvString(&s), IsFalse()); + std::vector b; + EXPECT_THAT(comms->RecvBytes(&b), IsFalse()); + int fd; + EXPECT_THAT(comms->RecvFD(&fd), IsFalse()); + CommsTestMsg msg; + EXPECT_THAT(comms->RecvProtoBuf(&msg), IsFalse()); + }; + auto b = [](Comms* comms) {}; + HandleCommunication(a, b); +} + +TEST(CommsTest, RecvIntFailsOnTagMismatch) { + auto a = [](Comms* comms) { + int8_t tmp; + EXPECT_THAT(comms->RecvInt8(&tmp), IsFalse()); + }; + auto b = [](Comms* comms) { ASSERT_THAT(comms->SendUint8(0), IsTrue()); }; + HandleCommunication(a, b); +} + +TEST(CommsTest, RecvStringBytesFailsOnTagMismatch) { + auto a = [](Comms* comms) { + std::string s; + EXPECT_THAT(comms->RecvString(&s), IsFalse()); + EXPECT_THAT(s, IsEmpty()); + ASSERT_THAT(comms->SendString("hello"), IsTrue()); + }; + auto b = [](Comms* comms) { + ASSERT_THAT(comms->SendBytes({1, 0}), IsTrue()); + std::vector b; + EXPECT_THAT(comms->RecvBytes(&b), IsFalse()); + EXPECT_THAT(b, IsEmpty()); + }; + HandleCommunication(a, b); +} + +TEST(CommsTest, RecvFDFailsOnTagMismatch) { + auto a = [](Comms* comms) { + int fd; + EXPECT_THAT(comms->RecvFD(&fd), IsFalse()); + }; + auto b = [](Comms* comms) { ASSERT_THAT(comms->SendBytes({}), IsTrue()); }; + HandleCommunication(a, b); +} + +TEST(CommsTest, RecvProtoBufFailsOnTagMismatch) { + auto a = [](Comms* comms) { + CommsTestMsg msg; + EXPECT_THAT(comms->RecvProtoBuf(&msg), IsFalse()); + }; + auto b = [](Comms* comms) { + ASSERT_THAT(comms->SendString("hello"), IsTrue()); + }; + HandleCommunication(a, b); +} + +} // namespace } // namespace sandbox2