Skip to content

Commit

Permalink
dcs fallback unit tests (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
kat-statsig authored Oct 15, 2024
1 parent 8e8a7dd commit a7b371c
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 82 deletions.
2 changes: 1 addition & 1 deletion statsig/spec_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def sync_config_spec():
self.get_config_spec(strategy)
outof_sync = False
time_elapsed = time.time() * 1000 - self.last_update_time
if (self._enforce_sync_fallback_threshold_in_ms is not None and time_elapsed > self._enforce_sync_fallback_threshold_in_ms):
if self._enforce_sync_fallback_threshold_in_ms is not None and time_elapsed > self._enforce_sync_fallback_threshold_in_ms:
outof_sync = True
if prev_failure_count == self._sync_failure_count and not outof_sync:
globals.logger.log_process("Config Sync", f"Syncing config values with {strategy.value} successful")
Expand Down
80 changes: 60 additions & 20 deletions tests/network_stub.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import gzip
import re
from typing import Callable, Union
from io import BytesIO
from typing import Callable, Union, Optional
from urllib.parse import urlparse, ParseResult

STATSIG_APIS = ["https://api.statsigcdn.com/", "https://statsigapi.net/"]
STATSIG_APIS = ["https://api.statsigcdn.com", "https://statsigapi.net"]


class NetworkStub:
host: str
Expand All @@ -22,60 +25,97 @@ def __init__(self, status, data=None, headers=None):
def json(self):
return self._json

def __init__(self, host: str, mock_statsig_api = False):
def __init__(self, host: str, mock_statsig_api=False):
self.host = host
self.mock_statsig_api = mock_statsig_api
self._stubs = {}
self._statsig_stubs = {}

def reset(self):
self._stubs = {}

def stub_request_with_value(
self, path, response_code: int, response_body: Union[dict, str]):
if not isinstance(response_body, dict) and not isinstance(
response_body, str):
self, path, response_code: int, response_body: Union[dict, str], headers: Optional[dict] = None):
if not isinstance(response_body, dict) and not isinstance(response_body, str):
raise "Must provide a dictionary or string"

self._stubs[path] = {
"response_code": response_code,
"response_body": response_body,
"headers": headers or {}
}

def stub_request_with_function(self, path, response_code: Union[int, Callable[[str, dict], int]],
response_func: Callable[[str, dict], object]):
response_func: Callable[[str, dict], object], headers: Optional[dict] = None):
if not callable(response_func):
raise "Must provide a function"

self._stubs[path] = {
"response_code": response_code,
"response_func": response_func
"response_func": response_func,
"headers": headers or {}
}

def stub_statsig_api_request_with_value(
self, path, response_code: int, response_body: Union[dict, str], headers: Optional[dict] = None):
if not isinstance(response_body, dict) and not isinstance(response_body, str):
raise "Must provide a dictionary or string"

self._statsig_stubs[path] = {
"response_code": response_code,
"response_body": response_body,
"headers": headers or {}
}

def stub_statsig_api_request_with_function(self, path, response_code: Union[int, Callable[[str, dict], int]],
response_func: Callable[[str, dict], object],
headers: Optional[dict] = None):
if not callable(response_func):
raise "Must provide a function"

self._statsig_stubs[path] = {
"response_code": response_code,
"response_func": response_func,
"headers": headers or {}
}

def mock(*args, **kwargs):
instance: NetworkStub = args[0]
method: str = args[1]
url: ParseResult = urlparse(args[2])
request_host = (url.scheme + "://" + url.hostname)
request_host = f"{url.scheme}://{url.hostname}"

if request_host != instance.host and (instance.mock_statsig_api and request_host not in STATSIG_APIS):
return

paths = list(instance._stubs.keys())
for path in paths:
stub_data: dict = instance._stubs[path]

if re.search(f".*{path}", url.path) is not None:
response_body = stub_data.get("response_body", None)
if stub_data.get("response_func", None) is not None:
stubs = instance._statsig_stubs if request_host in STATSIG_APIS and instance.mock_statsig_api else instance._stubs
for path, stub_data in stubs.items():
if re.search(f".*{path}", url.path):
response_body = stub_data.get("response_body")
headers = stub_data.get("headers", {})

