Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix socket watching read and write conflict. #4430

Draft
wants to merge 1 commit into
base: 25.lts.1+
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 66 additions & 20 deletions base/message_loop/message_pump_io_starboard.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,17 @@ MessagePumpIOStarboard::SocketWatcher::~SocketWatcher() {
bool MessagePumpIOStarboard::SocketWatcher::StopWatchingSocket() {
watcher_ = nullptr;
interests_ = kSbSocketWaiterInterestNone;
if (!SbSocketIsValid(socket_)) {
pump_ = nullptr;
// If this watcher is not watching anything, no-op and return success.
return true;
}

SbSocket socket = Release();
bool result = true;
if (SbSocketIsValid(socket)) {
DCHECK(pump_);
#if defined(STARBOARD)
// This may get called multiple times from TCPSocketStarboard.
if (pump_) {
result = pump_->StopWatching(socket);
result = pump_->UnregisterInterest(
socket, kSbSocketWaiterInterestRead || kSbSocketWaiterInterestWrite,
this);
}
#else
result = pump_->StopWatching(socket);
#endif
}
pump_ = nullptr;
return result;
Expand Down Expand Up @@ -109,27 +102,75 @@ MessagePumpIOStarboard::~MessagePumpIOStarboard() {
SbSocketWaiterDestroy(waiter_);
}

bool MessagePumpIOStarboard::UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller) {
DCHECK(SbSocketIsValid(socket));
DCHECK(controller);
DCHECK(dropped_interests == kSbSocketWaiterInterestRead ||
dropped_interests == kSbSocketWaiterInterestWrite ||
dropped_interests ==
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite));
DCHECK_CALLED_ON_VALID_THREAD(watch_socket_caller_checker_);

// Make sure we don't pick up any funky internal masks.
int old_interest_mask =
controller->interests() &
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite);
int interests = old_interest_mask & (~dropped_interests);
if (interests == old_interest_mask) {
// Interests didn't change, return.
return true;
}

SbSocket old_socket = controller->Release();
if (SbSocketIsValid(old_socket)) {
// It's illegal to use this function to listen on 2 separate fds with the
// same |controller|.
if (old_socket != socket) {
NOTREACHED() << "Sockets don't match" << old_socket << "!=" << socket;
return false;
}

// Must disarm the event before we can reuse it.
SbSocketWaiterRemove(waiter_, old_socket);
} else {
interests = kSbSocketWaiterInterestNone;
}

if (!SbSocketIsValid(socket)) {
NOTREACHED() << "Invalid socket" << socket;
return false;
}

if (interests) {
// Set current interest mask and waiter for this event.
if (!SbSocketWaiterAdd(waiter_, socket, controller,
OnSocketWaiterNotification, interests,
controller->persistent())) {
return false;
}
controller->Init(socket, controller->persistent());
}
return true;
}

bool MessagePumpIOStarboard::Watch(SbSocket socket,
bool persistent,
int mode,
int interests,
SocketWatcher* controller,
Watcher* delegate) {
DCHECK(SbSocketIsValid(socket));
DCHECK(controller);
DCHECK(delegate);
DCHECK(mode == WATCH_READ || mode == WATCH_WRITE || mode == WATCH_READ_WRITE);
DCHECK(interests == kSbSocketWaiterInterestRead ||
interests == kSbSocketWaiterInterestWrite ||
interests ==
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite));
// Watch should be called on the pump thread. It is not threadsafe, and your
// watcher may never be registered.
DCHECK_CALLED_ON_VALID_THREAD(watch_socket_caller_checker_);

int interests = kSbSocketWaiterInterestNone;
if (mode & WATCH_READ) {
interests |= kSbSocketWaiterInterestRead;
}
if (mode & WATCH_WRITE) {
interests |= kSbSocketWaiterInterestWrite;
}

