Skip to content

Commit

Permalink
Fix socket watching read and write conflict.
Browse files Browse the repository at this point in the history
  • Loading branch information
jellefoks committed Nov 14, 2024
1 parent 87f2a2c commit 9483cf0
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 103 deletions.
86 changes: 66 additions & 20 deletions base/message_loop/message_pump_io_starboard.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,17 @@ MessagePumpIOStarboard::SocketWatcher::~SocketWatcher() {
bool MessagePumpIOStarboard::SocketWatcher::StopWatchingSocket() {
watcher_ = nullptr;
interests_ = kSbSocketWaiterInterestNone;
if (!SbSocketIsValid(socket_)) {
pump_ = nullptr;
// If this watcher is not watching anything, no-op and return success.
return true;
}

SbSocket socket = Release();
bool result = true;
if (SbSocketIsValid(socket)) {
DCHECK(pump_);
#if defined(STARBOARD)
// This may get called multiple times from TCPSocketStarboard.
if (pump_) {
result = pump_->StopWatching(socket);
result = pump_->UnregisterInterest(
socket, kSbSocketWaiterInterestRead || kSbSocketWaiterInterestWrite,
this);
}
#else
result = pump_->StopWatching(socket);
#endif
}
pump_ = nullptr;
return result;
Expand Down Expand Up @@ -109,27 +102,75 @@ MessagePumpIOStarboard::~MessagePumpIOStarboard() {
SbSocketWaiterDestroy(waiter_);
}

bool MessagePumpIOStarboard::UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller) {
DCHECK(SbSocketIsValid(socket));
DCHECK(controller);
DCHECK(dropped_interests == kSbSocketWaiterInterestRead ||
dropped_interests == kSbSocketWaiterInterestWrite ||
dropped_interests ==
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite));
DCHECK_CALLED_ON_VALID_THREAD(watch_socket_caller_checker_);

// Make sure we don't pick up any funky internal masks.
int old_interest_mask =
controller->interests() &
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite);
int interests = old_interest_mask & (~dropped_interests);
if (interests == old_interest_mask) {
// Interests didn't change, return.
return true;
}

SbSocket old_socket = controller->Release();
if (SbSocketIsValid(old_socket)) {
// It's illegal to use this function to listen on 2 separate fds with the
// same |controller|.
if (old_socket != socket) {
NOTREACHED() << "Sockets don't match" << old_socket << "!=" << socket;
return false;
}

// Must disarm the event before we can reuse it.
SbSocketWaiterRemove(waiter_, old_socket);
} else {
interests = kSbSocketWaiterInterestNone;
}

if (!SbSocketIsValid(socket)) {
NOTREACHED() << "Invalid socket" << socket;
return false;
}

if (interests) {
// Set current interest mask and waiter for this event.
if (!SbSocketWaiterAdd(waiter_, socket, controller,
OnSocketWaiterNotification, interests,
controller->persistent())) {
return false;
}
controller->Init(socket, controller->persistent());
}
return true;
}

bool MessagePumpIOStarboard::Watch(SbSocket socket,
bool persistent,
int mode,
int interests,
SocketWatcher* controller,
Watcher* delegate) {
DCHECK(SbSocketIsValid(socket));
DCHECK(controller);
DCHECK(delegate);
DCHECK(mode == WATCH_READ || mode == WATCH_WRITE || mode == WATCH_READ_WRITE);
DCHECK(interests == kSbSocketWaiterInterestRead ||
interests == kSbSocketWaiterInterestWrite ||
interests ==
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite));
// Watch should be called on the pump thread. It is not threadsafe, and your
// watcher may never be registered.
DCHECK_CALLED_ON_VALID_THREAD(watch_socket_caller_checker_);

int interests = kSbSocketWaiterInterestNone;
if (mode & WATCH_READ) {
interests |= kSbSocketWaiterInterestRead;
}
if (mode & WATCH_WRITE) {
interests |= kSbSocketWaiterInterestWrite;
}

