From ba930ad54a95bb35e44edb69420c8aa355c362ec Mon Sep 17 00:00:00 2001 From: Ciro Spaciari Date: Sat, 18 Jan 2025 16:03:42 -0800 Subject: [PATCH] feat(sql) transactions, savepoints, connection pooling and reserve (#16381) --- packages/bun-types/bun.d.ts | 292 ++++++ src/bun.js/bindings/BunObject.cpp | 15 +- src/bun.js/bindings/ErrorCode.ts | 3 + src/bun.js/event_loop.zig | 1 - src/bun.js/module_loader.zig | 10 - src/js/builtins/BunBuiltinNames.h | 1 + src/js/bun/sql.ts | 1430 ++++++++++++++++++++++++++--- src/sql/postgres.zig | 32 +- test/js/sql/sql.test.ts | 355 ++++--- 9 files changed, 1856 insertions(+), 283 deletions(-) diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index eded771c7b8a90..ec3529efc1886e 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -1995,6 +1995,298 @@ declare module "bun" { */ stat(path: string, options?: S3Options): Promise; }; + /** + * Configuration options for SQL client connection and behavior + * @example + * const config: SQLOptions = { + * host: 'localhost', + * port: 5432, + * user: 'dbuser', + * password: 'secretpass', + * database: 'myapp', + * idleTimeout: 30000, + * max: 20, + * onconnect: (client) => { + * console.log('Connected to database'); + * } + * }; + */ + type SQLOptions = { + /** Connection URL (can be string or URL object) */ + url: URL | string; + /** Database server hostname */ + host: string; + /** Database server port number */ + port: number | string; + /** Database user for authentication */ + user: string; + /** Database password for authentication */ + password: string; + /** Name of the database to connect to */ + database: string; + /** Database adapter/driver to use */ + adapter: string; + /** Maximum time in milliseconds to wait for connection to become available */ + idleTimeout: number; + /** Maximum time in milliseconds to wait when establishing a connection */ + connectionTimeout: number; + /** Maximum lifetime in milliseconds of a connection */ + maxLifetime: number; + /** Whether to use TLS/SSL for the connection */ + tls: boolean; + /** Callback function executed when a connection is established */ + onconnect: (client: SQL) => void; + /** Callback function executed when a connection is closed */ + onclose: (client: SQL) => void; + /** Maximum number of connections in the pool */ + max: number; + }; + + /** + * Represents a SQL query that can be executed, with additional control methods + * Extends Promise to allow for async/await usage + */ + interface SQLQuery extends Promise { + /** Indicates if the query is currently executing */ + active: boolean; + /** Indicates if the query has been cancelled */ + cancelled: boolean; + /** Cancels the executing query */ + cancel(): SQLQuery; + /** Executes the query */ + execute(): SQLQuery; + /** Returns the raw query result */ + raw(): SQLQuery; + /** Returns only the values from the query result */ + values(): SQLQuery; + } + + /** + * Callback function type for transaction contexts + * @param sql Function to execute SQL queries within the transaction + */ + type SQLContextCallback = (sql: (strings: string, ...values: any[]) => SQLQuery | Array) => Promise; + + /** + * Main SQL client interface providing connection and transaction management + */ + interface SQL { + /** Creates a new SQL client instance + * @example + * const sql = new SQL("postgres://localhost:5432/mydb"); + * const sql = new SQL(new URL("postgres://localhost:5432/mydb")); + */ + new (connectionString: string | URL): SQL; + /** Creates a new SQL client instance with options + * @example + * const sql = new SQL("postgres://localhost:5432/mydb", { idleTimeout: 1000 }); + */ + new (connectionString: string | URL, options: SQLOptions): SQL; + /** Creates a new SQL client instance with options + * @example + * const sql = new SQL({ url: "postgres://localhost:5432/mydb", idleTimeout: 1000 }); + */ + new (options?: SQLOptions): SQL; + /** Executes a SQL query using template literals + * @example + * const [user] = await sql`select * from users where id = ${1}`; + */ + (strings: string, ...values: any[]): SQLQuery; + /** Commits a distributed transaction also know as prepared transaction in postgres or XA transaction in MySQL + * @example + * await sql.commitDistributed("my_distributed_transaction"); + */ + commitDistributed(name: string): Promise; + /** Rolls back a distributed transaction also know as prepared transaction in postgres or XA transaction in MySQL + * @example + * await sql.rollbackDistributed("my_distributed_transaction"); + */ + rollbackDistributed(name: string): Promise; + /** Waits for the database connection to be established + * @example + * await sql.connect(); + */ + connect(): Promise; + /** Closes the database connection with optional timeout in seconds + * @example + * await sql.close({ timeout: 1 }); + */ + close(options?: { timeout?: number }): Promise; + /** Closes the database connection with optional timeout in seconds + * @alias close + * @example + * await sql.end({ timeout: 1 }); + */ + end(options?: { timeout?: number }): Promise; + /** Flushes any pending operations */ + flush(): void; + /** The reserve method pulls out a connection from the pool, and returns a client that wraps the single connection. + * This can be used for running queries on an isolated connection. + * Calling reserve in a reserved Sql will return a new reserved connection, not the same connection (behavior matches postgres package). + * @example + * const reserved = await sql.reserve(); + * await reserved`select * from users`; + * await reserved.release(); + * // with in a production scenario would be something more like + * const reserved = await sql.reserve(); + * try { + * // ... queries + * } finally { + * await reserved.release(); + * } + * //To make it simpler bun supportsSymbol.dispose and Symbol.asyncDispose + * { + * // always release after context (safer) + * using reserved = await sql.reserve() + * await reserved`select * from users` + * } + */ + reserve(): Promise; + /** Begins a new transaction + * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.begin will resolve with the returned value from the callback function. + * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. + * @example + * const [user, account] = await sql.begin(async sql => { + * const [user] = await sql` + * insert into users ( + * name + * ) values ( + * 'Murray' + * ) + * returning * + * ` + * const [account] = await sql` + * insert into accounts ( + * user_id + * ) values ( + * ${ user.user_id } + * ) + * returning * + * ` + * return [user, account] + * }) + */ + begin(fn: SQLContextCallback): Promise; + /** Begins a new transaction with options + * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.begin will resolve with the returned value from the callback function. + * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. + * @example + * const [user, account] = await sql.begin("read write", async sql => { + * const [user] = await sql` + * insert into users ( + * name + * ) values ( + * 'Murray' + * ) + * returning * + * ` + * const [account] = await sql` + * insert into accounts ( + * user_id + * ) values ( + * ${ user.user_id } + * ) + * returning * + * ` + * return [user, account] + * }) + */ + begin(options: string, fn: SQLContextCallback): Promise; + /** Alternative method to begin a transaction + * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.transaction will resolve with the returned value from the callback function. + * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. + * @alias begin + * @example + * const [user, account] = await sql.transaction(async sql => { + * const [user] = await sql` + * insert into users ( + * name + * ) values ( + * 'Murray' + * ) + * returning * + * ` + * const [account] = await sql` + * insert into accounts ( + * user_id + * ) values ( + * ${ user.user_id } + * ) + * returning * + * ` + * return [user, account] + * }) + */ + transaction(fn: SQLContextCallback): Promise; + /** Alternative method to begin a transaction with options + * Will reserve a connection for the transaction and supply a scoped sql instance for all transaction uses in the callback function. sql.transaction will resolve with the returned value from the callback function. + * BEGIN is automatically sent with the optional options, and if anything fails ROLLBACK will be called so the connection can be released and execution can continue. + * @alias begin + * @example + * const [user, account] = await sql.transaction("read write", async sql => { + * const [user] = await sql` + * insert into users ( + * name + * ) values ( + * 'Murray' + * ) + * returning * + * ` + * const [account] = await sql` + * insert into accounts ( + * user_id + * ) values ( + * ${ user.user_id } + * ) + * returning * + * ` + * return [user, account] + * }) + */ + transaction(options: string, fn: SQLContextCallback): Promise; + /** Begins a distributed transaction + * Also know as Two-Phase Commit, in a distributed transaction, Phase 1 involves the coordinator preparing nodes by ensuring data is written and ready to commit, while Phase 2 finalizes with nodes committing or rolling back based on the coordinator's decision, ensuring durability and releasing locks. + * In PostgreSQL and MySQL distributed transactions persist beyond the original session, allowing privileged users or coordinators to commit/rollback them, ensuring support for distributed transactions, recovery, and administrative tasks. + * beginDistributed will automatic rollback if any exception are not caught, and you can commit and rollback later if everything goes well. + * PostgreSQL natively supports distributed transactions using PREPARE TRANSACTION, while MySQL uses XA Transactions, and MSSQL also supports distributed/XA transactions. However, in MSSQL, distributed transactions are tied to the original session, the DTC coordinator, and the specific connection. + * These transactions are automatically committed or rolled back following the same rules as regular transactions, with no option for manual intervention from other sessions, in MSSQL distributed transactions are used to coordinate transactions using Linked Servers. + * @example + * await sql.beginDistributed("numbers", async sql => { + * await sql`create table if not exists numbers (a int)`; + * await sql`insert into numbers values(1)`; + * }); + * // later you can call + * await sql.commitDistributed("numbers"); + * // or await sql.rollbackDistributed("numbers"); + */ + beginDistributed(name: string, fn: SQLContextCallback): Promise; + /** Alternative method to begin a distributed transaction + * @alias beginDistributed + */ + distributed(name: string, fn: SQLContextCallback): Promise; + /** Current client options */ + options: SQLOptions; + } + + /** + * Represents a reserved connection from the connection pool + * Extends SQL with additional release functionality + */ + interface ReservedSQL extends SQL { + /** Releases the client back to the connection pool */ + release(): void; + } + + /** + * Represents a client within a transaction context + * Extends SQL with savepoint functionality + */ + interface TransactionSQL extends SQL { + /** Creates a savepoint within the current transaction */ + savepoint(name: string, fn: SQLContextCallback): Promise; + } + + var sql: SQL; /** * This lets you use macros as regular imports diff --git a/src/bun.js/bindings/BunObject.cpp b/src/bun.js/bindings/BunObject.cpp index 5aa13b44d08808..cfe3b9fe33b709 100644 --- a/src/bun.js/bindings/BunObject.cpp +++ b/src/bun.js/bindings/BunObject.cpp @@ -292,7 +292,7 @@ static JSValue constructPluginObject(VM& vm, JSObject* bunObject) return pluginFunction; } -static JSValue constructBunSQLObject(VM& vm, JSObject* bunObject) +static JSValue defaultBunSQLObject(VM& vm, JSObject* bunObject) { auto scope = DECLARE_THROW_SCOPE(vm); auto* globalObject = defaultGlobalObject(bunObject->globalObject()); @@ -301,6 +301,16 @@ static JSValue constructBunSQLObject(VM& vm, JSObject* bunObject) return sqlValue.getObject()->get(globalObject, vm.propertyNames->defaultKeyword); } +static JSValue constructBunSQLObject(VM& vm, JSObject* bunObject) +{ + auto scope = DECLARE_THROW_SCOPE(vm); + auto* globalObject = defaultGlobalObject(bunObject->globalObject()); + JSValue sqlValue = globalObject->internalModuleRegistry()->requireId(globalObject, vm, InternalModuleRegistry::BunSql); + RETURN_IF_EXCEPTION(scope, {}); + auto clientData = WebCore::clientData(vm); + return sqlValue.getObject()->get(globalObject, clientData->builtinNames().SQLPublicName()); +} + extern "C" JSC::EncodedJSValue JSPasswordObject__create(JSGlobalObject*); static JSValue constructPasswordObject(VM& vm, JSObject* bunObject) @@ -745,7 +755,8 @@ JSC_DEFINE_HOST_FUNCTION(functionFileURLToPath, (JSC::JSGlobalObject * globalObj revision constructBunRevision ReadOnly|DontDelete|PropertyCallback semver BunObject_getter_wrap_semver ReadOnly|DontDelete|PropertyCallback s3 BunObject_callback_s3 DontDelete|Function 1 - sql constructBunSQLObject DontDelete|PropertyCallback + sql defaultBunSQLObject DontDelete|PropertyCallback + SQL constructBunSQLObject DontDelete|PropertyCallback serve BunObject_callback_serve DontDelete|Function 1 sha BunObject_callback_sha DontDelete|Function 1 shrink BunObject_callback_shrink DontDelete|Function 1 diff --git a/src/bun.js/bindings/ErrorCode.ts b/src/bun.js/bindings/ErrorCode.ts index f9d3d062829177..40982caa531d54 100644 --- a/src/bun.js/bindings/ErrorCode.ts +++ b/src/bun.js/bindings/ErrorCode.ts @@ -167,6 +167,9 @@ const errors: ErrorCodeMapping = [ ["ERR_POSTGRES_IDLE_TIMEOUT", Error, "PostgresError"], ["ERR_POSTGRES_CONNECTION_TIMEOUT", Error, "PostgresError"], ["ERR_POSTGRES_LIFETIME_TIMEOUT", Error, "PostgresError"], + ["ERR_POSTGRES_INVALID_TRANSACTION_STATE", Error, "PostgresError"], + ["ERR_POSTGRES_QUERY_CANCELLED", Error, "PostgresError"], + ["ERR_POSTGRES_UNSAFE_TRANSACTION", Error, "PostgresError"], // S3 ["ERR_S3_MISSING_CREDENTIALS", Error], diff --git a/src/bun.js/event_loop.zig b/src/bun.js/event_loop.zig index 48e23a010bc5e0..9d5263a3b836cd 100644 --- a/src/bun.js/event_loop.zig +++ b/src/bun.js/event_loop.zig @@ -898,7 +898,6 @@ pub const EventLoop = struct { pub fn runCallback(this: *EventLoop, callback: JSC.JSValue, globalObject: *JSC.JSGlobalObject, thisValue: JSC.JSValue, arguments: []const JSC.JSValue) void { this.enter(); defer this.exit(); - _ = callback.call(globalObject, thisValue, arguments) catch |err| globalObject.reportActiveExceptionAsUnhandled(err); } diff --git a/src/bun.js/module_loader.zig b/src/bun.js/module_loader.zig index 06c02b01467e7a..837b3653bd480c 100644 --- a/src/bun.js/module_loader.zig +++ b/src/bun.js/module_loader.zig @@ -2517,14 +2517,7 @@ pub const ModuleLoader = struct { // These are defined in src/js/* .@"bun:ffi" => return jsSyntheticModule(.@"bun:ffi", specifier), - .@"bun:sql" => { - if (!Environment.isDebug) { - if (!is_allowed_to_use_internal_testing_apis and !bun.FeatureFlags.postgresql) - return null; - } - return jsSyntheticModule(.@"bun:sql", specifier); - }, .@"bun:sqlite" => return jsSyntheticModule(.@"bun:sqlite", specifier), .@"detect-libc" => return jsSyntheticModule(if (!Environment.isLinux) .@"detect-libc" else if (!Environment.isMusl) .@"detect-libc/linux" else .@"detect-libc/musl", specifier), .@"node:assert" => return jsSyntheticModule(.@"node:assert", specifier), @@ -2732,7 +2725,6 @@ pub const HardcodedModule = enum { @"bun:jsc", @"bun:main", @"bun:test", // usually replaced by the transpiler but `await import("bun:" + "test")` has to work - @"bun:sql", @"bun:sqlite", @"detect-libc", @"node:assert", @@ -2819,7 +2811,6 @@ pub const HardcodedModule = enum { .{ "bun:test", HardcodedModule.@"bun:test" }, .{ "bun:sqlite", HardcodedModule.@"bun:sqlite" }, .{ "bun:internal-for-testing", HardcodedModule.@"bun:internal-for-testing" }, - .{ "bun:sql", HardcodedModule.@"bun:sql" }, .{ "detect-libc", HardcodedModule.@"detect-libc" }, .{ "node-fetch", HardcodedModule.@"node-fetch" }, .{ "isomorphic-fetch", HardcodedModule.@"isomorphic-fetch" }, @@ -3059,7 +3050,6 @@ pub const HardcodedModule = enum { .{ "bun:ffi", .{ .path = "bun:ffi" } }, .{ "bun:jsc", .{ .path = "bun:jsc" } }, .{ "bun:sqlite", .{ .path = "bun:sqlite" } }, - .{ "bun:sql", .{ .path = "bun:sql" } }, .{ "bun:wrap", .{ .path = "bun:wrap" } }, .{ "bun:internal-for-testing", .{ .path = "bun:internal-for-testing" } }, .{ "ffi", .{ .path = "bun:ffi" } }, diff --git a/src/js/builtins/BunBuiltinNames.h b/src/js/builtins/BunBuiltinNames.h index 7fd2016a106ffc..7d083d7b5f0930 100644 --- a/src/js/builtins/BunBuiltinNames.h +++ b/src/js/builtins/BunBuiltinNames.h @@ -259,6 +259,7 @@ using namespace JSC; macro(written) \ macro(napiDlopenHandle) \ macro(napiWrappedContents) \ + macro(SQL) \ BUN_ADDITIONAL_BUILTIN_NAMES(macro) // --- END of BUN_COMMON_PRIVATE_IDENTIFIERS_EACH_PROPERTY_NAME --- diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index abe2a973cc6793..1a7a8c7f651cd5 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -1,3 +1,5 @@ +const { hideFromStack } = require("internal/shared"); + const enum QueryStatus { active = 1 << 1, cancelled = 1 << 2, @@ -15,6 +17,11 @@ const enum SSLMode { verify_full = 4, } +function connectionClosedError() { + return $ERR_POSTGRES_CONNECTION_CLOSED("Connection closed"); +} +hideFromStack(connectionClosedError); + class SQLResultArray extends PublicArray { static [Symbol.toStringTag] = "SQLResults"; @@ -33,6 +40,7 @@ const _run = Symbol("run"); const _queryStatus = Symbol("status"); const _handler = Symbol("handler"); const PublicPromise = Promise; +type TransactionCallback = (sql: (strings: string, ...values: any[]) => Query) => Promise; const { createConnection: _createConnection, @@ -105,6 +113,7 @@ class Query extends PublicPromise { } this[_queryStatus] |= QueryStatus.executed; + // this avoids a infinite loop await 1; return handler(this, handle); } @@ -175,12 +184,16 @@ class Query extends PublicPromise { then() { this[_run](); - return super.$then.$apply(this, arguments); + const result = super.$then.$apply(this, arguments); + $markPromiseAsHandled(result); + return result; } catch() { this[_run](); - return super.catch.$apply(this, arguments); + const result = super.catch.$apply(this, arguments); + $markPromiseAsHandled(result); + return result; } finally() { @@ -228,6 +241,526 @@ init( }, ); +function onQueryFinish(onClose) { + this.queries.delete(onClose); + this.pool.release(this); +} + +enum PooledConnectionState { + pending = 0, + connected = 1, + closed = 2, +} +enum PooledConnectionFlags { + /// canBeConnected is used to indicate that at least one time we were able to connect to the database + canBeConnected = 1 << 0, + /// reserved is used to indicate that the connection is currently reserved + reserved = 1 << 1, + /// preReserved is used to indicate that the connection will be reserved in the future when queryCount drops to 0 + preReserved = 1 << 2, +} +class PooledConnection { + pool: ConnectionPool; + connection: ReturnType; + state: PooledConnectionState = PooledConnectionState.pending; + storedError: Error | null = null; + queries: Set<(err: Error) => void> = new Set(); + onFinish: ((err: Error | null) => void) | null = null; + connectionInfo: any; + + flags: number = 0; + /// queryCount is used to indicate the number of queries using the connection, if a connection is reserved or if its a transaction queryCount will be 1 independently of the number of queries + queryCount: number = 0; + #onConnected(err, _) { + const connectionInfo = this.connectionInfo; + if (connectionInfo?.onconnect) { + connectionInfo.onconnect(err); + } + this.storedError = err; + if (!err) { + this.flags |= PooledConnectionFlags.canBeConnected; + } + this.state = err ? PooledConnectionState.closed : PooledConnectionState.connected; + const onFinish = this.onFinish; + if (onFinish) { + this.queryCount = 0; + this.flags &= ~PooledConnectionFlags.reserved; + this.flags &= ~PooledConnectionFlags.preReserved; + + // pool is closed, lets finish the connection + // pool is closed, lets finish the connection + if (err) { + onFinish(err); + } else { + this.connection.close(); + } + return; + } + this.pool.release(this, true); + } + #onClose(err) { + const connectionInfo = this.connectionInfo; + if (connectionInfo?.onclose) { + connectionInfo.onclose(err); + } + this.state = PooledConnectionState.closed; + this.connection = null; + this.storedError = err; + + // remove from ready connections if its there + this.pool.readyConnections.delete(this); + const queries = new Set(this.queries); + this.queries.clear(); + this.queryCount = 0; + this.flags &= ~PooledConnectionFlags.reserved; + + // notify all queries that the connection is closed + for (const onClose of queries) { + onClose(err); + } + const onFinish = this.onFinish; + if (onFinish) { + onFinish(err); + } + + this.pool.release(this, true); + } + constructor(connectionInfo, pool: ConnectionPool) { + this.connection = createConnection(connectionInfo, this.#onConnected.bind(this), this.#onClose.bind(this)); + this.state = PooledConnectionState.pending; + this.pool = pool; + this.connectionInfo = connectionInfo; + } + onClose(onClose: (err: Error) => void) { + this.queries.add(onClose); + } + bindQuery(query: Query, onClose: (err: Error) => void) { + this.queries.add(onClose); + // @ts-ignore + query.finally(onQueryFinish.bind(this, onClose)); + } + #doRetry() { + if (this.pool.closed) { + return; + } + // reset error and state + this.storedError = null; + this.state = PooledConnectionState.pending; + // retry connection + this.connection = createConnection( + this.connectionInfo, + this.#onConnected.bind(this, this.connectionInfo), + this.#onClose.bind(this, this.connectionInfo), + ); + } + close() { + try { + if (this.state === PooledConnectionState.connected) { + this.connection?.close(); + } + } catch {} + } + flush() { + this.connection?.flush(); + } + retry() { + // if pool is closed, we can't retry + if (this.pool.closed) { + return false; + } + // we need to reconnect + // lets use a retry strategy + + // we can only retry if one day we are able to connect + if (this.flags & PooledConnectionFlags.canBeConnected) { + this.#doRetry(); + } else { + // analyse type of error to see if we can retry + switch (this.storedError?.code) { + case "ERR_POSTGRES_UNSUPPORTED_AUTHENTICATION_METHOD": + case "ERR_POSTGRES_UNKNOWN_AUTHENTICATION_METHOD": + case "ERR_POSTGRES_TLS_NOT_AVAILABLE": + case "ERR_POSTGRES_TLS_UPGRADE_FAILED": + case "ERR_POSTGRES_INVALID_SERVER_SIGNATURE": + case "ERR_POSTGRES_INVALID_SERVER_KEY": + case "ERR_POSTGRES_AUTHENTICATION_FAILED_PBKDF2": + // we can't retry these are authentication errors + return false; + default: + // we can retry + this.#doRetry(); + return true; + } + } + } +} +class ConnectionPool { + connectionInfo: any; + + connections: PooledConnection[]; + readyConnections: Set; + waitingQueue: Array<(err: Error | null, result: any) => void> = []; + reservedQueue: Array<(err: Error | null, result: any) => void> = []; + + poolStarted: boolean = false; + closed: boolean = false; + onAllQueriesFinished: (() => void) | null = null; + constructor(connectionInfo) { + this.connectionInfo = connectionInfo; + this.connections = new Array(connectionInfo.max); + this.readyConnections = new Set(); + } + + flushConcurrentQueries() { + if (this.waitingQueue.length === 0) { + return; + } + while (this.waitingQueue.length > 0) { + let endReached = true; + // no need to filter for reserved connections because there are not in the readyConnections + // preReserved only shows that we wanna avoiding adding more queries to it + const nonReservedConnections = Array.from(this.readyConnections).filter( + c => !(c.flags & PooledConnectionFlags.preReserved), + ); + if (nonReservedConnections.length === 0) { + return; + } + // kinda balance the load between connections + const orderedConnections = nonReservedConnections.sort((a, b) => a.queryCount - b.queryCount); + const leastQueries = orderedConnections[0].queryCount; + + for (const connection of orderedConnections) { + if (connection.queryCount > leastQueries) { + endReached = false; + break; + } + + const pending = this.waitingQueue.shift(); + if (pending) { + connection.queryCount++; + pending(null, connection); + } + } + const halfPoolSize = Math.ceil(this.connections.length / 2); + if (endReached || orderedConnections.length < halfPoolSize) { + // we are able to distribute the load between connections but the connection pool is less than half of the pool size + // so we can stop here and wait for the next tick to flush the waiting queue + break; + } + } + if (this.waitingQueue.length > 0) { + // we still wanna to flush the waiting queue but lets wait for the next tick because some connections might be released + // this is better for query performance + process.nextTick(this.flushConcurrentQueries.bind(this)); + } + } + + release(connection: PooledConnection, connectingEvent: boolean = false) { + if (!connectingEvent) { + connection.queryCount--; + } + const was_reserved = connection.flags & PooledConnectionFlags.reserved; + connection.flags &= ~PooledConnectionFlags.reserved; + connection.flags &= ~PooledConnectionFlags.preReserved; + if (this.onAllQueriesFinished) { + // we are waiting for all queries to finish, lets check if we can call it + if (!this.hasPendingQueries()) { + this.onAllQueriesFinished(); + } + } + if (connection.state !== PooledConnectionState.connected) { + // connection is not ready + return; + } + if (was_reserved) { + if (this.waitingQueue.length > 0) { + if (connection.storedError) { + // this connection got a error but maybe we can wait for another + + if (this.hasConnectionsAvailable()) { + return; + } + + // we have no connections available so lets fails + let pending; + while ((pending = this.waitingQueue.shift())) { + pending.onConnected(connection.storedError, connection); + } + return; + } + const pendingReserved = this.reservedQueue.shift(); + if (pendingReserved) { + connection.flags |= PooledConnectionFlags.reserved; + connection.queryCount++; + // we have a connection waiting for a reserved connection lets prioritize it + pendingReserved(connection.storedError, connection); + return; + } + this.flushConcurrentQueries(); + } else { + // connection is ready, lets add it back to the ready connections + this.readyConnections.add(connection); + } + } else { + if (connection.queryCount == 0) { + // ok we can actually bind reserved queries to it + const pendingReserved = this.reservedQueue.shift(); + if (pendingReserved) { + connection.flags |= PooledConnectionFlags.reserved; + connection.queryCount++; + // we have a connection waiting for a reserved connection lets prioritize it + pendingReserved(connection.storedError, connection); + return; + } + } + + this.readyConnections.add(connection); + + this.flushConcurrentQueries(); + } + } + + hasConnectionsAvailable() { + if (this.readyConnections.size > 0) return true; + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.state !== PooledConnectionState.closed) { + // some connection is connecting or connected + return true; + } + } + } + return false; + } + hasPendingQueries() { + if (this.waitingQueue.length > 0 || this.reservedQueue.length > 0) return true; + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.queryCount > 0) { + return true; + } + } + } + return false; + } + isConnected() { + if (this.readyConnections.size > 0) { + return true; + } + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.state === PooledConnectionState.connected) { + return true; + } + } + } + return false; + } + flush() { + if (this.closed) { + return; + } + if (this.poolStarted) { + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + if (connection.state === PooledConnectionState.connected) { + connection.connection.flush(); + } + } + } + } + + async #close() { + let pending; + while ((pending = this.waitingQueue.shift())) { + pending(connectionClosedError(), null); + } + while (this.reservedQueue.length > 0) { + const pendingReserved = this.reservedQueue.shift(); + if (pendingReserved) { + pendingReserved(connectionClosedError(), null); + } + } + const promises: Array> = []; + if (this.poolStarted) { + this.poolStarted = false; + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + switch (connection.state) { + case PooledConnectionState.pending: + { + const { promise, resolve } = Promise.withResolvers(); + connection.onFinish = resolve; + promises.push(promise); + } + break; + case PooledConnectionState.connected: + { + const { promise, resolve } = Promise.withResolvers(); + connection.onFinish = resolve; + promises.push(promise); + connection.connection.close(); + } + break; + } + // clean connection reference + // @ts-ignore + this.connections[i] = null; + } + } + this.readyConnections.clear(); + this.waitingQueue.length = 0; + return Promise.all(promises); + } + async close(options?: { timeout?: number }) { + if (this.closed) { + return Promise.reject(connectionClosedError()); + } + let timeout = options?.timeout; + if (timeout) { + timeout = Number(timeout); + if (timeout > 2 ** 31 || timeout < 0 || timeout !== timeout) { + throw $ERR_INVALID_ARG_VALUE("options.timeout", timeout, "must be a non-negative integer less than 2^31"); + } + this.closed = true; + if (timeout > 0 && this.hasPendingQueries()) { + const { promise, resolve } = Promise.withResolvers(); + const timer = setTimeout(() => { + // timeout is reached, lets close and probably fail some queries + this.#close().finally(resolve); + }, timeout * 1000); + timer.unref(); // dont block the event loop + this.onAllQueriesFinished = () => { + clearTimeout(timer); + // everything is closed, lets close the pool + this.#close().finally(resolve); + }; + + return promise; + } + } else { + this.closed = true; + } + + await this.#close(); + } + + /** + * @param {function} onConnected - The callback function to be called when the connection is established. + * @param {boolean} reserved - Whether the connection is reserved, if is reserved the connection will not be released until release is called, if not release will only decrement the queryCount counter + */ + connect(onConnected: (err: Error | null, result: any) => void, reserved: boolean = false) { + if (this.closed) { + return onConnected(connectionClosedError(), null); + } + + if (this.readyConnections.size === 0) { + // no connection ready lets make some + let retry_in_progress = false; + let all_closed = true; + let storedError: Error | null = null; + + if (this.poolStarted) { + // we already started the pool + // lets check if some connection is available to retry + const pollSize = this.connections.length; + for (let i = 0; i < pollSize; i++) { + const connection = this.connections[i]; + // we need a new connection and we have some connections that can retry + if (connection.state === PooledConnectionState.closed) { + if (connection.retry()) { + // lets wait for connection to be released + if (!retry_in_progress) { + // avoid adding to the queue twice, we wanna to retry every available pool connection + retry_in_progress = true; + if (reserved) { + // we are not sure what connection will be available so we dont pre reserve + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + } + } + } else { + // we have some error, lets grab it and fail if unable to start a connection + storedError = connection.storedError; + } + } else { + // we have some pending or open connections + all_closed = false; + } + } + if (!all_closed && !retry_in_progress) { + // is possible to connect because we have some working connections, or we are just without network for some reason + // wait for connection to be released or fail + if (reserved) { + // we are not sure what connection will be available so we dont pre reserve + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + } + } else { + // impossible to connect or retry + onConnected(storedError, null); + } + return; + } + // we never started the pool, lets start it + if (reserved) { + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + } + this.poolStarted = true; + const pollSize = this.connections.length; + // pool is always at least 1 connection + this.connections[0] = new PooledConnection(this.connectionInfo, this); + this.connections[0].flags |= PooledConnectionFlags.preReserved; // lets pre reserve the first connection + for (let i = 1; i < pollSize; i++) { + this.connections[i] = new PooledConnection(this.connectionInfo, this); + } + return; + } + if (reserved) { + let connectionWithLeastQueries: PooledConnection | null = null; + let leastQueries = Infinity; + for (const connection of this.readyConnections) { + if (connection.flags & PooledConnectionFlags.reserved || connection.flags & PooledConnectionFlags.preReserved) + continue; + const queryCount = connection.queryCount; + if (queryCount > 0) { + if (queryCount < leastQueries) { + leastQueries = queryCount; + connectionWithLeastQueries = connection; + continue; + } + } + connection.flags |= PooledConnectionFlags.reserved; + connection.queryCount++; + this.readyConnections.delete(connection); + onConnected(null, connection); + return; + } + if (connectionWithLeastQueries) { + // lets mark the connection with the least queries as preReserved if any + connectionWithLeastQueries.flags |= PooledConnectionFlags.preReserved; + } + // no connection available to be reserved lets wait for a connection to be released + this.reservedQueue.push(onConnected); + } else { + this.waitingQueue.push(onConnected); + this.flushConcurrentQueries(); + } + } +} + function createConnection( { hostname, @@ -367,7 +900,8 @@ function loadOptions(o) { connectionTimeout, maxLifetime, onconnect, - onclose; + onclose, + max; const env = Bun.env; var sslMode: SSLMode = SSLMode.disable; @@ -405,7 +939,14 @@ function loadOptions(o) { } if (url) { - ({ hostname, port, username, password, protocol: adapter } = o = url); + ({ hostname, port, username, password, adapter } = o); + // object overrides url + hostname ||= url.hostname; + port ||= url.port; + username ||= url.username; + password ||= url.password; + adapter ||= url.protocol; + if (adapter[adapter.length - 1] === ":") { adapter = adapter.slice(0, -1); } @@ -429,6 +970,7 @@ function loadOptions(o) { password ||= o.password || o.pass || env.PGPASSWORD || ""; tls ||= o.tls || o.ssl; adapter ||= o.adapter || "postgres"; + max = o.max; idleTimeout ??= o.idleTimeout; idleTimeout ??= o.idle_timeout; @@ -484,6 +1026,13 @@ function loadOptions(o) { } } + if (max != null) { + max = Number(max); + if (max > 2 ** 31 || max < 1 || max !== max) { + throw $ERR_INVALID_ARG_VALUE("options.max", max, "must be a non-negative integer between 1 and 2^31"); + } + } + if (sslMode !== SSLMode.disable && !tls?.serverName) { if (hostname) { tls = { @@ -500,14 +1049,18 @@ function loadOptions(o) { port = Number(port); if (!Number.isSafeInteger(port) || port < 1 || port > 65535) { - throw new Error(`Invalid port: ${port}`); + throw $ERR_INVALID_ARG_VALUE("port", port, "must be a non-negative integer between 1 and 65535"); } - if (adapter && !(adapter === "postgres" || adapter === "postgresql")) { - throw new Error(`Unsupported adapter: ${adapter}. Only \"postgres\" is supported for now`); + switch (adapter) { + case "postgres": + case "postgresql": + adapter = "postgres"; + break; + default: + throw new Error(`Unsupported adapter: ${adapter}. Only \"postgres\" is supported for now`); } - - const ret: any = { hostname, port, username, password, database, tls, query, sslMode }; + const ret: any = { hostname, port, username, password, database, tls, query, sslMode, adapter }; if (idleTimeout != null) { ret.idleTimeout = idleTimeout; } @@ -523,108 +1076,690 @@ function loadOptions(o) { if (onclose !== undefined) { ret.onclose = onclose; } + ret.max = max || 10; + return ret; } -function SQL(o) { - var connection, - connected = false, - connecting = false, - closed = false, - onConnect: any[] = [], - storedErrorForClosedConnection, - connectionInfo = loadOptions(o); +enum ReservedConnectionState { + acceptQueries = 1 << 0, + closed = 1 << 1, +} + +function assertValidTransactionName(name: string) { + if (name.indexOf("'") !== -1) { + throw Error(`Distributed transaction name cannot contain single quotes.`); + } +} +function SQL(o, e = {}) { + if (typeof o === "string" || o instanceof URL) { + o = { ...e, url: o }; + } + var connectionInfo = loadOptions(o); + var pool = new ConnectionPool(connectionInfo); - function connectedHandler(query, handle, err) { + function doCreateQuery(strings, values, allowUnsafeTransaction) { + const sqlString = normalizeStrings(strings, values); + let columns; + if (hasSQLArrayParameter) { + hasSQLArrayParameter = false; + const v = values[0]; + columns = v.columns; + values = v.value; + } + if (!allowUnsafeTransaction) { + if (connectionInfo.max !== 1) { + const upperCaseSqlString = sqlString.toUpperCase().trim(); + if (upperCaseSqlString.startsWith("BEGIN") || upperCaseSqlString.startsWith("START TRANSACTION")) { + throw $ERR_POSTGRES_UNSAFE_TRANSACTION("Only use sql.begin, sql.reserved or max: 1"); + } + } + } + return createQuery(sqlString, values, new SQLResultArray(), columns); + } + + function onQueryDisconnected(err) { + // connection closed mid query this will not be called if the query finishes first + const query = this; if (err) { return query.reject(err); } - - if (!connected) { - return query.reject(storedErrorForClosedConnection || new Error("Not connected")); + // query is cancelled when waiting for a connection from the pool + if (query.cancelled) { + return query.reject($ERR_POSTGRES_QUERY_CANCELLED("Query cancelled")); } + } + function onQueryConnected(handle, err, pooledConnection) { + const query = this; + if (err) { + // fail to aquire a connection from the pool + return query.reject(err); + } + // query is cancelled when waiting for a connection from the pool if (query.cancelled) { - return query.reject(new Error("Query cancelled")); + pool.release(pooledConnection); // release the connection back to the pool + return query.reject($ERR_POSTGRES_QUERY_CANCELLED("Query cancelled")); } - - handle.run(connection, query); - - // if the above throws, we don't want it to be in the array. - // This array exists mostly to keep the in-flight queries alive. - connection.queries.push(query); + // bind close event to the query (will unbind and auto release the connection when the query is finished) + pooledConnection.bindQuery(query, onQueryDisconnected.bind(query)); + handle.run(pooledConnection.connection, query); } + function queryFromPoolHandler(query, handle, err) { + if (err) { + // fail to create query + return query.reject(err); + } + // query is cancelled + if (!handle || query.cancelled) { + return query.reject($ERR_POSTGRES_QUERY_CANCELLED("Query cancelled")); + } - function pendingConnectionHandler(query, handle) { - onConnect.push(err => connectedHandler(query, handle, err)); - if (!connecting) { - connecting = true; - connection = createConnection(connectionInfo, onConnected, onClose); + pool.connect(onQueryConnected.bind(query, handle)); + } + function queryFromPool(strings, values) { + try { + return new Query(doCreateQuery(strings, values, false), queryFromPoolHandler); + } catch (err) { + return Promise.reject(err); } } - function closedConnectionHandler(query, handle) { - query.reject(storedErrorForClosedConnection || new Error("Connection closed")); + function onTransactionQueryDisconnected(query) { + const transactionQueries = this; + transactionQueries.delete(query); } + function queryFromTransactionHandler(transactionQueries, query, handle, err) { + const pooledConnection = this; + if (err) { + transactionQueries.delete(query); + return query.reject(err); + } + // query is cancelled + if (query.cancelled) { + transactionQueries.delete(query); + return query.reject($ERR_POSTGRES_QUERY_CANCELLED("Query cancelled")); + } - function onConnected(err, result) { - connected = !err; - for (const handler of onConnect) { - handler(err); + query.finally(onTransactionQueryDisconnected.bind(transactionQueries, query)); + handle.run(pooledConnection.connection, query); + } + function queryFromTransaction(strings, values, pooledConnection, transactionQueries) { + try { + const query = new Query( + doCreateQuery(strings, values, true), + queryFromTransactionHandler.bind(pooledConnection, transactionQueries), + ); + transactionQueries.add(query); + return query; + } catch (err) { + return Promise.reject(err); } - onConnect = []; + } + function onTransactionDisconnected(err) { + const reject = this.reject; + this.connectionState |= ReservedConnectionState.closed; - if (connected && connectionInfo?.onconnect) { - connectionInfo.onconnect(err); + for (const query of this.queries) { + (query as Query).reject(err); + } + if (err) { + return reject(err); } } - function onClose(err, queries) { - closed = true; - storedErrorForClosedConnection = err; - if (sql === lazyDefaultSQL) { - resetDefaultSQL(initialDefaultSQL); + function onReserveConnected(err, pooledConnection) { + const { promise, resolve, reject } = this; + if (err) { + return reject(err); } - onConnected(err, undefined); - if (queries) { - const queriesCopy = queries.slice(); - queries.length = 0; - for (const handler of queriesCopy) { - handler.reject(err); + let reservedTransaction = new Set(); + + const state = { + connectionState: ReservedConnectionState.acceptQueries, + reject, + storedError: null, + queries: new Set(), + }; + const onClose = onTransactionDisconnected.bind(state); + pooledConnection.onClose(onClose); + + function reserved_sql(strings, ...values) { + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + return Promise.reject(connectionClosedError()); } + if ($isJSArray(strings) && strings[0] && typeof strings[0] === "object") { + return new SQLArrayParameter(strings, values); + } + // we use the same code path as the transaction sql + return queryFromTransaction(strings, values, pooledConnection, state.queries); } + reserved_sql.connect = () => { + if (state.connectionState & ReservedConnectionState.closed) { + return Promise.reject(connectionClosedError()); + } + return Promise.resolve(reserved_sql); + }; + + reserved_sql.commitDistributed = async function (name: string) { + const adapter = connectionInfo.adapter; + assertValidTransactionName(name); + switch (adapter) { + case "postgres": + return await reserved_sql(`COMMIT PREPARED '${name}'`); + case "mysql": + return await reserved_sql(`XA COMMIT '${name}'`); + case "mssql": + throw Error(`MSSQL distributed transaction is automatically committed.`); + case "sqlite": + throw Error(`SQLite dont support distributed transactions.`); + default: + throw Error(`Unsupported adapter: ${adapter}.`); + } + }; + reserved_sql.rollbackDistributed = async function (name: string) { + assertValidTransactionName(name); + const adapter = connectionInfo.adapter; + switch (adapter) { + case "postgres": + return await reserved_sql(`ROLLBACK PREPARED '${name}'`); + case "mysql": + return await reserved_sql(`XA ROLLBACK '${name}'`); + case "mssql": + throw Error(`MSSQL distributed transaction is automatically rolled back.`); + case "sqlite": + throw Error(`SQLite dont support distributed transactions.`); + default: + throw Error(`Unsupported adapter: ${adapter}.`); + } + }; - if (connectionInfo?.onclose) { - connectionInfo.onclose(err); + // reserve is allowed to be called inside reserved connection but will return a new reserved connection from the pool + // this matchs the behavior of the postgres package + reserved_sql.reserve = () => sql.reserve(); + function onTransactionFinished(transaction_promise: Promise) { + reservedTransaction.delete(transaction_promise); } + reserved_sql.beginDistributed = (name: string, fn: TransactionCallback) => { + // begin is allowed the difference is that we need to make sure to use the same connection and never release it + if (state.connectionState & ReservedConnectionState.closed) { + return Promise.reject(connectionClosedError()); + } + let callback = fn; + + if (typeof name !== "string") { + return Promise.reject($ERR_INVALID_ARG_VALUE("name", name, "must be a string")); + } + + if (!$isCallable(callback)) { + return Promise.reject($ERR_INVALID_ARG_VALUE("fn", callback, "must be a function")); + } + const { promise, resolve, reject } = Promise.withResolvers(); + // lets just reuse the same code path as the transaction begin + onTransactionConnected(callback, name, resolve, reject, true, true, null, pooledConnection); + reservedTransaction.add(promise); + promise.finally(onTransactionFinished.bind(null, promise)); + return promise; + }; + reserved_sql.begin = (options_or_fn: string | TransactionCallback, fn?: TransactionCallback) => { + // begin is allowed the difference is that we need to make sure to use the same connection and never release it + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + return Promise.reject(connectionClosedError()); + } + let callback = fn; + let options: string | undefined = options_or_fn as unknown as string; + if ($isCallable(options_or_fn)) { + callback = options_or_fn as unknown as TransactionCallback; + options = undefined; + } else if (typeof options_or_fn !== "string") { + return Promise.reject($ERR_INVALID_ARG_VALUE("options", options_or_fn, "must be a string")); + } + if (!$isCallable(callback)) { + return Promise.reject($ERR_INVALID_ARG_VALUE("fn", callback, "must be a function")); + } + const { promise, resolve, reject } = Promise.withResolvers(); + // lets just reuse the same code path as the transaction begin + onTransactionConnected(callback, options, resolve, reject, true, false, null, pooledConnection); + reservedTransaction.add(promise); + promise.finally(onTransactionFinished.bind(null, promise)); + return promise; + }; + + reserved_sql.flush = () => { + if (state.connectionState & ReservedConnectionState.closed) { + throw connectionClosedError(); + } + return pooledConnection.flush(); + }; + reserved_sql.close = async (options?: { timeout?: number }) => { + const reserveQueries = state.queries; + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + return Promise.reject(connectionClosedError()); + } + state.connectionState &= ~ReservedConnectionState.acceptQueries; + let timeout = options?.timeout; + if (timeout) { + timeout = Number(timeout); + if (timeout > 2 ** 31 || timeout < 0 || timeout !== timeout) { + throw $ERR_INVALID_ARG_VALUE("options.timeout", timeout, "must be a non-negative integer less than 2^31"); + } + if (timeout > 0 && (reserveQueries.size > 0 || reservedTransaction.size > 0)) { + const { promise, resolve } = Promise.withResolvers(); + // race all queries vs timeout + const pending_queries = Array.from(reserveQueries); + const pending_transactions = Array.from(reservedTransaction); + const timer = setTimeout(() => { + state.connectionState |= ReservedConnectionState.closed; + for (const query of reserveQueries) { + (query as Query).cancel(); + } + state.connectionState |= ReservedConnectionState.closed; + pooledConnection.close(); + + resolve(); + }, timeout * 1000); + timer.unref(); // dont block the event loop + Promise.all([Promise.all(pending_queries), Promise.all(pending_transactions)]).finally(() => { + clearTimeout(timer); + resolve(); + }); + return promise; + } + } + state.connectionState |= ReservedConnectionState.closed; + for (const query of reserveQueries) { + (query as Query).cancel(); + } + + pooledConnection.close(); + + return Promise.resolve(undefined); + }; + reserved_sql.release = () => { + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + return Promise.reject(connectionClosedError()); + } + // just release the connection back to the pool + state.connectionState |= ReservedConnectionState.closed; + state.connectionState &= ~ReservedConnectionState.acceptQueries; + pooledConnection.queries.delete(onClose); + pool.release(pooledConnection); + return Promise.resolve(undefined); + }; + // this dont need to be async dispose only disposable but we keep compatibility with other types of sql functions + reserved_sql[Symbol.asyncDispose] = () => reserved_sql.release(); + reserved_sql[Symbol.dispose] = () => reserved_sql.release(); + + reserved_sql.options = sql.options; + reserved_sql.transaction = reserved_sql.begin; + reserved_sql.distributed = reserved_sql.beginDistributed; + reserved_sql.end = reserved_sql.close; + resolve(reserved_sql); } + async function onTransactionConnected( + callback, + options, + resolve, + reject, + dontRelease, + distributed, + err, + pooledConnection, + ) { + /* + BEGIN; -- works on POSTGRES, MySQL (autocommit is true, no options accepted), and SQLite (no options accepted) (need to change to BEGIN TRANSACTION on MSSQL) + START TRANSACTION; -- works on POSTGRES, MySQL (autocommit is false, options accepted), (need to change to BEGIN TRANSACTION on MSSQL and BEGIN on SQLite) + + -- Create a SAVEPOINT + SAVEPOINT my_savepoint; -- works on POSTGRES, MySQL, and SQLite (need to change to SAVE TRANSACTION on MSSQL) + + -- QUERY + + -- Roll back to SAVEPOINT if needed + ROLLBACK TO SAVEPOINT my_savepoint; -- works on POSTGRES, MySQL, and SQLite (need to change to ROLLBACK TRANSACTION on MSSQL) + + -- Release the SAVEPOINT + RELEASE SAVEPOINT my_savepoint; -- works on POSTGRES, MySQL, and SQLite (MSSQL dont have RELEASE SAVEPOINT you just need to transaction again) + + -- Commit the transaction + COMMIT; -- works on POSTGRES, MySQL, and SQLite (need to change to COMMIT TRANSACTION on MSSQL) + -- or rollback everything + ROLLBACK; -- works on POSTGRES, MySQL, and SQLite (need to change to ROLLBACK TRANSACTION on MSSQL) + + */ - function doCreateQuery(strings, values) { - const sqlString = normalizeStrings(strings, values); - let columns; - if (hasSQLArrayParameter) { - hasSQLArrayParameter = false; - const v = values[0]; - columns = v.columns; - values = v.value; + if (err) { + return reject(err); } + const state = { + connectionState: ReservedConnectionState.acceptQueries, + reject, + queries: new Set(), + }; + + let savepoints = 0; + let transactionSavepoints = new Set(); + const adapter = connectionInfo.adapter; + let BEGIN_COMMAND: string = "BEGIN"; + let ROLLBACK_COMMAND: string = "COMMIT"; + let COMMIT_COMMAND: string = "ROLLBACK"; + let SAVEPOINT_COMMAND: string = "SAVEPOINT"; + let RELEASE_SAVEPOINT_COMMAND: string | null = "RELEASE SAVEPOINT"; + let ROLLBACK_TO_SAVEPOINT_COMMAND: string = "ROLLBACK TO SAVEPOINT"; + // MySQL and maybe other adapters need to call XA END or some other command before commit or rollback in a distributed transaction + let BEFORE_COMMIT_OR_ROLLBACK_COMMAND: string | null = null; + if (distributed) { + if (options.indexOf("'") !== -1) { + pool.release(pooledConnection); + return reject(new Error(`Distributed transaction name cannot contain single quotes.`)); + } + // distributed transaction + // in distributed transaction options is the name/id of the transaction + switch (adapter) { + case "postgres": + // in postgres we only need to call prepare transaction instead of commit + COMMIT_COMMAND = `PREPARE TRANSACTION '${options}'`; + break; + case "mysql": + // MySQL we use XA transactions + // START TRANSACTION is autocommit false + BEGIN_COMMAND = `XA START '${options}'`; + BEFORE_COMMIT_OR_ROLLBACK_COMMAND = `XA END '${options}'`; + COMMIT_COMMAND = `XA PREPARE '${options}'`; + ROLLBACK_COMMAND = `XA ROLLBACK '${options}'`; + break; + case "sqlite": + pool.release(pooledConnection); + + // do not support options just use defaults + return reject(new Error(`SQLite dont support distributed transactions.`)); + case "mssql": + BEGIN_COMMAND = ` BEGIN DISTRIBUTED TRANSACTION ${options}`; + ROLLBACK_COMMAND = `ROLLBACK TRANSACTION ${options}`; + COMMIT_COMMAND = `COMMIT TRANSACTION ${options}`; + break; + default: + pool.release(pooledConnection); + + // TODO: use ERR_ + return reject(new Error(`Unsupported adapter: ${adapter}.`)); + } + } else { + // normal transaction + switch (adapter) { + case "postgres": + if (options) { + BEGIN_COMMAND = `BEGIN ${options}`; + } + break; + case "mysql": + // START TRANSACTION is autocommit false + BEGIN_COMMAND = options ? `START TRANSACTION ${options}` : "START TRANSACTION"; + break; + + case "sqlite": + if (options) { + // sqlite supports DEFERRED, IMMEDIATE, EXCLUSIVE + BEGIN_COMMAND = `BEGIN ${options}`; + } + break; + case "mssql": + BEGIN_COMMAND = options ? `START TRANSACTION ${options}` : "START TRANSACTION"; + ROLLBACK_COMMAND = "ROLLBACK TRANSACTION"; + COMMIT_COMMAND = "COMMIT TRANSACTION"; + SAVEPOINT_COMMAND = "SAVE"; + RELEASE_SAVEPOINT_COMMAND = null; // mssql dont have release savepoint + ROLLBACK_TO_SAVEPOINT_COMMAND = "ROLLBACK TRANSACTION"; + break; + default: + pool.release(pooledConnection); + // TODO: use ERR_ + return reject(new Error(`Unsupported adapter: ${adapter}.`)); + } + } + const onClose = onTransactionDisconnected.bind(state); + pooledConnection.onClose(onClose); - return createQuery(sqlString, values, new SQLResultArray(), columns); - } + function run_internal_transaction_sql(strings, ...values) { + if (state.connectionState & ReservedConnectionState.closed) { + return Promise.reject(connectionClosedError()); + } + return queryFromTransaction(strings, values, pooledConnection, state.queries); + } + function transaction_sql(strings, ...values) { + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + return Promise.reject(connectionClosedError()); + } + if ($isJSArray(strings) && strings[0] && typeof strings[0] === "object") { + return new SQLArrayParameter(strings, values); + } - function connectedSQL(strings, values) { - return new Query(doCreateQuery(strings, values), connectedHandler); - } + return queryFromTransaction(strings, values, pooledConnection, state.queries); + } + // reserve is allowed to be called inside transaction connection but will return a new reserved connection from the pool and will not be part of the transaction + // this matchs the behavior of the postgres package + transaction_sql.reserve = () => sql.reserve(); - function closedSQL(strings, values) { - return new Query(undefined, closedConnectionHandler); - } + transaction_sql.connect = () => { + if (state.connectionState & ReservedConnectionState.closed) { + return Promise.reject(connectionClosedError()); + } - function pendingSQL(strings, values) { - return new Query(doCreateQuery(strings, values), pendingConnectionHandler); - } + return Promise.resolve(transaction_sql); + }; + transaction_sql.commitDistributed = async function (name: string) { + assertValidTransactionName(name); + switch (adapter) { + case "postgres": + return await transaction_sql(`COMMIT PREPARED '${name}'`); + case "mysql": + return await transaction_sql(`XA COMMIT '${name}'`); + case "mssql": + throw Error(`MSSQL distributed transaction is automatically committed.`); + case "sqlite": + throw Error(`SQLite dont support distributed transactions.`); + default: + throw Error(`Unsupported adapter: ${adapter}.`); + } + }; + transaction_sql.rollbackDistributed = async function (name: string) { + assertValidTransactionName(name); + switch (adapter) { + case "postgres": + return await transaction_sql(`ROLLBACK PREPARED '${name}'`); + case "mysql": + return await transaction_sql(`XA ROLLBACK '${name}'`); + case "mssql": + throw Error(`MSSQL distributed transaction is automatically rolled back.`); + case "sqlite": + throw Error(`SQLite dont support distributed transactions.`); + default: + throw Error(`Unsupported adapter: ${adapter}.`); + } + }; + // begin is not allowed on a transaction we need to use savepoint() instead + transaction_sql.begin = function () { + if (distributed) { + throw $ERR_POSTGRES_INVALID_TRANSACTION_STATE("cannot call begin inside a distributed transaction"); + } + throw $ERR_POSTGRES_INVALID_TRANSACTION_STATE("cannot call begin inside a transaction use savepoint() instead"); + }; + + transaction_sql.beginDistributed = function () { + if (distributed) { + throw $ERR_POSTGRES_INVALID_TRANSACTION_STATE("cannot call beginDistributed inside a distributed transaction"); + } + throw $ERR_POSTGRES_INVALID_TRANSACTION_STATE( + "cannot call beginDistributed inside a transaction use savepoint() instead", + ); + }; + + transaction_sql.flush = function () { + if (state.connectionState & ReservedConnectionState.closed) { + throw connectionClosedError(); + } + return pooledConnection.flush(); + }; + transaction_sql.close = async function (options?: { timeout?: number }) { + // we dont actually close the connection here, we just set the state to closed and rollback the transaction + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + return Promise.reject(connectionClosedError()); + } + state.connectionState &= ~ReservedConnectionState.acceptQueries; + const transactionQueries = state.queries; + let timeout = options?.timeout; + if (timeout) { + timeout = Number(timeout); + if (timeout > 2 ** 31 || timeout < 0 || timeout !== timeout) { + throw $ERR_INVALID_ARG_VALUE("options.timeout", timeout, "must be a non-negative integer less than 2^31"); + } + + if (timeout > 0 && (transactionQueries.size > 0 || transactionSavepoints.size > 0)) { + const { promise, resolve } = Promise.withResolvers(); + // race all queries vs timeout + const pending_queries = Array.from(transactionQueries); + const pending_savepoints = Array.from(transactionSavepoints); + const timer = setTimeout(async () => { + for (const query of transactionQueries) { + (query as Query).cancel(); + } + if (BEFORE_COMMIT_OR_ROLLBACK_COMMAND) { + await run_internal_transaction_sql(BEFORE_COMMIT_OR_ROLLBACK_COMMAND); + } + await run_internal_transaction_sql(ROLLBACK_COMMAND); + state.connectionState |= ReservedConnectionState.closed; + resolve(); + }, timeout * 1000); + timer.unref(); // dont block the event loop + Promise.all([Promise.all(pending_queries), Promise.all(pending_savepoints)]).finally(() => { + clearTimeout(timer); + resolve(); + }); + return promise; + } + } + for (const query of transactionQueries) { + (query as Query).cancel(); + } + if (BEFORE_COMMIT_OR_ROLLBACK_COMMAND) { + await run_internal_transaction_sql(BEFORE_COMMIT_OR_ROLLBACK_COMMAND); + } + await run_internal_transaction_sql(ROLLBACK_COMMAND); + state.connectionState |= ReservedConnectionState.closed; + }; + transaction_sql[Symbol.asyncDispose] = () => transaction_sql.close(); + transaction_sql.options = sql.options; + + transaction_sql.transaction = transaction_sql.begin; + transaction_sql.distributed = transaction_sql.beginDistributed; + transaction_sql.end = transaction_sql.close; + function onSavepointFinished(savepoint_promise: Promise) { + transactionSavepoints.delete(savepoint_promise); + } + async function run_internal_savepoint(save_point_name: string, savepoint_callback: TransactionCallback) { + await run_internal_transaction_sql(`${SAVEPOINT_COMMAND} ${save_point_name}`); + + try { + let result = await savepoint_callback(transaction_sql); + if (RELEASE_SAVEPOINT_COMMAND) { + // mssql dont have release savepoint + await run_internal_transaction_sql(`${RELEASE_SAVEPOINT_COMMAND} ${save_point_name}`); + } + if (Array.isArray(result)) { + result = await Promise.all(result); + } + return result; + } catch (err) { + if (!(state.connectionState & ReservedConnectionState.closed)) { + await run_internal_transaction_sql(`${ROLLBACK_TO_SAVEPOINT_COMMAND} ${save_point_name}`); + } + throw err; + } + } + if (distributed) { + transaction_sql.savepoint = async (fn: TransactionCallback, name?: string): Promise => { + throw $ERR_POSTGRES_INVALID_TRANSACTION_STATE("cannot call savepoint inside a distributed transaction"); + }; + } else { + transaction_sql.savepoint = async (fn: TransactionCallback, name?: string): Promise => { + let savepoint_callback = fn; + + if ( + state.connectionState & ReservedConnectionState.closed || + !(state.connectionState & ReservedConnectionState.acceptQueries) + ) { + throw connectionClosedError(); + } + if ($isCallable(name)) { + savepoint_callback = name as unknown as TransactionCallback; + name = ""; + } + if (!$isCallable(savepoint_callback)) { + throw $ERR_INVALID_ARG_VALUE("fn", callback, "must be a function"); + } + // matchs the format of the savepoint name in postgres package + const save_point_name = `s${savepoints++}${name ? `_${name}` : ""}`; + const promise = run_internal_savepoint(save_point_name, savepoint_callback); + transactionSavepoints.add(promise); + promise.finally(onSavepointFinished.bind(null, promise)); + return await promise; + }; + } + let needs_rollback = false; + try { + await run_internal_transaction_sql(BEGIN_COMMAND); + needs_rollback = true; + let transaction_result = await callback(transaction_sql); + if (Array.isArray(transaction_result)) { + transaction_result = await Promise.all(transaction_result); + } + // at this point we dont need to rollback anymore + needs_rollback = false; + if (BEFORE_COMMIT_OR_ROLLBACK_COMMAND) { + await run_internal_transaction_sql(BEFORE_COMMIT_OR_ROLLBACK_COMMAND); + } + await run_internal_transaction_sql(COMMIT_COMMAND); + return resolve(transaction_result); + } catch (err) { + try { + if (!(state.connectionState & ReservedConnectionState.closed) && needs_rollback) { + if (BEFORE_COMMIT_OR_ROLLBACK_COMMAND) { + await run_internal_transaction_sql(BEFORE_COMMIT_OR_ROLLBACK_COMMAND); + } + await run_internal_transaction_sql(ROLLBACK_COMMAND); + } + } catch (err) { + return reject(err); + } + return reject(err); + } finally { + state.connectionState |= ReservedConnectionState.closed; + pooledConnection.queries.delete(onClose); + if (!dontRelease) { + pool.release(pooledConnection); + } + } + } function sql(strings, ...values) { /** * const users = [ @@ -643,77 +1778,132 @@ function SQL(o) { return new SQLArrayParameter(strings, values); } - if (closed) { - return closedSQL(strings, values); - } - - if (connected) { - return connectedSQL(strings, values); - } - - return pendingSQL(strings, values); + return queryFromPool(strings, values); } - sql.connect = () => { - if (closed) { - return Promise.reject(new Error("Connection closed")); + sql.reserve = () => { + if (pool.closed) { + return Promise.reject(connectionClosedError()); } - if (connected) { - return Promise.resolve(sql); + const promiseWithResolvers = Promise.withResolvers(); + pool.connect(onReserveConnected.bind(promiseWithResolvers), true); + return promiseWithResolvers.promise; + }; + sql.rollbackDistributed = async function (name: string) { + if (pool.closed) { + throw connectionClosedError(); } - - var { resolve, reject, promise } = Promise.withResolvers(); - onConnect.push(err => (err ? reject(err) : resolve(sql))); - if (!connecting) { - connecting = true; - connection = createConnection(connectionInfo, onConnected, onClose); + assertValidTransactionName(name); + const adapter = connectionInfo.adapter; + switch (adapter) { + case "postgres": + return await sql(`ROLLBACK PREPARED '${name}'`); + case "mysql": + return await sql(`XA ROLLBACK '${name}'`); + case "mssql": + throw Error(`MSSQL distributed transaction is automatically rolled back.`); + case "sqlite": + throw Error(`SQLite dont support distributed transactions.`); + default: + throw Error(`Unsupported adapter: ${adapter}.`); } - - return promise; }; - sql.close = () => { - if (closed) { - return Promise.resolve(); + sql.commitDistributed = async function (name: string) { + if (pool.closed) { + throw connectionClosedError(); + } + assertValidTransactionName(name); + const adapter = connectionInfo.adapter; + switch (adapter) { + case "postgres": + return await sql(`COMMIT PREPARED '${name}'`); + case "mysql": + return await sql(`XA COMMIT '${name}'`); + case "mssql": + throw Error(`MSSQL distributed transaction is automatically committed.`); + case "sqlite": + throw Error(`SQLite dont support distributed transactions.`); + default: + throw Error(`Unsupported adapter: ${adapter}.`); } - - var { resolve, promise } = Promise.withResolvers(); - onConnect.push(resolve); - connection.close(); - return promise; }; - sql[Symbol.asyncDispose] = () => sql.close(); + sql.beginDistributed = (name: string, fn: TransactionCallback) => { + if (pool.closed) { + return Promise.reject(connectionClosedError()); + } + let callback = fn; - sql.flush = () => { - if (closed || !connected) { - return; + if (typeof name !== "string") { + return Promise.reject($ERR_INVALID_ARG_VALUE("name", name, "must be a string")); } - connection.flush(); + if (!$isCallable(callback)) { + return Promise.reject($ERR_INVALID_ARG_VALUE("fn", callback, "must be a function")); + } + const { promise, resolve, reject } = Promise.withResolvers(); + // lets just reuse the same code path as the transaction begin + pool.connect(onTransactionConnected.bind(null, callback, name, resolve, reject, false, true), true); + return promise; }; - sql.options = connectionInfo; - sql.then = () => { - if (closed) { - return Promise.reject(new Error("Connection closed")); + sql.begin = (options_or_fn: string | TransactionCallback, fn?: TransactionCallback) => { + if (pool.closed) { + return Promise.reject(connectionClosedError()); + } + let callback = fn; + let options: string | undefined = options_or_fn as unknown as string; + if ($isCallable(options_or_fn)) { + callback = options_or_fn as unknown as TransactionCallback; + options = undefined; + } else if (typeof options_or_fn !== "string") { + return Promise.reject($ERR_INVALID_ARG_VALUE("options", options_or_fn, "must be a string")); + } + if (!$isCallable(callback)) { + return Promise.reject($ERR_INVALID_ARG_VALUE("fn", callback, "must be a function")); + } + const { promise, resolve, reject } = Promise.withResolvers(); + pool.connect(onTransactionConnected.bind(null, callback, options, resolve, reject, false, false), true); + return promise; + }; + sql.connect = () => { + if (pool.closed) { + return Promise.reject(connectionClosedError()); } - if (connected) { + if (pool.isConnected()) { return Promise.resolve(sql); } - const { resolve, reject, promise } = Promise.withResolvers(); - onConnect.push(err => (err ? reject(err) : resolve(sql))); - if (!connecting) { - connecting = true; - connection = createConnection(connectionInfo, onConnected, onClose); - } + let { resolve, reject, promise } = Promise.withResolvers(); + const onConnected = (err, connection) => { + if (err) { + return reject(err); + } + // we are just measuring the connection here lets release it + pool.release(connection); + resolve(sql); + }; + + pool.connect(onConnected); return promise; }; + sql.close = async (options?: { timeout?: number }) => { + await pool.close(options); + }; + + sql[Symbol.asyncDispose] = () => sql.close(); + + sql.flush = () => pool.flush(); + sql.options = connectionInfo; + + sql.transaction = sql.begin; + sql.distributed = sql.beginDistributed; + sql.end = sql.close; return sql; } diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index c0f2bbef847e23..445ad4e654f188 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -217,8 +217,11 @@ pub const PostgresSQLQuery = struct { binary: bool = false, pub usingnamespace JSC.Codegen.JSPostgresSQLQuery; - + const log = bun.Output.scoped(.PostgresSQLQuery, false); pub fn getTarget(this: *PostgresSQLQuery, globalObject: *JSC.JSGlobalObject) JSC.JSValue { + if (this.thisValue == .zero) { + return .zero; + } const target = PostgresSQLQuery.targetGetCached(this.thisValue) orelse return .zero; PostgresSQLQuery.targetSetCached(this.thisValue, globalObject, .zero); return target; @@ -325,10 +328,13 @@ pub const PostgresSQLQuery = struct { return; } - // TODO: error handling var vm = JSC.VirtualMachine.get(); const function = vm.rareData().postgresql_context.onQueryRejectFn.get().?; - globalObject.queueMicrotask(function, &[_]JSValue{ targetValue, err.toJS(globalObject) }); + const event_loop = vm.eventLoop(); + event_loop.runCallback(function, globalObject, thisValue, &.{ + targetValue, + err.toJS(globalObject), + }); } const CommandTag = union(enum) { @@ -484,9 +490,14 @@ pub const PostgresSQLQuery = struct { pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { const arguments = callframe.arguments_old(4).slice(); - const query = arguments[0]; - const values = arguments[1]; - const columns = arguments[3]; + var args = JSC.Node.ArgumentsSlice.init(globalThis.bunVM(), arguments); + defer args.deinit(); + const query = args.nextEat() orelse { + return globalThis.throw("query must be a string", .{}); + }; + const values = args.nextEat() orelse { + return globalThis.throw("values must be an array", .{}); + }; if (!query.isString()) { return globalThis.throw("query must be a string", .{}); @@ -496,7 +507,9 @@ pub const PostgresSQLQuery = struct { return globalThis.throw("values must be an array", .{}); } - const pending_value = arguments[2]; + const pending_value = args.nextEat() orelse .undefined; + const columns = args.nextEat() orelse .undefined; + if (!pending_value.jsType().isArrayLike()) { return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); } @@ -573,10 +586,11 @@ pub const PostgresSQLQuery = struct { signature.deinit(); if (has_params and this.statement.?.status == .parsing) { + // if it has params, we need to wait for ParamDescription to be received before we can write the data } else { this.binary = this.statement.?.fields.len > 0; - + log("bindAndExecute", .{}); PostgresRequest.bindAndExecute(globalObject, this.statement.?, binding_value, columns_value, PostgresSQLConnection.Writer, writer) catch |err| { if (!globalObject.hasException()) return globalObject.throwError(err, "failed to bind and execute query"); @@ -2447,7 +2461,7 @@ pub const PostgresSQLConnection = struct { DataCell.Putter.put, ); - const pending_value = PostgresSQLQuery.pendingValueGetCached(request.thisValue) orelse .zero; + const pending_value = if (request.thisValue == .zero) .zero else PostgresSQLQuery.pendingValueGetCached(request.thisValue) orelse .zero; pending_value.ensureStillAlive(); const result = putter.toJS(this.globalObject, pending_value, statement.structure(this.js_value, this.globalObject), statement.fields_flags); diff --git a/test/js/sql/sql.test.ts b/test/js/sql/sql.test.ts index ce681484fe6004..4d4fd38c613264 100644 --- a/test/js/sql/sql.test.ts +++ b/test/js/sql/sql.test.ts @@ -1,4 +1,5 @@ -import { postgres, sql } from "bun:sql"; +import { sql } from "bun"; +const postgres = (...args) => new sql(...args); import { expect, test, mock } from "bun:test"; import { $ } from "bun"; import { bunExe, isCI, withoutAggressiveGC } from "harness"; @@ -382,168 +383,240 @@ if (!isCI && hasPsql) { } }); - // t('Throws on illegal transactions', async() => { - // const sql = postgres({ ...options, max: 2, fetch_types: false }) - // const error = await sql`begin`.catch(e => e) - // return [ - // error.code, - // 'UNSAFE_TRANSACTION' - // ] - // }) - - // t('Transaction throws', async() => { - // await sql`create table test (a int)` - // return ['22P02', await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql`insert into test values('hej')` - // }).catch(x => x.code), await sql`drop table test`] - // }) - - // t('Transaction rolls back', async() => { - // await sql`create table test (a int)` - // await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql`insert into test values('hej')` - // }).catch(() => { /* ignore */ }) - // return [0, (await sql`select a from test`).count, await sql`drop table test`] - // }) - - // t('Transaction throws on uncaught savepoint', async() => { - // await sql`create table test (a int)` + test("Throws on illegal transactions", async () => { + const sql = postgres({ ...options, max: 2, fetch_types: false }); + const error = await sql`begin`.catch(e => e); + return expect(error.code).toBe("ERR_POSTGRES_UNSAFE_TRANSACTION"); + }); - // return ['fail', (await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql.savepoint(async sql => { - // await sql`insert into test values(2)` - // throw new Error('fail') - // }) - // }).catch((err) => err.message)), await sql`drop table test`] - // }) + test("Transaction throws", async () => { + await sql`create table if not exists test (a int)`; + try { + expect( + await sql + .begin(async sql => { + await sql`insert into test values(1)`; + await sql`insert into test values('hej')`; + }) + .catch(e => e.errno), + ).toBe(22); + } finally { + await sql`drop table test`; + } + }); - // t('Transaction throws on uncaught named savepoint', async() => { - // await sql`create table test (a int)` + test("Transaction rolls back", async () => { + await sql`create table if not exists test (a int)`; - // return ['fail', (await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql.savepoit('watpoint', async sql => { - // await sql`insert into test values(2)` - // throw new Error('fail') - // }) - // }).catch(() => 'fail')), await sql`drop table test`] - // }) + try { + await sql + .begin(async sql => { + await sql`insert into test values(1)`; + await sql`insert into test values('hej')`; + }) + .catch(() => { + /* ignore */ + }); + + expect((await sql`select a from test`).count).toBe(0); + } finally { + await sql`drop table test`; + } + }); - // t('Transaction succeeds on caught savepoint', async() => { - // await sql`create table test (a int)` - // await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql.savepoint(async sql => { - // await sql`insert into test values(2)` - // throw new Error('please rollback') - // }).catch(() => { /* ignore */ }) - // await sql`insert into test values(3)` - // }) + test("Transaction throws on uncaught savepoint", async () => { + await sql`create table test (a int)`; + try { + expect( + await sql + .begin(async sql => { + await sql`insert into test values(1)`; + await sql.savepoint(async sql => { + await sql`insert into test values(2)`; + throw new Error("fail"); + }); + }) + .catch(err => err.message), + ).toBe("fail"); + } finally { + await sql`drop table test`; + } + }); - // return ['2', (await sql`select count(1) from test`)[0].count, await sql`drop table test`] - // }) + test("Transaction throws on uncaught named savepoint", async () => { + await sql`create table test (a int)`; + try { + expect( + await sql + .begin(async sql => { + await sql`insert into test values(1)`; + await sql.savepoit("watpoint", async sql => { + await sql`insert into test values(2)`; + throw new Error("fail"); + }); + }) + .catch(() => "fail"), + ).toBe("fail"); + } finally { + await sql`drop table test`; + } + }); - // t('Savepoint returns Result', async() => { - // let result - // await sql.begin(async sql => { - // result = await sql.savepoint(sql => - // sql`select 1 as x` - // ) - // }) + test("Transaction succeeds on caught savepoint", async () => { + try { + await sql`create table test (a int)`; + await sql.begin(async sql => { + await sql`insert into test values(1)`; + await sql + .savepoint(async sql => { + await sql`insert into test values(2)`; + throw new Error("please rollback"); + }) + .catch(() => { + /* ignore */ + }); + await sql`insert into test values(3)`; + }); + expect((await sql`select count(1) from test`)[0].count).toBe("2"); + } finally { + await sql`drop table test`; + } + }); - // return [1, result[0].x] - // }) + test("Savepoint returns Result", async () => { + let result; + await sql.begin(async t => { + result = await t.savepoint(s => s`select 1 as x`); + }); + expect(result[0]?.x).toBe(1); + }); - // t('Prepared transaction', async() => { - // await sql`create table test (a int)` + // test("Prepared transaction", async () => { + // await sql`create table test (a int)`; // await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql.prepare('tx1') - // }) + // await sql`insert into test values(1)`; + // await sql.prepare("tx1"); + // }); - // await sql`commit prepared 'tx1'` - - // return ['1', (await sql`select count(1) from test`)[0].count, await sql`drop table test`] - // }) + // await sql`commit prepared 'tx1'`; + // try { + // expect((await sql`select count(1) from test`)[0].count).toBe("1"); + // } finally { + // await sql`drop table test`; + // } + // }); - // t('Transaction requests are executed implicitly', async() => { - // const sql = postgres({ debug: true, idle_timeout: 1, fetch_types: false }) - // return [ - // 'testing', - // (await sql.begin(sql => [ - // sql`select set_config('bun_sql.test', 'testing', true)`, - // sql`select current_setting('bun_sql.test') as x` - // ]))[1][0].x - // ] - // }) + test("Prepared transaction", async () => { + await sql`create table test (a int)`; - // t('Uncaught transaction request errors bubbles to transaction', async() => [ - // '42703', - // (await sql.begin(sql => [ - // sql`select wat`, - // sql`select current_setting('bun_sql.test') as x, ${ 1 } as a` - // ]).catch(e => e.code)) - // ]) + try { + await sql.beginDistributed("tx1", async sql => { + await sql`insert into test values(1)`; + }); - // t('Fragments in transactions', async() => [ - // true, - // (await sql.begin(sql => sql`select true as x where ${ sql`1=1` }`))[0].x - // ]) + await sql.commitDistributed("tx1"); + expect((await sql`select count(1) from test`)[0].count).toBe("1"); + } finally { + await sql`drop table test`; + } + }); - // t('Transaction rejects with rethrown error', async() => [ - // 'WAT', - // await sql.begin(async sql => { - // try { - // await sql`select exception` - // } catch (ex) { - // throw new Error('WAT') - // } - // }).catch(e => e.message) - // ]) + test("Transaction requests are executed implicitly", async () => { + const sql = postgres({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); + expect( + ( + await sql.begin(sql => [ + sql`select set_config('bun_sql.test', 'testing', true)`, + sql`select current_setting('bun_sql.test') as x`, + ]) + )[1][0].x, + ).toBe("testing"); + }); - // t('Parallel transactions', async() => { - // await sql`create table test (a int)` - // return ['11', (await Promise.all([ - // sql.begin(sql => sql`select 1`), - // sql.begin(sql => sql`select 1`) - // ])).map(x => x.count).join(''), await sql`drop table test`] - // }) + test("Uncaught transaction request errosó rs bubbles to transaction", async () => { + const sql = postgres({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); + expect( + await sql + .begin(sql => [sql`select wat`, sql`select current_setting('bun_sql.test') as x, ${1} as a`]) + .catch(e => e.errno), + ).toBe(42703); + }); - // t("Many transactions at beginning of connection", async () => { - // const sql = postgres(options); - // const xs = await Promise.all(Array.from({ length: 100 }, () => sql.begin(sql => sql`select 1`))); - // return [100, xs.length]; + // test.only("Fragments in transactions", async () => { + // const sql = postgres({ ...options, debug: true, idle_timeout: 1, fetch_types: false }); + // expect((await sql.begin(sql => sql`select true as x where ${sql`1=1`}`))[0].x).toBe(true); // }); - // t('Transactions array', async() => { - // await sql`create table test (a int)` + test("Transaction rejects with rethrown error", async () => { + await using sql = postgres({ ...options }); + expect( + await sql + .begin(async sql => { + try { + await sql`select exception`; + } catch (ex) { + throw new Error("WAT"); + } + }) + .catch(e => e.message), + ).toBe("WAT"); + }); - // return ['11', (await sql.begin(sql => [ - // sql`select 1`.then(x => x), - // sql`select 1` - // ])).map(x => x.count).join(''), await sql`drop table test`] - // }) + test("Parallel transactions", async () => { + await sql`create table test (a int)`; + expect( + (await Promise.all([sql.begin(sql => sql`select 1 as count`), sql.begin(sql => sql`select 1 as count`)])) + .map(x => x[0].count) + .join(""), + ).toBe("11"); + await sql`drop table test`; + }); - // t('Transaction waits', async() => { - // await sql`create table test (a int)` - // await sql.begin(async sql => { - // await sql`insert into test values(1)` - // await sql.savepoint(async sql => { - // await sql`insert into test values(2)` - // throw new Error('please rollback') - // }).catch(() => { /* ignore */ }) - // await sql`insert into test values(3)` - // }) + test("Many transactions at beginning of connection", async () => { + await using sql = postgres(options); + const xs = await Promise.all(Array.from({ length: 100 }, () => sql.begin(sql => sql`select 1`))); + return expect(xs.length).toBe(100); + }); - // return ['11', (await Promise.all([ - // sql.begin(sql => sql`select 1`), - // sql.begin(sql => sql`select 1`) - // ])).map(x => x.count).join(''), await sql`drop table test`] - // }) + test("Transactions array", async () => { + await using sql = postgres(options); + await sql`create table test (a int)`; + try { + expect( + (await sql.begin(sql => [sql`select 1 as count`, sql`select 1 as count`])).map(x => x[0].count).join(""), + ).toBe("11"); + } finally { + await sql`drop table test`; + } + }); + + test("Transaction waits", async () => { + await using sql = postgres({ ...options }); + await sql`create table test (a int)`; + try { + await sql.begin(async sql => { + await sql`insert into test values(1)`; + await sql + .savepoint(async sql => { + await sql`insert into test values(2)`; + throw new Error("please rollback"); + }) + .catch(() => { + /* ignore */ + }); + await sql`insert into test values(3)`; + }); + + expect( + (await Promise.all([sql.begin(sql => sql`select 1 as count`), sql.begin(sql => sql`select 1 as count`)])) + .map(x => x[0].count) + .join(""), + ).toBe("11"); + } finally { + await sql`drop table test`; + } + }); // t('Helpers in Transaction', async() => { // return ['1', (await sql.begin(async sql =>