From 63596c3f8c9c66a2aa48b24ab9865927ce9e958b Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Sat, 17 Aug 2024 09:53:10 +0000 Subject: [PATCH] fix(sockets) always uncork when closing (#13358) Co-authored-by: Jarred Sumner --- packages/bun-uws/src/AsyncSocket.h | 71 +++++++++++------------ packages/bun-uws/src/HttpContext.h | 5 ++ packages/bun-uws/src/HttpResponse.h | 2 +- packages/bun-uws/src/LoopData.h | 75 ++++++++++++++++++++----- packages/bun-uws/src/WebSocket.h | 8 +-- packages/bun-uws/src/WebSocketContext.h | 3 + 6 files changed, 110 insertions(+), 54 deletions(-) diff --git a/packages/bun-uws/src/AsyncSocket.h b/packages/bun-uws/src/AsyncSocket.h index 753212cc12728..4a0d82968cb03 100644 --- a/packages/bun-uws/src/AsyncSocket.h +++ b/packages/bun-uws/src/AsyncSocket.h @@ -117,47 +117,47 @@ struct AsyncSocket { /* Immediately close socket */ us_socket_t *close() { + this->uncork(); return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr); } void corkUnchecked() { /* What if another socket is corked? */ - getLoopData()->corkedSocket = this; - getLoopData()->corkedSocketIsSSL = SSL; + getLoopData()->setCorkedSocket(this, SSL); } void uncorkWithoutSending() { if (isCorked()) { - getLoopData()->corkedSocket = nullptr; + getLoopData()->cleanCorkedSocket(); } } /* Cork this socket. Only one socket may ever be corked per-loop at any given time */ void cork() { + auto* corked = getLoopData()->getCorkedSocket(); /* Extra check for invalid corking of others */ - if (getLoopData()->corkOffset && getLoopData()->corkedSocket != this) { + if (getLoopData()->isCorked() && corked != this) { // We uncork the other socket early instead of terminating the program // is unlikely to be cause any issues and is better than crashing - if(getLoopData()->corkedSocketIsSSL) { - ((AsyncSocket *) getLoopData()->corkedSocket)->uncork(); + if(getLoopData()->isCorkedSSL()) { + ((AsyncSocket *) corked)->uncork(); } else { - ((AsyncSocket *) getLoopData()->corkedSocket)->uncork(); + ((AsyncSocket *) corked)->uncork(); } } /* What if another socket is corked? */ - getLoopData()->corkedSocket = this; - getLoopData()->corkedSocketIsSSL = SSL; + getLoopData()->setCorkedSocket(this, SSL); } /* Returns wheter we are corked or not */ bool isCorked() { - return getLoopData()->corkedSocket == this; + return getLoopData()->isCorkedWith(this); } /* Returns whether we could cork (it is free) */ bool canCork() { - return getLoopData()->corkedSocket == nullptr; + return getLoopData()->canCork(); } /* Returns a suitable buffer for temporary assemblation of send data */ @@ -166,16 +166,16 @@ struct AsyncSocket { LoopData *loopData = getLoopData(); BackPressure &backPressure = getAsyncSocketData()->buffer; size_t existingBackpressure = backPressure.length(); - if ((!existingBackpressure) && (isCorked() || canCork()) && (loopData->corkOffset + size < LoopData::CORK_BUFFER_SIZE)) { + if ((!existingBackpressure) && (isCorked() || canCork()) && (loopData->getCorkOffset() + size < LoopData::CORK_BUFFER_SIZE)) { /* Cork automatically if we can */ if (isCorked()) { - char *sendBuffer = loopData->corkBuffer + loopData->corkOffset; - loopData->corkOffset += (unsigned int) size; + char *sendBuffer = loopData->getCorkSendBuffer(); + loopData->incrementCorkedOffset((unsigned int) size); return {sendBuffer, SendBufferAttribute::NEEDS_NOTHING}; } else { cork(); - char *sendBuffer = loopData->corkBuffer + loopData->corkOffset; - loopData->corkOffset += (unsigned int) size; + char *sendBuffer = loopData->getCorkSendBuffer(); + loopData->incrementCorkedOffset((unsigned int) size); return {sendBuffer, SendBufferAttribute::NEEDS_UNCORK}; } } else { @@ -183,17 +183,19 @@ struct AsyncSocket { /* If we are corked and there is already data in the cork buffer, mark how much is ours and reset it */ unsigned int ourCorkOffset = 0; - if (isCorked() && loopData->corkOffset) { - ourCorkOffset = loopData->corkOffset; - loopData->corkOffset = 0; + + if (isCorked()) { + ourCorkOffset = loopData->getCorkOffset(); + loopData->setCorkOffset(0); } /* Fallback is to use the backpressure as buffer */ backPressure.resize(ourCorkOffset + existingBackpressure + size); - /* And copy corkbuffer in front */ - memcpy((char *) backPressure.data() + existingBackpressure, loopData->corkBuffer, ourCorkOffset); - + if(ourCorkOffset > 0) { + /* And copy corkbuffer in front */ + memcpy((char *) backPressure.data() + existingBackpressure, loopData->getCorkBuffer(), ourCorkOffset); + } return {(char *) backPressure.data() + ourCorkOffset + existingBackpressure, SendBufferAttribute::NEEDS_DRAIN}; } } @@ -279,20 +281,20 @@ struct AsyncSocket { } if (length) { - if (loopData->corkedSocket == this) { + if (loopData->isCorkedWith(this)) { /* We are corked */ - if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= (unsigned int) length) { + if (LoopData::CORK_BUFFER_SIZE - loopData->getCorkOffset() >= (unsigned int) length) { /* If the entire chunk fits in cork buffer */ - memcpy(loopData->corkBuffer + loopData->corkOffset, src, (unsigned int) length); - loopData->corkOffset += (unsigned int) length; + memcpy(loopData->getCorkSendBuffer(), src, (unsigned int) length); + loopData->incrementCorkedOffset((unsigned int) length); /* Fall through to default return */ } else { /* Strategy differences between SSL and non-SSL regarding syscall minimizing */ if constexpr (false) { /* Cork up as much as we can */ - unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset; - memcpy(loopData->corkBuffer + loopData->corkOffset, src, stripped); - loopData->corkOffset = LoopData::CORK_BUFFER_SIZE; + unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->getCorkOffset(); + memcpy(loopData->getCorkSendBuffer(), src, stripped); + loopData->setCorkOffset(LoopData::CORK_BUFFER_SIZE); auto [written, failed] = uncork(src + stripped, length - (int) stripped, optionally); return {written + (int) stripped, failed}; @@ -335,14 +337,13 @@ struct AsyncSocket { /* It does NOT count bytes written from cork buffer (they are already accounted for in the write call responsible for its corking)! */ std::pair uncork(const char *src = nullptr, int length = 0, bool optionally = false) { LoopData *loopData = getLoopData(); + if (loopData->isCorkedWith(this)) { + auto offset = loopData->getCorkOffset(); + loopData->cleanCorkedSocket(); - if (loopData->corkedSocket == this) { - loopData->corkedSocket = nullptr; - - if (loopData->corkOffset) { + if (offset) { /* Corked data is already accounted for via its write call */ - auto [written, failed] = write(loopData->corkBuffer, (int) loopData->corkOffset, false, length); - loopData->corkOffset = 0; + auto [written, failed] = write(loopData->getCorkBuffer(), (int) offset, false, length); if (failed && optionally) { /* We do not need to care for buffering here, write does that */ diff --git a/packages/bun-uws/src/HttpContext.h b/packages/bun-uws/src/HttpContext.h index 43c68fd90f9e4..68f900b70d3e5 100644 --- a/packages/bun-uws/src/HttpContext.h +++ b/packages/bun-uws/src/HttpContext.h @@ -113,6 +113,9 @@ struct HttpContext { /* Handle socket disconnections */ us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s, int /*code*/, void */*reason*/) { + ((AsyncSocket *)s)->uncorkWithoutSending(); + + /* Get socket ext */ HttpResponseData *httpResponseData = (HttpResponseData *) us_socket_ext(SSL, s); @@ -126,6 +129,7 @@ struct HttpContext { if (httpResponseData->onAborted) { httpResponseData->onAborted((HttpResponse *)s, httpResponseData->userData); } + /* Destruct socket ext */ httpResponseData->~HttpResponseData(); @@ -400,6 +404,7 @@ struct HttpContext { /* Handle FIN, HTTP does not support half-closed sockets, so simply close */ us_socket_context_on_end(SSL, getSocketContext(), [](us_socket_t *s) { + ((AsyncSocket *)s)->uncorkWithoutSending(); /* We do not care for half closed sockets */ AsyncSocket *asyncSocket = (AsyncSocket *) s; diff --git a/packages/bun-uws/src/HttpResponse.h b/packages/bun-uws/src/HttpResponse.h index 15e7057a11f66..aee8963c92f3c 100644 --- a/packages/bun-uws/src/HttpResponse.h +++ b/packages/bun-uws/src/HttpResponse.h @@ -513,7 +513,7 @@ struct HttpResponse : public AsyncSocket { /* The only way we could possibly have changed the corked socket during handler call, would be if * the HTTP socket was upgraded to WebSocket and caused a realloc. Because of this we cannot use "this" * from here downwards. The corking is done with corkUnchecked() in upgrade. It steals cork. */ - auto *newCorkedSocket = loopData->corkedSocket; + auto *newCorkedSocket = loopData->getCorkedSocket(); /* If nobody is corked, it means most probably that large amounts of data has * been written and the cork buffer has already been sent off and uncorked. diff --git a/packages/bun-uws/src/LoopData.h b/packages/bun-uws/src/LoopData.h index 92bd9ffff853b..e68ca51b0eb0d 100644 --- a/packages/bun-uws/src/LoopData.h +++ b/packages/bun-uws/src/LoopData.h @@ -18,17 +18,17 @@ #ifndef UWS_LOOPDATA_H #define UWS_LOOPDATA_H -#include +#include +#include #include -#include -#include #include -#include -#include +#include +#include +#include -#include "PerMessageDeflate.h" #include "MoveOnlyFunction.h" - +#include "PerMessageDeflate.h" +// clang-format off struct us_timer_t; namespace uWS { @@ -44,7 +44,11 @@ struct alignas(16) LoopData { /* Map from void ptr to handler */ std::map> postHandlers, preHandlers; - + /* Cork data */ + char *corkBuffer = new char[CORK_BUFFER_SIZE]; + unsigned int corkOffset = 0; + void *corkedSocket = nullptr; + bool corkedSocketIsSSL = false; public: LoopData() { updateDate(); @@ -59,6 +63,55 @@ struct alignas(16) LoopData { } delete [] corkBuffer; } + void* getCorkedSocket() { + return this->corkedSocket; + } + + void setCorkedSocket(void *corkedSocket, bool ssl) { + this->corkedSocket = corkedSocket; + this->corkedSocketIsSSL = ssl; + } + + bool isCorkedSSL() { + return this->corkedSocketIsSSL; + } + + bool isCorked() { + return this->corkOffset && this->corkedSocket; + } + + bool canCork() { + return this->corkedSocket == nullptr; + } + + bool isCorkedWith(void* socket) { + return this->corkedSocket == socket; + } + + char* getCorkSendBuffer() { + return this->corkBuffer + this->corkOffset; + } + + void cleanCorkedSocket() { + this->corkedSocket = nullptr; + this->corkOffset = 0; + } + + unsigned int getCorkOffset() { + return this->corkOffset; + } + + void setCorkOffset(unsigned int offset) { + this->corkOffset = offset; + } + + void incrementCorkedOffset(unsigned int offset) { + this->corkOffset += offset; + } + + char* getCorkBuffer() { + return this->corkBuffer; + } void updateDate() { time_t now = time(0); @@ -94,12 +147,6 @@ struct alignas(16) LoopData { /* Good 16k for SSL perf. */ static const unsigned int CORK_BUFFER_SIZE = 16 * 1024; - /* Cork data */ - char *corkBuffer = new char[CORK_BUFFER_SIZE]; - unsigned int corkOffset = 0; - void *corkedSocket = nullptr; - bool corkedSocketIsSSL = false; - /* Per message deflate data */ ZlibContext *zlibContext = nullptr; InflationStream *inflationStream = nullptr; diff --git a/packages/bun-uws/src/WebSocket.h b/packages/bun-uws/src/WebSocket.h index 94d7976616cf0..ee17fb8225f82 100644 --- a/packages/bun-uws/src/WebSocket.h +++ b/packages/bun-uws/src/WebSocket.h @@ -18,13 +18,13 @@ #ifndef UWS_WEBSOCKET_H #define UWS_WEBSOCKET_H -#include "WebSocketData.h" -#include "WebSocketProtocol.h" #include "AsyncSocket.h" #include "WebSocketContextData.h" +#include "WebSocketData.h" +#include "WebSocketProtocol.h" #include - +// clang-format off namespace uWS { template @@ -107,7 +107,7 @@ struct WebSocket : AsyncSocket { WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData(); /* Special path for long sends of non-compressed, non-SSL messages */ - if (message.length() >= 16 * 1024 && !compress && !SSL && !webSocketData->subscriber && getBufferedAmount() == 0 && Super::getLoopData()->corkOffset == 0) { + if (message.length() >= 16 * 1024 && !compress && !SSL && !webSocketData->subscriber && getBufferedAmount() == 0 && Super::getLoopData()->getCorkOffset() == 0) { char header[10]; int header_length = (int) protocol::formatMessage(header, "", 0, opCode, message.length(), compress, fin); int written = us_socket_write2(0, (struct us_socket_t *)this, header, header_length, message.data(), (int) message.length()); diff --git a/packages/bun-uws/src/WebSocketContext.h b/packages/bun-uws/src/WebSocketContext.h index 950db9b0b99e7..25c6e216acdf6 100644 --- a/packages/bun-uws/src/WebSocketContext.h +++ b/packages/bun-uws/src/WebSocketContext.h @@ -252,6 +252,8 @@ struct WebSocketContext { /* Handle socket disconnections */ us_socket_context_on_close(SSL, getSocketContext(), [](auto *s, int code, void *reason) { + ((AsyncSocket *)s)->uncorkWithoutSending(); + /* For whatever reason, if we already have emitted close event, do not emit it again */ WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s)); if (!webSocketData->isShuttingDown) { @@ -371,6 +373,7 @@ struct WebSocketContext { /* Handle FIN, HTTP does not support half-closed sockets, so simply close */ us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) { + ((AsyncSocket *)s)->uncorkWithoutSending(); /* If we get a fin, we just close I guess */ us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);