feat: better reconnect gherkins
Signed-off-by: Simon Schrottner <[email protected]>
aepfli committed Dec 18, 2024
1 parent 335cce7 commit 578692f
Showing 26 changed files with 837 additions and 348 deletions.
2 changes: 1 addition & 1 deletion providers/openfeature-provider-flagd/pyproject.toml
@@ -18,7 +18,7 @@ classifiers = [
keywords = []
dependencies = [
"openfeature-sdk>=0.6.0",
"grpcio>=1.68.0",
"grpcio>=1.68.1",
"protobuf>=4.25.2",
"mmh3>=4.1.0",
"panzi-json-logic>=1.0.1",
8 changes: 8 additions & 0 deletions providers/openfeature-provider-flagd/pytest.ini
@@ -4,7 +4,15 @@ markers =
in-process: tests for rpc mode.
customCert: Supports custom certs.
unixsocket: Supports unixsockets.
targetURI: Supports targetURI.
grace: Supports grace attempts.
targeting: Supports targeting.
fractional: Supports fractional.
string: Supports string.
semver: Supports semver.
reconnect: Supports reconnect.
events: Supports events.
sync: Supports sync.
caching: Supports caching.
offline: Supports offline.
bdd_features_base_dir = tests/features
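
For illustration only (not part of this commit): the markers added above can be used to select the corresponding gherkin suites. A minimal sketch, assuming it is run from the providers/openfeature-provider-flagd directory where this pytest.ini lives:

    import pytest

    # Run only the scenarios tagged with the new "reconnect" and "events" markers,
    # equivalent to `pytest -m "reconnect or events"` on the command line.
    if __name__ == "__main__":
        raise SystemExit(pytest.main(["-m", "reconnect or events"]))
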
@@ -26,7 +26,7 @@ class CacheType(Enum):
DEFAULT_RESOLVER_TYPE = ResolverType.RPC
DEFAULT_RETRY_BACKOFF = 1000
DEFAULT_RETRY_BACKOFF_MAX = 120000
DEFAULT_RETRY_GRACE_ATTEMPTS = 5
DEFAULT_RETRY_GRACE_PERIOD = 5
DEFAULT_STREAM_DEADLINE = 600000
DEFAULT_TLS = False

@@ -41,7 +41,7 @@ class CacheType(Enum):
ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER"
ENV_VAR_RETRY_BACKOFF_MS = "FLAGD_RETRY_BACKOFF_MS"
ENV_VAR_RETRY_BACKOFF_MAX_MS = "FLAGD_RETRY_BACKOFF_MAX_MS"
ENV_VAR_RETRY_GRACE_ATTEMPTS = "FLAGD_RETRY_GRACE_ATTEMPTS"
ENV_VAR_RETRY_GRACE_PERIOD = "FLAGD_RETRY_GRACE_PERIOD"
ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS"
ENV_VAR_TLS = "FLAGD_TLS"

@@ -81,7 +81,7 @@ def __init__( # noqa: PLR0913
offline_poll_interval_ms: typing.Optional[int] = None,
retry_backoff_ms: typing.Optional[int] = None,
retry_backoff_max_ms: typing.Optional[int] = None,
retry_grace_attempts: typing.Optional[int] = None,
retry_grace_period: typing.Optional[int] = None,
deadline_ms: typing.Optional[int] = None,
stream_deadline_ms: typing.Optional[int] = None,
keep_alive_time: typing.Optional[int] = None,
@@ -115,14 +115,14 @@ def __init__( # noqa: PLR0913
else retry_backoff_max_ms
)

self.retry_grace_attempts: int = (
self.retry_grace_period: int = (
int(
env_or_default(
ENV_VAR_RETRY_GRACE_ATTEMPTS, DEFAULT_RETRY_GRACE_ATTEMPTS, cast=int
ENV_VAR_RETRY_GRACE_PERIOD, DEFAULT_RETRY_GRACE_PERIOD, cast=int
)
)
if retry_grace_attempts is None
else retry_grace_attempts
if retry_grace_period is None
else retry_grace_period
)

self.resolver = (
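
As an aside, a minimal usage sketch for the renamed grace-period setting, supplied either through the renamed environment variable or the renamed keyword argument. The Config import path is an assumption based on the package layout and may differ:

    import os

    from openfeature.contrib.provider.flagd.config import Config  # assumed path

    # Option 1: via the renamed environment variable (was FLAGD_RETRY_GRACE_ATTEMPTS).
    os.environ["FLAGD_RETRY_GRACE_PERIOD"] = "10"
    config = Config()

    # Option 2: via the renamed keyword argument (was retry_grace_attempts);
    # an explicit argument takes precedence over the environment variable.
    config = Config(retry_grace_period=10)

    assert config.retry_grace_period == 10
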
@@ -43,34 +43,34 @@ def __init__( # noqa: PLR0913
host: typing.Optional[str] = None,
port: typing.Optional[int] = None,
tls: typing.Optional[bool] = None,
deadline: typing.Optional[int] = None,
deadline_ms: typing.Optional[int] = None,
timeout: typing.Optional[int] = None,
retry_backoff_ms: typing.Optional[int] = None,
resolver_type: typing.Optional[ResolverType] = None,
offline_flag_source_path: typing.Optional[str] = None,
stream_deadline_ms: typing.Optional[int] = None,
keep_alive_time: typing.Optional[int] = None,
cache_type: typing.Optional[CacheType] = None,
cache: typing.Optional[CacheType] = None,
max_cache_size: typing.Optional[int] = None,
retry_backoff_max_ms: typing.Optional[int] = None,
retry_grace_attempts: typing.Optional[int] = None,
retry_grace_period: typing.Optional[int] = None,
):
"""
Create an instance of the FlagdProvider
:param host: the host to make requests to
:param port: the port the flagd service is available on
:param tls: enable/disable secure TLS connectivity
:param deadline: the maximum to wait before a request times out
:param deadline_ms: the maximum to wait before a request times out
:param timeout: the maximum time to wait before a request times out
:param retry_backoff_ms: the number of milliseconds to backoff
:param offline_flag_source_path: the path to the flag source file
:param stream_deadline_ms: the maximum time to wait before a request times out
:param keep_alive_time: the number of milliseconds to keep alive
:param resolver_type: the type of resolver to use
"""
if deadline is None and timeout is not None:
deadline = timeout * 1000
if deadline_ms is None and timeout is not None:
deadline_ms = timeout * 1000
warnings.warn(
"'timeout' property is deprecated, please use 'deadline' instead, be aware that 'deadline' is in milliseconds",
DeprecationWarning,
@@ -81,15 +81,15 @@ def __init__( # noqa: PLR0913
host=host,
port=port,
tls=tls,
deadline_ms=deadline,
deadline_ms=deadline_ms,
retry_backoff_ms=retry_backoff_ms,
retry_backoff_max_ms=retry_backoff_max_ms,
retry_grace_attempts=retry_grace_attempts,
retry_grace_period=retry_grace_period,
resolver=resolver_type,
offline_flag_source_path=offline_flag_source_path,
stream_deadline_ms=stream_deadline_ms,
keep_alive_time=keep_alive_time,
cache=cache_type,
cache=cache,
max_cache_size=max_cache_size,
)

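
A construction sketch for the renamed provider parameters (illustrative, not from the diff): passing the old seconds-based `timeout` still works but now triggers the deprecation warning shown above. The import paths are assumed from the openfeature-sdk and this package's name:

    from openfeature import api
    from openfeature.contrib.provider.flagd import FlagdProvider  # assumed path

    provider = FlagdProvider(
        host="localhost",
        port=8013,
        deadline_ms=500,        # milliseconds, replaces the deprecated seconds-based `timeout`
        retry_grace_period=5,   # seconds spent STALE before an ERROR event is emitted
        retry_backoff_ms=1000,
        retry_backoff_max_ms=120000,
    )
    api.set_provider(provider)
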
@@ -7,6 +7,7 @@
from cachebox import BaseCacheImpl, LRUCache
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from grpc import ChannelConnectivity

from openfeature.evaluation_context import EvaluationContext
from openfeature.event import ProviderEventDetails
@@ -47,53 +48,60 @@ def __init__(
[ProviderEventDetails], None
],
):
self.active = False
self.config = config
self.emit_provider_ready = emit_provider_ready
self.emit_provider_error = emit_provider_error
self.emit_provider_stale = emit_provider_stale
self.emit_provider_configuration_changed = emit_provider_configuration_changed
self.cache: typing.Optional[BaseCacheImpl] = (
LRUCache(maxsize=self.config.max_cache_size)
if self.config.cache == CacheType.LRU
else None
)
self.stub, self.channel = self._create_stub()
self.retry_backoff_seconds = config.retry_backoff_ms * 0.001
self.retry_backoff_max_seconds = config.retry_backoff_max_ms * 0.001
self.retry_grace_attempts = config.retry_grace_attempts
self.cache: typing.Optional[BaseCacheImpl] = self._create_cache()

self.retry_grace_period = config.retry_grace_period
self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001
self.deadline = config.deadline_ms * 0.001
self.connected = False

def _create_stub(
self,
) -> typing.Tuple[evaluation_pb2_grpc.ServiceStub, grpc.Channel]:
config = self.config
channel_factory = grpc.secure_channel if config.tls else grpc.insecure_channel
channel = channel_factory(

# Create the channel with the service config
options = [
("grpc.keepalive_time_ms", config.keep_alive_time),
("grpc.initial_reconnect_backoff_ms", config.retry_backoff_ms),
("grpc.max_reconnect_backoff_ms", config.retry_backoff_max_ms),
("grpc.min_reconnect_backoff_ms", config.deadline_ms),
]

self.channel = channel_factory(
f"{config.host}:{config.port}",
options=(("grpc.keepalive_time_ms", config.keep_alive_time),),
options=options,
)
stub = evaluation_pb2_grpc.ServiceStub(channel)
self.stub = evaluation_pb2_grpc.ServiceStub(self.channel)

return stub, channel
self.thread: typing.Optional[threading.Thread] = None
self.timer: typing.Optional[threading.Timer] = None
self.active = False

def _create_cache(self):
return (
LRUCache(maxsize=self.config.max_cache_size)
if self.config.cache == CacheType.LRU
else None
)

def initialize(self, evaluation_context: EvaluationContext) -> None:
self.connect()

def shutdown(self) -> None:
self.active = False
self.channel.close()
if self.cache:
self.cache.clear()

def connect(self) -> None:
self.active = True
self.thread = threading.Thread(
target=self.listen, daemon=True, name="FlagdGrpcServiceWorkerThread"
)
self.thread.start()

# Run monitoring in a separate thread
self.monitor_thread = threading.Thread(
target=self.monitor, daemon=True, name="FlagdGrpcServiceMonitorThread"
)
self.monitor_thread.start()
## block until ready or deadline reached
timeout = self.deadline + time.time()
while not self.connected and time.time() < timeout:
@@ -105,81 +113,87 @@ def connect(self) -> None:
"Blocking init finished before data synced. Consider increasing startup deadline to avoid inconsistent evaluations."
)

def monitor(self) -> None:
def state_change_callback(new_state: ChannelConnectivity) -> None:
logger.debug(f"gRPC state change: {new_state}")
if new_state == ChannelConnectivity.READY:
if not self.thread or not self.thread.is_alive():
self.thread = threading.Thread(
target=self.listen,
daemon=True,
name="FlagdGrpcServiceWorkerThread",
)
self.thread.start()

if self.timer and self.timer.is_alive():
logger.debug("gRPC error timer expired")
self.timer.cancel()

elif new_state == ChannelConnectivity.TRANSIENT_FAILURE:
# this is the failed reconnect attempt so we are going into stale
self.emit_provider_stale(
ProviderEventDetails(
message="gRPC sync disconnected, reconnecting",
)
)
# adding a timer, so we can emit the error event after time
self.timer = threading.Timer(self.retry_grace_period, self.emit_error)

logger.debug("gRPC error timer started")
self.timer.start()
self.connected = False

self.channel.subscribe(state_change_callback, try_to_connect=True)

def emit_error(self) -> None:
logger.debug("gRPC error emitted")
if self.cache is not None:
self.cache.clear()
self.emit_provider_error(
ProviderEventDetails(
message="gRPC sync disconnected, reconnecting",
error_code=ErrorCode.GENERAL,
)
)

def listen(self) -> None:
retry_delay = self.retry_backoff_seconds
logger.info("gRPC starting listener thread")
call_args = (
{"timeout": self.streamline_deadline_seconds}
if self.streamline_deadline_seconds > 0
else {}
)
retry_counter = 0
while self.active:
request = evaluation_pb2.EventStreamRequest()
request = evaluation_pb2.EventStreamRequest()

# defining a never ending loop to recreate the stream
while self.active:
try:
logger.debug("Setting up gRPC sync flags connection")
for message in self.stub.EventStream(request, **call_args):
logger.info("Setting up gRPC sync flags connection")
for message in self.stub.EventStream(
request, wait_for_ready=True, **call_args
):
if message.type == "provider_ready":
if not self.connected:
self.emit_provider_ready(
ProviderEventDetails(
message="gRPC sync connection established"
)
self.connected = True
self.emit_provider_ready(
ProviderEventDetails(
message="gRPC sync connection established"
)
self.connected = True
retry_counter = 0
# reset retry delay after successful read
retry_delay = self.retry_backoff_seconds

)
elif message.type == "configuration_change":
data = MessageToDict(message)["data"]
self.handle_changed_flags(data)

if not self.active:
logger.info("Terminating gRPC sync thread")
return
except grpc.RpcError as e:
logger.error(f"SyncFlags stream error, {e.code()=} {e.details()=}")
# re-create the stub if there's a connection issue - otherwise reconnect does not work as expected
self.stub, self.channel = self._create_stub()
except grpc.RpcError as e: # noqa: PERF203
# although it seems like this error log is not interesting, without it, the retry is not working as expected
logger.debug(f"SyncFlags stream error, {e.code()=} {e.details()=}")
except ParseError:
logger.exception(
f"Could not parse flag data using flagd syntax: {message=}"
)

self.connected = False
self.on_connection_error(retry_counter, retry_delay)

retry_delay = self.handle_retry(retry_counter, retry_delay)

retry_counter = retry_counter + 1

def handle_retry(self, retry_counter: int, retry_delay: float) -> float:
if retry_counter == 0:
logger.info("gRPC sync disconnected, reconnecting immediately")
else:
logger.info(f"gRPC sync disconnected, reconnecting in {retry_delay}s")
time.sleep(retry_delay)
retry_delay = min(1.1 * retry_delay, self.retry_backoff_max_seconds)
return retry_delay

def on_connection_error(self, retry_counter: int, retry_delay: float) -> None:
if retry_counter == self.retry_grace_attempts:
if self.cache:
self.cache.clear()
self.emit_provider_error(
ProviderEventDetails(
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
error_code=ErrorCode.GENERAL,
)
)
elif retry_counter == 1:
self.emit_provider_stale(
ProviderEventDetails(
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
)
)

def handle_changed_flags(self, data: typing.Any) -> None:
changed_flags = list(data["flags"].keys())

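
The reconnect handling introduced above combines gRPC channel-connectivity callbacks with a grace-period timer: go stale on TRANSIENT_FAILURE, emit an error only once the grace period elapses, and cancel the pending error when the channel becomes READY again. A self-contained sketch of that pattern with illustrative names, not the provider's actual classes:

    import threading
    import typing

    import grpc
    from grpc import ChannelConnectivity

    class ConnectionMonitor:
        def __init__(self, channel: grpc.Channel, grace_period_seconds: float = 5.0):
            self.channel = channel
            self.grace_period_seconds = grace_period_seconds
            self.timer: typing.Optional[threading.Timer] = None

        def start(self) -> None:
            # try_to_connect=True asks the channel to (re)connect eagerly.
            self.channel.subscribe(self._on_state_change, try_to_connect=True)

        def _on_state_change(self, state: ChannelConnectivity) -> None:
            if state == ChannelConnectivity.READY:
                # Recovered within the grace period: cancel the pending error.
                if self.timer and self.timer.is_alive():
                    self.timer.cancel()
            elif state == ChannelConnectivity.TRANSIENT_FAILURE:
                # Go stale immediately, emit an error only after the grace period.
                self.timer = threading.Timer(self.grace_period_seconds, self._emit_error)
                self.timer.start()

        def _emit_error(self) -> None:
            print("grace period elapsed without reconnect, emitting ERROR")

    # Usage sketch:
    # monitor = ConnectionMonitor(grpc.insecure_channel("localhost:8013"))
    # monitor.start()
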
@@ -31,17 +31,19 @@ def __init__(
self.last_modified = 0.0
self.flag_data: typing.Mapping[str, Flag] = {}
self.load_data()
self.active = True
self.thread = threading.Thread(target=self.refresh_file, daemon=True)
self.thread.start()

def shutdown(self) -> None:
self.active = False
pass

def get_flag(self, key: str) -> typing.Optional[Flag]:
return self.flag_data.get(key)

def refresh_file(self) -> None:
while True:
while self.active:
time.sleep(self.poll_interval_seconds)
logger.debug("checking for new flag store contents from file")
last_modified = os.path.getmtime(self.file_path)
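
The shutdown change above turns the file-polling loop into a stoppable thread by replacing `while True` with `while self.active`. That pattern in isolation, as an illustrative sketch with hypothetical names:

    import threading
    import time

    class Poller:
        def __init__(self, poll_interval_seconds: float = 1.0):
            self.poll_interval_seconds = poll_interval_seconds
            self.active = True
            self.thread = threading.Thread(target=self._run, daemon=True)
            self.thread.start()

        def _run(self) -> None:
            while self.active:
                time.sleep(self.poll_interval_seconds)
                # ... check the flag file's mtime and reload its contents here ...

        def shutdown(self) -> None:
            # The loop observes the flag and exits after the current sleep.
            self.active = False
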