From d41907347b569a85edaf1bfdcec9628b611670d4 Mon Sep 17 00:00:00 2001 From: Abtin Keshavarzian Date: Fri, 1 Sep 2023 16:38:01 -0700 Subject: [PATCH] [mdns] update 'TxtEntry' to handle boolean attribute This commit updates `TxtEntry`, `EncodeTxtData()`, and `DecodeTxtData()` to support boolean attributes, which are encoded as `key` string without the `=` character. It also ensures that `EncodeTxtData()` will generate valid TXT data (containing a single zero byte) when given an empty `TxtList`. Additionally, this commit updates `AdvertisingProxy::MakeTxtList()` to use `DecodeTxtData()`, and adds a test to validate the encoder and decoder behavior. --- src/border_agent/border_agent.cpp | 7 ++-- src/mdns/mdns.cpp | 54 +++++++++++++++++++--------- src/mdns/mdns.hpp | 39 ++++++++++++++------ src/mdns/mdns_avahi.cpp | 32 +++++++++++------ src/sdp_proxy/advertising_proxy.cpp | 10 +----- src/trel_dnssd/trel_dnssd.cpp | 7 +++- tests/mdns/main.cpp | 55 +++++++++++++++++++++++++++++ 7 files changed, 154 insertions(+), 50 deletions(-) diff --git a/src/border_agent/border_agent.cpp b/src/border_agent/border_agent.cpp index 065421f7ebe..a1a30c0ef78 100644 --- a/src/border_agent/border_agent.cpp +++ b/src/border_agent/border_agent.cpp @@ -343,10 +343,11 @@ void AppendVendorTxtEntries(const std::map> &a for (auto &addedEntry : aTxtList) { - if (addedEntry.mName == key) + if (addedEntry.mKey == key) { - addedEntry.mValue = value; - found = true; + addedEntry.mValue = value; + addedEntry.mIsBooleanAttribute = false; + found = true; break; } } diff --git a/src/mdns/mdns.cpp b/src/mdns/mdns.cpp index c2c692c461b..2ceaf797714 100644 --- a/src/mdns/mdns.cpp +++ b/src/mdns/mdns.cpp @@ -99,18 +99,32 @@ otbrError Publisher::EncodeTxtData(const TxtList &aTxtList, std::vector { otbrError error = OTBR_ERROR_NONE; - for (const auto &txtEntry : aTxtList) + aTxtData.clear(); + + for (const TxtEntry &txtEntry : aTxtList) { - const auto &name = txtEntry.mName; - const auto &value = txtEntry.mValue; - const size_t entryLength = name.length() + 1 + value.size(); + size_t entryLength = txtEntry.mKey.length(); + + if (!txtEntry.mIsBooleanAttribute) + { + entryLength += txtEntry.mValue.size() + sizeof(uint8_t); // for `=` char. + } VerifyOrExit(entryLength <= kMaxTextEntrySize, error = OTBR_ERROR_INVALID_ARGS); aTxtData.push_back(static_cast(entryLength)); - aTxtData.insert(aTxtData.end(), name.begin(), name.end()); - aTxtData.push_back('='); - aTxtData.insert(aTxtData.end(), value.begin(), value.end()); + aTxtData.insert(aTxtData.end(), txtEntry.mKey.begin(), txtEntry.mKey.end()); + + if (!txtEntry.mIsBooleanAttribute) + { + aTxtData.push_back('='); + aTxtData.insert(aTxtData.end(), txtEntry.mValue.begin(), txtEntry.mValue.end()); + } + } + + if (aTxtData.empty()) + { + aTxtData.push_back(0); } exit: @@ -121,30 +135,39 @@ otbrError Publisher::DecodeTxtData(Publisher::TxtList &aTxtList, const uint8_t * { otbrError error = OTBR_ERROR_NONE; + aTxtList.clear(); + for (uint16_t r = 0; r < aTxtLength;) { uint16_t entrySize = aTxtData[r]; uint16_t keyStart = r + 1; uint16_t entryEnd = keyStart + entrySize; uint16_t keyEnd = keyStart; - uint16_t valStart; + + VerifyOrExit(entryEnd <= aTxtLength, error = OTBR_ERROR_PARSE); while (keyEnd < entryEnd && aTxtData[keyEnd] != '=') { keyEnd++; } - valStart = keyEnd; - if (valStart < entryEnd && aTxtData[valStart] == '=') + if (keyEnd == entryEnd) { - valStart++; + if (keyEnd > keyStart) + { + // No `=`, treat as a boolean attribute. + aTxtList.emplace_back(reinterpret_cast(&aTxtData[keyStart]), keyEnd - keyStart); + } } + else + { + uint16_t valStart = keyEnd + 1; // To skip over `=` - aTxtList.emplace_back(reinterpret_cast(&aTxtData[keyStart]), keyEnd - keyStart, - &aTxtData[valStart], entryEnd - valStart); + aTxtList.emplace_back(reinterpret_cast(&aTxtData[keyStart]), keyEnd - keyStart, + &aTxtData[valStart], entryEnd - valStart); + } r += entrySize + 1; - VerifyOrExit(r <= aTxtLength, error = OTBR_ERROR_PARSE); } exit: @@ -260,7 +283,7 @@ Publisher::SubTypeList Publisher::SortSubTypeList(SubTypeList aSubTypeList) Publisher::TxtList Publisher::SortTxtList(TxtList aTxtList) { std::sort(aTxtList.begin(), aTxtList.end(), - [](const TxtEntry &aLhs, const TxtEntry &aRhs) { return aLhs.mName < aRhs.mName; }); + [](const TxtEntry &aLhs, const TxtEntry &aRhs) { return aLhs.mKey < aRhs.mKey; }); return aTxtList; } @@ -577,5 +600,4 @@ void Publisher::UpdateHostResolutionEmaLatency(const std::string &aHostName, otb } } // namespace Mdns - } // namespace otbr diff --git a/src/mdns/mdns.hpp b/src/mdns/mdns.hpp index 0159009d11b..78e8c5187e1 100644 --- a/src/mdns/mdns.hpp +++ b/src/mdns/mdns.hpp @@ -70,31 +70,48 @@ class Publisher : private NonCopyable { public: /** - * This structure represents a name/value pair of the TXT record. + * This structure represents a key/value pair of the TXT record. * */ struct TxtEntry { - std::string mName; ///< The name of the TXT entry. - std::vector mValue; ///< The value of the TXT entry. + std::string mKey; ///< The key of the TXT entry. + std::vector mValue; ///< The value of the TXT entry. Can be empty. + bool mIsBooleanAttribute; ///< This entry is boolean attribute (encoded as `key` without `=`). - TxtEntry(const char *aName, const char *aValue) - : TxtEntry(aName, reinterpret_cast(aValue), strlen(aValue)) + TxtEntry(const char *aKey, const char *aValue) + : TxtEntry(aKey, reinterpret_cast(aValue), strlen(aValue)) { } - TxtEntry(const char *aName, const uint8_t *aValue, size_t aValueLength) - : TxtEntry(aName, strlen(aName), aValue, aValueLength) + TxtEntry(const char *aKey, const uint8_t *aValue, size_t aValueLength) + : TxtEntry(aKey, strlen(aKey), aValue, aValueLength) { } - TxtEntry(const char *aName, size_t aNameLength, const uint8_t *aValue, size_t aValueLength) - : mName(aName, aNameLength) + TxtEntry(const char *aKey, size_t aKeyLength, const uint8_t *aValue, size_t aValueLength) + : mKey(aKey, aKeyLength) , mValue(aValue, aValue + aValueLength) + , mIsBooleanAttribute(false) { } - bool operator==(const TxtEntry &aOther) const { return mName == aOther.mName && mValue == aOther.mValue; } + TxtEntry(const char *aKey) + : TxtEntry(aKey, strlen(aKey)) + { + } + + TxtEntry(const char *aKey, size_t aKeyLength) + : mKey(aKey, aKeyLength) + , mIsBooleanAttribute(true) + { + } + + bool operator==(const TxtEntry &aOther) const + { + return (mKey == aOther.mKey) && (mValue == aOther.mValue) && + (mIsBooleanAttribute == aOther.mIsBooleanAttribute); + } }; typedef std::vector TxtList; @@ -356,7 +373,7 @@ class Publisher : private NonCopyable * See RFC 6763 for details: https://tools.ietf.org/html/rfc6763#section-6. * * @param[in] aTxtList A TXT entry list. - * @param[out] aTxtData A TXT data buffer. + * @param[out] aTxtData A TXT data buffer. Will be cleared. * * @retval OTBR_ERROR_NONE Successfully write the TXT entry list. * @retval OTBR_ERROR_INVALID_ARGS The @p aTxtList includes invalid TXT entry. diff --git a/src/mdns/mdns_avahi.cpp b/src/mdns/mdns_avahi.cpp index e982f12f5b2..735a2c3a1ef 100644 --- a/src/mdns/mdns_avahi.cpp +++ b/src/mdns/mdns_avahi.cpp @@ -803,24 +803,36 @@ otbrError PublisherAvahi::TxtListToAvahiStringList(const TxtList &aTxtList, aHead = nullptr; for (const auto &txtEntry : aTxtList) { - const char *name = txtEntry.mName.c_str(); - size_t nameLength = txtEntry.mName.length(); + const char *key = txtEntry.mKey.c_str(); + size_t keyLength = txtEntry.mKey.length(); const uint8_t *value = txtEntry.mValue.data(); size_t valueLength = txtEntry.mValue.size(); - // +1 for the size of "=", avahi doesn't need '\0' at the end of the entry - size_t needed = sizeof(AvahiStringList) - sizeof(AvahiStringList::text) + nameLength + valueLength + 1; + size_t needed = sizeof(AvahiStringList) - sizeof(AvahiStringList::text) + keyLength; + const uint8_t *next; + + if (!txtEntry.mIsBooleanAttribute) + { + needed += valueLength + 1; // +1 is for `=` character. + } VerifyOrExit(used + needed <= aBufferSize, error = OTBR_ERROR_INVALID_ARGS); curr->next = last; last = curr; - memcpy(curr->text, name, nameLength); - curr->text[nameLength] = '='; - memcpy(curr->text + nameLength + 1, value, valueLength); - curr->size = nameLength + valueLength + 1; + memcpy(curr->text, key, keyLength); + + if (!txtEntry.mIsBooleanAttribute) { - const uint8_t *next = curr->text + curr->size; - curr = OTBR_ALIGNED(next, AvahiStringList *); + curr->text[keyLength] = '='; + memcpy(curr->text + keyLength + 1, value, valueLength); + curr->size = keyLength + valueLength + 1; } + else + { + curr->size = keyLength; + } + + next = curr->text + curr->size; + curr = OTBR_ALIGNED(next, AvahiStringList *); used = static_cast(reinterpret_cast(curr) - reinterpret_cast(aBuffer)); } SuccessOrExit(error); diff --git a/src/sdp_proxy/advertising_proxy.cpp b/src/sdp_proxy/advertising_proxy.cpp index 817786c0ede..9f6ce494ad5 100644 --- a/src/sdp_proxy/advertising_proxy.cpp +++ b/src/sdp_proxy/advertising_proxy.cpp @@ -342,18 +342,10 @@ Mdns::Publisher::TxtList AdvertisingProxy::MakeTxtList(const otSrpServerService { const uint8_t *txtData; uint16_t txtDataLength = 0; - otDnsTxtEntryIterator iterator; - otDnsTxtEntry txtEntry; Mdns::Publisher::TxtList txtList; txtData = otSrpServerServiceGetTxtData(aSrpService, &txtDataLength); - - otDnsInitTxtEntryIterator(&iterator, txtData, txtDataLength); - - while (otDnsGetNextTxtEntry(&iterator, &txtEntry) == OT_ERROR_NONE) - { - txtList.emplace_back(txtEntry.mKey, txtEntry.mValue, txtEntry.mValueLength); - } + Mdns::Publisher::DecodeTxtData(txtList, txtData, txtDataLength); return txtList; } diff --git a/src/trel_dnssd/trel_dnssd.cpp b/src/trel_dnssd/trel_dnssd.cpp index 11794fcb911..37ec8c5809b 100644 --- a/src/trel_dnssd/trel_dnssd.cpp +++ b/src/trel_dnssd/trel_dnssd.cpp @@ -495,7 +495,12 @@ void TrelDnssd::Peer::ReadExtAddrFromTxtData(void) for (const auto &txtEntry : txtEntries) { - if (StringUtils::EqualCaseInsensitive(txtEntry.mName, kTxtRecordExtAddressKey)) + if (txtEntry.mIsBooleanAttribute) + { + continue; + } + + if (StringUtils::EqualCaseInsensitive(txtEntry.mKey, kTxtRecordExtAddressKey)) { VerifyOrExit(txtEntry.mValue.size() == sizeof(mExtAddr)); diff --git a/tests/mdns/main.cpp b/tests/mdns/main.cpp index 758ed1b27f3..607f14435c4 100644 --- a/tests/mdns/main.cpp +++ b/tests/mdns/main.cpp @@ -410,10 +410,65 @@ otbrError TestStopService(void) return ret; } +otbrError CheckTxtDataEncoderDecoder(void) +{ + otbrError error = OTBR_ERROR_NONE; + Mdns::Publisher::TxtList txtList; + Mdns::Publisher::TxtList parsedTxtList; + std::vector txtData; + + // Encode empty `TxtList` + + SuccessOrExit(error = Mdns::Publisher::EncodeTxtData(txtList, txtData)); + VerifyOrExit(txtData.size() == 1, error = OTBR_ERROR_PARSE); + VerifyOrExit(txtData[0] == 0, error = OTBR_ERROR_PARSE); + + SuccessOrExit(error = Mdns::Publisher::DecodeTxtData(parsedTxtList, txtData.data(), txtData.size())); + VerifyOrExit(parsedTxtList.size() == 0, error = OTBR_ERROR_PARSE); + + // TxtList with one bool attribute + + txtList.clear(); + txtList.emplace_back("b1"); + + SuccessOrExit(error = Mdns::Publisher::EncodeTxtData(txtList, txtData)); + SuccessOrExit(error = Mdns::Publisher::DecodeTxtData(parsedTxtList, txtData.data(), txtData.size())); + VerifyOrExit(parsedTxtList == txtList, error = OTBR_ERROR_PARSE); + + // TxtList with one one key/value + + txtList.clear(); + txtList.emplace_back("k1", "v1"); + + SuccessOrExit(error = Mdns::Publisher::EncodeTxtData(txtList, txtData)); + SuccessOrExit(error = Mdns::Publisher::DecodeTxtData(parsedTxtList, txtData.data(), txtData.size())); + VerifyOrExit(parsedTxtList == txtList, error = OTBR_ERROR_PARSE); + + // TxtList with multiple entries + + txtList.clear(); + txtList.emplace_back("k1", "v1"); + txtList.emplace_back("b1"); + txtList.emplace_back("b2"); + txtList.emplace_back("k2", "valu2"); + + SuccessOrExit(error = Mdns::Publisher::EncodeTxtData(txtList, txtData)); + SuccessOrExit(error = Mdns::Publisher::DecodeTxtData(parsedTxtList, txtData.data(), txtData.size())); + VerifyOrExit(parsedTxtList == txtList, error = OTBR_ERROR_PARSE); + +exit: + return error; +} + int main(int argc, char *argv[]) { int ret = 0; + if (CheckTxtDataEncoderDecoder() != OTBR_ERROR_NONE) + { + return 1; + } + if (argc < 2) { return 1;