SbSocket old_socket = controller->Release();
if (SbSocketIsValid(old_socket)) {
// It's illegal to use this function to listen on 2 separate fds with the
Expand All @@ -151,6 +192,11 @@ bool MessagePumpIOStarboard::Watch(SbSocket socket,
SbSocketWaiterRemove(waiter_, old_socket);
}

if (!SbSocketIsValid(socket)) {
NOTREACHED() << "Invalid socket" << socket;
return false;
}

// Set current interest mask and waiter for this event.
if (!SbSocketWaiterAdd(waiter_, socket, controller,
OnSocketWaiterNotification, interests, persistent)) {
Expand Down
13 changes: 6 additions & 7 deletions base/message_loop/message_pump_io_starboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ class BASE_EXPORT MessagePumpIOStarboard : public MessagePump {
base::WeakPtrFactory<SocketWatcher> weak_factory_;
};

enum Mode {
WATCH_READ = 1 << 0,
WATCH_WRITE = 1 << 1,
WATCH_READ_WRITE = WATCH_READ | WATCH_WRITE
};

MessagePumpIOStarboard();
virtual ~MessagePumpIOStarboard();

Expand All @@ -125,10 +119,15 @@ class BASE_EXPORT MessagePumpIOStarboard : public MessagePump {
// success. Must be called on the same thread the message_pump is running on.
bool Watch(SbSocket socket,
bool persistent,
int mode,
int interests,
SocketWatcher* controller,
Watcher* delegate);

// Removes an interest from a socket, and stops watching the socket if needed.
bool UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller);

// Stops watching the socket.
bool StopWatching(SbSocket socket);

Expand Down
26 changes: 10 additions & 16 deletions base/message_loop/message_pump_io_starboard_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_DeleteWatcher) {
std::make_unique<MessagePumpIOStarboard::SocketWatcher>(FROM_HERE));
std::unique_ptr<MessagePumpIOStarboard> pump = CreateMessagePump();
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ_WRITE,
delegate.controller(),
&delegate);
/*persistent=*/false,
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite),
delegate.controller(), &delegate);
SimulateIOEvent(delegate.controller());
}

Expand All @@ -165,10 +164,9 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_StopWatcher) {
MessagePumpIOStarboard::SocketWatcher controller(FROM_HERE);
StopWatcher delegate(&controller);
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ_WRITE,
&controller,
&delegate);
/*persistent=*/false,
(kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite),
&controller, &delegate);
SimulateIOEvent(&controller);
}

Expand Down Expand Up @@ -202,10 +200,8 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_NestedPumpWatcher) {
std::unique_ptr<MessagePumpIOStarboard> pump = CreateMessagePump();
MessagePumpIOStarboard::SocketWatcher controller(FROM_HERE);
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ,
&controller,
&delegate);
/*persistent=*/false, kSbSocketWaiterInterestRead, &controller,
&delegate);
SimulateIOEvent(&controller);
}

Expand Down Expand Up @@ -253,10 +249,8 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_QuitWatcher) {

// Tell the pump to watch the pipe.
pump->Watch(socket(),
/*persistent=*/false,
MessagePumpIOStarboard::WATCH_READ,
&controller,
&delegate);
/*persistent=*/false, kSbSocketWaiterInterestRead, &controller,
&delegate);

// Make the IO thread wait for |event| before writing to pipefds[1].
const char buf = 0;
Expand Down
10 changes: 8 additions & 2 deletions base/task/current_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,17 @@ MessagePumpForIO* CurrentIOThread::GetMessagePumpForIO() const {
#if defined(STARBOARD)
bool CurrentIOThread::Watch(SbSocket socket,
bool persistent,
int mode,
SbSocketWaiterInterest interests,
SocketWatcher* controller,
Watcher* delegate) {
return static_cast<MessagePumpIOStarboard*>(GetMessagePumpForIO())
->Watch(socket, persistent, mode, controller, delegate);
->Watch(socket, persistent, interests, controller, delegate);
}
bool CurrentIOThread::UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller) {
return static_cast<MessagePumpIOStarboard*>(GetMessagePumpForIO())
->UnregisterInterest(socket, dropped_interests, controller);
}
#elif BUILDFLAG(IS_WIN)
HRESULT CurrentIOThread::RegisterIOHandler(
Expand Down
9 changes: 4 additions & 5 deletions base/task/current_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,14 @@ class BASE_EXPORT CurrentIOThread : public CurrentThread {
typedef base::MessagePumpIOStarboard::SocketWatcher SocketWatcher;
typedef base::MessagePumpIOStarboard::IOObserver IOObserver;

enum Mode{WATCH_READ = base::MessagePumpIOStarboard::WATCH_READ,
WATCH_WRITE = base::MessagePumpIOStarboard::WATCH_WRITE,
WATCH_READ_WRITE = base::MessagePumpIOStarboard::WATCH_READ_WRITE};

bool Watch(SbSocket socket,
bool persistent,
int mode,
SbSocketWaiterInterest interests,
SocketWatcher* controller,
Watcher* delegate);
bool UnregisterInterest(SbSocket socket,
int dropped_interests,
SocketWatcher* controller);
#elif BUILDFLAG(IS_WIN)
// Please see MessagePumpWin for definitions of these methods.
HRESULT RegisterIOHandler(HANDLE file, MessagePumpForIO::IOHandler* handler);
Expand Down
48 changes: 28 additions & 20 deletions net/socket/tcp_socket_starboard.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ int TCPSocketStarboard::Accept(std::unique_ptr<TCPSocketStarboard>* socket,
int result = AcceptInternal(socket, address);

if (result == ERR_IO_PENDING) {
if (!base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_READ,
&socket_watcher_, this)) {
if (!base::CurrentIOThread::Get()->Watch(socket_, true,
kSbSocketWaiterInterestRead,
&socket_watcher_, this)) {
DLOG(ERROR) << "WatchSocket failed on read";
return MapLastSocketError(socket_);
}
Expand Down Expand Up @@ -252,7 +252,9 @@ void TCPSocketStarboard::Close() {
}

void TCPSocketStarboard::StopWatchingAndCleanUp() {
bool ok = socket_watcher_.StopWatchingSocket();
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite,
&socket_watcher_);
DCHECK(ok);

if (!accept_callback_.is_null()) {
Expand Down Expand Up @@ -293,7 +295,7 @@ void TCPSocketStarboard::OnSocketReadyToRead(SbSocket socket) {
} else if (read_pending()) {
DidCompleteRead();
} else {
ClearWatcherIfOperationsNotPending();
ClearReadWatcherIfOperationsNotPending();
}
}

Expand Down Expand Up @@ -345,8 +347,7 @@ int TCPSocketStarboard::Connect(const IPEndPoint& address,

// When it is ready to write, it will have connected.
base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE,
&socket_watcher_, this);
socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this);

return ERR_IO_PENDING;
}
Expand Down Expand Up @@ -374,7 +375,7 @@ void TCPSocketStarboard::DidCompleteConnect() {
waiting_connect_ = false;
CompletionOnceCallback callback = std::move(write_callback_);
write_callback_.Reset();
ClearWatcherIfOperationsNotPending();
ClearReadWatcherIfOperationsNotPending();
std::move(callback).Run(HandleConnectCompleted(rv));
}

Expand Down Expand Up @@ -460,17 +461,17 @@ int TCPSocketStarboard::ReadIfReady(IOBuffer* buf,
}

read_if_ready_callback_ = std::move(callback);
base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_READ,
&socket_watcher_, this);
base::CurrentIOThread::Get()->Watch(
socket_, true, kSbSocketWaiterInterestRead, &socket_watcher_, this);

return rv;
}

