Skip to content

Commit

Permalink
[mdns-avahi] handle freed AvahiWatch/Timeout entries as processing …
Browse files Browse the repository at this point in the history
…run loop

Avahi provides a mechanism to integrate it to a `select()` style run
loop where an implementation can provide an `AvahiWatch` and an
`AvahiTimeout` structure along with a set of related APIs (as
function pointers given to Avahi client) for it to allocate new
watch/timeout entries or update or free them. A watch entry provides
a set of events to watch for and expects its given callback to be
invoked when the events occur. A timeout provides an abstraction for
a timer.

This commit updates `AvahiPoller` (which integrates Avahi to run loop
as a `MainloopProcessor`) to protect against situations where entries
get freed or updated as Avahi callbacks are invoked from
`AvahiPoller::Process()`.

When we invoke the callback for an `AvahiWatch` or `AvahiTimeout` the
Avahi module can call any of APIs we provided to it. For example, it
can update or free any existing `AvahiWatch/Timeout` entry, which in
turn, modifies the `mWatches` or `mTimers` list being tracked by the
`AvahiPoller`.

This commit updates the `AvahiPoller` class to correctly handle such
situations. Before invoking the callback, we update the entry's state
and after returning from callback we restart the iteration over the
watch/timer list to find the next entry to report, as the list may
have changed during execution of the callback.
  • Loading branch information
abtink committed Sep 20, 2023
1 parent feaf59b commit 3bf835b
Showing 1 changed file with 112 additions and 60 deletions.
172 changes: 112 additions & 60 deletions src/mdns/mdns_avahi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,42 @@
#include "common/logging.hpp"
#include "common/time.hpp"

namespace otbr {
namespace Mdns {

class AvahiPoller;

} // namespace Mdns
} // namespace otbr

