Skip to content

Commit

Permalink
Fix unbalanced ref count involving file descriptors passed to Bun.con…
Browse files Browse the repository at this point in the history
…nect (#14107)
  • Loading branch information
Jarred-Sumner authored Sep 23, 2024
1 parent 2f8c20e commit ff9560c
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 130 deletions.
9 changes: 7 additions & 2 deletions packages/bun-usockets/src/eventing/epoll_kqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ struct us_poll_t *us_poll_resize(struct us_poll_t *p, struct us_loop_t *loop, un
return new_p;
}

void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) {
int us_poll_start_rc(struct us_poll_t *p, struct us_loop_t *loop, int events) {
p->state.poll_type = us_internal_poll_type(p) | ((events & LIBUS_SOCKET_READABLE) ? POLL_TYPE_POLLING_IN : 0) | ((events & LIBUS_SOCKET_WRITABLE) ? POLL_TYPE_POLLING_OUT : 0);

#ifdef LIBUS_USE_EPOLL
Expand All @@ -379,11 +379,16 @@ void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) {
do {
ret = epoll_ctl(loop->fd, EPOLL_CTL_ADD, p->state.fd, &event);
} while (IS_EINTR(ret));
return ret;
#else
kqueue_change(loop->fd, p->state.fd, 0, events, p);
return kqueue_change(loop->fd, p->state.fd, 0, events, p);
#endif
}

void us_poll_start(struct us_poll_t *p, struct us_loop_t *loop, int events) {
us_poll_start_rc(p, loop, events);
}

void us_poll_change(struct us_poll_t *p, struct us_loop_t *loop, int events) {
int old_events = us_poll_events(p);
if (old_events != events) {
Expand Down
2 changes: 2 additions & 0 deletions packages/bun-usockets/src/libusockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ void us_poll_init(us_poll_r p, LIBUS_SOCKET_DESCRIPTOR fd, int poll_type);

/* Start, change and stop polling for events */
void us_poll_start(us_poll_r p, us_loop_r loop, int events) nonnull_fn_decl;
/* Returns 0 if successful */
int us_poll_start_rc(us_poll_r p, us_loop_r loop, int events) nonnull_fn_decl;
void us_poll_change(us_poll_r p, us_loop_r loop, int events) nonnull_fn_decl;
void us_poll_stop(us_poll_r p, struct us_loop_t *loop) nonnull_fn_decl;

Expand Down
7 changes: 5 additions & 2 deletions packages/bun-usockets/src/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
#include <stdio.h>
#include <errno.h>

#ifndef WIN32
Expand Down Expand Up @@ -310,7 +309,11 @@ struct us_socket_t *us_socket_from_fd(struct us_socket_context_t *ctx, int socke
#else
struct us_poll_t *p1 = us_create_poll(ctx->loop, 0, sizeof(struct us_socket_t) + socket_ext_size);
us_poll_init(p1, fd, POLL_TYPE_SOCKET);
us_poll_start(p1, ctx->loop, LIBUS_SOCKET_READABLE | LIBUS_SOCKET_WRITABLE);
int rc = us_poll_start_rc(p1, ctx->loop, LIBUS_SOCKET_READABLE | LIBUS_SOCKET_WRITABLE);
if (rc != 0) {
us_poll_free(p1, ctx->loop);
return 0;
}

struct us_socket_t *s = (struct us_socket_t *) p1;
s->context = ctx;
Expand Down
110 changes: 68 additions & 42 deletions src/bun.js/api/bun/socket.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1236,46 +1236,30 @@ pub const Listener = struct {
const promise_value = promise.asValue(globalObject);
handlers_ptr.promise.set(globalObject, promise_value);

if (ssl_enabled) {
var tls = TLSSocket.new(.{
.handlers = handlers_ptr,
.this_value = .zero,
.socket = TLSSocket.Socket.detached,
.connection = connection,
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch bun.outOfMemory()) else null,
.server_name = server_name,
.socket_context = socket_context, // owns the socket context
});
switch (ssl_enabled) {
inline else => |is_ssl_enabled| {
const SocketType = NewSocket(is_ssl_enabled);
var socket = SocketType.new(.{
.handlers = handlers_ptr,
.this_value = .zero,
.socket = SocketType.Socket.detached,
.connection = connection,
.protos = if (protos) |p| (bun.default_allocator.dupe(u8, p) catch bun.outOfMemory()) else null,
.server_name = server_name,
.socket_context = socket_context, // owns the socket context
});

SocketType.dataSetCached(socket.getThisValue(globalObject), globalObject, default_data);

socket.doConnect(connection) catch {
socket.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED));
return promise_value;
};

TLSSocket.dataSetCached(tls.getThisValue(globalObject), globalObject, default_data);
socket.poll_ref.ref(handlers.vm);

tls.doConnect(connection) catch {
tls.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED));
return promise_value;
};

tls.poll_ref.ref(handlers.vm);

return promise_value;
} else {
var tcp = TCPSocket.new(.{
.handlers = handlers_ptr,
.this_value = .zero,
.socket = TCPSocket.Socket.detached,
.connection = null,
.protos = null,
.server_name = null,
.socket_context = socket_context, // owns the socket context
});

TCPSocket.dataSetCached(tcp.getThisValue(globalObject), globalObject, default_data);
tcp.doConnect(connection) catch {
tcp.handleConnectError(@intFromEnum(if (port == null) bun.C.SystemErrno.ENOENT else bun.C.SystemErrno.ECONNREFUSED));
return promise_value;
};
tcp.poll_ref.ref(handlers.vm);

return promise_value;
},
}
}
};
Expand Down Expand Up @@ -1371,9 +1355,13 @@ fn NewSocket(comptime ssl: bool) type {

pub fn doConnect(this: *This, connection: Listener.UnixOrHost) !void {
bun.assert(this.socket_context != null);
this.ref();
errdefer {
this.deref();
}

switch (connection) {
.host => |c| {
this.ref();
this.socket = try This.Socket.connectAnon(
normalizeHost(c.host),
c.port,
Expand All @@ -1382,8 +1370,6 @@ fn NewSocket(comptime ssl: bool) type {
);
},
.unix => |u| {
this.ref();

this.socket = try This.Socket.connectUnixAnon(
u,
this.socket_context.?,
Expand Down Expand Up @@ -1469,10 +1455,14 @@ fn NewSocket(comptime ssl: bool) type {

fn handleConnectError(this: *This, errno: c_int) void {
log("onConnectError({d}, {})", .{ errno, this.ref_count });
// Ensure the socket is still alive for any defer's we have
this.ref();
defer this.deref();

const needs_deref = !this.socket.isDetached();
this.socket = Socket.detached;
defer if (needs_deref) this.deref();
defer this.markInactive();
defer if (needs_deref) this.deref();

const handlers = this.handlers;
const vm = handlers.vm;
Expand Down Expand Up @@ -1500,6 +1490,11 @@ fn NewSocket(comptime ssl: bool) type {

if (callback == .zero) {
if (handlers.promise.trySwap()) |promise| {
if (this.this_value != .zero) {
this.this_value = .zero;
}
this.has_pending_activity.store(false, .release);

// reject the promise on connect() error
const err_value = err.toErrorInstance(globalObject);
promise.asPromise().?.rejectOnNextTick(globalObject, err_value);
Expand All @@ -1509,6 +1504,9 @@ fn NewSocket(comptime ssl: bool) type {
}

const this_value = this.getThisValue(globalObject);
this.this_value = .zero;
this.has_pending_activity.store(false, .release);

const err_value = err.toErrorInstance(globalObject);
const result = callback.call(globalObject, this_value, &[_]JSValue{
this_value,
Expand All @@ -1524,7 +1522,6 @@ fn NewSocket(comptime ssl: bool) type {
var promise = val.asPromise().?;
const err_ = err.toErrorInstance(globalObject);
promise.rejectOnNextTickAsHandled(globalObject, err_);
this.has_pending_activity.store(false, .release);
}
}
pub fn onConnectError(this: *This, _: Socket, errno: c_int) void {
Expand Down Expand Up @@ -1566,6 +1563,10 @@ fn NewSocket(comptime ssl: bool) type {
}

pub fn onOpen(this: *This, socket: Socket) void {
// Ensure the socket remains alive until this is finished
this.ref();
defer this.deref();

log("onOpen {} {}", .{ this.socket.isDetached(), this.ref_count });
// update the internal socket instance to the one that was just connected
// This socket must be replaced because the previous one is a connecting socket not a uSockets socket
Expand Down Expand Up @@ -1664,6 +1665,9 @@ fn NewSocket(comptime ssl: bool) type {
JSC.markBinding(@src());
log("onEnd", .{});
if (this.socket.isDetached()) return;
// Ensure the socket remains alive until this is finished
this.ref();
defer this.deref();

const handlers = this.handlers;

Expand Down Expand Up @@ -4121,3 +4125,25 @@ pub fn createNodeTLSBinding(global: *JSC.JSGlobalObject) JSC.JSValue {
JSC.JSFunction.create(global, "isNamedPipeSocket", JSC.toJSHostFunction(jsIsNamedPipeSocket), 1, .{}),
});
}

pub fn jsCreateSocketPair(global: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(JSC.conv) JSValue {
JSC.markBinding(@src());

if (Environment.isWindows) {
global.throw("Not implemented on Windows", .{});
return .zero;
}

var fds_: [2]std.c.fd_t = .{ 0, 0 };
const rc = std.c.socketpair(std.posix.AF.UNIX, std.posix.SOCK.STREAM, 0, &fds_);
if (rc != 0) {
const err = bun.sys.Error.fromCode(bun.C.getErrno(rc), .socketpair);
global.throwValue(err.toJSC(global));
return .zero;
}

const array = JSC.JSValue.createEmptyArray(global, 2);
array.putIndex(global, 0, JSC.jsNumber(fds_[0]));
array.putIndex(global, 1, JSC.jsNumber(fds_[1]));
return array;
}
Loading

0 comments on commit ff9560c

Please sign in to comment.