int TCPSocketStarboard::CancelReadIfReady() {
DCHECK(read_if_ready_callback_);

bool ok = socket_watcher_.StopWatchingSocket();
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestRead, &socket_watcher_);
DCHECK(ok);

read_if_ready_callback_.Reset();
Expand Down Expand Up @@ -522,7 +523,7 @@ void TCPSocketStarboard::DidCompleteRead() {
CompletionOnceCallback callback = std::move(read_if_ready_callback_);
read_if_ready_callback_.Reset();

ClearWatcherIfOperationsNotPending();
ClearReadWatcherIfOperationsNotPending();
std::move(callback).Run(OK);
}

Expand All @@ -545,8 +546,7 @@ int TCPSocketStarboard::Write(
write_buf_len_ = buf_len;
write_callback_ = std::move(callback);
base::CurrentIOThread::Get()->Watch(
socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE,
&socket_watcher_, this);
socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this);
}

return rv;
Expand Down Expand Up @@ -586,7 +586,7 @@ void TCPSocketStarboard::DidCompleteWrite() {
CompletionOnceCallback callback = std::move(write_callback_);
write_callback_.Reset();

ClearWatcherIfOperationsNotPending();
ClearWriteWatcherIfOperationsNotPending();
std::move(callback).Run(rv);
}
}
Expand Down Expand Up @@ -667,10 +667,18 @@ void TCPSocketStarboard::ApplySocketTag(const SocketTag& tag) {
tag_ = tag;
}

void TCPSocketStarboard::ClearWatcherIfOperationsNotPending() {
if (!read_pending() && !write_pending() && !accept_pending() &&
!connect_pending()) {
bool ok = socket_watcher_.StopWatchingSocket();
void TCPSocketStarboard::ClearReadWatcherIfOperationsNotPending() {
if (!read_pending() && !accept_pending() && !connect_pending()) {
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestRead, &socket_watcher_);
DCHECK(ok);
}
}

void TCPSocketStarboard::ClearWriteWatcherIfOperationsNotPending() {
if (!write_pending()) {
bool ok = base::CurrentIOThread::Get()->UnregisterInterest(
socket_, kSbSocketWaiterInterestWrite, &socket_watcher_);
DCHECK(ok);
}
}
Expand Down
3 changes: 2 additions & 1 deletion net/socket/tcp_socket_starboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class NET_EXPORT TCPSocketStarboard : public base::MessagePumpIOStarboard::Watch
int DoWrite(IOBuffer* buf, int buf_len);

void StopWatchingAndCleanUp();
void ClearWatcherIfOperationsNotPending();
void ClearReadWatcherIfOperationsNotPending();
void ClearWriteWatcherIfOperationsNotPending();

bool read_pending() const { return !read_if_ready_callback_.is_null(); }
bool write_pending() const {
Expand Down
Loading
Loading