Skip to content

Commit

Permalink
feat: fastest response strategy plugin (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
joyc-bq authored Mar 13, 2024
1 parent aaf805a commit c0921b6
Show file tree
Hide file tree
Showing 17 changed files with 888 additions and 6 deletions.
11 changes: 11 additions & 0 deletions aws_advanced_python_wrapper/driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DriverDialect(ABC):
"""
Driver dialects help the driver-agnostic AWS Python Driver interface with the driver-specific functionality of the underlying Python Driver.
"""
_QUERY = "SELECT 1"

_executor: ClassVar[Executor] = ThreadPoolExecutor()
_dialect_code: str = DriverDialectCodes.GENERIC
Expand Down Expand Up @@ -152,3 +153,13 @@ def unwrap_connection(self, conn_obj: object) -> Any:

def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
return

def ping(self, conn: Connection) -> bool:
try:
with conn.cursor() as cursor:
query = DriverDialect._QUERY
self.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=10)
cursor.fetchone()
return True
except Exception:
return False
345 changes: 345 additions & 0 deletions aws_advanced_python_wrapper/fastest_response_strategy_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from copy import copy
from dataclasses import dataclass
from datetime import datetime
from threading import Event, Lock, Thread
from time import sleep
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional,
Set, Tuple)

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.hostselector import RandomHostSelector
from aws_advanced_python_wrapper.plugin import Plugin
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
TelemetryContext, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel)

if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.notifications import HostEvent

logger = Logger(__name__)

MAX_VALUE = 2147483647


class FastestResponseStrategyPlugin(Plugin):
_FASTEST_RESPONSE_STRATEGY_NAME = "fastest_response"
_SUBSCRIBED_METHODS: Set[str] = {"accepts_strategy",
"get_host_info_by_strategy",
"notify_host_list_changed"}

def __init__(self, plugin_service: PluginService, props: Properties):
self._plugin_service = plugin_service
self._properties = props
self._host_response_time_service: HostResponseTimeService = \
HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props))
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MILLIS.get_int(props) * 10 ^ 6
self._random_host_selector = RandomHostSelector()
self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
self._hosts: Tuple[HostInfo, ...] = ()

@property
def subscribed_methods(self) -> Set[str]:
return self._SUBSCRIBED_METHODS

def connect(
self,
target_driver_func: Callable,
driver_dialect: DriverDialect,
host_info: HostInfo,
props: Properties,
is_initial_connection: bool,
connect_func: Callable) -> Connection:
return self._connect(host_info, props, is_initial_connection, connect_func)

def force_connect(
self,
target_driver_func: Callable,
driver_dialect: DriverDialect,
host_info: HostInfo,
props: Properties,
is_initial_connection: bool,
force_connect_func: Callable) -> Connection:
return self._connect(host_info, props, is_initial_connection, force_connect_func)

def _connect(
self,
host: HostInfo,
properties: Properties,
is_initial_connection: bool,
connect_func: Callable) -> Connection:
conn = connect_func()

if is_initial_connection:
self._plugin_service.refresh_host_list(conn)

return conn

def accepts_strategy(self, role: HostRole, strategy: str) -> bool:
return strategy == FastestResponseStrategyPlugin._FASTEST_RESPONSE_STRATEGY_NAME

