diff --git a/README.md b/README.md index b4aa5333..8db62cd1 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,33 @@ const urlProvider = async () => { const rws = new ReconnectingWebSocket(urlProvider); ``` +### Update Protocols + +The `protocols` parameter will be resolved before connecting, possible types: + +- `null` +- `string` +- `string[]` +- `() => string | string[] | null` +- `() => Promise` + +```javascript +import ReconnectingWebSocket from 'reconnecting-websocket`; +const rws = new ReconnectingWebSocket('ws://your.site.com', 'your protocol'); +``` + +```javascript +import ReconnectingWebSocket from 'reconnecting-websocket`; + +const protocols = ['p1', 'p2', ['p3.1', 'p3.2']]; +let protocolsIndex = 0; + +// round robin protocols provider +const protocolsProvider = () => protocols[protocolsIndex++ % protocols.length]; + +const rws = new ReconnectingWebSocket('ws://your.site.com', protocolsProvider); +``` + ### Options #### Sample with custom options @@ -130,7 +157,7 @@ debug: false, ### Methods ```typescript -constructor(url: UrlProvider, protocols?: string | string[], options?: Options) +constructor(url: UrlProvider, protocols?: ProtocolsProvider, options?: Options) close(code?: number, reason?: string) reconnect(code?: number, reason?: string) diff --git a/__tests__/test.ts b/__tests__/test.ts index b05922fc..c61c0533 100644 --- a/__tests__/test.ts +++ b/__tests__/test.ts @@ -178,6 +178,38 @@ test('null websocket protocol', done => { }); }); +test('websocket sync protocolsProvider', done => { + const anyProtocol = 'bar'; + const wss = new WebSocketServer({port: PORT}); + + const ws = new ReconnectingWebSocket(URL, () => anyProtocol, {}); + ws.addEventListener('open', () => { + expect(ws.url).toBe(URL); + expect(ws.protocol).toBe(anyProtocol); + ws.close(); + }); + + ws.addEventListener('close', () => { + wss.close(() => setTimeout(done, 100)); + }); +}); + +test('websocket async protocolsProvider', done => { + const anyProtocol = 'foo'; + const wss = new WebSocketServer({port: PORT}); + + const ws = new ReconnectingWebSocket(URL, async () => anyProtocol, {}); + ws.addEventListener('open', () => { + expect(ws.url).toBe(URL); + expect(ws.protocol).toBe(anyProtocol); + ws.close(); + }); + + ws.addEventListener('close', () => { + wss.close(() => setTimeout(done, 100)); + }); +}); + test('connection status constants', () => { const ws = new ReconnectingWebSocket(URL, undefined, {maxRetries: 0}); diff --git a/reconnecting-websocket.ts b/reconnecting-websocket.ts index ba7f4c47..3b44a027 100644 --- a/reconnecting-websocket.ts +++ b/reconnecting-websocket.ts @@ -49,6 +49,12 @@ const DEFAULT = { }; export type UrlProvider = string | (() => string) | (() => Promise); +export type ProtocolsProvider = + | null + | string + | string[] + | (() => string | string[] | null) + | (() => Promise); export type Message = string | ArrayBuffer | Blob | ArrayBufferView; @@ -77,10 +83,10 @@ export default class ReconnectingWebSocket { private _messageQueue: Message[] = []; private readonly _url: UrlProvider; - private readonly _protocols?: string | string[]; + private readonly _protocols?: ProtocolsProvider; private readonly _options: Options; - constructor(url: UrlProvider, protocols?: string | string[], options: Options = {}) { + constructor(url: UrlProvider, protocols?: ProtocolsProvider, options: Options = {}) { this._url = url; this._protocols = protocols; this._options = options; @@ -330,6 +336,32 @@ export default class ReconnectingWebSocket { }); } + private _getNextProtocols( + protocolsProvider: ProtocolsProvider | null, + ): Promise { + if (!protocolsProvider) return Promise.resolve(null); + + if (typeof protocolsProvider === 'string' || Array.isArray(protocolsProvider)) { + return Promise.resolve(protocolsProvider); + } + + if (typeof protocolsProvider === 'function') { + const protocols = protocolsProvider(); + if (!protocols) return Promise.resolve(null); + + if (typeof protocols === 'string' || Array.isArray(protocols)) { + return Promise.resolve(protocols); + } + + // @ts-ignore redundant check + if (protocols.then) { + return protocols; + } + } + + throw Error('Invalid protocols'); + } + private _getNextUrl(urlProvider: UrlProvider): Promise { if (typeof urlProvider === 'string') { return Promise.resolve(urlProvider); @@ -372,16 +404,19 @@ export default class ReconnectingWebSocket { throw Error('No valid WebSocket class provided'); } this._wait() - .then(() => this._getNextUrl(this._url)) - .then(url => { + .then(() => + Promise.all([ + this._getNextUrl(this._url), + this._getNextProtocols(this._protocols || null), + ]), + ) + .then(([url, protocols]) => { // close could be called before creating the ws if (this._closeCalled) { return; } - this._debug('connect', {url, protocols: this._protocols}); - this._ws = this._protocols - ? new WebSocket(url, this._protocols) - : new WebSocket(url); + this._debug('connect', {url, protocols}); + this._ws = protocols ? new WebSocket(url, protocols) : new WebSocket(url); this._ws!.binaryType = this._binaryType; this._connectLock = false; this._addListeners();