diff --git a/requirements.txt b/requirements.txt index 99ab87b..5ea4679 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ grpcio==1.62.3 protobuf==4.24.4 requests semver -pyfarmhash \ No newline at end of file +pyfarmhash +ijson \ No newline at end of file diff --git a/setup.py b/setup.py index 3e4fc99..7a50f6c 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ 'ip3country', 'grpcio', 'protobuf', + 'ijson', ], tests_require=test_deps, extras_require=extras, diff --git a/statsig/dynamic_config.py b/statsig/dynamic_config.py index a1ca971..d1689e4 100644 --- a/statsig/dynamic_config.py +++ b/statsig/dynamic_config.py @@ -1,7 +1,7 @@ from typing import Optional -from statsig.evaluation_details import EvaluationDetails, EvaluationReason, DataSource -from statsig.statsig_user import StatsigUser +from .evaluation_details import EvaluationDetails, EvaluationReason, DataSource +from .statsig_user import StatsigUser class DynamicConfig: diff --git a/statsig/http_worker.py b/statsig/http_worker.py index 4d71533..59780f1 100644 --- a/statsig/http_worker.py +++ b/statsig/http_worker.py @@ -2,9 +2,11 @@ import json import time from concurrent.futures.thread import ThreadPoolExecutor +from decimal import Decimal from io import BytesIO from typing import Callable, Tuple, Optional, Any +import ijson import requests from . 
import globals @@ -53,7 +55,7 @@ def get_dcs(self, on_complete: Callable, since_time=0, log_on_exception=False, i tag="download_config_specs") self._context.source_api = self.__api_for_download_config_specs if response is not None and self._is_success_code(response.status_code): - on_complete(DataSource.NETWORK, response.json() or {}, None) + on_complete(DataSource.NETWORK, self._stream_response_into_result_dict(response) or {}, None) return on_complete(DataSource.NETWORK, None, None) @@ -64,7 +66,7 @@ def get_dcs_fallback(self, on_complete: Callable, since_time=0, log_on_exception tag="download_config_specs") self._context.source_api = STATSIG_CDN if response is not None and self._is_success_code(response.status_code): - on_complete(DataSource.STATSIG_NETWORK, response.json() or {}, None) + on_complete(DataSource.STATSIG_NETWORK, self._stream_response_into_result_dict(response) or {}, None) return on_complete(DataSource.STATSIG_NETWORK, None, None) @@ -78,7 +80,7 @@ def get_id_lists(self, on_complete: Callable, log_on_exception=False, init_timeo tag="get_id_lists", ) if response is not None and self._is_success_code(response.status_code): - return on_complete(response.json() or {}, None) + return on_complete(self._stream_response_into_result_dict(response) or {}, None) return on_complete(None, None) def get_id_lists_fallback(self, on_complete: Callable, log_on_exception=False, init_timeout=None): @@ -91,7 +93,7 @@ def get_id_lists_fallback(self, on_complete: Callable, log_on_exception=False, i tag="get_id_lists", ) if response is not None and self._is_success_code(response.status_code): - return on_complete(response.json() or {}, None) + return on_complete(self._stream_response_into_result_dict(response) or {}, None) return on_complete(None, None) def get_id_list(self, on_complete, url, headers, log_on_exception=False): @@ -189,7 +191,7 @@ def _request( timeout = self.__req_timeout def request_task(): - return requests.request(method, url, data=payload, 
def _stream_response_into_result_dict(self, response):
    """Incrementally parse a streamed JSON-object response body into a dict.

    The body is read from ``response.raw`` and handed to ``ijson.kvitems``
    one top-level key at a time instead of being buffered and parsed in a
    single ``response.json()`` call.

    Args:
        response: Response-like object exposing ``headers``, ``raw`` and
            ``url`` (and, optionally, ``close``).

    Returns:
        A dict holding the top-level key/value pairs with every ``Decimal``
        converted to ``float``, or ``None`` when reading or parsing fails.
        Failures are logged as warnings; partial results are discarded.
    """
    try:
        # ``raw`` is the undecoded byte stream, so gzip has to be undone here.
        # Content-coding tokens are case-insensitive, hence the normalisation.
        encoding = (response.headers.get("Content-Encoding") or "").strip().lower()
        if encoding == "gzip":
            stream = gzip.GzipFile(fileobj=response.raw)
        else:
            stream = response.raw
        return {
            key: self._convert_decimals_to_floats(value)
            for key, value in ijson.kvitems(stream, "")
        }
    except Exception as e:
        globals.logger.warning(
            f"Failed to stream response into result dict from {response.url}. {e}"
        )
        return None
    finally:
        # The request is made with stream=True, so the body is only pulled on
        # demand. Close explicitly so the connection is released even when
        # parsing stops before the body is exhausted. Test doubles may not
        # implement close(), so look it up defensively.
        close = getattr(response, "close", None)
        if callable(close):
            close()


def _convert_decimals_to_floats(self, obj):
    """Return ``obj`` with every ``Decimal`` replaced by a ``float``.

    Dicts and lists are rebuilt recursively; any other value (including
    ints, strings, booleans and ``None``) is returned unchanged.
    """
    if isinstance(obj, Decimal):
        return float(obj)
    if isinstance(obj, dict):
        return {k: self._convert_decimals_to_floats(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [self._convert_decimals_to_floats(v) for v in obj]
    return obj


def _is_success_code(self, status_code: int) -> bool:
    """Return True when ``status_code`` is in the 2xx range."""
    return 200 <= status_code < 300
"7w9rbTSffLT89pxqpyhuqA", "salt": "e452510f-bd5b-42cb-a71e-00498a7903fD" } - ] + ], + "entity": "feature_gate" } ], "layer_configs": [ diff --git a/tests/network_stub.py b/tests/network_stub.py index 1ce9308..3db88f9 100644 --- a/tests/network_stub.py +++ b/tests/network_stub.py @@ -1,4 +1,6 @@ import gzip +import io +import json import re from io import BytesIO from typing import Callable, Union, Optional @@ -12,7 +14,7 @@ class NetworkStub: mock_statsig_api: bool class StubResponse: - def __init__(self, status, data=None, headers=None): + def __init__(self, status, data=None, headers=None, raw=None): if headers is None: headers = {} @@ -21,6 +23,7 @@ def __init__(self, status, data=None, headers=None): self.headers = headers self._json = data self.text = data + self.raw = raw def json(self): return self._json @@ -107,8 +110,17 @@ def mock(*args, **kwargs): if isinstance(response_body, str): headers["content-length"] = len(response_body) + byte_body = response_body.encode("utf-8") + else: + byte_body = json.dumps(response_body).encode("utf-8") - return NetworkStub.StubResponse(response_code, response_body, headers) + try: + raw = io.BytesIO(byte_body) + except Exception as e: + print(f"Error in creating raw response: {e}") + raw = None + + return NetworkStub.StubResponse(response_code, response_body, headers, raw) return NetworkStub.StubResponse(404) diff --git a/tests/test_layer_exposures.py b/tests/test_layer_exposures.py index cd72439..c48dab5 100644 --- a/tests/test_layer_exposures.py +++ b/tests/test_layer_exposures.py @@ -1,11 +1,11 @@ -import unittest -import os import json +import os +import unittest from unittest.mock import patch +from gzip_helpers import GzipHelpers from network_stub import NetworkStub from statsig import statsig, StatsigUser, StatsigOptions, StatsigEnvironmentTier, Layer -from gzip_helpers import GzipHelpers from test_case_with_extras import TestCaseWithExtras with open(os.path.join(os.path.abspath(os.path.dirname(__file__)), diff 
def _contains_spec(self, specs, key, spec_type):
    """Report whether ``specs`` holds an entry with the given name and entity.

    Args:
        specs: Iterable of spec dicts; ``None`` is treated as empty.
        key: Expected value of the spec's ``"name"`` field.
        spec_type: Expected value of the spec's ``"entity"`` field.

    Returns:
        True if at least one spec matches both fields, otherwise False.
    """
    return any(
        spec.get("name") == key and spec.get("entity") == spec_type
        for spec in (specs or ())
    )