Skip to content

Commit

Permalink
[mdns] add PublishKey() & UnpublishKey() methods
Browse files Browse the repository at this point in the history
This commit adds new methods in `Mdns::Publisher` to publish or
unpublish a key record for a given (host or service instance) name.
New methods are implemented for both MDNSResponder and Avahi
sub-classes.

In the MDNSResponder implementation, if a key registration is for a
service instance name matching a service registration,
`DNSServiceAddRecord()` is used to associate the new record with the
service. Otherwise, `DNSServiceRegisterRecord()` is used to register
the KEY record on its own. The implementation handles cases when
related service and key registrations are updated or unregistered.

This commit also simplifies and updates the `test/mdns/main.cpp`
tests, adding a common `Test()` function that takes a function pointer
to run the test and handles all common boilerplate code, as well as
adding a common callback to check the registration result. New test
cases are also added to check key registration, registering keys on
their own, and registering keys and services in different orders.
  • Loading branch information
abtink committed Sep 20, 2023
1 parent feaf59b commit 4a4026a
Show file tree
Hide file tree
Showing 8 changed files with 818 additions and 292 deletions.
2 changes: 2 additions & 0 deletions src/common/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,11 +402,13 @@ struct MdnsTelemetryInfo
"kEmaFactorDenominator must be greater than kEmaFactorNumerator");

MdnsResponseCounters mHostRegistrations;
MdnsResponseCounters mKeyRegistrations;
MdnsResponseCounters mServiceRegistrations;
MdnsResponseCounters mHostResolutions;
MdnsResponseCounters mServiceResolutions;

uint32_t mHostRegistrationEmaLatency; ///< The EMA latency of host registrations in milliseconds
uint32_t mKeyRegistrationEmaLatency; ///< The EMA latency of key registrations in milliseconds
uint32_t mServiceRegistrationEmaLatency; ///< The EMA latency of service registrations in milliseconds
uint32_t mHostResolutionEmaLatency; ///< The EMA latency of host resolutions in milliseconds
uint32_t mServiceResolutionEmaLatency; ///< The EMA latency of service resolutions in milliseconds
Expand Down
130 changes: 129 additions & 1 deletion src/mdns/mdns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ void Publisher::PublishHost(const std::string &aName, const AddressList &aAddres
}
}

void Publisher::PublishKey(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback)
{
otbrError error;

mKeyRegistrationBeginTime[aName] = Clock::now();

error = PublishKeyImpl(aName, aKeyData, std::move(aCallback));
if (error != OTBR_ERROR_NONE)
{
UpdateMdnsResponseCounters(mTelemetryInfo.mKeyRegistrations, error);
}
}

