diff --git a/src/mdns/mdns.cpp b/src/mdns/mdns.cpp index d7e093531a5..438e2a27d6e 100644 --- a/src/mdns/mdns.cpp +++ b/src/mdns/mdns.cpp @@ -174,31 +174,24 @@ otbrError Publisher::DecodeTxtData(Publisher::TxtList &aTxtList, const uint8_t * void Publisher::RemoveSubscriptionCallbacks(uint64_t aSubscriberId) { - size_t erased; - - OTBR_UNUSED_VARIABLE(erased); - - assert(aSubscriberId > 0); - - erased = mDiscoveredCallbacks.erase(aSubscriberId); - - assert(erased == 1); + mDiscoverCallbacks.remove_if( + [aSubscriberId](DiscoverCallback &aCallback) { return (aCallback.mId == aSubscriberId); }); } uint64_t Publisher::AddSubscriptionCallbacks(Publisher::DiscoveredServiceInstanceCallback aInstanceCallback, Publisher::DiscoveredHostCallback aHostCallback) { - uint64_t subscriberId = mNextSubscriberId++; + uint64_t id = mNextSubscriberId++; - assert(subscriberId > 0); + assert(id > 0); + mDiscoverCallbacks.emplace_back(id, aInstanceCallback, aHostCallback); - mDiscoveredCallbacks.emplace(subscriberId, std::make_pair(std::move(aInstanceCallback), std::move(aHostCallback))); - return subscriberId; + return id; } void Publisher::OnServiceResolved(std::string aType, DiscoveredInstanceInfo aInstanceInfo) { - std::vector subscriberIds; + bool checkToInvoke = false; otbrLogInfo("Service %s is resolved successfully: %s %s host %s addresses %zu", aType.c_str(), aInstanceInfo.mRemoved ? "remove" : "add", aInstanceInfo.mName.c_str(), aInstanceInfo.mHostName.c_str(), @@ -216,22 +209,33 @@ void Publisher::OnServiceResolved(std::string aType, DiscoveredInstanceInfo aIns UpdateMdnsResponseCounters(mTelemetryInfo.mServiceResolutions, OTBR_ERROR_NONE); UpdateServiceInstanceResolutionEmaLatency(aInstanceInfo.mName, aType, OTBR_ERROR_NONE); - // In a callback, the mDiscoveredCallbacks may get changed which invalidates the running iterator. We need to refer - // to the callbacks by subscriberId to avoid invalid memory access. - subscriberIds.reserve(mDiscoveredCallbacks.size()); - for (const auto &subCallback : mDiscoveredCallbacks) + // The `mDiscoverCallbacks` list can get updated as the callbacks + // are invoked. We first mark `mShouldInvoke` on all non-null + // service callbacks. We clear it before invoking the callback + // and restart the iteration over the `mDiscoverCallbacks` list + // to find the next one to signal, since the list may have changed. + + for (DiscoverCallback &callback : mDiscoverCallbacks) { - subscriberIds.push_back(subCallback.first); + if (callback.mServiceCallback != nullptr) + { + callback.mShouldInvoke = true; + checkToInvoke = true; + } } - for (const auto &subscriberId : subscriberIds) + + while (checkToInvoke) { - auto it = mDiscoveredCallbacks.find(subscriberId); - if (it != mDiscoveredCallbacks.end()) + checkToInvoke = false; + + for (DiscoverCallback &callback : mDiscoverCallbacks) { - const auto &subCallback = *it; - if (subCallback.second.first != nullptr) + if (callback.mShouldInvoke) { - subCallback.second.first(aType, aInstanceInfo); + callback.mShouldInvoke = false; + checkToInvoke = true; + callback.mServiceCallback(aType, aInstanceInfo); + break; } } } @@ -252,6 +256,8 @@ void Publisher::OnServiceRemoved(uint32_t aNetifIndex, std::string aType, std::s void Publisher::OnHostResolved(std::string aHostName, Publisher::DiscoveredHostInfo aHostInfo) { + bool checkToInvoke = false; + otbrLogInfo("Host %s is resolved successfully: host %s addresses %zu ttl %u", aHostName.c_str(), aHostInfo.mHostName.c_str(), aHostInfo.mAddresses.size(), aHostInfo.mTtl); @@ -263,11 +269,34 @@ void Publisher::OnHostResolved(std::string aHostName, Publisher::DiscoveredHostI UpdateMdnsResponseCounters(mTelemetryInfo.mHostResolutions, OTBR_ERROR_NONE); UpdateHostResolutionEmaLatency(aHostName, OTBR_ERROR_NONE); - for (const auto &subCallback : mDiscoveredCallbacks) + // The `mDiscoverCallbacks` list can get updated as the callbacks + // are invoked. We first mark `mShouldInvoke` on all non-null + // host callbacks. We clear it before invoking the callback + // and restart the iteration over the `mDiscoverCallbacks` list + // to find the next one to signal, since the list may have changed. + + for (DiscoverCallback &callback : mDiscoverCallbacks) { - if (subCallback.second.second != nullptr) + if (callback.mHostCallback != nullptr) { - subCallback.second.second(aHostName, aHostInfo); + callback.mShouldInvoke = true; + checkToInvoke = true; + } + } + + while (checkToInvoke) + { + checkToInvoke = false; + + for (DiscoverCallback &callback : mDiscoverCallbacks) + { + if (callback.mShouldInvoke) + { + callback.mShouldInvoke = false; + checkToInvoke = true; + callback.mHostCallback(aHostName, aHostInfo); + break; + } } } } diff --git a/src/mdns/mdns.hpp b/src/mdns/mdns.hpp index 9a28b884cda..43fd3c26242 100644 --- a/src/mdns/mdns.hpp +++ b/src/mdns/mdns.hpp @@ -37,6 +37,7 @@ #include "openthread-br/config.h" #include +#include #include #include #include @@ -575,9 +576,28 @@ class Publisher : private NonCopyable ServiceRegistrationMap mServiceRegistrations; HostRegistrationMap mHostRegistrations; + struct DiscoverCallback + { + DiscoverCallback(uint64_t aId, + DiscoveredServiceInstanceCallback aServiceCallback, + DiscoveredHostCallback aHostCallback) + : mId(aId) + , mServiceCallback(aServiceCallback) + , mHostCallback(aHostCallback) + , mShouldInvoke(false) + { + } + + uint64_t mId; + DiscoveredServiceInstanceCallback mServiceCallback; + DiscoveredHostCallback mHostCallback; + bool mShouldInvoke; + }; + uint64_t mNextSubscriberId = 1; - std::map> mDiscoveredCallbacks; + std::list mDiscoverCallbacks; + // {instance name, service type} -> the timepoint to begin service registration std::map, Timepoint> mServiceRegistrationBeginTime; // host name -> the timepoint to begin host registration