SbSocket old_socket = controller->Release();
if (SbSocketIsValid(old_socket)) {
// It's illegal to use this function to listen on 2 separate fds with the
Expand All @@ -151,6 +192,11 @@ bool MessagePumpIOStarboard::Watch(SbSocket socket,
SbSocketWaiterRemove(waiter_, old_socket);
}

if (!SbSocketIsValid(socket)) {
NOTREACHED() << "Invalid socket" << socket;
return false;
}

// Set current interest mask and waiter for this event.
if (!SbSocketWaiterAdd(waiter_, socket, controller,
OnSocketWaiterNotification, interests, persistent)) {
Expand Down
13 changes: 6 additions & 7 deletions base/message_loop/message_pump_io_starboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ class BASE_EXPORT MessagePumpIOStarboard : public MessagePump {
base::WeakPtrFactory<SocketWatcher> weak_factory_;
};

enum Mode {
WATCH_READ = 1 << 0,
WATCH_WRITE = 1 << 1,
WATCH_READ_WRITE = WATCH_READ | WATCH_WRITE
};

MessagePumpIOStarboard();
virtual ~MessagePumpIOStarboard();

Expand All @@ -125,10 +119,15 @@ class BASE_EXPORT MessagePumpIOStarboard : public MessagePump {
// success. Must be called on the same thread the message_pump is running on.
bool Watch(SbSocket socket,
bool persistent,
int mode,
int interests,
SocketWatcher* controller,
Watcher* delegate);

// Removes an interest from a socket, and stops watching the socket if needed.
bool UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller);

// Stops watching the socket.
bool StopWatching(SbSocket socket);

Expand Down
26 changes: 10 additions & 16 deletions base/message_loop/message_pump_io_starboard_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_DeleteWatcher) {
std::make_unique<MessagePumpIOStarboard::SocketWatcher>(FROM_HERE));
std::unique_ptr<MessagePumpIOStarboard> pump = CreateMessagePump();
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ_WRITE,
delegate.controller(),
&delegate);
/*persistent=*/false,
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite),
delegate.controller(), &delegate);
SimulateIOEvent(delegate.controller());
}

Expand All @@ -165,10 +164,9 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_StopWatcher) {
MessagePumpIOStarboard::SocketWatcher controller(FROM_HERE);
StopWatcher delegate(&controller);
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ_WRITE,
&controller,
&delegate);
/*persistent=*/false,
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite),
&controller, &delegate);
SimulateIOEvent(&controller);
}

Expand Down Expand Up @@ -202,10 +200,8 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_NestedPumpWatcher) {
std::unique_ptr<MessagePumpIOStarboard> pump = CreateMessagePump();
MessagePumpIOStarboard::SocketWatcher controller(FROM_HERE);
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ,
&controller,
&delegate);
/*persistent=*/false, kSbSocketWaiterInterestRead, &controller,
&delegate);
SimulateIOEvent(&controller);
}

Expand Down Expand Up @@ -253,10 +249,8 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_QuitWatcher) {

// Tell the pump to watch the pipe.
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ,
&controller,
&delegate);
/*persistent=*/false, kSbSocketWaiterInterestRead, &controller,
&delegate);

// Make the IO thread wait for |event| before writing to pipefds[1].
const char buf = 0;
Expand Down
10 changes: 8 additions & 2 deletions base/task/current_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,17 @@ MessagePumpForIO* CurrentIOThread::GetMessagePumpForIO() const {
#if defined(STARBOARD)
bool CurrentIOThread::Watch(SbSocket socket,
bool persistent,
int mode,
SbSocketWaiterInterest interests,
SocketWatcher* controller,
Watcher* delegate) {
return static_cast<MessagePumpIOStarboard*>(GetMessagePumpForIO())
->Watch(socket, persistent, mode, controller, delegate);
->Watch(socket, persistent, interests, controller, delegate);
}
bool CurrentIOThread::UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller) {
return static_cast<MessagePumpIOStarboard*>(GetMessagePumpForIO())
->UnregisterInterest(socket, dropped_interests, controller);
}
#elif BUILDFLAG(IS_WIN)
HRESULT CurrentIOThread::RegisterIOHandler(
Expand Down
9 changes: 4 additions & 5 deletions base/task/current_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,14 @@ class BASE_EXPORT CurrentIOThread : public CurrentThread {
typedef base::MessagePumpIOStarboard::SocketWatcher SocketWatcher;
typedef base::MessagePumpIOStarboard::IOObserver IOObserver;

enum Mode{WATCH_READ = base::MessagePumpIOStarboard::WATCH_READ,
WATCH_WRITE = base::MessagePumpIOStarboard::WATCH_WRITE,
WATCH_READ_WRITE = base::MessagePumpIOStarboard::WATCH_READ_WRITE};

bool Watch(SbSocket socket,
bool persistent,
int mode,
SbSocketWaiterInterest interests,
SocketWatcher* controller,
Watcher* delegate);
bool UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller);
#elif BUILDFLAG(IS_WIN)
// Please see MessagePumpWin for definitions of these methods.
HRESULT RegisterIOHandler(HANDLE file, MessagePumpForIO::IOHandler* handler);
Expand Down
48 changes: 28 additions & 20 deletions net/socket/tcp_socket_starboard.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ int TCPSocketStarboard::Accept(std::unique_ptr<TCPSocketStarboard>* socket,
int result = AcceptInternal(socket, address);

if (result == ERR_IO_PENDING) {
if (!base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_READ,
&socket_watcher_, this)) {
if (!base::CurrentIOThread::Get()->Watch(socket_, true,
kSbSocketWaiterInterestRead,
&socket_watcher_, this)) {
DLOG(ERROR) << "WatchSocket failed on read";
return MapLastSocketError(socket_);
}
Expand Down Expand Up @@ -252,7 +252,9 @@ void TCPSocketStarboard::Close() {
}

void TCPSocketStarboard::StopWatchingAndCleanUp() {
bool ok = socket_watcher_.StopWatchingSocket();
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite,
&socket_watcher_);
DCHECK(ok);

if (!accept_callback_.is_null()) {
Expand Down Expand Up @@ -293,7 +295,7 @@ void TCPSocketStarboard::OnSocketReadyToRead(SbSocket socket) {
} else if (read_pending()) {
DidCompleteRead();
} else {
ClearWatcherIfOperationsNotPending();
ClearReadWatcherIfOperationsNotPending();
}
}

Expand Down Expand Up @@ -345,8 +347,7 @@ int TCPSocketStarboard::Connect(const IPEndPoint& address,

// When it is ready to write, it will have connected.
base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE,
&socket_watcher_, this);
socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this);

return ERR_IO_PENDING;
}
Expand Down Expand Up @@ -374,7 +375,7 @@ void TCPSocketStarboard::DidCompleteConnect() {
waiting_connect_ = false;
CompletionOnceCallback callback = std::move(write_callback_);
write_callback_.Reset();
ClearWatcherIfOperationsNotPending();
ClearReadWatcherIfOperationsNotPending();
std::move(callback).Run(HandleConnectCompleted(rv));
}

Expand Down Expand Up @@ -460,17 +461,17 @@ int TCPSocketStarboard::ReadIfReady(IOBuffer* buf,
}

read_if_ready_callback_ = std::move(callback);
base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_READ,
&socket_watcher_, this);
base::CurrentIOThread::Get()->Watch(
socket_, true, kSbSocketWaiterInterestRead, &socket_watcher_, this);

