diff --git a/package.json b/package.json index 438fc6063..6b169f544 100644 --- a/package.json +++ b/package.json @@ -137,7 +137,7 @@ "@bundled-es-modules/statuses": "^1.0.1", "@bundled-es-modules/tough-cookie": "^0.1.6", "@inquirer/confirm": "^3.0.0", - "@mswjs/interceptors": "^0.36.1", + "@mswjs/interceptors": "^0.36.4", "@open-draft/deferred-promise": "^2.2.0", "@open-draft/until": "^2.1.0", "@types/cookie": "^0.6.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3a630b654..22fc02bc1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -18,8 +18,8 @@ dependencies: specifier: ^3.0.0 version: 3.1.1 '@mswjs/interceptors': - specifier: ^0.36.1 - version: 0.36.1 + specifier: ^0.36.4 + version: 0.36.4 '@open-draft/deferred-promise': specifier: ^2.2.0 version: 2.2.0 @@ -1457,8 +1457,8 @@ packages: - utf-8-validate dev: true - /@mswjs/interceptors@0.36.1: - resolution: {integrity: sha512-wjKbjecynhZIqtvTuq61Q1JOzC1obSlrn1U6/f9JvwIw0zFX7+CZD3C/ncdUPbxpmfgq4GoyY2mFU+JobK+1SQ==} + /@mswjs/interceptors@0.36.4: + resolution: {integrity: sha512-ktzj7bra4HatOGqXw/PXyresXxFtnZa570rm4olAyf9HbvNdEWRkQl81ykmJK0nCHxNndmh2zQ84TBYKFDM+sg==} engines: {node: '>=18'} dependencies: '@open-draft/deferred-promise': 2.2.0 diff --git a/src/core/handlers/WebSocketHandler.ts b/src/core/handlers/WebSocketHandler.ts index 1cd5043d9..55f66def0 100644 --- a/src/core/handlers/WebSocketHandler.ts +++ b/src/core/handlers/WebSocketHandler.ts @@ -1,4 +1,5 @@ import { Emitter } from 'strict-event-emitter' +import { createRequestId } from '@mswjs/interceptors' import type { WebSocketConnectionData } from '@mswjs/interceptors/WebSocket' import { type Match, @@ -23,13 +24,18 @@ interface WebSocketHandlerConnection extends WebSocketConnectionData { export const kEmitter = Symbol('kEmitter') export const kDispatchEvent = Symbol('kDispatchEvent') export const kSender = Symbol('kSender') +const kStopPropagationPatched = Symbol('kStopPropagationPatched') +const KOnStopPropagation = Symbol('KOnStopPropagation') export class WebSocketHandler { + public id: string public callFrame?: string protected [kEmitter]: Emitter constructor(private readonly url: Path) { + this.id = createRequestId() + this[kEmitter] = new Emitter() this.callFrame = getCallFrame(new Error()) } @@ -63,8 +69,74 @@ export class WebSocketHandler { params: parsedResult.match.params || {}, } + // Support `event.stopPropagation()` for various client/server events. + connection.client.addEventListener( + 'message', + createStopPropagationListener(this), + ) + connection.client.addEventListener( + 'close', + createStopPropagationListener(this), + ) + + connection.server.addEventListener( + 'open', + createStopPropagationListener(this), + ) + connection.server.addEventListener( + 'message', + createStopPropagationListener(this), + ) + connection.server.addEventListener( + 'error', + createStopPropagationListener(this), + ) + connection.server.addEventListener( + 'close', + createStopPropagationListener(this), + ) + // Emit the connection event on the handler. // This is what the developer adds listeners for. this[kEmitter].emit('connection', resolvedConnection) } } + +function createStopPropagationListener(handler: WebSocketHandler) { + return function stopPropagationListener(event: Event) { + const propagationStoppedAt = Reflect.get(event, 'kPropagationStoppedAt') as + | string + | undefined + + if (propagationStoppedAt && handler.id !== propagationStoppedAt) { + event.stopImmediatePropagation() + return + } + + Object.defineProperty(event, KOnStopPropagation, { + value(this: WebSocketHandler) { + Object.defineProperty(event, 'kPropagationStoppedAt', { + value: handler.id, + }) + }, + configurable: true, + }) + + // Since the same event instance is shared between all client/server objects, + // make sure to patch its `stopPropagation` method only once. + if (!Reflect.get(event, kStopPropagationPatched)) { + event.stopPropagation = new Proxy(event.stopPropagation, { + apply: (target, thisArg, args) => { + Reflect.get(event, KOnStopPropagation)?.call(handler) + return Reflect.apply(target, thisArg, args) + }, + }) + + Object.defineProperty(event, kStopPropagationPatched, { + value: true, + // If something else attempts to redefine this, throw. + configurable: false, + }) + } + } +} diff --git a/test/node/ws-api/on-unhandled-request/error.test.ts b/test/node/ws-api/on-unhandled-request/error.test.ts index 36873939c..5b34d2b1b 100644 --- a/test/node/ws-api/on-unhandled-request/error.test.ts +++ b/test/node/ws-api/on-unhandled-request/error.test.ts @@ -27,12 +27,14 @@ it( const socket = new WebSocket('wss://localhost:4321') const errorListener = vi.fn() - await vi.waitFor(() => { + await vi.waitUntil(() => { return new Promise((resolve, reject) => { // These are intentionally swapped. The connection MUST error. socket.addEventListener('error', errorListener) socket.addEventListener('error', resolve) - socket.onopen = reject + socket.onopen = () => { + reject(new Error('WebSocket connection opened unexpectedly')) + } }) }) diff --git a/test/node/ws-api/ws.stop-propagation.test.ts b/test/node/ws-api/ws.stop-propagation.test.ts new file mode 100644 index 000000000..b7cbb3cf6 --- /dev/null +++ b/test/node/ws-api/ws.stop-propagation.test.ts @@ -0,0 +1,493 @@ +// @vitest-environment node-websocket +import { ws } from 'msw' +import { setupServer } from 'msw/node' +import { WebSocketServer } from '../../support/WebSocketServer' + +const server = setupServer() +const service = ws.link('ws://*') + +const originalServer = new WebSocketServer() + +beforeAll(async () => { + server.listen({ + // We are intentionally connecting to non-existing WebSocket URLs. + // Skip the unhandled request warnings, they are intentional. + onUnhandledRequest: 'bypass', + }) + await originalServer.listen() +}) + +afterEach(() => { + server.resetHandlers() + originalServer.resetState() +}) + +afterAll(async () => { + server.close() + await originalServer.close() +}) + +it('stops propagation for client "message" event', async () => { + const clientMessageListener = vi.fn<[number]>() + + server.use( + service.addEventListener('connection', ({ client }) => { + client.addEventListener('message', (event) => { + // Calling `stopPropagation` will prevent this event from being + // dispatched on the `client` beloning to a different event handler. + event.stopPropagation() + clientMessageListener(1) + }) + + client.addEventListener('message', () => { + clientMessageListener(2) + }) + }), + + service.addEventListener('connection', ({ client }) => { + client.addEventListener('message', () => { + clientMessageListener(3) + }) + }), + + service.addEventListener('connection', ({ client }) => { + client.addEventListener('message', () => { + clientMessageListener(4) + }) + + process.nextTick(() => { + client.close() + }) + }), + ) + + const ws = new WebSocket('ws://localhost') + ws.onopen = () => ws.send('hello world') + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(clientMessageListener).toHaveBeenNthCalledWith(1, 1) + expect(clientMessageListener).toHaveBeenNthCalledWith(2, 2) + expect(clientMessageListener).toHaveBeenCalledTimes(2) +}) + +it('stops immediate propagation for client "message" event', async () => { + const clientMessageListener = vi.fn<[number]>() + + server.use( + service.addEventListener('connection', ({ client }) => { + client.addEventListener('message', (event) => { + // Calling `stopPropagation` will prevent this event from being + // dispatched on the `client` beloning to a different event handler. + event.stopImmediatePropagation() + clientMessageListener(1) + }) + + client.addEventListener('message', () => { + clientMessageListener(2) + }) + + client.addEventListener('message', () => { + clientMessageListener(3) + }) + }), + + service.addEventListener('connection', ({ client }) => { + client.addEventListener('message', () => { + clientMessageListener(4) + }) + + process.nextTick(() => { + client.close() + }) + }), + ) + + const ws = new WebSocket('ws://localhost') + ws.onopen = () => ws.send('hello world') + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(clientMessageListener).toHaveBeenNthCalledWith(1, 1) + expect(clientMessageListener).toHaveBeenCalledOnce() +}) + +it('stops propagation for server "open" event', async () => { + const serverOpenListener = vi.fn<[number]>() + + originalServer.addListener('connection', () => {}) + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('open', (event) => { + // Calling `stopPropagation` will prevent this event from being + // dispatched on the `server` beloning to a different event handler. + event.stopPropagation() + serverOpenListener(1) + + process.nextTick(() => client.close()) + }) + + server.addEventListener('open', () => { + serverOpenListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('open', () => { + serverOpenListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('open', () => { + serverOpenListener(4) + }) + }), + ) + + const ws = new WebSocket(originalServer.url) + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(serverOpenListener).toHaveBeenNthCalledWith(1, 1) + expect(serverOpenListener).toHaveBeenNthCalledWith(2, 2) + expect(serverOpenListener).toHaveBeenCalledTimes(2) +}) + +it('stops immediate propagation for server "open" event', async () => { + const serverOpenListener = vi.fn<[number]>() + + originalServer.addListener('connection', () => {}) + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('open', (event) => { + event.stopImmediatePropagation() + serverOpenListener(1) + + process.nextTick(() => client.close()) + }) + + server.addEventListener('open', () => { + serverOpenListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('open', () => { + serverOpenListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('open', () => { + serverOpenListener(4) + }) + }), + ) + + const ws = new WebSocket(originalServer.url) + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(serverOpenListener).toHaveBeenNthCalledWith(1, 1) + expect(serverOpenListener).toHaveBeenCalledOnce() +}) + +it('stops propagation for server "message" event', async () => { + const serverMessageListener = vi.fn<[number]>() + + originalServer.addListener('connection', (ws) => { + // Send data from the original server to trigger the "message" event. + ws.send('hello') + }) + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('message', (event) => { + // Calling `stopPropagation` will prevent this event from being + // dispatched on the `server` beloning to a different event handler. + event.stopPropagation() + serverMessageListener(1) + + process.nextTick(() => client.close()) + }) + + server.addEventListener('message', () => { + serverMessageListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('message', () => { + serverMessageListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('message', () => { + serverMessageListener(4) + }) + }), + ) + + const ws = new WebSocket(originalServer.url) + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(serverMessageListener).toHaveBeenNthCalledWith(1, 1) + expect(serverMessageListener).toHaveBeenNthCalledWith(2, 2) + expect(serverMessageListener).toHaveBeenCalledTimes(2) +}) + +it('stops immediate propagation for server "message" event', async () => { + const serverMessageListener = vi.fn<[number]>() + + originalServer.addListener('connection', (ws) => { + // Send data from the original server to trigger the "message" event. + ws.send('hello') + }) + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('message', (event) => { + event.stopImmediatePropagation() + serverMessageListener(1) + + process.nextTick(() => client.close()) + }) + + server.addEventListener('message', () => { + serverMessageListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('message', () => { + serverMessageListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('message', () => { + serverMessageListener(4) + }) + }), + ) + + const ws = new WebSocket(originalServer.url) + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(serverMessageListener).toHaveBeenNthCalledWith(1, 1) + expect(serverMessageListener).toHaveBeenCalledOnce() +}) + +it('stops propagation for server "error" event', async () => { + const serverErrorListener = vi.fn<[number]>() + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('error', (event) => { + event.stopPropagation() + serverErrorListener(1) + }) + + server.addEventListener('error', () => { + serverErrorListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('error', () => { + serverErrorListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('error', () => { + serverErrorListener(4) + }) + }), + ) + + const ws = new WebSocket('ws://localhost/non-existing-path') + + await vi.waitFor(() => { + /** + * @note Ideally, await the "CLOSED" ready state, + * but Node.js doesn't dispatch it correctly. + * @see https://github.com/nodejs/undici/issues/3697 + */ + return new Promise((resolve) => { + ws.onerror = () => resolve() + }) + }) + + expect(serverErrorListener).toHaveBeenNthCalledWith(1, 1) + expect(serverErrorListener).toHaveBeenNthCalledWith(2, 2) + expect(serverErrorListener).toHaveBeenCalledTimes(2) +}) + +it('stops immediate propagation for server "error" event', async () => { + const serverErrorListener = vi.fn<[number]>() + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('error', (event) => { + event.stopImmediatePropagation() + serverErrorListener(1) + }) + + server.addEventListener('error', () => { + serverErrorListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('error', () => { + serverErrorListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('error', () => { + serverErrorListener(4) + }) + }), + ) + + const ws = new WebSocket('ws://localhost/non-existing-path') + + await vi.waitFor(() => { + /** + * @note Ideally, await the "CLOSED" ready state, + * but Node.js doesn't dispatch it correctly. + * @see https://github.com/nodejs/undici/issues/3697 + */ + return new Promise((resolve) => { + ws.onerror = () => resolve() + }) + }) + + expect(serverErrorListener).toHaveBeenNthCalledWith(1, 1) + expect(serverErrorListener).toHaveBeenCalledOnce() +}) + +it('stops propagation for server "close" event', async () => { + const serverCloseListener = vi.fn<[number]>() + + originalServer.addListener('connection', (ws) => { + ws.close() + }) + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('close', (event) => { + event.stopPropagation() + serverCloseListener(1) + + process.nextTick(() => client.close()) + }) + + server.addEventListener('close', () => { + serverCloseListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('close', () => { + serverCloseListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('close', () => { + serverCloseListener(4) + }) + }), + ) + + const ws = new WebSocket(originalServer.url) + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(serverCloseListener).toHaveBeenNthCalledWith(1, 1) + expect(serverCloseListener).toHaveBeenNthCalledWith(2, 2) + expect(serverCloseListener).toHaveBeenCalledTimes(2) +}) + +it('stops immediate propagation for server "close" event', async () => { + const serverCloseListener = vi.fn<[number]>() + + originalServer.addListener('connection', (ws) => { + ws.close() + }) + + server.use( + service.addEventListener('connection', ({ client, server }) => { + server.connect() + + server.addEventListener('close', (event) => { + event.stopImmediatePropagation() + serverCloseListener(1) + + process.nextTick(() => client.close()) + }) + + server.addEventListener('close', () => { + serverCloseListener(2) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('close', () => { + serverCloseListener(3) + }) + }), + + service.addEventListener('connection', ({ server }) => { + server.addEventListener('close', () => { + serverCloseListener(4) + }) + }), + ) + + const ws = new WebSocket(originalServer.url) + + await vi.waitFor(() => { + expect(ws.readyState).toBe(WebSocket.CLOSED) + }) + + expect(serverCloseListener).toHaveBeenNthCalledWith(1, 1) + expect(serverCloseListener).toHaveBeenCalledOnce() +})