Skip to content

Commit

Permalink
fix(sockets) always uncork when closing (#13358)
Browse files Browse the repository at this point in the history
Co-authored-by: Jarred Sumner <[email protected]>
  • Loading branch information
cirospaciari and Jarred-Sumner authored Aug 17, 2024
1 parent 996847b commit 63596c3
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 54 deletions.
71 changes: 36 additions & 35 deletions packages/bun-uws/src/AsyncSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,47 +117,47 @@ struct AsyncSocket {

/* Immediately close socket */
us_socket_t *close() {
this->uncork();
return us_socket_close(SSL, (us_socket_t *) this, 0, nullptr);
}

void corkUnchecked() {
/* What if another socket is corked? */
getLoopData()->corkedSocket = this;
getLoopData()->corkedSocketIsSSL = SSL;
getLoopData()->setCorkedSocket(this, SSL);
}

void uncorkWithoutSending() {
if (isCorked()) {
getLoopData()->corkedSocket = nullptr;
getLoopData()->cleanCorkedSocket();
}
}

/* Cork this socket. Only one socket may ever be corked per-loop at any given time */
void cork() {
auto* corked = getLoopData()->getCorkedSocket();
/* Extra check for invalid corking of others */
if (getLoopData()->corkOffset && getLoopData()->corkedSocket != this) {
if (getLoopData()->isCorked() && corked != this) {
// We uncork the other socket early instead of terminating the program
// is unlikely to be cause any issues and is better than crashing
if(getLoopData()->corkedSocketIsSSL) {
((AsyncSocket<true> *) getLoopData()->corkedSocket)->uncork();
if(getLoopData()->isCorkedSSL()) {
((AsyncSocket<true> *) corked)->uncork();
} else {
((AsyncSocket<false> *) getLoopData()->corkedSocket)->uncork();
((AsyncSocket<false> *) corked)->uncork();
}
}

/* What if another socket is corked? */
getLoopData()->corkedSocket = this;
getLoopData()->corkedSocketIsSSL = SSL;
getLoopData()->setCorkedSocket(this, SSL);
}

/* Returns wheter we are corked or not */
bool isCorked() {
return getLoopData()->corkedSocket == this;
return getLoopData()->isCorkedWith(this);
}

/* Returns whether we could cork (it is free) */
bool canCork() {
return getLoopData()->corkedSocket == nullptr;
return getLoopData()->canCork();
}

/* Returns a suitable buffer for temporary assemblation of send data */
Expand All @@ -166,34 +166,36 @@ struct AsyncSocket {
LoopData *loopData = getLoopData();
BackPressure &backPressure = getAsyncSocketData()->buffer;
size_t existingBackpressure = backPressure.length();
if ((!existingBackpressure) && (isCorked() || canCork()) && (loopData->corkOffset + size < LoopData::CORK_BUFFER_SIZE)) {
if ((!existingBackpressure) && (isCorked() || canCork()) && (loopData->getCorkOffset() + size < LoopData::CORK_BUFFER_SIZE)) {
/* Cork automatically if we can */
if (isCorked()) {
char *sendBuffer = loopData->corkBuffer + loopData->corkOffset;
loopData->corkOffset += (unsigned int) size;
char *sendBuffer = loopData->getCorkSendBuffer();
loopData->incrementCorkedOffset((unsigned int) size);
return {sendBuffer, SendBufferAttribute::NEEDS_NOTHING};
} else {
cork();
char *sendBuffer = loopData->corkBuffer + loopData->corkOffset;
loopData->corkOffset += (unsigned int) size;
char *sendBuffer = loopData->getCorkSendBuffer();
loopData->incrementCorkedOffset((unsigned int) size);
return {sendBuffer, SendBufferAttribute::NEEDS_UNCORK};
}
} else {

/* If we are corked and there is already data in the cork buffer,
mark how much is ours and reset it */
unsigned int ourCorkOffset = 0;
if (isCorked() && loopData->corkOffset) {
ourCorkOffset = loopData->corkOffset;
loopData->corkOffset = 0;

if (isCorked()) {
ourCorkOffset = loopData->getCorkOffset();
loopData->setCorkOffset(0);
}

/* Fallback is to use the backpressure as buffer */
backPressure.resize(ourCorkOffset + existingBackpressure + size);

/* And copy corkbuffer in front */
memcpy((char *) backPressure.data() + existingBackpressure, loopData->corkBuffer, ourCorkOffset);

if(ourCorkOffset > 0) {
/* And copy corkbuffer in front */
memcpy((char *) backPressure.data() + existingBackpressure, loopData->getCorkBuffer(), ourCorkOffset);
}
return {(char *) backPressure.data() + ourCorkOffset + existingBackpressure, SendBufferAttribute::NEEDS_DRAIN};
}
}
Expand Down Expand Up @@ -279,20 +281,20 @@ struct AsyncSocket {
}

if (length) {
if (loopData->corkedSocket == this) {
if (loopData->isCorkedWith(this)) {
/* We are corked */
if (LoopData::CORK_BUFFER_SIZE - loopData->corkOffset >= (unsigned int) length) {
if (LoopData::CORK_BUFFER_SIZE - loopData->getCorkOffset() >= (unsigned int) length) {
/* If the entire chunk fits in cork buffer */
memcpy(loopData->corkBuffer + loopData->corkOffset, src, (unsigned int) length);
loopData->corkOffset += (unsigned int) length;
memcpy(loopData->getCorkSendBuffer(), src, (unsigned int) length);
loopData->incrementCorkedOffset((unsigned int) length);
/* Fall through to default return */
} else {
/* Strategy differences between SSL and non-SSL regarding syscall minimizing */
if constexpr (false) {
/* Cork up as much as we can */
unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->corkOffset;
memcpy(loopData->corkBuffer + loopData->corkOffset, src, stripped);
loopData->corkOffset = LoopData::CORK_BUFFER_SIZE;
unsigned int stripped = LoopData::CORK_BUFFER_SIZE - loopData->getCorkOffset();
memcpy(loopData->getCorkSendBuffer(), src, stripped);
loopData->setCorkOffset(LoopData::CORK_BUFFER_SIZE);

auto [written, failed] = uncork(src + stripped, length - (int) stripped, optionally);
return {written + (int) stripped, failed};
Expand Down Expand Up @@ -335,14 +337,13 @@ struct AsyncSocket {
/* It does NOT count bytes written from cork buffer (they are already accounted for in the write call responsible for its corking)! */
std::pair<int, bool> uncork(const char *src = nullptr, int length = 0, bool optionally = false) {
LoopData *loopData = getLoopData();
if (loopData->isCorkedWith(this)) {
auto offset = loopData->getCorkOffset();
loopData->cleanCorkedSocket();

if (loopData->corkedSocket == this) {
loopData->corkedSocket = nullptr;

if (loopData->corkOffset) {
if (offset) {
/* Corked data is already accounted for via its write call */
auto [written, failed] = write(loopData->corkBuffer, (int) loopData->corkOffset, false, length);
loopData->corkOffset = 0;
auto [written, failed] = write(loopData->getCorkBuffer(), (int) offset, false, length);

if (failed && optionally) {
/* We do not need to care for buffering here, write does that */
Expand Down
5 changes: 5 additions & 0 deletions packages/bun-uws/src/HttpContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ struct HttpContext {

/* Handle socket disconnections */
us_socket_context_on_close(SSL, getSocketContext(), [](us_socket_t *s, int /*code*/, void */*reason*/) {
((AsyncSocket<SSL> *)s)->uncorkWithoutSending();


/* Get socket ext */
HttpResponseData<SSL> *httpResponseData = (HttpResponseData<SSL> *) us_socket_ext(SSL, s);

Expand All @@ -126,6 +129,7 @@ struct HttpContext {
if (httpResponseData->onAborted) {
httpResponseData->onAborted((HttpResponse<SSL> *)s, httpResponseData->userData);
}


/* Destruct socket ext */
httpResponseData->~HttpResponseData<SSL>();
Expand Down Expand Up @@ -400,6 +404,7 @@ struct HttpContext {

/* Handle FIN, HTTP does not support half-closed sockets, so simply close */
us_socket_context_on_end(SSL, getSocketContext(), [](us_socket_t *s) {
((AsyncSocket<SSL> *)s)->uncorkWithoutSending();

/* We do not care for half closed sockets */
AsyncSocket<SSL> *asyncSocket = (AsyncSocket<SSL> *) s;
Expand Down
2 changes: 1 addition & 1 deletion packages/bun-uws/src/HttpResponse.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ struct HttpResponse : public AsyncSocket<SSL> {
/* The only way we could possibly have changed the corked socket during handler call, would be if
* the HTTP socket was upgraded to WebSocket and caused a realloc. Because of this we cannot use "this"
* from here downwards. The corking is done with corkUnchecked() in upgrade. It steals cork. */
auto *newCorkedSocket = loopData->corkedSocket;
auto *newCorkedSocket = loopData->getCorkedSocket();

/* If nobody is corked, it means most probably that large amounts of data has
* been written and the cork buffer has already been sent off and uncorked.
Expand Down
75 changes: 61 additions & 14 deletions packages/bun-uws/src/LoopData.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
#ifndef UWS_LOOPDATA_H
#define UWS_LOOPDATA_H

#include <thread>
#include <cstdint>
#include <ctime>
#include <functional>
#include <vector>
#include <mutex>
#include <map>
#include <ctime>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

#include "PerMessageDeflate.h"
#include "MoveOnlyFunction.h"

#include "PerMessageDeflate.h"
// clang-format off
struct us_timer_t;

namespace uWS {
Expand All @@ -44,7 +44,11 @@ struct alignas(16) LoopData {

/* Map from void ptr to handler */
std::map<void *, MoveOnlyFunction<void(Loop *)>> postHandlers, preHandlers;

/* Cork data */
char *corkBuffer = new char[CORK_BUFFER_SIZE];
unsigned int corkOffset = 0;
void *corkedSocket = nullptr;
bool corkedSocketIsSSL = false;
public:
LoopData() {
updateDate();
Expand All @@ -59,6 +63,55 @@ struct alignas(16) LoopData {
}
delete [] corkBuffer;
}
void* getCorkedSocket() {
return this->corkedSocket;
}

void setCorkedSocket(void *corkedSocket, bool ssl) {
this->corkedSocket = corkedSocket;
this->corkedSocketIsSSL = ssl;
}

bool isCorkedSSL() {
return this->corkedSocketIsSSL;
}

bool isCorked() {
return this->corkOffset && this->corkedSocket;
}

bool canCork() {
return this->corkedSocket == nullptr;
}

bool isCorkedWith(void* socket) {
return this->corkedSocket == socket;
}

char* getCorkSendBuffer() {
return this->corkBuffer + this->corkOffset;
}

void cleanCorkedSocket() {
this->corkedSocket = nullptr;
this->corkOffset = 0;
}

unsigned int getCorkOffset() {
return this->corkOffset;
}

void setCorkOffset(unsigned int offset) {
this->corkOffset = offset;
}

void incrementCorkedOffset(unsigned int offset) {
this->corkOffset += offset;
}

char* getCorkBuffer() {
return this->corkBuffer;
}

void updateDate() {
time_t now = time(0);
Expand Down Expand Up @@ -94,12 +147,6 @@ struct alignas(16) LoopData {
/* Good 16k for SSL perf. */
static const unsigned int CORK_BUFFER_SIZE = 16 * 1024;

/* Cork data */
char *corkBuffer = new char[CORK_BUFFER_SIZE];
unsigned int corkOffset = 0;
void *corkedSocket = nullptr;
bool corkedSocketIsSSL = false;

/* Per message deflate data */
ZlibContext *zlibContext = nullptr;
InflationStream *inflationStream = nullptr;
Expand Down
8 changes: 4 additions & 4 deletions packages/bun-uws/src/WebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
#ifndef UWS_WEBSOCKET_H
#define UWS_WEBSOCKET_H

#include "WebSocketData.h"
#include "WebSocketProtocol.h"
#include "AsyncSocket.h"
#include "WebSocketContextData.h"
#include "WebSocketData.h"
#include "WebSocketProtocol.h"

#include <string_view>

// clang-format off
namespace uWS {

template <bool SSL, bool isServer, typename USERDATA>
Expand Down Expand Up @@ -107,7 +107,7 @@ struct WebSocket : AsyncSocket<SSL> {
WebSocketData *webSocketData = (WebSocketData *) Super::getAsyncSocketData();

/* Special path for long sends of non-compressed, non-SSL messages */
if (message.length() >= 16 * 1024 && !compress && !SSL && !webSocketData->subscriber && getBufferedAmount() == 0 && Super::getLoopData()->corkOffset == 0) {
if (message.length() >= 16 * 1024 && !compress && !SSL && !webSocketData->subscriber && getBufferedAmount() == 0 && Super::getLoopData()->getCorkOffset() == 0) {
char header[10];
int header_length = (int) protocol::formatMessage<isServer>(header, "", 0, opCode, message.length(), compress, fin);
int written = us_socket_write2(0, (struct us_socket_t *)this, header, header_length, message.data(), (int) message.length());
Expand Down
3 changes: 3 additions & 0 deletions packages/bun-uws/src/WebSocketContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ struct WebSocketContext {

/* Handle socket disconnections */
us_socket_context_on_close(SSL, getSocketContext(), [](auto *s, int code, void *reason) {
((AsyncSocket<SSL> *)s)->uncorkWithoutSending();

/* For whatever reason, if we already have emitted close event, do not emit it again */
WebSocketData *webSocketData = (WebSocketData *) (us_socket_ext(SSL, s));
if (!webSocketData->isShuttingDown) {
Expand Down Expand Up @@ -371,6 +373,7 @@ struct WebSocketContext {

/* Handle FIN, HTTP does not support half-closed sockets, so simply close */
us_socket_context_on_end(SSL, getSocketContext(), [](auto *s) {
((AsyncSocket<SSL> *)s)->uncorkWithoutSending();

/* If we get a fin, we just close I guess */
us_socket_close(SSL, (us_socket_t *) s, 0, nullptr);
Expand Down

0 comments on commit 63596c3

Please sign in to comment.