void Publisher::OnServiceResolveFailed(std::string aType, std::string aInstanceName, int32_t aErrorCode)
{
UpdateMdnsResponseCounters(mTelemetryInfo.mServiceResolutions, DnsErrorToOtbrError(aErrorCode));
Expand Down Expand Up @@ -289,7 +302,7 @@ std::string Publisher::MakeFullServiceName(const std::string &aName, const std::
return aName + "." + aType + ".local";
}

std::string Publisher::MakeFullHostName(const std::string &aName)
std::string Publisher::MakeFullName(const std::string &aName)
{
return aName + ".local";
}
Expand Down Expand Up @@ -325,6 +338,13 @@ Publisher::ServiceRegistration *Publisher::FindServiceRegistration(const std::st
return it != mServiceRegistrations.end() ? it->second.get() : nullptr;
}

Publisher::ServiceRegistration *Publisher::FindServiceRegistration(const std::string &aNameAndType)
{
auto it = mServiceRegistrations.find(MakeFullName(aNameAndType));

return it != mServiceRegistrations.end() ? it->second.get() : nullptr;
}

Publisher::ResultCallback Publisher::HandleDuplicateServiceRegistration(const std::string &aHostName,
const std::string &aName,
const std::string &aType,
Expand Down Expand Up @@ -435,6 +455,82 @@ Publisher::HostRegistration *Publisher::FindHostRegistration(const std::string &
return it != mHostRegistrations.end() ? it->second.get() : nullptr;
}

Publisher::ResultCallback Publisher::HandleDuplicateKeyRegistration(const std::string &aName,
const KeyData &aKeyData,
ResultCallback &&aCallback)
{
KeyRegistration *keyReg = FindKeyRegistration(aName);

VerifyOrExit(keyReg != nullptr);

if (keyReg->IsOutdated(aName, aKeyData))
{
otbrLogInfo("Removing existing key %s: outdated", aName.c_str());
RemoveKeyRegistration(keyReg->mName, OTBR_ERROR_ABORTED);
}
else if (keyReg->IsCompleted())
{
// Returns success if the same key has already been
// registered with exactly the same parameters.
std::move(aCallback)(OTBR_ERROR_NONE);
}
else
{
// If the same key is being registered with the same parameters,
// let's join the waiting queue for the result.
keyReg->mCallback = std::bind(
[](std::shared_ptr<ResultCallback> aExistingCallback, std::shared_ptr<ResultCallback> aNewCallback,
otbrError aError) {
std::move (*aExistingCallback)(aError);
std::move (*aNewCallback)(aError);
},
std::make_shared<ResultCallback>(std::move(keyReg->mCallback)),
std::make_shared<ResultCallback>(std::move(aCallback)), std::placeholders::_1);
}

exit:
return std::move(aCallback);
}

void Publisher::AddKeyRegistration(KeyRegistrationPtr &&aKeyReg)
{
mKeyRegistrations.emplace(MakeFullKeyName(aKeyReg->mName), std::move(aKeyReg));
}

void Publisher::RemoveKeyRegistration(const std::string &aName, otbrError aError)
{
auto it = mKeyRegistrations.find(MakeFullKeyName(aName));
KeyRegistrationPtr keyReg;

otbrLogInfo("Removing key %s", aName.c_str());
VerifyOrExit(it != mKeyRegistrations.end());

// Keep the KeyRegistration around before calling `Complete`
// to invoke the callback. This is for avoiding invalid access
// to the KeyRegistration when it's freed from the callback.
keyReg = std::move(it->second);
mKeyRegistrations.erase(it);
keyReg->Complete(aError);
otbrLogInfo("Removed key %s", aName.c_str());

exit:
return;
}

Publisher::KeyRegistration *Publisher::FindKeyRegistration(const std::string &aName)
{
auto it = mKeyRegistrations.find(MakeFullKeyName(aName));

return it != mKeyRegistrations.end() ? it->second.get() : nullptr;
}

Publisher::KeyRegistration *Publisher::FindKeyRegistration(const std::string &aName, const std::string &aType)
{
auto it = mKeyRegistrations.find(MakeFullServiceName(aName, aType));

return it != mKeyRegistrations.end() ? it->second.get() : nullptr;
}

Publisher::Registration::~Registration(void)
{
TriggerCompleteCallback(OTBR_ERROR_ABORTED);
Expand Down Expand Up @@ -486,6 +582,26 @@ void Publisher::HostRegistration::OnComplete(otbrError aError)
}
}

bool Publisher::KeyRegistration::IsOutdated(const std::string &aName, const KeyData &aKeyData) const
{
return !(mName == aName && mKeyData == aKeyData);
}

void Publisher::KeyRegistration::Complete(otbrError aError)
{
OnComplete(aError);
Registration::TriggerCompleteCallback(aError);
}

void Publisher::KeyRegistration::OnComplete(otbrError aError)
{
if (!IsCompleted())
{
mPublisher->UpdateMdnsResponseCounters(mPublisher->mTelemetryInfo.mKeyRegistrations, aError);
mPublisher->UpdateKeyRegistrationEmaLatency(mName, aError);
}
}

void Publisher::UpdateMdnsResponseCounters(otbr::MdnsResponseCounters &aCounters, otbrError aError)
{
switch (aError)
Expand Down Expand Up @@ -564,6 +680,18 @@ void Publisher::UpdateHostRegistrationEmaLatency(const std::string &aHostName, o
}
}

void Publisher::UpdateKeyRegistrationEmaLatency(const std::string &aKeyName, otbrError aError)
{
auto it = mKeyRegistrationBeginTime.find(aKeyName);

if (it != mKeyRegistrationBeginTime.end())
{
uint32_t latency = std::chrono::duration_cast<Milliseconds>(Clock::now() - it->second).count();
UpdateEmaLatency(mTelemetryInfo.mKeyRegistrationEmaLatency, latency, aError);
mKeyRegistrationBeginTime.erase(it);
}
}

void Publisher::UpdateServiceInstanceResolutionEmaLatency(const std::string &aInstanceName,
const std::string &aType,
otbrError aError)
Expand Down
70 changes: 69 additions & 1 deletion src/mdns/mdns.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class Publisher : private NonCopyable
typedef std::vector<TxtEntry> TxtList;
typedef std::vector<std::string> SubTypeList;
typedef std::vector<Ip6Address> AddressList;
typedef std::vector<uint8_t> KeyData;

/**
* This structure represents information of a discovered service instance.
Expand Down Expand Up @@ -266,6 +267,29 @@ class Publisher : private NonCopyable
*/
virtual void UnpublishHost(const std::string &aName, ResultCallback &&aCallback) = 0;

/**
* This method publishes or updates a key record for a name.
*
* @param[in] aName The name associated with key record (can be host name or service instance name).
* @param[in] aKeyData The key data to publish.
* @param[in] aCallback The callback for receiving the publishing result.`OTBR_ERROR_NONE` will be
* returned if the operation is successful and all other values indicate a
* failure. Specifically, `OTBR_ERROR_DUPLICATED` indicates that the name has
* already been published and the caller can re-publish with a new name if an
* alternative name is available/acceptable.
*
*/
void PublishKey(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback);

/**
* This method un-publishes a key record
*
* @param[in] aName The name associated with key record.
* @param[in] aCallback The callback for receiving the publishing result.
*
*/
virtual void UnpublishKey(const std::string &aName, ResultCallback &&aCallback) = 0;

/**
* This method subscribes a given service or service instance.
*
Expand Down Expand Up @@ -501,15 +525,43 @@ class Publisher : private NonCopyable
void OnComplete(otbrError aError);
};

class KeyRegistration : public Registration
{
public:
std::string mName;
KeyData mKeyData;

KeyRegistration(std::string aName, KeyData aKeyData, ResultCallback &&aCallback, Publisher *aPublisher)
: Registration(std::move(aCallback), aPublisher)
, mName(std::move(aName))
, mKeyData(std::move(aKeyData))
{
}

~KeyRegistration(void) { OnComplete(OTBR_ERROR_ABORTED); }

void Complete(otbrError aError);

// Tells whether this `KeyRegistration` object is outdated comparing to the given parameters.
bool IsOutdated(const std::string &aName, const KeyData &aKeyData) const;

private:
void OnComplete(otbrError aError);
};

using ServiceRegistrationPtr = std::unique_ptr<ServiceRegistration>;
using ServiceRegistrationMap = std::map<std::string, ServiceRegistrationPtr>;
using HostRegistrationPtr = std::unique_ptr<HostRegistration>;
using HostRegistrationMap = std::map<std::string, HostRegistrationPtr>;
using KeyRegistrationPtr = std::unique_ptr<KeyRegistration>;
using KeyRegistrationMap = std::map<std::string, KeyRegistrationPtr>;

static SubTypeList SortSubTypeList(SubTypeList aSubTypeList);
static AddressList SortAddressList(AddressList aAddressList);
static std::string MakeFullName(const std::string &aName);
static std::string MakeFullServiceName(const std::string &aName, const std::string &aType);
static std::string MakeFullHostName(const std::string &aName);
static std::string MakeFullHostName(const std::string &aName) { return MakeFullName(aName); }
static std::string MakeFullKeyName(const std::string &aName) { return MakeFullName(aName); }

virtual otbrError PublishServiceImpl(const std::string &aHostName,
const std::string &aName,
Expand All @@ -523,6 +575,8 @@ class Publisher : private NonCopyable
const AddressList &aAddresses,
ResultCallback &&aCallback) = 0;

virtual otbrError PublishKeyImpl(const std::string &aName, const KeyData &aKeyData, ResultCallback &&aCallback) = 0;

virtual void OnServiceResolveFailedImpl(const std::string &aType,
const std::string &aInstanceName,
int32_t aErrorCode) = 0;
Expand All @@ -534,6 +588,7 @@ class Publisher : private NonCopyable
void AddServiceRegistration(ServiceRegistrationPtr &&aServiceReg);
void RemoveServiceRegistration(const std::string &aName, const std::string &aType, otbrError aError);
ServiceRegistration *FindServiceRegistration(const std::string &aName, const std::string &aType);
ServiceRegistration *FindServiceRegistration(const std::string &aNameAndType);

void OnServiceResolved(std::string aType, DiscoveredInstanceInfo aInstanceInfo);
void OnServiceResolveFailed(std::string aType, std::string aInstanceName, int32_t aErrorCode);
Expand All @@ -556,24 +611,35 @@ class Publisher : private NonCopyable
const AddressList &aAddresses,
ResultCallback &&aCallback);

ResultCallback HandleDuplicateKeyRegistration(const std::string &aName,
const KeyData &aKeyData,
ResultCallback &&aCallback);

void AddHostRegistration(HostRegistrationPtr &&aHostReg);
void RemoveHostRegistration(const std::string &aName, otbrError aError);
HostRegistration *FindHostRegistration(const std::string &aName);

void AddKeyRegistration(KeyRegistrationPtr &&aKeyReg);
void RemoveKeyRegistration(const std::string &aName, otbrError aError);
KeyRegistration *FindKeyRegistration(const std::string &aName);
KeyRegistration *FindKeyRegistration(const std::string &aName, const std::string &aType);

static void UpdateMdnsResponseCounters(MdnsResponseCounters &aCounters, otbrError aError);
static void UpdateEmaLatency(uint32_t &aEmaLatency, uint32_t aLatency, otbrError aError);

void UpdateServiceRegistrationEmaLatency(const std::string &aInstanceName,
const std::string &aType,
otbrError aError);
void UpdateHostRegistrationEmaLatency(const std::string &aHostName, otbrError aError);
void UpdateKeyRegistrationEmaLatency(const std::string &aKeyName, otbrError aError);
void UpdateServiceInstanceResolutionEmaLatency(const std::string &aInstanceName,
const std::string &aType,
otbrError aError);
void UpdateHostResolutionEmaLatency(const std::string &aHostName, otbrError aError);

ServiceRegistrationMap mServiceRegistrations;
HostRegistrationMap mHostRegistrations;
KeyRegistrationMap mKeyRegistrations;

uint64_t mNextSubscriberId = 1;

Expand All @@ -582,6 +648,8 @@ class Publisher : private NonCopyable
std::map<std::pair<std::string, std::string>, Timepoint> mServiceRegistrationBeginTime;
// host name -> the timepoint to begin host registration
std::map<std::string, Timepoint> mHostRegistrationBeginTime;
// key name -> the timepoint to begin key registration
std::map<std::string, Timepoint> mKeyRegistrationBeginTime;
// {instance name, service type} -> the timepoint to begin service resolution
std::map<std::pair<std::string, std::string>, Timepoint> mServiceInstanceResolutionBeginTime;
// host name -> the timepoint to begin host resolution
Expand Down
Loading

0 comments on commit 4a4026a

Please sign in to comment.