struct AvahiWatch
{
int mFd; ///< The file descriptor to watch.
AvahiWatchEvent mEvents; ///< The interested events.
int mHappened; ///< The events happened.
AvahiWatchCallback mCallback; ///< The function to be called when interested events happened on mFd.
void *mContext; ///< A pointer to application-specific context.
void *mPoller; ///< The poller created this watch.
typedef otbr::Mdns::AvahiPoller AvahiPoller;

int mFd; ///< The file descriptor to watch.
AvahiWatchEvent mEvents; ///< The interested events.
int mHappened; ///< The events happened.
AvahiWatchCallback mCallback; ///< The function to be called to report events happened on `mFd`.
void *mContext; ///< A pointer to application-specific context to use with `mCallback`.
bool mShouldReport; ///< Whether or not we need to report events (invoking callback).
AvahiPoller &mPoller; ///< The poller owning this watch.

/**
* The constructor to initialize an Avahi watch.
*
* @param[in] aFd The file descriptor to watch.
* @param[in] aEvents The events to watch.
* @param[in] aCallback The function to be called when events happend on this file descriptor.
* @param[in] aCallback The function to be called when events happened on this file descriptor.
* @param[in] aContext A pointer to application-specific context.
* @param[in] aPoller The AvahiPoller this watcher belongs to.
*
*/
AvahiWatch(int aFd, AvahiWatchEvent aEvents, AvahiWatchCallback aCallback, void *aContext, void *aPoller)
AvahiWatch(int aFd, AvahiWatchEvent aEvents, AvahiWatchCallback aCallback, void *aContext, AvahiPoller &aPoller)
: mFd(aFd)
, mEvents(aEvents)
, mCallback(aCallback)
, mContext(aContext)
, mShouldReport(false)
, mPoller(aPoller)
{
}
Expand All @@ -88,10 +100,13 @@ struct AvahiWatch
*/
struct AvahiTimeout
{
otbr::Timepoint mTimeout; ///< Absolute time when this timer timeout.
AvahiTimeoutCallback mCallback; ///< The function to be called when timeout.
void *mContext; ///< The pointer to application-specific context.
void *mPoller; ///< The poller created this timer.
typedef otbr::Mdns::AvahiPoller AvahiPoller;

otbr::Timepoint mTimeout; ///< Absolute time when this timer timeout.
AvahiTimeoutCallback mCallback; ///< The function to be called when timeout.
void *mContext; ///< The pointer to application-specific context.
bool mShouldReport; ///< Whether or not timeout occurred and need to reported (invoking callback).
AvahiPoller &mPoller; ///< The poller created this timer.

/**
* The constructor to initialize an AvahiTimeout.
Expand All @@ -102,9 +117,10 @@ struct AvahiTimeout
* @param[in] aPoller The AvahiPoller this timeout belongs to.
*
*/
AvahiTimeout(const struct timeval *aTimeout, AvahiTimeoutCallback aCallback, void *aContext, void *aPoller)
AvahiTimeout(const struct timeval *aTimeout, AvahiTimeoutCallback aCallback, void *aContext, AvahiPoller &aPoller)
: mCallback(aCallback)
, mContext(aContext)
, mShouldReport(false)
, mPoller(aPoller)
{
if (aTimeout)
Expand Down Expand Up @@ -168,13 +184,13 @@ class AvahiPoller : public MainloopProcessor
void Update(MainloopContext &aMainloop) override;
void Process(const MainloopContext &aMainloop) override;

const AvahiPoll *GetAvahiPoll(void) const { return &mAvahiPoller; }
const AvahiPoll *GetAvahiPoll(void) const { return &mAvahiPoll; }

private:
typedef std::vector<AvahiWatch *> Watches;
typedef std::vector<AvahiTimeout *> Timers;

static AvahiWatch *WatchNew(const struct AvahiPoll *aPoller,
static AvahiWatch *WatchNew(const struct AvahiPoll *aPoll,
int aFd,
AvahiWatchEvent aEvent,
AvahiWatchCallback aCallback,
Expand All @@ -184,7 +200,7 @@ class AvahiPoller : public MainloopProcessor
static AvahiWatchEvent WatchGetEvents(AvahiWatch *aWatch);
static void WatchFree(AvahiWatch *aWatch);
void WatchFree(AvahiWatch &aWatch);
static AvahiTimeout *TimeoutNew(const AvahiPoll *aPoller,
static AvahiTimeout *TimeoutNew(const AvahiPoll *aPoll,
const struct timeval *aTimeout,
AvahiTimeoutCallback aCallback,
void *aContext);
Expand All @@ -195,36 +211,36 @@ class AvahiPoller : public MainloopProcessor

Watches mWatches;
Timers mTimers;
AvahiPoll mAvahiPoller;
AvahiPoll mAvahiPoll;
};

AvahiPoller::AvahiPoller(void)
{
mAvahiPoller.userdata = this;
mAvahiPoller.watch_new = WatchNew;
mAvahiPoller.watch_update = WatchUpdate;
mAvahiPoller.watch_get_events = WatchGetEvents;
mAvahiPoller.watch_free = WatchFree;

mAvahiPoller.timeout_new = TimeoutNew;
mAvahiPoller.timeout_update = TimeoutUpdate;
mAvahiPoller.timeout_free = TimeoutFree;
mAvahiPoll.userdata = this;
mAvahiPoll.watch_new = WatchNew;
mAvahiPoll.watch_update = WatchUpdate;
mAvahiPoll.watch_get_events = WatchGetEvents;
mAvahiPoll.watch_free = WatchFree;

mAvahiPoll.timeout_new = TimeoutNew;
mAvahiPoll.timeout_update = TimeoutUpdate;
mAvahiPoll.timeout_free = TimeoutFree;
}

AvahiWatch *AvahiPoller::WatchNew(const struct AvahiPoll *aPoller,
AvahiWatch *AvahiPoller::WatchNew(const struct AvahiPoll *aPoll,
int aFd,
AvahiWatchEvent aEvent,
AvahiWatchCallback aCallback,
void *aContext)
{
return reinterpret_cast<AvahiPoller *>(aPoller->userdata)->WatchNew(aFd, aEvent, aCallback, aContext);
return reinterpret_cast<AvahiPoller *>(aPoll->userdata)->WatchNew(aFd, aEvent, aCallback, aContext);
}

AvahiWatch *AvahiPoller::WatchNew(int aFd, AvahiWatchEvent aEvent, AvahiWatchCallback aCallback, void *aContext)
{
assert(aEvent && aCallback && aFd >= 0);

mWatches.push_back(new AvahiWatch(aFd, aEvent, aCallback, aContext, this));
mWatches.push_back(new AvahiWatch(aFd, aEvent, aCallback, aContext, *this));

return mWatches.back();
}
Expand All @@ -241,7 +257,7 @@ AvahiWatchEvent AvahiPoller::WatchGetEvents(AvahiWatch *aWatch)

void AvahiPoller::WatchFree(AvahiWatch *aWatch)
{
reinterpret_cast<AvahiPoller *>(aWatch->mPoller)->WatchFree(*aWatch);
aWatch->mPoller.WatchFree(*aWatch);
}

void AvahiPoller::WatchFree(AvahiWatch &aWatch)
Expand All @@ -257,18 +273,18 @@ void AvahiPoller::WatchFree(AvahiWatch &aWatch)
}
}

AvahiTimeout *AvahiPoller::TimeoutNew(const AvahiPoll *aPoller,
AvahiTimeout *AvahiPoller::TimeoutNew(const AvahiPoll *aPoll,
const struct timeval *aTimeout,
AvahiTimeoutCallback aCallback,
void *aContext)
{
assert(aPoller && aCallback);
return static_cast<AvahiPoller *>(aPoller->userdata)->TimeoutNew(aTimeout, aCallback, aContext);
assert(aPoll && aCallback);
return static_cast<AvahiPoller *>(aPoll->userdata)->TimeoutNew(aTimeout, aCallback, aContext);
}

AvahiTimeout *AvahiPoller::TimeoutNew(const struct timeval *aTimeout, AvahiTimeoutCallback aCallback, void *aContext)
{
mTimers.push_back(new AvahiTimeout(aTimeout, aCallback, aContext, this));
mTimers.push_back(new AvahiTimeout(aTimeout, aCallback, aContext, *this));
return mTimers.back();
}

Expand All @@ -286,7 +302,7 @@ void AvahiPoller::TimeoutUpdate(AvahiTimeout *aTimer, const struct timeval *aTim

void AvahiPoller::TimeoutFree(AvahiTimeout *aTimer)
{
static_cast<AvahiPoller *>(aTimer->mPoller)->TimeoutFree(*aTimer);
aTimer->mPoller.TimeoutFree(*aTimer);
}

void AvahiPoller::TimeoutFree(AvahiTimeout &aTimer)
Expand All @@ -306,10 +322,10 @@ void AvahiPoller::Update(MainloopContext &aMainloop)
{
Timepoint now = Clock::now();

for (Watches::iterator it = mWatches.begin(); it != mWatches.end(); ++it)
for (AvahiWatch *watch : mWatches)
{
int fd = (*it)->mFd;
AvahiWatchEvent events = (*it)->mEvents;
int fd = watch->mFd;
AvahiWatchEvent events = watch->mEvents;

if (AVAHI_WATCH_IN & events)
{
Expand All @@ -333,12 +349,12 @@ void AvahiPoller::Update(MainloopContext &aMainloop)

aMainloop.mMaxFd = std::max(aMainloop.mMaxFd, fd);

(*it)->mHappened = 0;
watch->mHappened = 0;
}

for (Timers::iterator it = mTimers.begin(); it != mTimers.end(); ++it)
for (AvahiTimeout *timer : mTimers)
{
Timepoint timeout = (*it)->mTimeout;
Timepoint timeout = timer->mTimeout;

if (timeout == Timepoint::min())
{
Expand All @@ -364,56 +380,92 @@ void AvahiPoller::Update(MainloopContext &aMainloop)

void AvahiPoller::Process(const MainloopContext &aMainloop)
{
Timepoint now = Clock::now();
std::vector<AvahiTimeout *> expired;
Timepoint now = Clock::now();
bool shouldReport = false;

for (Watches::iterator it = mWatches.begin(); it != mWatches.end(); ++it)
for (AvahiWatch *watch : mWatches)
{
int fd = (*it)->mFd;
AvahiWatchEvent events = (*it)->mEvents;
int fd = watch->mFd;
AvahiWatchEvent events = watch->mEvents;

(*it)->mHappened = 0;
watch->mHappened = 0;

if ((AVAHI_WATCH_IN & events) && FD_ISSET(fd, &aMainloop.mReadFdSet))
{
(*it)->mHappened |= AVAHI_WATCH_IN;
watch->mHappened |= AVAHI_WATCH_IN;
}

if ((AVAHI_WATCH_OUT & events) && FD_ISSET(fd, &aMainloop.mWriteFdSet))
{
(*it)->mHappened |= AVAHI_WATCH_OUT;
watch->mHappened |= AVAHI_WATCH_OUT;
}

if ((AVAHI_WATCH_ERR & events) && FD_ISSET(fd, &aMainloop.mErrorFdSet))
{
(*it)->mHappened |= AVAHI_WATCH_ERR;
watch->mHappened |= AVAHI_WATCH_ERR;
}

// TODO hup events
if ((*it)->mHappened)
if (watch->mHappened != 0)
{
(*it)->mCallback(*it, (*it)->mFd, static_cast<AvahiWatchEvent>((*it)->mHappened), (*it)->mContext);
watch->mShouldReport = true;
shouldReport = true;
}
}

for (Timers::iterator it = mTimers.begin(); it != mTimers.end(); ++it)
// When we invoke the callback for an `AvahiWatch` or `AvahiTimeout`,
// the Avahi module can call any of `mAvahiPoll` APIs we provided to
// it. For example, it can update or free any of `AvahiWatch/Timeout`
// entries, which in turn, modifies our `mWatches` or `mTimers` list.
// So, before invoking the callback, we update the entry's state and
// then restart the iteration over the `mWacthes` list to find the
// next entry to report, as the list may have changed.

while (shouldReport)
{
if ((*it)->mTimeout == Timepoint::min())
shouldReport = false;

for (AvahiWatch *watch : mWatches)
{
if (watch->mShouldReport)
{
shouldReport = true;
watch->mShouldReport = false;
watch->mCallback(watch, watch->mFd, WatchGetEvents(watch), watch->mContext);

break;
}
}
}

for (AvahiTimeout *timer : mTimers)
{
if (timer->mTimeout == Timepoint::min())
{
continue;
}

if ((*it)->mTimeout <= now)
if (timer->mTimeout <= now)
{
expired.push_back(*it);
timer->mShouldReport = true;
shouldReport = true;
}
}

for (std::vector<AvahiTimeout *>::iterator it = expired.begin(); it != expired.end(); ++it)
while (shouldReport)
{
AvahiTimeout *avahiTimeout = *it;
shouldReport = false;

avahiTimeout->mCallback(avahiTimeout, avahiTimeout->mContext);
for (AvahiTimeout *timer : mTimers)
{
if (timer->mShouldReport)
{
shouldReport = true;
timer->mShouldReport = false;
timer->mCallback(timer, timer->mContext);

break;
}
}
}
}

Expand Down

0 comments on commit 3bf835b

Please sign in to comment.