diff --git a/base/message_loop/message_pump_io_starboard.cc b/base/message_loop/message_pump_io_starboard.cc index 714012b8c455..ef33f601344d 100644 --- a/base/message_loop/message_pump_io_starboard.cc +++ b/base/message_loop/message_pump_io_starboard.cc @@ -42,24 +42,17 @@ MessagePumpIOStarboard::SocketWatcher::~SocketWatcher() { bool MessagePumpIOStarboard::SocketWatcher::StopWatchingSocket() { watcher_ = nullptr; interests_ = kSbSocketWaiterInterestNone; - if (!SbSocketIsValid(socket_)) { - pump_ = nullptr; - // If this watcher is not watching anything, no-op and return success. - return true; - } SbSocket socket = Release(); bool result = true; if (SbSocketIsValid(socket)) { DCHECK(pump_); -#if defined(STARBOARD) // This may get called multiple times from TCPSocketStarboard. if (pump_) { - result = pump_->StopWatching(socket); + result = pump_->UnregisterInterest( + socket, kSbSocketWaiterInterestRead || kSbSocketWaiterInterestWrite, + this); } -#else - result = pump_->StopWatching(socket); -#endif } pump_ = nullptr; return result; @@ -109,27 +102,75 @@ MessagePumpIOStarboard::~MessagePumpIOStarboard() { SbSocketWaiterDestroy(waiter_); } +bool MessagePumpIOStarboard::UnregisterInterest(SbSocket socket, + int dropped_interests, + SocketWatcher* controller) { + DCHECK(SbSocketIsValid(socket)); + DCHECK(controller); + DCHECK(dropped_interests == kSbSocketWaiterInterestRead || + dropped_interests == kSbSocketWaiterInterestWrite || + dropped_interests == + (kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite)); + DCHECK_CALLED_ON_VALID_THREAD(watch_socket_caller_checker_); + + // Make sure we don't pick up any funky internal masks. + int old_interest_mask = + controller->interests() & + (kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite); + int interests = old_interest_mask & (~dropped_interests); + if (interests == old_interest_mask) { + // Interests didn't change, return. + return true; + } + + SbSocket old_socket = controller->Release(); + if (SbSocketIsValid(old_socket)) { + // It's illegal to use this function to listen on 2 separate fds with the + // same |controller|. + if (old_socket != socket) { + NOTREACHED() << "Sockets don't match" << old_socket << "!=" << socket; + return false; + } + + // Must disarm the event before we can reuse it. + SbSocketWaiterRemove(waiter_, old_socket); + } else { + interests = kSbSocketWaiterInterestNone; + } + + if (!SbSocketIsValid(socket)) { + NOTREACHED() << "Invalid socket" << socket; + return false; + } + + if (interests) { + // Set current interest mask and waiter for this event. + if (!SbSocketWaiterAdd(waiter_, socket, controller, + OnSocketWaiterNotification, interests, + controller->persistent())) { + return false; + } + controller->Init(socket, controller->persistent()); + } + return true; +} + bool MessagePumpIOStarboard::Watch(SbSocket socket, bool persistent, - int mode, + int interests, SocketWatcher* controller, Watcher* delegate) { DCHECK(SbSocketIsValid(socket)); DCHECK(controller); DCHECK(delegate); - DCHECK(mode == WATCH_READ || mode == WATCH_WRITE || mode == WATCH_READ_WRITE); + DCHECK(interests == kSbSocketWaiterInterestRead || + interests == kSbSocketWaiterInterestWrite || + interests == + (kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite)); // Watch should be called on the pump thread. It is not threadsafe, and your // watcher may never be registered. DCHECK_CALLED_ON_VALID_THREAD(watch_socket_caller_checker_); - int interests = kSbSocketWaiterInterestNone; - if (mode & WATCH_READ) { - interests |= kSbSocketWaiterInterestRead; - } - if (mode & WATCH_WRITE) { - interests |= kSbSocketWaiterInterestWrite; - } - SbSocket old_socket = controller->Release(); if (SbSocketIsValid(old_socket)) { // It's illegal to use this function to listen on 2 separate fds with the @@ -151,6 +192,11 @@ bool MessagePumpIOStarboard::Watch(SbSocket socket, SbSocketWaiterRemove(waiter_, old_socket); } + if (!SbSocketIsValid(socket)) { + NOTREACHED() << "Invalid socket" << socket; + return false; + } + // Set current interest mask and waiter for this event. if (!SbSocketWaiterAdd(waiter_, socket, controller, OnSocketWaiterNotification, interests, persistent)) { diff --git a/base/message_loop/message_pump_io_starboard.h b/base/message_loop/message_pump_io_starboard.h index b7d807dec914..132d53addffc 100644 --- a/base/message_loop/message_pump_io_starboard.h +++ b/base/message_loop/message_pump_io_starboard.h @@ -102,12 +102,6 @@ class BASE_EXPORT MessagePumpIOStarboard : public MessagePump { base::WeakPtrFactory weak_factory_; }; - enum Mode { - WATCH_READ = 1 << 0, - WATCH_WRITE = 1 << 1, - WATCH_READ_WRITE = WATCH_READ | WATCH_WRITE - }; - MessagePumpIOStarboard(); virtual ~MessagePumpIOStarboard(); @@ -125,10 +119,15 @@ class BASE_EXPORT MessagePumpIOStarboard : public MessagePump { // success. Must be called on the same thread the message_pump is running on. bool Watch(SbSocket socket, bool persistent, - int mode, + int interests, SocketWatcher* controller, Watcher* delegate); + // Removes an interest from a socket, and stops watching the socket if needed. + bool UnregisterInterest(SbSocket socket, + int dropped_interests, + SocketWatcher* controller); + // Stops watching the socket. bool StopWatching(SbSocket socket); diff --git a/base/message_loop/message_pump_io_starboard_unittest.cc b/base/message_loop/message_pump_io_starboard_unittest.cc index c6231f70d882..65ba8151043f 100644 --- a/base/message_loop/message_pump_io_starboard_unittest.cc +++ b/base/message_loop/message_pump_io_starboard_unittest.cc @@ -137,10 +137,9 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_DeleteWatcher) { std::make_unique(FROM_HERE)); std::unique_ptr pump = CreateMessagePump(); pump->Watch(socket(), - /*persistent=*/false, - MessagePumpIOStarboard::WATCH_READ_WRITE, - delegate.controller(), - &delegate); + /*persistent=*/false, + (kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite), + delegate.controller(), &delegate); SimulateIOEvent(delegate.controller()); } @@ -165,10 +164,9 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_StopWatcher) { MessagePumpIOStarboard::SocketWatcher controller(FROM_HERE); StopWatcher delegate(&controller); pump->Watch(socket(), - /*persistent=*/false, - MessagePumpIOStarboard::WATCH_READ_WRITE, - &controller, - &delegate); + /*persistent=*/false, + (kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite), + &controller, &delegate); SimulateIOEvent(&controller); } @@ -202,10 +200,8 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_NestedPumpWatcher) { std::unique_ptr pump = CreateMessagePump(); MessagePumpIOStarboard::SocketWatcher controller(FROM_HERE); pump->Watch(socket(), - /*persistent=*/false, - MessagePumpIOStarboard::WATCH_READ, - &controller, - &delegate); + /*persistent=*/false, kSbSocketWaiterInterestRead, &controller, + &delegate); SimulateIOEvent(&controller); } @@ -253,10 +249,8 @@ TEST_F(MessagePumpIOStarboardTest, DISABLED_QuitWatcher) { // Tell the pump to watch the pipe. pump->Watch(socket(), - /*persistent=*/false, - MessagePumpIOStarboard::WATCH_READ, - &controller, - &delegate); + /*persistent=*/false, kSbSocketWaiterInterestRead, &controller, + &delegate); // Make the IO thread wait for |event| before writing to pipefds[1]. const char buf = 0; diff --git a/base/task/current_thread.cc b/base/task/current_thread.cc index 673e2028d4a2..d7e572fec80f 100644 --- a/base/task/current_thread.cc +++ b/base/task/current_thread.cc @@ -214,11 +214,17 @@ MessagePumpForIO* CurrentIOThread::GetMessagePumpForIO() const { #if defined(STARBOARD) bool CurrentIOThread::Watch(SbSocket socket, bool persistent, - int mode, + SbSocketWaiterInterest interests, SocketWatcher* controller, Watcher* delegate) { return static_cast(GetMessagePumpForIO()) - ->Watch(socket, persistent, mode, controller, delegate); + ->Watch(socket, persistent, interests, controller, delegate); +} +bool CurrentIOThread::UnregisterInterest(SbSocket socket, + int dropped_interests, + SocketWatcher* controller) { + return static_cast(GetMessagePumpForIO()) + ->UnregisterInterest(socket, dropped_interests, controller); } #elif BUILDFLAG(IS_WIN) HRESULT CurrentIOThread::RegisterIOHandler( diff --git a/base/task/current_thread.h b/base/task/current_thread.h index 5106c5e9fd71..c39eb38669bd 100644 --- a/base/task/current_thread.h +++ b/base/task/current_thread.h @@ -275,15 +275,14 @@ class BASE_EXPORT CurrentIOThread : public CurrentThread { typedef base::MessagePumpIOStarboard::SocketWatcher SocketWatcher; typedef base::MessagePumpIOStarboard::IOObserver IOObserver; - enum Mode{WATCH_READ = base::MessagePumpIOStarboard::WATCH_READ, - WATCH_WRITE = base::MessagePumpIOStarboard::WATCH_WRITE, - WATCH_READ_WRITE = base::MessagePumpIOStarboard::WATCH_READ_WRITE}; - bool Watch(SbSocket socket, bool persistent, - int mode, + SbSocketWaiterInterest interests, SocketWatcher* controller, Watcher* delegate); + bool UnregisterInterest(SbSocket socket, + int dropped_interests, + SocketWatcher* controller); #elif BUILDFLAG(IS_WIN) // Please see MessagePumpWin for definitions of these methods. HRESULT RegisterIOHandler(HANDLE file, MessagePumpForIO::IOHandler* handler); diff --git a/net/socket/tcp_socket_starboard.cc b/net/socket/tcp_socket_starboard.cc index 0fdd752ae130..920bc13f69e9 100644 --- a/net/socket/tcp_socket_starboard.cc +++ b/net/socket/tcp_socket_starboard.cc @@ -139,9 +139,9 @@ int TCPSocketStarboard::Accept(std::unique_ptr* socket, int result = AcceptInternal(socket, address); if (result == ERR_IO_PENDING) { - if (!base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_READ, - &socket_watcher_, this)) { + if (!base::CurrentIOThread::Get()->Watch(socket_, true, + kSbSocketWaiterInterestRead, + &socket_watcher_, this)) { DLOG(ERROR) << "WatchSocket failed on read"; return MapLastSocketError(socket_); } @@ -252,7 +252,9 @@ void TCPSocketStarboard::Close() { } void TCPSocketStarboard::StopWatchingAndCleanUp() { - bool ok = socket_watcher_.StopWatchingSocket(); + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite, + &socket_watcher_); DCHECK(ok); if (!accept_callback_.is_null()) { @@ -293,7 +295,7 @@ void TCPSocketStarboard::OnSocketReadyToRead(SbSocket socket) { } else if (read_pending()) { DidCompleteRead(); } else { - ClearWatcherIfOperationsNotPending(); + ClearReadWatcherIfOperationsNotPending(); } } @@ -345,8 +347,7 @@ int TCPSocketStarboard::Connect(const IPEndPoint& address, // When it is ready to write, it will have connected. base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE, - &socket_watcher_, this); + socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this); return ERR_IO_PENDING; } @@ -374,7 +375,7 @@ void TCPSocketStarboard::DidCompleteConnect() { waiting_connect_ = false; CompletionOnceCallback callback = std::move(write_callback_); write_callback_.Reset(); - ClearWatcherIfOperationsNotPending(); + ClearReadWatcherIfOperationsNotPending(); std::move(callback).Run(HandleConnectCompleted(rv)); } @@ -460,9 +461,8 @@ int TCPSocketStarboard::ReadIfReady(IOBuffer* buf, } read_if_ready_callback_ = std::move(callback); - base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_READ, - &socket_watcher_, this); + base::CurrentIOThread::Get()->Watch( + socket_, true, kSbSocketWaiterInterestRead, &socket_watcher_, this); return rv; } @@ -470,7 +470,8 @@ int TCPSocketStarboard::ReadIfReady(IOBuffer* buf, int TCPSocketStarboard::CancelReadIfReady() { DCHECK(read_if_ready_callback_); - bool ok = socket_watcher_.StopWatchingSocket(); + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestRead, &socket_watcher_); DCHECK(ok); read_if_ready_callback_.Reset(); @@ -522,7 +523,7 @@ void TCPSocketStarboard::DidCompleteRead() { CompletionOnceCallback callback = std::move(read_if_ready_callback_); read_if_ready_callback_.Reset(); - ClearWatcherIfOperationsNotPending(); + ClearReadWatcherIfOperationsNotPending(); std::move(callback).Run(OK); } @@ -545,8 +546,7 @@ int TCPSocketStarboard::Write( write_buf_len_ = buf_len; write_callback_ = std::move(callback); base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE, - &socket_watcher_, this); + socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this); } return rv; @@ -586,7 +586,7 @@ void TCPSocketStarboard::DidCompleteWrite() { CompletionOnceCallback callback = std::move(write_callback_); write_callback_.Reset(); - ClearWatcherIfOperationsNotPending(); + ClearWriteWatcherIfOperationsNotPending(); std::move(callback).Run(rv); } } @@ -667,10 +667,18 @@ void TCPSocketStarboard::ApplySocketTag(const SocketTag& tag) { tag_ = tag; } -void TCPSocketStarboard::ClearWatcherIfOperationsNotPending() { - if (!read_pending() && !write_pending() && !accept_pending() && - !connect_pending()) { - bool ok = socket_watcher_.StopWatchingSocket(); +void TCPSocketStarboard::ClearReadWatcherIfOperationsNotPending() { + if (!read_pending() && !accept_pending() && !connect_pending()) { + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestRead, &socket_watcher_); + DCHECK(ok); + } +} + +void TCPSocketStarboard::ClearWriteWatcherIfOperationsNotPending() { + if (!write_pending()) { + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestWrite, &socket_watcher_); DCHECK(ok); } } diff --git a/net/socket/tcp_socket_starboard.h b/net/socket/tcp_socket_starboard.h index 876dac02d8a5..f23cc4ad2417 100644 --- a/net/socket/tcp_socket_starboard.h +++ b/net/socket/tcp_socket_starboard.h @@ -158,7 +158,8 @@ class NET_EXPORT TCPSocketStarboard : public base::MessagePumpIOStarboard::Watch int DoWrite(IOBuffer* buf, int buf_len); void StopWatchingAndCleanUp(); - void ClearWatcherIfOperationsNotPending(); + void ClearReadWatcherIfOperationsNotPending(); + void ClearWriteWatcherIfOperationsNotPending(); bool read_pending() const { return !read_if_ready_callback_.is_null(); } bool write_pending() const { diff --git a/net/socket/udp_socket_starboard.cc b/net/socket/udp_socket_starboard.cc index 6cf1f737f691..e3bd7a4c1b41 100644 --- a/net/socket/udp_socket_starboard.cc +++ b/net/socket/udp_socket_starboard.cc @@ -132,7 +132,9 @@ void UDPSocketStarboard::Close() { write_callback_.Reset(); send_to_address_.reset(); - bool ok = socket_watcher_.StopWatchingSocket(); + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestRead | kSbSocketWaiterInterestWrite, + &socket_watcher_); DCHECK(ok); is_connected_ = false; @@ -224,8 +226,7 @@ int UDPSocketStarboard::ReadMultiplePackets(Socket::ReadPacketResults* results, } if (!base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_READ, - &socket_watcher_, this)) { + socket_, true, kSbSocketWaiterInterestRead, &socket_watcher_, this)) { PLOG(ERROR) << "WatchSocket failed on read"; Error result = MapLastSocketError(socket_); if (result == ERR_IO_PENDING) { @@ -266,8 +267,7 @@ int UDPSocketStarboard::RecvFrom(IOBuffer* buf, return nread; if (!base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_READ, - &socket_watcher_, this)) { + socket_, true, kSbSocketWaiterInterestRead, &socket_watcher_, this)) { PLOG(ERROR) << "WatchSocket failed on read"; Error result = MapLastSocketError(socket_); if (result == ERR_IO_PENDING) { @@ -315,9 +315,9 @@ int UDPSocketStarboard::SendToOrWrite(IOBuffer* buf, if (result != ERR_IO_PENDING) return result; - if (!base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE, - &socket_watcher_, this)) { + if (!base::CurrentIOThread::Get()->Watch(socket_, true, + kSbSocketWaiterInterestWrite, + &socket_watcher_, this)) { DVLOG(1) << "Watch failed on write, error " << SbSocketGetLastError(socket_); Error result = MapLastSocketError(socket_); @@ -476,7 +476,7 @@ void UDPSocketStarboard::WriteAsyncWatcher::OnSocketReadyToWrite( SbSocket /*socket*/) { DVLOG(1) << __func__ << " queue " << socket_->pending_writes_.size() << " out of " << socket_->write_async_outstanding_ << " total"; - socket_->StopWatchingSocket(); + socket_->StopWatchingSocketForWriting(); socket_->FlushPending(); } @@ -504,7 +504,7 @@ void UDPSocketStarboard::DidCompleteRead() { read_buf_ = NULL; read_buf_len_ = 0; recv_from_address_ = NULL; - InternalStopWatchingSocket(); + StopWatchingSocketForReading(); DoReadCallback(result); } } @@ -513,7 +513,7 @@ void UDPSocketStarboard::DidCompleteMultiplePacketRead() { int result = InternalReadMultiplePackets(results_); if (result != ERR_IO_PENDING) { results_ = nullptr; - InternalStopWatchingSocket(); + StopWatchingSocketForReading(); DoReadCallback(result); } } @@ -543,7 +543,7 @@ void UDPSocketStarboard::DidCompleteWrite() { write_buf_ = NULL; write_buf_len_ = 0; send_to_address_.reset(); - InternalStopWatchingSocket(); + StopWatchingSocketForWriting(); DoWriteCallback(result); } } @@ -959,7 +959,7 @@ void UDPSocketStarboard::DidSendBuffers(SendResult send_result) { last_async_result_ = send_result.rv; if (last_async_result_ == ERR_IO_PENDING) { DVLOG(2) << __func__ << " WatchSocket start"; - if (!WatchSocket()) { + if (!WatchSocketForWriting()) { last_async_result_ = MapLastSocketError(socket_); DVLOG(1) << "WatchSocket failed on write, error: " << last_async_result_; LogWrite(last_async_result_, NULL, NULL); @@ -970,7 +970,7 @@ void UDPSocketStarboard::DidSendBuffers(SendResult send_result) { DVLOG(2) << __func__ << " WatchSocket stop: result " << ErrorToShortString(last_async_result_) << " pending_writes " << pending_writes_.size(); - StopWatchingSocket(); + StopWatchingSocketForWriting(); } DCHECK(last_async_result_ != ERR_IO_PENDING); @@ -1002,32 +1002,32 @@ void UDPSocketStarboard::SetMsgConfirm(bool confirm) { NOTIMPLEMENTED(); } -bool UDPSocketStarboard::WatchSocket() { +bool UDPSocketStarboard::WatchSocketForWriting() { if (write_async_watcher_->watching()) return true; - bool result = InternalWatchSocket(); + bool result = base::CurrentIOThread::Get()->Watch( + socket_, true, kSbSocketWaiterInterestWrite, &socket_watcher_, this); if (result) { write_async_watcher_->set_watching(true); } return result; } -void UDPSocketStarboard::StopWatchingSocket() { +void UDPSocketStarboard::StopWatchingSocketForWriting() { if (!write_async_watcher_->watching()) return; write_async_watcher_->set_watching(false); - InternalStopWatchingSocket(); -} - -bool UDPSocketStarboard::InternalWatchSocket() { - return base::CurrentIOThread::Get()->Watch( - socket_, true, base::MessagePumpIOStarboard::WATCH_WRITE, - &socket_watcher_, this); + if (!write_buf_) { + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestWrite, &socket_watcher_); + DCHECK(ok); + } } -void UDPSocketStarboard::InternalStopWatchingSocket() { - if (!read_buf_ && !write_buf_ && !write_async_watcher_->watching()) { - bool ok = socket_watcher_.StopWatchingSocket(); +void UDPSocketStarboard::StopWatchingSocketForReading() { + if (!read_buf_) { + bool ok = base::CurrentIOThread::Get()->UnregisterInterest( + socket_, kSbSocketWaiterInterestRead, &socket_watcher_); DCHECK(ok); } } diff --git a/net/socket/udp_socket_starboard.h b/net/socket/udp_socket_starboard.h index 080d778ba9c5..ae4430c5cf95 100644 --- a/net/socket/udp_socket_starboard.h +++ b/net/socket/udp_socket_starboard.h @@ -361,9 +361,6 @@ class NET_EXPORT UDPSocketStarboard write_async_outstanding_ += increment; } - virtual bool InternalWatchSocket(); - virtual void InternalStopWatchingSocket(); - void SetWriteCallback(CompletionOnceCallback callback) { write_callback_ = std::move(callback); } @@ -385,8 +382,9 @@ class NET_EXPORT UDPSocketStarboard int InternalWriteAsync(CompletionOnceCallback callback, const NetworkTrafficAnnotationTag& traffic_annotation); - bool WatchSocket(); - void StopWatchingSocket(); + bool WatchSocketForWriting(); + void StopWatchingSocketForReading(); + void StopWatchingSocketForWriting(); void DoReadCallback(int rv); void DoWriteCallback(int rv);