if "response_func" in stub_data:
response_body = stub_data["response_func"](url, **kwargs)
response_code = stub_data.get("response_code", None)

response_code = stub_data.get("response_code")
if callable(response_code):
response_code = response_code(url, kwargs)

headers = {}
if "Content-Encoding" in headers and headers["Content-Encoding"] == "gzip":
response_body = gzip_compress(response_body)

if isinstance(response_body, str):
headers["content-length"] = len(response_body)

return NetworkStub.StubResponse(
stub_data["response_code"], response_body, headers)
return NetworkStub.StubResponse(response_code, response_body, headers)

return NetworkStub.StubResponse(404)


def gzip_compress(data: Union[str, bytes]) -> bytes:
if isinstance(data, str):
data = data.encode('utf-8')
buf = BytesIO()
with gzip.GzipFile(fileobj=buf, mode='wb') as gz:
gz.write(data)
return buf.getvalue()
198 changes: 137 additions & 61 deletions tests/test_sync_config_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,150 @@
import os
import time
import unittest

from unittest.mock import patch

from network_stub import NetworkStub
from statsig import StatsigOptions, statsig, StatsigUser
from statsig.evaluation_details import EvaluationDetails, EvaluationReason
from statsig.http_worker import HttpWorker
from statsig.statsig_options import DataSource

_network_stub = NetworkStub("http://test-sync-config-fallback", mock_statsig_api=True)
with open(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../testdata/download_config_specs.json')) as r:
CONFIG_SPECS_RESPONSE = r.read()
PARSED_CONFIG_SPEC = json.loads(CONFIG_SPECS_RESPONSE)

UPDATED_TIME_CONFIG_SPEC = PARSED_CONFIG_SPEC.copy()
UPDATED_TIME_CONFIG_SPEC['time'] = 1631638014821


@patch('requests.request', side_effect=_network_stub.mock)
class TestSyncConfigFallback(unittest.TestCase):
@classmethod
@patch('requests.request', side_effect=_network_stub.mock)
def setUpClass(cls, mock_proxy):
cls.dcs_hit = 0
_network_stub.reset()
def dcs_proxy_callback(url: str, **kwargs):
cls.dcs_hit += 1
return json.loads(CONFIG_SPECS_RESPONSE)

_network_stub.stub_request_with_function(
"download_config_specs/.*", 200, dcs_proxy_callback)

cls.test_user = StatsigUser("123", email="[email protected]")

def tearDown(self):
self.dcs_hit = 0
statsig.shutdown()

@patch('requests.request', side_effect=_network_stub.mock)
@patch.object(HttpWorker, 'get_dcs_fallback')
def test_default_behavior(self, fallback_mock, request_mock):
# default behavior is no fallback if is out of sync
options = StatsigOptions(api=_network_stub.host, fallback_to_statsig_api=True, rulesets_sync_interval=1)
statsig.initialize("secret-key", options)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
time.sleep(1.1)

fallback_mock.assert_not_called()

@patch('requests.request', side_effect=_network_stub.mock)
@patch.object(HttpWorker, 'get_dcs_fallback')
def test_fallback_when_out_of_sync(self, fallback_mock, request_mock):
# default behavior is no fallback if is out of sync
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True, rulesets_sync_interval=1, out_of_sync_threshold_in_s=0.5)
statsig.initialize("secret-key", options)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
time.sleep(1.1)
#ensure it falls back
fallback_mock.assert_called_once()

@patch('requests.request', side_effect=_network_stub.mock)
@patch.object(HttpWorker, 'get_dcs_fallback')
def test_behavior_when_not_out_of_sync(self, fallback_mock, request_mock):
# default behavior is no fallback if is out of sync
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True, rulesets_sync_interval=1, out_of_sync_threshold_in_s=4e10)
statsig.initialize("secret-key", options)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
time.sleep(1.1)
#ensure no fallback
fallback_mock.assert_not_called()
@classmethod
@patch('requests.request', side_effect=_network_stub.mock)
def setUpClass(cls, mock_proxy):
cls.dcs_called = False
cls.statsig_dcs_called = False
cls.status_code = 200

cls.test_user = StatsigUser("123", email="[email protected]")

def setUp(self):
self.__class__.dcs_called = False
self.__class__.statsig_dcs_called = False
self.__class__.status_code = 200