def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
if not self.accepts_strategy(role, strategy):
logger.error("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy)
raise AwsWrapperError(Messages.get_formatted("FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy", strategy))

fastest_response_host: Optional[HostInfo] = self._cached_fastest_response_host_by_role.get(role.name)
if fastest_response_host is not None:

# Found a fastest host. Let's find it in the the latest topology.
for host in self._plugin_service.hosts:
if host == fastest_response_host:
# found the fastest host in the topology
return host
# It seems that the fastest cached host isn't in the latest topology.
# Let's ignore cached results and find the fastest host.

# Cached result isn't available. Need to find the fastest response time host.
eligible_hosts: List[FastestResponseStrategyPlugin.ResponseTimeTuple] = []
for host in self._plugin_service.hosts:
if role == host.role:
response_time_tuple = FastestResponseStrategyPlugin.ResponseTimeTuple(host,
self._host_response_time_service.get_response_time(host))
eligible_hosts.append(response_time_tuple)

# Sort by response time then retrieve the first host
sorted_eligible_hosts: List[FastestResponseStrategyPlugin.ResponseTimeTuple] = \
sorted(eligible_hosts, key=lambda x: x.response_time)

calculated_fastest_response_host = sorted_eligible_hosts[0].host_info
if calculated_fastest_response_host is None or \
self._host_response_time_service.get_response_time(calculated_fastest_response_host) == MAX_VALUE:
logger.debug("FastestResponseStrategyPlugin.RandomHostSelected")
return self._random_host_selector.get_host(self._plugin_service.hosts, role, self._properties)

self._cached_fastest_response_host_by_role.put(role.name,
calculated_fastest_response_host,
self._cache_expiration_nanos)

return calculated_fastest_response_host

def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
self._hosts = self._plugin_service.hosts
if self._host_response_time_service is not None:
self._host_response_time_service.set_hosts(self._hosts)

@dataclass
class ResponseTimeTuple:
host_info: HostInfo
response_time: int


class FastestResponseStrategyPluginFactory:

def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return FastestResponseStrategyPlugin(plugin_service, props)


class HostResponseTimeMonitor:

_MONITORING_PROPERTY_PREFIX: str = "frt-"
_NUM_OF_MEASURES: int = 5
_DEFAULT_CONNECT_TIMEOUT_SEC = 10

def __init__(self, plugin_service: PluginService, host_info: HostInfo, props: Properties, interval_ms: int):
self._plugin_service = plugin_service
self._host_info = host_info
self._properties = props
self._interval_ms = interval_ms

self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
self._response_time: int = MAX_VALUE
self._lock: Lock = Lock()
self._monitoring_conn: Optional[Connection] = None
self._is_stopped: Event = Event()

self._host_id: Optional[str] = self._host_info.host_id
if self._host_id is None or self._host_id == "":
self._host_id = self._host_info.host

self._daemon_thread: Thread = Thread(daemon=True, target=self.run)

# Report current response time (in milliseconds) to telemetry engine.
# Report -1 if response time couldn't be measured.
self._response_time_gauge: TelemetryGauge = \
self._telemetry_factory.create_gauge("frt.response.time." + self._host_id,
lambda: self._response_time if self._response_time != MAX_VALUE else -1)
self._daemon_thread.start()

@property
def response_time(self):
return self._response_time

@response_time.setter
def response_time(self, response_time: int):
self._response_time = response_time

@property
def host_info(self):
return self._host_info

@property
def is_stopped(self):
return self._is_stopped.is_set()

def close(self):
self._is_stopped.set()
self._daemon_thread.join(5)
logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host)

def _get_current_time(self):
return datetime.now().microsecond / 1000 # milliseconds

def run(self):
context: TelemetryContext = self._telemetry_factory.open_telemetry_context(
"node response time thread", TelemetryTraceLevel.TOP_LEVEL)
context.set_attribute("url", self._host_info.url)
try:
while not self.is_stopped:
self._open_connection()

if self._monitoring_conn is not None:

response_time_sum = 0
count = 0
for i in range(self._NUM_OF_MEASURES):
if self.is_stopped:
break
start_time = self._get_current_time()
if self._plugin_service.driver_dialect.ping(self._monitoring_conn):
calculated_response_time = self._get_current_time() - start_time
response_time_sum = response_time_sum + calculated_response_time
count = count + 1

if count > 0:
self.response_time = response_time_sum / count
else:
self.response_time = MAX_VALUE
logger.debug("HostResponseTimeMonitor.ResponseTime", self._host_info.host, self._response_time)

sleep(self._interval_ms / 1000)

except InterruptedError:
# exit thread
logger.debug("HostResponseTimeMonitor.InterruptedExceptionDuringMonitoring", self._host_info.host)
except Exception as e:
# this should not be reached; log and exit thread
logger.debug("HostResponseTimeMonitor.ExceptionDuringMonitoringStop",
self._host_info.host,
e) # print full trace stack of the exception.
finally:
self._is_stopped.set()
if self._monitoring_conn is not None:
try:
self._monitoring_conn.close()
except Exception:
# Do nothing
pass

if context is not None:
context.close_context()

