diff --git a/src/ErrorBoundary.ts b/src/ErrorBoundary.ts index bf3dc70..1c529fc 100644 --- a/src/ErrorBoundary.ts +++ b/src/ErrorBoundary.ts @@ -6,7 +6,7 @@ import { StatsigUninitializedError, } from './Errors'; import OutputLogger from './OutputLogger'; -import { StatsigOptions } from './StatsigOptions'; +import { NetworkOverrideFunc, StatsigOptions } from './StatsigOptions'; import { getSDKType, getSDKVersion, getStatsigMetadata } from './utils/core'; import safeFetch from './utils/safeFetch'; import { StatsigContext } from './utils/StatsigContext'; @@ -18,6 +18,7 @@ export default class ErrorBoundary { private optionsLoggingCopy: StatsigOptions; private statsigMetadata = getStatsigMetadata(); private seen = new Set(); + private networkOverrideFunc: NetworkOverrideFunc | null; constructor( sdkKey: string, @@ -27,6 +28,7 @@ export default class ErrorBoundary { this.sdkKey = sdkKey; this.optionsLoggingCopy = optionsLoggingCopy; this.statsigMetadata['sessionID'] = sessionID; + this.networkOverrideFunc = optionsLoggingCopy.networkOverrideFunc ?? null; } swallow(task: (ctx: StatsigContext) => T, ctx: StatsigContext) { @@ -113,7 +115,9 @@ export default class ErrorBoundary { statsigOptions: this.optionsLoggingCopy, ...ctx.getContextForLogging(), }); - safeFetch(ExceptionEndpoint, { + + const fetcher = this.networkOverrideFunc ?? safeFetch; + fetcher(ExceptionEndpoint, { method: 'POST', headers: { 'STATSIG-API-KEY': this.sdkKey, diff --git a/src/SpecStore.ts b/src/SpecStore.ts index 8957792..f51a6ef 100644 --- a/src/SpecStore.ts +++ b/src/SpecStore.ts @@ -14,7 +14,11 @@ import { InitializationSource } from './InitializationDetails'; import { DataAdapterKey, IDataAdapter } from './interfaces/IDataAdapter'; import OutputLogger from './OutputLogger'; import SDKFlags from './SDKFlags'; -import { ExplicitStatsigOptions, InitStrategy } from './StatsigOptions'; +import { + ExplicitStatsigOptions, + InitStrategy, + NetworkOverrideFunc, +} from './StatsigOptions'; import { poll } from './utils/core'; import IDListUtil, { IDList } from './utils/IDListUtil'; import safeFetch from './utils/safeFetch'; @@ -64,6 +68,7 @@ export default class SpecStore { private hashedClientSDKKeyToAppMap: Record = {}; private hashedSDKKeysToEntities: Record = {}; private primaryTargetAppID: string | null; + private networkOverrideFunc: NetworkOverrideFunc | null; public constructor(fetcher: StatsigFetcher, options: ExplicitStatsigOptions) { this.fetcher = fetcher; @@ -77,6 +82,7 @@ export default class SpecStore { layers: {}, experimentToLayer: {}, }; + this.networkOverrideFunc = options.networkOverrideFunc ?? null; this.rulesetsSyncInterval = options.rulesetsSyncIntervalMs; this.idListsSyncInterval = options.idListsSyncIntervalMs; this.disableRulesetsSync = options.disableRulesetsSync; @@ -756,7 +762,9 @@ export default class SpecStore { url: url, markerID, }); - res = await safeFetch(url, { + + const fetcher = this.networkOverrideFunc ?? safeFetch; + res = await fetcher(url, { method: 'GET', headers: { Range: `bytes=${readSize}-`, @@ -771,9 +779,11 @@ export default class SpecStore { markerID, }); } + if (threwNetworkError || !res) { return; } + try { diagnostics?.process.start({ markerID }); const contentLength = res.headers.get('content-length'); diff --git a/src/StatsigOptions.ts b/src/StatsigOptions.ts index 6fffbc2..6412eeb 100644 --- a/src/StatsigOptions.ts +++ b/src/StatsigOptions.ts @@ -33,11 +33,17 @@ export interface LoggerInterface { logLevel: 'none' | 'debug' | 'info' | 'warn' | 'error'; } +export type NetworkOverrideFunc = ( + url: string, + params: RequestInit, +) => Promise; + export type ExplicitStatsigOptions = { api: string; apiForDownloadConfigSpecs: string; apiForGetIdLists: string; fallbackToStatsigAPI: boolean; + networkOverrideFunc: NetworkOverrideFunc | null; bootstrapValues: string | null; environment: StatsigEnvironment | null; rulesUpdatedCallback: RulesUpdatedCallback | null; @@ -80,6 +86,7 @@ export function OptionsWithDefaults( normalizeUrl(getString(opts, 'apiForGetIdLists', opts.api ?? null)) ?? STATSIG_API, fallbackToStatsigAPI: getBoolean(opts, 'fallbackToStatsigAPI', false), + networkOverrideFunc: opts.networkOverrideFunc ?? null, bootstrapValues: getString(opts, 'bootstrapValues', null), environment: opts.environment ? (getObject(opts, 'environment', {}) as StatsigEnvironment) diff --git a/src/__tests__/NetworkOverrideFunc.test.ts b/src/__tests__/NetworkOverrideFunc.test.ts new file mode 100644 index 0000000..245800a --- /dev/null +++ b/src/__tests__/NetworkOverrideFunc.test.ts @@ -0,0 +1,49 @@ +import * as statsigsdk from '../index'; +import StatsigInstanceUtils from '../StatsigInstanceUtils'; + +const statsig = statsigsdk.default; + +jest.mock('node-fetch', () => jest.fn()); + +function verifyDcsFetchCall(call: any[]) { + expect(call[0]).toEqual( + 'https://api.statsigcdn.com/v1/download_config_specs/secret-key.json?sinceTime=0', + ); + expect(call[1].method).toBe('GET'); +} + +function verifyGetIdListsFetchCall(call: any[]) { + expect(call[0]).toEqual('https://statsigapi.net/v1/get_id_lists'); + expect(call[1].method).toBe('POST'); +} + +describe('NetworkOverrideFunc', () => { + const fetchSpy = require('node-fetch'); + const networkOverrideSpy = jest.fn(); + + beforeEach(() => { + StatsigInstanceUtils.setInstance(null); + + fetchSpy.mockClear(); + networkOverrideSpy.mockClear(); + }); + + it('calls the networkOverrideFunc', async () => { + await statsig.initialize('secret-key', { + networkOverrideFunc: networkOverrideSpy, + }); + + expect(fetchSpy).not.toHaveBeenCalled(); + + verifyDcsFetchCall(networkOverrideSpy.mock.calls[0]); + verifyGetIdListsFetchCall(networkOverrideSpy.mock.calls[1]); + }); + + it('calls fetch when no override is given', async () => { + await statsig.initialize('secret-key'); + expect(networkOverrideSpy).not.toHaveBeenCalled(); + + verifyDcsFetchCall(fetchSpy.mock.calls[0]); + verifyGetIdListsFetchCall(fetchSpy.mock.calls[1]); + }); +}); diff --git a/src/utils/StatsigFetcher.ts b/src/utils/StatsigFetcher.ts index 56f7e3a..08e8ec6 100644 --- a/src/utils/StatsigFetcher.ts +++ b/src/utils/StatsigFetcher.ts @@ -5,7 +5,11 @@ import { StatsigSDKKeyMismatchError, StatsigTooManyRequestsError, } from '../Errors'; -import { ExplicitStatsigOptions, RetryBackoffFunc } from '../StatsigOptions'; +import { + ExplicitStatsigOptions, + NetworkOverrideFunc, + RetryBackoffFunc, +} from '../StatsigOptions'; import { getSDKType, getSDKVersion } from './core'; import Dispatcher from './Dispatcher'; import getCompressionFunc from './getCompressionFunc'; @@ -38,6 +42,7 @@ export default class StatsigFetcher { private localMode: boolean; private sdkKey: string; private errorBoundry: ErrorBoundary; + private networkOverrideFunc: NetworkOverrideFunc | null; public constructor( secretKey: string, @@ -55,6 +60,7 @@ export default class StatsigFetcher { this.localMode = options.localMode; this.sdkKey = secretKey; this.errorBoundry = errorBoundry; + this.networkOverrideFunc = options.networkOverrideFunc; } public validateSDKKeyUsed(hashedSDKKeyUsed: string): boolean { @@ -158,11 +164,11 @@ export default class StatsigFetcher { ...options?.additionalHeaders, 'Content-type': 'application/json; charset=UTF-8', 'STATSIG-API-KEY': this.sdkKey, - 'STATSIG-CLIENT-TIME': Date.now(), + 'STATSIG-CLIENT-TIME': `${Date.now()}`, 'STATSIG-SERVER-SESSION-ID': this.sessionID, 'STATSIG-SDK-TYPE': getSDKType(), 'STATSIG-SDK-VERSION': getSDKVersion(), - } as Record; + } as Record; let contents: BodyInit | undefined = undefined; const gzipSync = getCompressionFunc(); @@ -173,12 +179,6 @@ export default class StatsigFetcher { } else if (body) { contents = JSON.stringify(body); } - const params = { - method: method, - body: contents, - headers, - signal: signal, - }; if (!isRetrying) { markDiagnostic?.start({}); @@ -186,7 +186,13 @@ export default class StatsigFetcher { let res: Response | undefined; let error: unknown; - return safeFetch(url, params) + const fetcher = this.networkOverrideFunc ?? safeFetch; + return fetcher(url, { + method: method, + body: contents, + headers, + signal: signal, + }) .then((localRes) => { res = localRes; if ((!res.ok || retryStatusCodes.includes(res.status)) && retries > 0) {