Skip to content

Commit

Permalink
selectively import and load grpc workers (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
kat-statsig authored Sep 3, 2024
1 parent b59b31d commit 82189f2
Showing 1 changed file with 77 additions and 50 deletions.
127 changes: 77 additions & 50 deletions statsig/statsig_network.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import threading
from typing import Any, Callable, Optional

Expand All @@ -7,11 +8,9 @@
DEFAULT_RULESET_SYNC_INTERVAL,
StatsigOptions,
STATSIG_CDN,
STATSIG_API,
STATSIG_API, ProxyConfig,
)
from .diagnostics import Diagnostics
from .grpc_websocket_worker import GRPCWebsocketWorker
from .grpc_worker import GRPCWorker
from .globals import logger
from .http_worker import HttpWorker
from .interface_network import (
Expand All @@ -27,7 +26,7 @@

class StreamingFallback(IStreamingFallback):
def __init__(
self, fn: Callable, interval: int, name: str, eb: _StatsigErrorBoundary
self, fn: Callable, interval: int, name: str, eb: _StatsigErrorBoundary
):
self.fn = fn
self.stop_event = threading.Event()
Expand Down Expand Up @@ -62,17 +61,19 @@ def _sync(self):

class _StatsigNetwork:
def __init__(
self,
sdk_key: str,
options: StatsigOptions,
statsig_metadata: dict,
error_boundary: _StatsigErrorBoundary,
diagnostics: Diagnostics,
shutdown_event,
self,
sdk_key: str,
options: StatsigOptions,
statsig_metadata: dict,
error_boundary: _StatsigErrorBoundary,
diagnostics: Diagnostics,
shutdown_event,
):
self.sdk_key = sdk_key
self.error_boundary = error_boundary
self.statsig_options = options
self.diagnostics = diagnostics
self.shutdown_event = shutdown_event
self.statsig_metadata = statsig_metadata
defaultHttpWorker: IStatsigNetworkWorker = HttpWorker(
sdk_key, options, statsig_metadata, error_boundary, diagnostics
Expand All @@ -85,26 +86,52 @@ def __init__(
protocol = config.protocol
worker = defaultHttpWorker
if protocol == NetworkProtocol.GRPC:
worker = GRPCWorker(sdk_key, config)
self.load_grpc_worker(endpoint, config)
elif protocol == NetworkProtocol.GRPC_WEBSOCKET:
worker = GRPCWebsocketWorker(
sdk_key,
config,
options,
error_boundary,
diagnostics,
shutdown_event,
)

if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS:
self.dcs_worker = worker
elif endpoint == NetworkEndpoint.GET_ID_LISTS:
self.id_list_worker = worker
elif endpoint == NetworkEndpoint.LOG_EVENT:
self.log_event_worker = worker
self.load_grpc_websocket_worker(endpoint, config)

self._background_download_configs_from_statsig = None
self._background_download_id_lists_from_statsig = None

def load_grpc_websocket_worker(self, endpoint: NetworkEndpoint, config: ProxyConfig):
grpc_worker_module = importlib.import_module("statsig.grpc_websocket_worker")
grpc_webhook_worker_class = getattr(grpc_worker_module, 'GRPCWebsocketWorker')
grpc_webhook_worker = grpc_webhook_worker_class(
self.sdk_key,
config,
self.statsig_options,
self.error_boundary,
self.diagnostics,
self.shutdown_event,
)
if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS:
self.dcs_worker = grpc_webhook_worker
elif endpoint == NetworkEndpoint.GET_ID_LISTS:
self.id_list_worker = grpc_webhook_worker
elif endpoint == NetworkEndpoint.LOG_EVENT:
self.log_event_worker = grpc_webhook_worker
elif endpoint == NetworkEndpoint.ALL:
self.log_event_worker = grpc_webhook_worker
self.id_list_worker = grpc_webhook_worker
self.dcs_worker = grpc_webhook_worker

def load_grpc_worker(self, endpoint: NetworkEndpoint, config: ProxyConfig):
grpc_worker_module = importlib.import_module("statsig.grpc_worker")
grpc_worker_class = getattr(grpc_worker_module, 'GRPCWorker')
grpc_worker = grpc_worker_class(
self.sdk_key, config
)
if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS:
self.dcs_worker = grpc_worker
elif endpoint == NetworkEndpoint.GET_ID_LISTS:
self.id_list_worker = grpc_worker
elif endpoint == NetworkEndpoint.LOG_EVENT:
self.log_event_worker = grpc_worker
elif endpoint == NetworkEndpoint.ALL:
self.log_event_worker = grpc_worker
self.id_list_worker = grpc_worker
self.dcs_worker = grpc_worker

def is_pull_worker(self, endpoint: str) -> bool:
if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS.value:
return self.dcs_worker.is_pull_worker()
Expand All @@ -115,23 +142,23 @@ def is_pull_worker(self, endpoint: str) -> bool:
return True

def get_dcs(
self,
on_complete: Any,
since_time: int = 0,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
self,
on_complete: Any,
since_time: int = 0,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
):
if self.statsig_options.local_mode:
logger.warning("Local mode is enabled. Not fetching DCS.")
return
self.dcs_worker.get_dcs(on_complete, since_time, log_on_exception, timeout)

def get_dcs_fallback(
self,
on_complete: Any,
since_time: int = 0,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
self,
on_complete: Any,
since_time: int = 0,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
):
if self.statsig_options.local_mode:
logger.warning("Local mode is enabled. Not fetching DCS.")
Expand All @@ -140,31 +167,31 @@ def get_dcs_fallback(
NetworkEndpoint.DOWNLOAD_CONFIG_SPECS
)
is_proxy_dcs = (
dcs_proxy
and dcs_proxy.proxy_address != STATSIG_CDN
or self.statsig_options.api_for_download_config_specs != STATSIG_CDN
dcs_proxy
and dcs_proxy.proxy_address != STATSIG_CDN
or self.statsig_options.api_for_download_config_specs != STATSIG_CDN
)
if is_proxy_dcs:
self.http_worker.get_dcs_fallback(
on_complete, since_time, log_on_exception, timeout
)

def get_id_lists(
self,
on_complete: Any,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
self,
on_complete: Any,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
):
if self.statsig_options.local_mode:
logger.warning("Local mode is enabled. Not fetching ID Lists.")
return
self.id_list_worker.get_id_lists(on_complete, log_on_exception, timeout)

def get_id_lists_fallback(
self,
on_complete: Any,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
self,
on_complete: Any,
log_on_exception: Optional[bool] = False,
timeout: Optional[int] = None,
):
if self.statsig_options.local_mode:
logger.warning("Local mode is enabled. Not fetching ID Lists.")
Expand Down Expand Up @@ -199,8 +226,8 @@ def listen_for_dcs(self, listeners: IStreamingListeners, fallback: Callable):
if isinstance(self.dcs_worker, IStatsigWebhookWorker):
self.dcs_worker.start_listen_for_config_spec(listeners)
interval = (
self.statsig_options.rulesets_sync_interval
or DEFAULT_RULESET_SYNC_INTERVAL
self.statsig_options.rulesets_sync_interval
or DEFAULT_RULESET_SYNC_INTERVAL
)
callbacks = StreamingFallback(
fn=fallback,
Expand Down

0 comments on commit 82189f2

Please sign in to comment.