return rv;
}

int TCPSocketStarboard::CancelReadIfReady() {
DCHECK(read_if_ready_callback_);

bool ok = socket_watcher_.StopWatchingSocket();
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestRead, &socket_watcher_);
DCHECK(ok);

read_if_ready_callback_.Reset();
Expand Down Expand Up @@ -522,7 +523,7 @@ void TCPSocketStarboard::DidCompleteRead() {
CompletionOnceCallback callback = std::move(read_if_ready_callback_);
read_if_ready_callback_.Reset();

ClearWatcherIfOperationsNotPending();
ClearReadWatcherIfOperationsNotPending();
std::move(callback).Run(OK);
}

Expand All @@ -545,8 +546,7 @@ int TCPSocketStarboard::Write(
write_buf_len_ = buf_len;
write_callback_ = std::move(callback);
base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE,
&socket_watcher_, this);
socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this);
}

return rv;
Expand Down Expand Up @@ -586,7 +586,7 @@ void TCPSocketStarboard::DidCompleteWrite() {
CompletionOnceCallback callback = std::move(write_callback_);
write_callback_.Reset();

ClearWatcherIfOperationsNotPending();
ClearWriteWatcherIfOperationsNotPending();
std::move(callback).Run(rv);
}
}
Expand Down Expand Up @@ -667,10 +667,18 @@ void TCPSocketStarboard::ApplySocketTag(const SocketTag& tag) {
tag_ = tag;
}

void TCPSocketStarboard::ClearWatcherIfOperationsNotPending() {
if (!read_pending() && !write_pending() && !accept_pending() &&
!connect_pending()) {
bool ok = socket_watcher_.StopWatchingSocket();
void TCPSocketStarboard::ClearReadWatcherIfOperationsNotPending() {
if (!read_pending() && !accept_pending() && !connect_pending()) {
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestRead, &socket_watcher_);
DCHECK(ok);
}
}

void TCPSocketStarboard::ClearWriteWatcherIfOperationsNotPending() {
if (!write_pending()) {
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestWrite, &socket_watcher_);
DCHECK(ok);
}
}
Expand Down
3 changes: 2 additions & 1 deletion net/socket/tcp_socket_starboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class NET_EXPORT TCPSocketStarboard : public base::MessagePumpIOStarboard::Watch
int DoWrite(IOBuffer* buf, int buf_len);

void StopWatchingAndCleanUp();
void ClearWatcherIfOperationsNotPending();
void ClearReadWatcherIfOperationsNotPending();
void ClearWriteWatcherIfOperationsNotPending();

bool read_pending() const { return !read_if_ready_callback_.is_null(); }
bool write_pending() const {
Expand Down
Loading

0 comments on commit 9483cf0

Please sign in to comment.