def common_callback(url, **kwargs):
if 'statsig' in url.netloc:
self.__class__.statsig_dcs_called = True
return PARSED_CONFIG_SPEC
else:
self.__class__.dcs_called = True

if self.__class__.status_code == 200:
return PARSED_CONFIG_SPEC
if self.__class__.status_code == 300:
return "{jiBbRIsh;"
if self.__class__.status_code == 400:
return "Bad Request"
if self.__class__.status_code == 500:
raise Exception("Internal Server Error")

_network_stub.stub_request_with_function(
"download_config_specs/.*", self.__class__.status_code, common_callback)
_network_stub.stub_statsig_api_request_with_function(
"download_config_specs/.*", 200, common_callback)

def tearDown(self):
statsig.shutdown()
_network_stub.reset()

def test_default_sync_success(self, request_mock):
options = StatsigOptions(api=_network_stub.host, fallback_to_statsig_api=True, rulesets_sync_interval=1)
statsig.initialize("secret-key", options)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
time.sleep(1.1)
self.assertFalse(self.__class__.statsig_dcs_called)

def test_fallback_when_out_of_sync(self, request_mock):
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True,
rulesets_sync_interval=1, out_of_sync_threshold_in_s=0.5)
statsig.initialize("secret-key", options)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
time.sleep(1.1)
self.assertTrue(self.__class__.statsig_dcs_called)

def test_no_fallback_when_not_out_of_sync(self, request_mock):
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True,
rulesets_sync_interval=1, out_of_sync_threshold_in_s=4e10)
statsig.initialize("secret-key", options)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
time.sleep(1.1)
self.assertFalse(self.__class__.statsig_dcs_called)

def test_fallback_when_dcs_400(self, request_mock):
self.__class__.status_code = 400
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True,
rulesets_sync_interval=1)
statsig.initialize("secret-key", options)
self.assertEqual(statsig.get_instance()._spec_store.init_source, DataSource.STATSIG_NETWORK)
self.get_gate_and_validate()
self.wait_for_sync_and_validate()

def test_fallback_when_dcs_500(self, request_mock):
self.__class__.status_code = 500
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True,
rulesets_sync_interval=1)
statsig.initialize("secret-key", options)
self.assertEqual(statsig.get_instance()._spec_store.init_source, DataSource.STATSIG_NETWORK)
self.get_gate_and_validate()
self.wait_for_sync_and_validate()

def test_fallback_when_dcs_invalid_json(self, request_mock):
self.__class__.status_code = 300
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True,
rulesets_sync_interval=1)
statsig.initialize("secret-key", options)
self.assertEqual(statsig.get_instance()._spec_store.init_source, DataSource.STATSIG_NETWORK)
self.get_gate_and_validate()
self.wait_for_sync_and_validate()

def test_fallback_when_invalid_gzip_content(self, request_mock):
def cb(url, **kwargs):
self.__class__.dcs_called = True
return "{jiBbRIsh;"

_network_stub.stub_request_with_function(
"download_config_specs/.*", 200, cb,
headers={"Content-Encoding": "gzip"}
)
options = StatsigOptions(api_for_download_config_specs=_network_stub.host, fallback_to_statsig_api=True,
rulesets_sync_interval=1)
statsig.initialize("secret-key", options)
self.assertEqual(statsig.get_instance()._spec_store.init_source, DataSource.STATSIG_NETWORK)
self.get_gate_and_validate()
self.wait_for_sync_and_validate()

def wait_for_sync_and_validate(self):
_network_stub.stub_statsig_api_request_with_value("download_config_specs/.*", 200,
UPDATED_TIME_CONFIG_SPEC)
time.sleep(1.1)
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.config_sync_time, 1631638014821)

def get_gate_and_validate(self):
gate = statsig.get_feature_gate(self.test_user, "always_on_gate")
eval_detail: EvaluationDetails = gate.get_evaluation_details()
self.assertEqual(eval_detail.reason, EvaluationReason.network)
self.assertEqual(eval_detail.config_sync_time, 1631638014811)
self.assertTrue(self.__class__.dcs_called)
self.assertTrue(self.__class__.statsig_dcs_called)

0 comments on commit a7b371c

Please sign in to comment.