Skip to content

Commit

Permalink
feat: add StatsigOptions.networkOverrideFunc (#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-statsig committed Oct 30, 2024
1 parent b1b4630 commit 532cd78
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 14 deletions.
8 changes: 6 additions & 2 deletions src/ErrorBoundary.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
StatsigUninitializedError,
} from './Errors';
import OutputLogger from './OutputLogger';
import { StatsigOptions } from './StatsigOptions';
import { NetworkOverrideFunc, StatsigOptions } from './StatsigOptions';
import { getSDKType, getSDKVersion, getStatsigMetadata } from './utils/core';
import safeFetch from './utils/safeFetch';
import { StatsigContext } from './utils/StatsigContext';
Expand All @@ -18,6 +18,7 @@ export default class ErrorBoundary {
private optionsLoggingCopy: StatsigOptions;
private statsigMetadata = getStatsigMetadata();
private seen = new Set<string>();
private networkOverrideFunc: NetworkOverrideFunc | null;

constructor(
sdkKey: string,
Expand All @@ -27,6 +28,7 @@ export default class ErrorBoundary {
this.sdkKey = sdkKey;
this.optionsLoggingCopy = optionsLoggingCopy;
this.statsigMetadata['sessionID'] = sessionID;
this.networkOverrideFunc = optionsLoggingCopy.networkOverrideFunc ?? null;
}

swallow<T>(task: (ctx: StatsigContext) => T, ctx: StatsigContext) {
Expand Down Expand Up @@ -113,7 +115,9 @@ export default class ErrorBoundary {
statsigOptions: this.optionsLoggingCopy,
...ctx.getContextForLogging(),
});
safeFetch(ExceptionEndpoint, {

const fetcher = this.networkOverrideFunc ?? safeFetch;
fetcher(ExceptionEndpoint, {
method: 'POST',
headers: {
'STATSIG-API-KEY': this.sdkKey,
Expand Down
14 changes: 12 additions & 2 deletions src/SpecStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ import { InitializationSource } from './InitializationDetails';
import { DataAdapterKey, IDataAdapter } from './interfaces/IDataAdapter';
import OutputLogger from './OutputLogger';
import SDKFlags from './SDKFlags';
import { ExplicitStatsigOptions, InitStrategy } from './StatsigOptions';
import {
ExplicitStatsigOptions,
InitStrategy,
NetworkOverrideFunc,
} from './StatsigOptions';
import { poll } from './utils/core';
import IDListUtil, { IDList } from './utils/IDListUtil';
import safeFetch from './utils/safeFetch';
Expand Down Expand Up @@ -64,6 +68,7 @@ export default class SpecStore {
private hashedClientSDKKeyToAppMap: Record<string, string> = {};
private hashedSDKKeysToEntities: Record<string, APIEntityNames> = {};
private primaryTargetAppID: string | null;
private networkOverrideFunc: NetworkOverrideFunc | null;

public constructor(fetcher: StatsigFetcher, options: ExplicitStatsigOptions) {
this.fetcher = fetcher;
Expand All @@ -77,6 +82,7 @@ export default class SpecStore {
layers: {},
experimentToLayer: {},
};
this.networkOverrideFunc = options.networkOverrideFunc ?? null;
this.rulesetsSyncInterval = options.rulesetsSyncIntervalMs;
this.idListsSyncInterval = options.idListsSyncIntervalMs;
this.disableRulesetsSync = options.disableRulesetsSync;
Expand Down Expand Up @@ -756,7 +762,9 @@ export default class SpecStore {
url: url,
markerID,
});
res = await safeFetch(url, {

const fetcher = this.networkOverrideFunc ?? safeFetch;
res = await fetcher(url, {
method: 'GET',
headers: {
Range: `bytes=${readSize}-`,
Expand All @@ -771,9 +779,11 @@ export default class SpecStore {
markerID,
});
}

if (threwNetworkError || !res) {
return;
}

try {
diagnostics?.process.start({ markerID });
const contentLength = res.headers.get('content-length');
Expand Down
7 changes: 7 additions & 0 deletions src/StatsigOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,17 @@ export interface LoggerInterface {
logLevel: 'none' | 'debug' | 'info' | 'warn' | 'error';
}

export type NetworkOverrideFunc = (
url: string,
params: RequestInit,
) => Promise<Response>;

export type ExplicitStatsigOptions = {
api: string;
apiForDownloadConfigSpecs: string;
apiForGetIdLists: string;
fallbackToStatsigAPI: boolean;
networkOverrideFunc: NetworkOverrideFunc | null;
bootstrapValues: string | null;
environment: StatsigEnvironment | null;
rulesUpdatedCallback: RulesUpdatedCallback | null;
Expand Down Expand Up @@ -80,6 +86,7 @@ export function OptionsWithDefaults(
normalizeUrl(getString(opts, 'apiForGetIdLists', opts.api ?? null)) ??
STATSIG_API,
fallbackToStatsigAPI: getBoolean(opts, 'fallbackToStatsigAPI', false),
networkOverrideFunc: opts.networkOverrideFunc ?? null,
bootstrapValues: getString(opts, 'bootstrapValues', null),
environment: opts.environment
? (getObject(opts, 'environment', {}) as StatsigEnvironment)
Expand Down
49 changes: 49 additions & 0 deletions src/__tests__/NetworkOverrideFunc.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import * as statsigsdk from '../index';
import StatsigInstanceUtils from '../StatsigInstanceUtils';

const statsig = statsigsdk.default;

jest.mock('node-fetch', () => jest.fn());

function verifyDcsFetchCall(call: any[]) {
expect(call[0]).toEqual(
'https://api.statsigcdn.com/v1/download_config_specs/secret-key.json?sinceTime=0',
);
expect(call[1].method).toBe('GET');
}

function verifyGetIdListsFetchCall(call: any[]) {
expect(call[0]).toEqual('https://statsigapi.net/v1/get_id_lists');
expect(call[1].method).toBe('POST');
}

describe('NetworkOverrideFunc', () => {
const fetchSpy = require('node-fetch');
const networkOverrideSpy = jest.fn();

beforeEach(() => {
StatsigInstanceUtils.setInstance(null);

fetchSpy.mockClear();
networkOverrideSpy.mockClear();
});

it('calls the networkOverrideFunc', async () => {
await statsig.initialize('secret-key', {
networkOverrideFunc: networkOverrideSpy,
});

expect(fetchSpy).not.toHaveBeenCalled();

verifyDcsFetchCall(networkOverrideSpy.mock.calls[0]);
verifyGetIdListsFetchCall(networkOverrideSpy.mock.calls[1]);
});

it('calls fetch when no override is given', async () => {
await statsig.initialize('secret-key');
expect(networkOverrideSpy).not.toHaveBeenCalled();

verifyDcsFetchCall(fetchSpy.mock.calls[0]);
verifyGetIdListsFetchCall(fetchSpy.mock.calls[1]);
});
});
26 changes: 16 additions & 10 deletions src/utils/StatsigFetcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ import {
StatsigSDKKeyMismatchError,
StatsigTooManyRequestsError,
} from '../Errors';
import { ExplicitStatsigOptions, RetryBackoffFunc } from '../StatsigOptions';
import {
ExplicitStatsigOptions,
NetworkOverrideFunc,
RetryBackoffFunc,
} from '../StatsigOptions';
import { getSDKType, getSDKVersion } from './core';
import Dispatcher from './Dispatcher';
import getCompressionFunc from './getCompressionFunc';
Expand Down Expand Up @@ -38,6 +42,7 @@ export default class StatsigFetcher {
private localMode: boolean;
private sdkKey: string;
private errorBoundry: ErrorBoundary;
private networkOverrideFunc: NetworkOverrideFunc | null;

public constructor(
secretKey: string,
Expand All @@ -55,6 +60,7 @@ export default class StatsigFetcher {
this.localMode = options.localMode;
this.sdkKey = secretKey;
this.errorBoundry = errorBoundry;
this.networkOverrideFunc = options.networkOverrideFunc;
}

public validateSDKKeyUsed(hashedSDKKeyUsed: string): boolean {
Expand Down Expand Up @@ -158,11 +164,11 @@ export default class StatsigFetcher {
...options?.additionalHeaders,
'Content-type': 'application/json; charset=UTF-8',
'STATSIG-API-KEY': this.sdkKey,
'STATSIG-CLIENT-TIME': Date.now(),
'STATSIG-CLIENT-TIME': `${Date.now()}`,
'STATSIG-SERVER-SESSION-ID': this.sessionID,
'STATSIG-SDK-TYPE': getSDKType(),
'STATSIG-SDK-VERSION': getSDKVersion(),
} as Record<string, string | number>;
} as Record<string, string>;

let contents: BodyInit | undefined = undefined;
const gzipSync = getCompressionFunc();
Expand All @@ -173,20 +179,20 @@ export default class StatsigFetcher {
} else if (body) {
contents = JSON.stringify(body);
}
const params = {
method: method,
body: contents,
headers,
signal: signal,
};

if (!isRetrying) {
markDiagnostic?.start({});
}

let res: Response | undefined;
let error: unknown;
return safeFetch(url, params)
const fetcher = this.networkOverrideFunc ?? safeFetch;
return fetcher(url, {
method: method,
body: contents,
headers,
signal: signal,
})
.then((localRes) => {
res = localRes;
if ((!res.ok || retryStatusCodes.includes(res.status)) && retries > 0) {
Expand Down

0 comments on commit 532cd78

Please sign in to comment.