diff --git a/src/bun.js/api/bun/ssl_wrapper.zig b/src/bun.js/api/bun/ssl_wrapper.zig index 9d0fc0d1d3542..48c95c346dffc 100644 --- a/src/bun.js/api/bun/ssl_wrapper.zig +++ b/src/bun.js/api/bun/ssl_wrapper.zig @@ -154,15 +154,16 @@ pub fn SSLWrapper(comptime T: type) type { // we sent the shutdown this.flags.sent_ssl_shutdown = ret >= 0; - defer if (ret < 0) { + if (ret < 0) { const err = BoringSSL.SSL_get_error(ssl, ret); BoringSSL.ERR_clear_error(); if (err == BoringSSL.SSL_ERROR_SSL or err == BoringSSL.SSL_ERROR_SYSCALL) { this.flags.fatal_error = true; this.triggerCloseCallback(); + return false; } - }; + } return ret == 1; // truly closed } @@ -424,7 +425,6 @@ pub fn SSLWrapper(comptime T: type) type { if (read > 0) { this.triggerDataCallback(buffer[0..read]); } - this.triggerCloseCallback(); return false; } else { diff --git a/src/http.zig b/src/http.zig index baa5b5199bbe8..1cfb72daee52f 100644 --- a/src/http.zig +++ b/src/http.zig @@ -210,13 +210,18 @@ const ProxyTunnel = struct { none: void, } = .{ .none = {} }, write_buffer: bun.io.StreamBuffer = .{}, + ref_count: u32 = 1, const ProxyTunnelWrapper = SSLWrapper(*HTTPClient); + usingnamespace bun.NewRefCounted(ProxyTunnel, ProxyTunnel.deinit); + fn onOpen(this: *HTTPClient) void { this.state.response_stage = .proxy_handshake; this.state.request_stage = .proxy_handshake; - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { + proxy.ref(); + defer proxy.deref(); if (proxy.wrapper) |*wrapper| { var ssl_ptr = wrapper.ssl orelse return; const _hostname = this.hostname orelse this.url.hostname; @@ -244,7 +249,9 @@ const ProxyTunnel = struct { if (decoded_data.len == 0) return; log("onData decoded {}", .{decoded_data.len}); - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { + proxy.ref(); + defer proxy.deref(); switch (this.state.response_stage) { .body => { if (decoded_data.len == 0) return; @@ -286,16 +293,13 @@ const ProxyTunnel = struct { return; } }, - .proxy_decoded_headers, .proxy_headers => { - this.flags.proxy_tunneling = false; - this.state.response_stage = .proxy_decoded_headers; - + .proxy_headers => { switch (proxy.socket) { .ssl => |socket| { - this.onData(true, decoded_data, &http_thread.https_context, socket); + this.handleOnDataHeaders(true, decoded_data, &http_thread.https_context, socket); }, .tcp => |socket| { - this.onData(false, decoded_data, &http_thread.http_context, socket); + this.handleOnDataHeaders(false, decoded_data, &http_thread.http_context, socket); }, .none => {}, } @@ -309,7 +313,9 @@ const ProxyTunnel = struct { } fn onHandshake(this: *HTTPClient, handshake_success: bool, ssl_error: uws.us_bun_verify_error_t) void { - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { + proxy.ref(); + defer proxy.deref(); this.state.response_stage = .proxy_headers; this.state.request_stage = .proxy_headers; this.state.request_sent_len = 0; @@ -375,7 +381,7 @@ const ProxyTunnel = struct { } pub fn write(this: *HTTPClient, encoded_data: []const u8) void { - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { const written = switch (proxy.socket) { .ssl => |socket| socket.write(encoded_data, true), .tcp => |socket| socket.write(encoded_data, true), @@ -390,7 +396,10 @@ const ProxyTunnel = struct { } fn onClose(this: *HTTPClient) void { - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { + proxy.ref(); + // defer the proxy deref the proxy tunnel may still be in use after triggering the close callback + defer http_thread.scheduleProxyDeref(proxy); const err = proxy.shutdown_err; switch (proxy.socket) { .ssl => |socket| { @@ -401,41 +410,41 @@ const ProxyTunnel = struct { }, .none => {}, } + proxy.detachSocket(); } } - fn start(ctx: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, ssl_options: JSC.API.ServerConfig.SSLConfig) void { - ctx.proxy_tunnel = .{}; - if (ctx.proxy_tunnel) |*this| { - if (is_ssl) { - this.socket = .{ .ssl = socket }; - } else { - this.socket = .{ .tcp = socket }; + fn start(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket, ssl_options: JSC.API.ServerConfig.SSLConfig) void { + const proxy_tunnel = ProxyTunnel.new(.{}); + + var custom_options = ssl_options; + // we always request the cert so we can verify it and also we manually abort the connection if the hostname doesn't match + custom_options.reject_unauthorized = 0; + custom_options.request_cert = 1; + proxy_tunnel.wrapper = SSLWrapper(*HTTPClient).init(custom_options, true, .{ + .onOpen = ProxyTunnel.onOpen, + .onData = ProxyTunnel.onData, + .onHandshake = ProxyTunnel.onHandshake, + .onClose = ProxyTunnel.onClose, + .write = ProxyTunnel.write, + .ctx = this, + }) catch |err| { + if (err == error.OutOfMemory) { + bun.outOfMemory(); } - var custom_options = ssl_options; - // we always request the cert so we can verify it and also we manually abort the connection if the hostname doesn't match - custom_options.reject_unauthorized = 0; - custom_options.request_cert = 1; - this.wrapper = SSLWrapper(*HTTPClient).init(custom_options, true, .{ - .onOpen = ProxyTunnel.onOpen, - .onData = ProxyTunnel.onData, - .onHandshake = ProxyTunnel.onHandshake, - .onClose = ProxyTunnel.onClose, - .write = ProxyTunnel.write, - .ctx = ctx, - }) catch |err| { - if (err == error.OutOfMemory) { - bun.outOfMemory(); - } - // invalid TLS Options - this.socket = .{ .none = {} }; - this.wrapper = null; - ctx.closeAndFail(error.ConnectionRefused, is_ssl, socket); - return; - }; - this.wrapper.?.start(); + // invalid TLS Options + proxy_tunnel.detachAndDeref(); + this.closeAndFail(error.ConnectionRefused, is_ssl, socket); + return; + }; + this.proxy_tunnel = proxy_tunnel; + if (is_ssl) { + proxy_tunnel.socket = .{ .ssl = socket }; + } else { + proxy_tunnel.socket = .{ .tcp = socket }; } + proxy_tunnel.wrapper.?.start(); } pub fn close(this: *ProxyTunnel, err: anyerror) void { @@ -447,6 +456,8 @@ const ProxyTunnel = struct { } pub fn onWritable(this: *ProxyTunnel, comptime is_ssl: bool, socket: NewHTTPContext(is_ssl).HTTPSocket) void { + this.ref(); + defer this.deref(); const encoded_data = this.write_buffer.slice(); if (encoded_data.len == 0) { return; @@ -465,6 +476,8 @@ const ProxyTunnel = struct { } pub fn receiveData(this: *ProxyTunnel, buf: []const u8) void { + this.ref(); + defer this.deref(); if (this.wrapper) |*wrapper| { wrapper.receiveData(buf); } @@ -477,6 +490,15 @@ const ProxyTunnel = struct { return 0; } + pub fn detachSocket(this: *ProxyTunnel) void { + this.socket = .{ .none = {} }; + } + + pub fn detachAndDeref(this: *ProxyTunnel) void { + this.detachSocket(); + this.deref(); + } + pub fn deinit(this: *ProxyTunnel) void { this.socket = .{ .none = {} }; if (this.wrapper) |*wrapper| { @@ -484,6 +506,7 @@ const ProxyTunnel = struct { this.wrapper = null; } this.write_buffer.deinit(); + this.destroy(); } }; @@ -951,6 +974,8 @@ pub const HTTPThread = struct { queued_shutdowns: std.ArrayListUnmanaged(ShutdownMessage) = std.ArrayListUnmanaged(ShutdownMessage){}, queued_shutdowns_lock: bun.Lock = .{}, + queued_proxy_deref: std.ArrayListUnmanaged(*ProxyTunnel) = std.ArrayListUnmanaged(*ProxyTunnel){}, + has_awoken: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), timer: std.time.Timer, @@ -1039,8 +1064,10 @@ pub const HTTPThread = struct { for (custom_ssl_context_map.keys()) |other_config| { if (requested_config.isSame(other_config)) { // we free the callers config since we have a existing one - requested_config.deinit(); - bun.default_allocator.destroy(requested_config); + if (requested_config != client.tls_props) { + requested_config.deinit(); + bun.default_allocator.destroy(requested_config); + } client.tls_props = other_config; if (client.http_proxy) |url| { return try custom_ssl_context_map.get(other_config).?.connect(client, url.hostname, url.getPortAuto()); @@ -1052,8 +1079,10 @@ pub const HTTPThread = struct { // we need the config so dont free it var custom_context = try bun.default_allocator.create(NewHTTPContext(is_ssl)); custom_context.initWithClientConfig(client) catch |err| { - requested_config.deinit(); client.tls_props = null; + + requested_config.deinit(); + bun.default_allocator.destroy(requested_config); bun.default_allocator.destroy(custom_context); return err; }; @@ -1104,6 +1133,10 @@ pub const HTTPThread = struct { this.queued_shutdowns.clearRetainingCapacity(); } + while (this.queued_proxy_deref.popOrNull()) |http| { + http.deref(); + } + var count: usize = 0; var active = AsyncHTTP.active_requests_count.load(.monotonic); const max = AsyncHTTP.max_simultaneous_requests.load(.monotonic); @@ -1174,6 +1207,15 @@ pub const HTTPThread = struct { this.loop.loop.wakeup(); } + pub fn scheduleProxyDeref(this: *@This(), proxy: *ProxyTunnel) void { + // this is always called on the http thread + { + this.queued_proxy_deref.append(bun.default_allocator, proxy) catch bun.outOfMemory(); + } + if (this.has_awoken.load(.monotonic)) + this.loop.loop.wakeup(); + } + pub fn wakeup(this: *@This()) void { if (this.has_awoken.load(.monotonic)) this.loop.loop.wakeup(); @@ -1360,11 +1402,16 @@ pub fn onClose( log("Closed {s}\n", .{client.url.href}); // the socket is closed, we need to unregister the abort tracker client.unregisterAbortTracker(); + if (client.signals.get(.aborted)) { client.fail(error.Aborted); return; } - + if (client.proxy_tunnel) |tunnel| { + client.proxy_tunnel = null; + // always detach the socket from the tunnel onClose (timeout, connectError will call fail that will do the same) + tunnel.detachAndDeref(); + } const in_progress = client.state.stage != .done and client.state.stage != .fail and client.state.flags.is_redirect_pending == false; if (in_progress) { @@ -1407,6 +1454,7 @@ pub fn onTimeout( ) void { if (client.flags.disable_timeout) return; log("Timeout {s}\n", .{client.url.href}); + defer NewHTTPContext(is_ssl).terminateSocket(socket); client.fail(error.Timeout); } @@ -1535,7 +1583,6 @@ pub const HTTPStage = enum { done, proxy_handshake, proxy_headers, - proxy_decoded_headers, proxy_body, }; @@ -1908,7 +1955,7 @@ request_content_len_buf: ["-4294967295".len]u8 = undefined, http_proxy: ?URL = null, proxy_authorization: ?[]u8 = null, -proxy_tunnel: ?ProxyTunnel = null, +proxy_tunnel: ?*ProxyTunnel = null, signals: Signals = .{}, async_http_id: u32 = 0, hostname: ?[]u8 = null, @@ -1923,10 +1970,9 @@ pub fn deinit(this: *HTTPClient) void { this.allocator.free(auth); this.proxy_authorization = null; } - if (this.proxy_tunnel != null) { - var tunnel = this.proxy_tunnel.?; + if (this.proxy_tunnel) |tunnel| { this.proxy_tunnel = null; - tunnel.deinit(); + tunnel.detachAndDeref(); } this.unix_socket_path.deinit(); this.unix_socket_path = JSC.ZigString.Slice.empty; @@ -2078,6 +2124,8 @@ pub const AsyncHTTP = struct { verbose: HTTPVerboseLevel = .none, client: HTTPClient = undefined, + waitingDeffered: bool = false, + finalized: bool = false, err: ?anyerror = null, async_http_id: u32 = 0, @@ -2654,10 +2702,9 @@ pub fn doRedirect( this.state.reset(this.allocator); // also reset proxy to redirect this.flags.proxy_tunneling = false; - if (this.proxy_tunnel != null) { - var tunnel = this.proxy_tunnel.?; + if (this.proxy_tunnel) |tunnel| { this.proxy_tunnel = null; - tunnel.deinit(); + tunnel.detachAndDeref(); } return this.start(.{ .bytes = request_body }, body_out_str); @@ -2782,7 +2829,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s } } - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { proxy.onWritable(is_ssl, socket); } @@ -2952,7 +2999,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s if (this.state.original_request_body != .bytes) { @panic("sendfile is only supported without SSL. This code should never have been reached!"); } - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { this.setTimeout(socket, 5); const to_send = this.state.request_body; @@ -2968,7 +3015,7 @@ pub fn onWritable(this: *HTTPClient, comptime is_first_call: bool, comptime is_s } }, .proxy_headers => { - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { this.setTimeout(socket, 5); var stack_fallback = std.heap.stackFallback(16384, default_allocator); const allocator = stack_fallback.get(); @@ -3073,157 +3120,166 @@ inline fn handleShortRead( this.setTimeout(socket, 5); } -pub fn onData( + +pub fn handleOnDataHeaders( this: *HTTPClient, comptime is_ssl: bool, incoming_data: []const u8, ctx: *NewHTTPContext(is_ssl), socket: NewHTTPContext(is_ssl).HTTPSocket, ) void { - log("onData {}", .{incoming_data.len}); - if (this.signals.get(.aborted)) { - this.closeAndAbort(is_ssl, socket); + var to_read = incoming_data; + var amount_read: usize = 0; + var needs_move = true; + if (this.state.response_message_buffer.list.items.len > 0) { + // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating + this.state.response_message_buffer.appendSliceExact(incoming_data) catch bun.outOfMemory(); + to_read = this.state.response_message_buffer.list.items; + needs_move = false; + } + + // we reset the pending_response each time wich means that on parse error this will be always be empty + this.state.pending_response = picohttp.Response{}; + + // minimal http/1.1 request size is 16 bytes without headers and 26 with Host header + // if is less than 16 will always be a ShortRead + if (to_read.len < 16) { + this.handleShortRead(is_ssl, incoming_data, socket, needs_move); return; } - switch (this.state.response_stage) { - .pending, .headers, .proxy_decoded_headers => { - var to_read = incoming_data; - var amount_read: usize = 0; - var needs_move = true; - if (this.state.response_message_buffer.list.items.len > 0) { - // this one probably won't be another chunk, so we use appendSliceExact() to avoid over-allocating - this.state.response_message_buffer.appendSliceExact(incoming_data) catch bun.outOfMemory(); - to_read = this.state.response_message_buffer.list.items; - needs_move = false; - } + var response = picohttp.Response.parseParts( + to_read, + &shared_response_headers_buf, + &amount_read, + ) catch |err| { + switch (err) { + error.ShortRead => { + this.handleShortRead(is_ssl, incoming_data, socket, needs_move); + }, + else => { + this.closeAndFail(err, is_ssl, socket); + }, + } + return; + }; - // we reset the pending_response each time wich means that on parse error this will be always be empty - this.state.pending_response = picohttp.Response{}; + // we save the successful parsed response + this.state.pending_response = response; - // minimal http/1.1 request size is 16 bytes without headers and 26 with Host header - // if is less than 16 will always be a ShortRead - if (to_read.len < 16) { - this.handleShortRead(is_ssl, incoming_data, socket, needs_move); - return; - } + const body_buf = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..]; + // handle the case where we have a 100 Continue + if (response.status_code == 100) { + // we still can have the 200 OK in the same buffer sometimes + if (body_buf.len > 0) { + this.onData(is_ssl, body_buf, ctx, socket); + } + return; + } + const should_continue = this.handleResponseMetadata( + &response, + ) catch |err| { + this.closeAndFail(err, is_ssl, socket); + return; + }; - var response = picohttp.Response.parseParts( - to_read, - &shared_response_headers_buf, - &amount_read, - ) catch |err| { - switch (err) { - error.ShortRead => { - this.handleShortRead(is_ssl, incoming_data, socket, needs_move); - }, - else => { - this.closeAndFail(err, is_ssl, socket); - }, - } - return; - }; + if (this.state.content_encoding_i < response.headers.len and !this.state.flags.did_set_content_encoding) { + // if it compressed with this header, it is no longer because we will decompress it + const mutable_headers = std.ArrayListUnmanaged(picohttp.Header){ .items = response.headers, .capacity = response.headers.len }; + this.state.flags.did_set_content_encoding = true; + response.headers = mutable_headers.items; + this.state.content_encoding_i = std.math.maxInt(@TypeOf(this.state.content_encoding_i)); + // we need to reset the pending response because we removed a header + this.state.pending_response = response; + } - // we save the successful parsed response - this.state.pending_response = response; + if (should_continue == .finished) { + if (this.state.flags.is_redirect_pending) { + this.doRedirect(is_ssl, ctx, socket); + return; + } + // this means that the request ended + // clone metadata and return the progress at this point + this.cloneMetadata(); + // if is chuncked but no body is expected we mark the last chunk + this.state.flags.received_last_chunk = true; + // if is not we ignore the content_length + this.state.content_length = 0; + this.progressUpdate(is_ssl, ctx, socket); + return; + } - const body_buf = to_read[@min(@as(usize, @intCast(response.bytes_read)), to_read.len)..]; - // handle the case where we have a 100 Continue - if (response.status_code == 100) { - // we still can have the 200 OK in the same buffer sometimes - if (body_buf.len > 0) { - this.onData(is_ssl, body_buf, ctx, socket); - } - return; - } - const should_continue = this.handleResponseMetadata( - &response, - ) catch |err| { + if (this.flags.proxy_tunneling and this.proxy_tunnel == null) { + // we are proxing we dont need to cloneMetadata yet + this.startProxyHandshake(is_ssl, socket); + return; + } + + // we have body data incoming so we clone metadata and keep going + this.cloneMetadata(); + + if (body_buf.len == 0) { + // no body data yet, but we can report the headers + if (this.signals.get(.header_progress)) { + this.progressUpdate(is_ssl, ctx, socket); + } + return; + } + + if (this.state.response_stage == .body) { + { + const report_progress = this.handleResponseBody(body_buf, true) catch |err| { this.closeAndFail(err, is_ssl, socket); return; }; - if (this.state.content_encoding_i < response.headers.len and !this.state.flags.did_set_content_encoding) { - // if it compressed with this header, it is no longer because we will decompress it - const mutable_headers = std.ArrayListUnmanaged(picohttp.Header){ .items = response.headers, .capacity = response.headers.len }; - this.state.flags.did_set_content_encoding = true; - response.headers = mutable_headers.items; - this.state.content_encoding_i = std.math.maxInt(@TypeOf(this.state.content_encoding_i)); - // we need to reset the pending response because we removed a header - this.state.pending_response = response; - } - - if (should_continue == .finished) { - if (this.state.flags.is_redirect_pending) { - this.doRedirect(is_ssl, ctx, socket); - return; - } - // this means that the request ended - // clone metadata and return the progress at this point - this.cloneMetadata(); - // if is chuncked but no body is expected we mark the last chunk - this.state.flags.received_last_chunk = true; - // if is not we ignore the content_length - this.state.content_length = 0; + if (report_progress) { this.progressUpdate(is_ssl, ctx, socket); return; } - - if (this.flags.proxy_tunneling and this.proxy_tunnel == null) { - // we are proxing we dont need to cloneMetadata yet - this.startProxyHandshake(is_ssl, socket); + } + } else if (this.state.response_stage == .body_chunk) { + this.setTimeout(socket, 5); + { + const report_progress = this.handleResponseBodyChunkedEncoding(body_buf) catch |err| { + this.closeAndFail(err, is_ssl, socket); return; - } - - // we have body data incoming so we clone metadata and keep going - this.cloneMetadata(); + }; - if (body_buf.len == 0) { - // no body data yet, but we can report the headers - if (this.signals.get(.header_progress)) { - this.progressUpdate(is_ssl, ctx, socket); - } + if (report_progress) { + this.progressUpdate(is_ssl, ctx, socket); return; } + } + } - if (this.state.response_stage == .body) { - { - const report_progress = this.handleResponseBody(body_buf, true) catch |err| { - this.closeAndFail(err, is_ssl, socket); - return; - }; - - if (report_progress) { - this.progressUpdate(is_ssl, ctx, socket); - return; - } - } - } else if (this.state.response_stage == .body_chunk) { - this.setTimeout(socket, 5); - { - const report_progress = this.handleResponseBodyChunkedEncoding(body_buf) catch |err| { - this.closeAndFail(err, is_ssl, socket); - return; - }; - - if (report_progress) { - this.progressUpdate(is_ssl, ctx, socket); - return; - } - } - } + // if not reported we report partially now + if (this.signals.get(.header_progress)) { + this.progressUpdate(is_ssl, ctx, socket); + return; + } +} +pub fn onData( + this: *HTTPClient, + comptime is_ssl: bool, + incoming_data: []const u8, + ctx: *NewHTTPContext(is_ssl), + socket: NewHTTPContext(is_ssl).HTTPSocket, +) void { + log("onData {}", .{incoming_data.len}); + if (this.signals.get(.aborted)) { + this.closeAndAbort(is_ssl, socket); + return; + } - // if not reported we report partially now - if (this.signals.get(.header_progress)) { - this.progressUpdate(is_ssl, ctx, socket); - return; - } + switch (this.state.response_stage) { + .pending, .headers => { + this.handleOnDataHeaders(is_ssl, incoming_data, ctx, socket); }, - .body => { this.setTimeout(socket, 5); - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { proxy.receiveData(incoming_data); } else { const report_progress = this.handleResponseBody(incoming_data, false) catch |err| { @@ -3241,7 +3297,7 @@ pub fn onData( .body_chunk => { this.setTimeout(socket, 5); - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { proxy.receiveData(incoming_data); } else { const report_progress = this.handleResponseBodyChunkedEncoding(incoming_data) catch |err| { @@ -3259,7 +3315,7 @@ pub fn onData( .fail => {}, .proxy_headers, .proxy_handshake => { this.setTimeout(socket, 5); - if (this.proxy_tunnel) |*proxy| { + if (this.proxy_tunnel) |proxy| { proxy.receiveData(incoming_data); } return; @@ -3278,10 +3334,11 @@ pub fn closeAndAbort(this: *HTTPClient, comptime is_ssl: bool, socket: NewHTTPCo fn fail(this: *HTTPClient, err: anyerror) void { this.unregisterAbortTracker(); - if (this.proxy_tunnel != null) { - var tunnel = this.proxy_tunnel.?; + + if (this.proxy_tunnel) |tunnel| { this.proxy_tunnel = null; - tunnel.deinit(); + // always detach the socket from the tunnel in case of fail + tunnel.detachAndDeref(); } if (this.state.stage != .done and this.state.stage != .fail) { this.state.request_stage = .fail; diff --git a/test/js/bun/http/proxy.test.ts b/test/js/bun/http/proxy.test.ts index 4099f2d2881ea..3d297df5cf491 100644 --- a/test/js/bun/http/proxy.test.ts +++ b/test/js/bun/http/proxy.test.ts @@ -1,5 +1,5 @@ import type { Server } from "bun"; -import { afterAll, beforeAll, expect, test } from "bun:test"; +import { afterAll, beforeAll, expect, test, describe } from "bun:test"; import { tls as tlsCert } from "harness"; import { once } from "node:events"; import net from "node:net"; @@ -131,6 +131,104 @@ for (const proxy_tls of [false, true]) { } } +for (const server_tls of [false, true]) { + describe(`proxy can handle redirects with ${server_tls ? "TLS" : "non-TLS"} server`, () => { + test("with empty body #12007", async () => { + using server = Bun.serve({ + tls: server_tls ? tlsCert : undefined, + port: 0, + async fetch(req) { + if (req.url.endsWith("/bunbun")) { + return Response.redirect("/bun", 302); + } + if (req.url.endsWith("/bun")) { + return Response.redirect("/", 302); + } + return new Response("", { status: 403 }); + }, + }); + const response = await fetch(`${server.url.origin}/bunbun`, { + proxy: httpsProxyServer.url, + tls: { + cert: tlsCert.cert, + rejectUnauthorized: false, + }, + }); + expect(response.ok).toBe(false); + expect(response.status).toBe(403); + expect(response.statusText).toBe("Forbidden"); + }); + + test("with body #12007", async () => { + using server = Bun.serve({ + tls: server_tls ? tlsCert : undefined, + port: 0, + async fetch(req) { + if (req.url.endsWith("/bunbun")) { + return new Response("Hello, bunbun", { status: 302, headers: { Location: "/bun" } }); + } + if (req.url.endsWith("/bun")) { + return new Response("Hello, bun", { status: 302, headers: { Location: "/" } }); + } + return new Response("BUN!", { status: 200 }); + }, + }); + const response = await fetch(`${server.url.origin}/bunbun`, { + proxy: httpsProxyServer.url, + tls: { + cert: tlsCert.cert, + rejectUnauthorized: false, + }, + }); + expect(response.ok).toBe(true); + expect(response.status).toBe(200); + expect(response.statusText).toBe("OK"); + + const result = await response.text(); + expect(result).toBe("BUN!"); + }); + + test("with chunked body #12007", async () => { + using server = Bun.serve({ + tls: server_tls ? tlsCert : undefined, + port: 0, + async fetch(req) { + async function* body() { + await Bun.sleep(100); + yield "bun"; + await Bun.sleep(100); + yield "bun"; + await Bun.sleep(100); + yield "bun"; + await Bun.sleep(100); + yield "bun"; + } + if (req.url.endsWith("/bunbun")) { + return new Response(body, { status: 302, headers: { Location: "/bun" } }); + } + if (req.url.endsWith("/bun")) { + return new Response(body, { status: 302, headers: { Location: "/" } }); + } + return new Response(body, { status: 200 }); + }, + }); + const response = await fetch(`${server.url.origin}/bunbun`, { + proxy: httpsProxyServer.url, + tls: { + cert: tlsCert.cert, + rejectUnauthorized: false, + }, + }); + expect(response.ok).toBe(true); + expect(response.status).toBe(200); + expect(response.statusText).toBe("OK"); + + const result = await response.text(); + expect(result).toBe("bunbunbunbun"); + }); + }); +} + test("unsupported protocol", async () => { expect( fetch("https://httpbin.org/get", {