From 82189f2debb723f8d2a2b0f2bc7c7d404c7f6c1d Mon Sep 17 00:00:00 2001 From: kat-statsig <167801639+kat-statsig@users.noreply.github.com> Date: Tue, 3 Sep 2024 09:34:52 -0700 Subject: [PATCH] selectively import and load grpc workers (#304) --- statsig/statsig_network.py | 127 ++++++++++++++++++++++--------------- 1 file changed, 77 insertions(+), 50 deletions(-) diff --git a/statsig/statsig_network.py b/statsig/statsig_network.py index 75fe21f..ff4fc8b 100644 --- a/statsig/statsig_network.py +++ b/statsig/statsig_network.py @@ -1,3 +1,4 @@ +import importlib import threading from typing import Any, Callable, Optional @@ -7,11 +8,9 @@ DEFAULT_RULESET_SYNC_INTERVAL, StatsigOptions, STATSIG_CDN, - STATSIG_API, + STATSIG_API, ProxyConfig, ) from .diagnostics import Diagnostics -from .grpc_websocket_worker import GRPCWebsocketWorker -from .grpc_worker import GRPCWorker from .globals import logger from .http_worker import HttpWorker from .interface_network import ( @@ -27,7 +26,7 @@ class StreamingFallback(IStreamingFallback): def __init__( - self, fn: Callable, interval: int, name: str, eb: _StatsigErrorBoundary + self, fn: Callable, interval: int, name: str, eb: _StatsigErrorBoundary ): self.fn = fn self.stop_event = threading.Event() @@ -62,17 +61,19 @@ def _sync(self): class _StatsigNetwork: def __init__( - self, - sdk_key: str, - options: StatsigOptions, - statsig_metadata: dict, - error_boundary: _StatsigErrorBoundary, - diagnostics: Diagnostics, - shutdown_event, + self, + sdk_key: str, + options: StatsigOptions, + statsig_metadata: dict, + error_boundary: _StatsigErrorBoundary, + diagnostics: Diagnostics, + shutdown_event, ): self.sdk_key = sdk_key self.error_boundary = error_boundary self.statsig_options = options + self.diagnostics = diagnostics + self.shutdown_event = shutdown_event self.statsig_metadata = statsig_metadata defaultHttpWorker: IStatsigNetworkWorker = HttpWorker( sdk_key, options, statsig_metadata, error_boundary, diagnostics @@ -85,26 +86,52 @@ def __init__( protocol = config.protocol worker = defaultHttpWorker if protocol == NetworkProtocol.GRPC: - worker = GRPCWorker(sdk_key, config) + self.load_grpc_worker(endpoint, config) elif protocol == NetworkProtocol.GRPC_WEBSOCKET: - worker = GRPCWebsocketWorker( - sdk_key, - config, - options, - error_boundary, - diagnostics, - shutdown_event, - ) - - if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS: - self.dcs_worker = worker - elif endpoint == NetworkEndpoint.GET_ID_LISTS: - self.id_list_worker = worker - elif endpoint == NetworkEndpoint.LOG_EVENT: - self.log_event_worker = worker + self.load_grpc_websocket_worker(endpoint, config) + self._background_download_configs_from_statsig = None self._background_download_id_lists_from_statsig = None + def load_grpc_websocket_worker(self, endpoint: NetworkEndpoint, config: ProxyConfig): + grpc_worker_module = importlib.import_module("statsig.grpc_websocket_worker") + grpc_webhook_worker_class = getattr(grpc_worker_module, 'GRPCWebsocketWorker') + grpc_webhook_worker = grpc_webhook_worker_class( + self.sdk_key, + config, + self.statsig_options, + self.error_boundary, + self.diagnostics, + self.shutdown_event, + ) + if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS: + self.dcs_worker = grpc_webhook_worker + elif endpoint == NetworkEndpoint.GET_ID_LISTS: + self.id_list_worker = grpc_webhook_worker + elif endpoint == NetworkEndpoint.LOG_EVENT: + self.log_event_worker = grpc_webhook_worker + elif endpoint == NetworkEndpoint.ALL: + self.log_event_worker = grpc_webhook_worker + self.id_list_worker = grpc_webhook_worker + self.dcs_worker = grpc_webhook_worker + + def load_grpc_worker(self, endpoint: NetworkEndpoint, config: ProxyConfig): + grpc_worker_module = importlib.import_module("statsig.grpc_worker") + grpc_worker_class = getattr(grpc_worker_module, 'GRPCWorker') + grpc_worker = grpc_worker_class( + self.sdk_key, config + ) + if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS: + self.dcs_worker = grpc_worker + elif endpoint == NetworkEndpoint.GET_ID_LISTS: + self.id_list_worker = grpc_worker + elif endpoint == NetworkEndpoint.LOG_EVENT: + self.log_event_worker = grpc_worker + elif endpoint == NetworkEndpoint.ALL: + self.log_event_worker = grpc_worker + self.id_list_worker = grpc_worker + self.dcs_worker = grpc_worker + def is_pull_worker(self, endpoint: str) -> bool: if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS.value: return self.dcs_worker.is_pull_worker() @@ -115,11 +142,11 @@ def is_pull_worker(self, endpoint: str) -> bool: return True def get_dcs( - self, - on_complete: Any, - since_time: int = 0, - log_on_exception: Optional[bool] = False, - timeout: Optional[int] = None, + self, + on_complete: Any, + since_time: int = 0, + log_on_exception: Optional[bool] = False, + timeout: Optional[int] = None, ): if self.statsig_options.local_mode: logger.warning("Local mode is enabled. Not fetching DCS.") @@ -127,11 +154,11 @@ def get_dcs( self.dcs_worker.get_dcs(on_complete, since_time, log_on_exception, timeout) def get_dcs_fallback( - self, - on_complete: Any, - since_time: int = 0, - log_on_exception: Optional[bool] = False, - timeout: Optional[int] = None, + self, + on_complete: Any, + since_time: int = 0, + log_on_exception: Optional[bool] = False, + timeout: Optional[int] = None, ): if self.statsig_options.local_mode: logger.warning("Local mode is enabled. Not fetching DCS.") @@ -140,9 +167,9 @@ def get_dcs_fallback( NetworkEndpoint.DOWNLOAD_CONFIG_SPECS ) is_proxy_dcs = ( - dcs_proxy - and dcs_proxy.proxy_address != STATSIG_CDN - or self.statsig_options.api_for_download_config_specs != STATSIG_CDN + dcs_proxy + and dcs_proxy.proxy_address != STATSIG_CDN + or self.statsig_options.api_for_download_config_specs != STATSIG_CDN ) if is_proxy_dcs: self.http_worker.get_dcs_fallback( @@ -150,10 +177,10 @@ def get_dcs_fallback( ) def get_id_lists( - self, - on_complete: Any, - log_on_exception: Optional[bool] = False, - timeout: Optional[int] = None, + self, + on_complete: Any, + log_on_exception: Optional[bool] = False, + timeout: Optional[int] = None, ): if self.statsig_options.local_mode: logger.warning("Local mode is enabled. Not fetching ID Lists.") @@ -161,10 +188,10 @@ def get_id_lists( self.id_list_worker.get_id_lists(on_complete, log_on_exception, timeout) def get_id_lists_fallback( - self, - on_complete: Any, - log_on_exception: Optional[bool] = False, - timeout: Optional[int] = None, + self, + on_complete: Any, + log_on_exception: Optional[bool] = False, + timeout: Optional[int] = None, ): if self.statsig_options.local_mode: logger.warning("Local mode is enabled. Not fetching ID Lists.") @@ -199,8 +226,8 @@ def listen_for_dcs(self, listeners: IStreamingListeners, fallback: Callable): if isinstance(self.dcs_worker, IStatsigWebhookWorker): self.dcs_worker.start_listen_for_config_spec(listeners) interval = ( - self.statsig_options.rulesets_sync_interval - or DEFAULT_RULESET_SYNC_INTERVAL + self.statsig_options.rulesets_sync_interval + or DEFAULT_RULESET_SYNC_INTERVAL ) callbacks = StreamingFallback( fn=fallback,