def _open_connection(self):
try:
driver_dialect = self._plugin_service.driver_dialect
if self._monitoring_conn is None or driver_dialect.is_closed(self._monitoring_conn):
monitoring_conn_properties: Properties = copy(self._properties)
for key, value in self._properties.items():
if key.startswith(self._MONITORING_PROPERTY_PREFIX):
monitoring_conn_properties[key[len(self._MONITORING_PROPERTY_PREFIX):len(key)]] = value
monitoring_conn_properties.pop(key, None)

# Set a default connect timeout if the user hasn't configured one
if monitoring_conn_properties.get(WrapperProperties.CONNECT_TIMEOUT_SEC.name, None) is None:
monitoring_conn_properties[WrapperProperties.CONNECT_TIMEOUT_SEC.name] = HostResponseTimeMonitor._DEFAULT_CONNECT_TIMEOUT_SEC

logger.debug("HostResponseTimeMonitor.OpeningConnection", self._host_info.url)
self._monitoring_conn = self._plugin_service.force_connect(self._host_info, monitoring_conn_properties, None)
logger.debug("HostResponseTimeMonitor.OpenedConnection", self._host_info.url)

except Exception:
if self._monitoring_conn is not None:
try:
self._monitoring_conn.close()
except Exception:
pass # ignore

self._monitoring_conn = None


class HostResponseTimeService:
_CACHE_EXPIRATION_NS: int = 6 * 10 ^ 11 # 10 minutes
_CACHE_CLEANUP_NS: int = 6 * 10 ^ 10 # 1 minute
_lock: Lock = Lock()
_monitoring_nodes: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostResponseTimeMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
should_dispose_func=lambda monitor: True,
item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor))

def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int):
self._plugin_service = plugin_service
self._properties = props
self._interval_ms = interval_ms
self._hosts: Tuple[HostInfo, ...] = ()
self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
self._host_count_gauge: TelemetryGauge = self._telemetry_factory.create_gauge("frt.nodes.count", lambda: len(self._monitoring_nodes))

@property
def hosts(self) -> Tuple[HostInfo, ...]:
return self._hosts

@hosts.setter
def hosts(self, new_hosts: Tuple[HostInfo, ...]):
self._hosts = new_hosts

@staticmethod
def _monitor_close(monitor: HostResponseTimeMonitor):
try:
monitor.close()
except Exception:
pass

def get_response_time(self, host_info: HostInfo) -> int:
monitor: Optional[HostResponseTimeMonitor] = HostResponseTimeService._monitoring_nodes.get(host_info.url)
if monitor is None:
return MAX_VALUE
return monitor.response_time

def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None:
old_hosts_dict = {x.url: x for x in self.hosts}
self.hosts = new_hosts

for host in self.hosts:
if host.url not in old_hosts_dict:
with self._lock:
self._monitoring_nodes.compute_if_absent(host.url,
lambda _: HostResponseTimeMonitor(
self._plugin_service,
host,
self._properties,
self._interval_ms), self._CACHE_EXPIRATION_NS)
3 changes: 2 additions & 1 deletion aws_advanced_python_wrapper/host_monitoring_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ class Monitor:
_INACTIVE_SLEEP_MS = 100
_MIN_HOST_CHECK_TIMEOUT_MS = 3000
_MONITORING_PROPERTY_PREFIX = "monitoring-"
_QUERY = "SELECT 1"

def __init__(
self,
Expand Down Expand Up @@ -560,7 +561,7 @@ def _is_host_available(self, conn: Connection, timeout_sec: float) -> bool:
def _execute_conn_check(self, conn: Connection, timeout_sec: float):
driver_dialect = self._plugin_service.driver_dialect
with conn.cursor() as cursor:
query = "SELECT 1"
query = Monitor._QUERY
driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec)
cursor.fetchone()

Expand Down
3 changes: 3 additions & 0 deletions aws_advanced_python_wrapper/mysql_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)):
to_conn.autocommit = from_conn.autocommit

def ping(self, conn: Connection) -> bool:
return not self.is_closed(conn)

def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) -> Properties:
driver_props: Properties = Properties(original_props.copy())
PropertiesUtils.remove_wrapper_props(driver_props)
Expand Down
Loading

0 comments on commit c0921b6

Please sign in to comment.