Issue #112 replace kazoo_client_factory config with zk_memoizer_tracking
The old system in `prime_caches` depended on mutable configs, so some refactoring was required to move to an alternative with immutable configs.
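
In practice the new setup works roughly as follows (illustrative sketch, not code from this commit): enabling `AggregatorBackendConfig.zk_memoizer_tracking` (e.g. through the `OPENEO_AGGREGATOR_ZK_MEMOIZER_TRACKING` environment variable) makes the ZooKeeper-backed memoizers wrap their `KazooClient` in an `AttrStatsProxy`, which counts calls in the module-level `openeo_aggregator.caching.zk_memoizer_stats` dict that `prime_caches` then inspects:

```python
import os

# Enable tracking before importing openeo_aggregator modules:
# the default of AggregatorBackendConfig.zk_memoizer_tracking is read
# from this environment variable (via smart_bool) at class definition time.
os.environ["OPENEO_AGGREGATOR_ZK_MEMOIZER_TRACKING"] = "true"

import openeo_aggregator.caching
from openeo_aggregator.background.prime_caches import prime_caches

# prime_caches() builds ZooKeeper-backed memoizers; with tracking enabled,
# each KazooClient is wrapped in an AttrStatsProxy that counts calls to
# start/stop/create/get/set in the shared zk_memoizer_stats dict.
prime_caches(config="aggregator.conf.py")  # hypothetical config file path

stats = openeo_aggregator.caching.zk_memoizer_stats
zk_writes = sum(stats.get(k, 0) for k in ["create", "set"])
print(f"ZooKeeper stats: {stats=} {zk_writes=}")
```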
soxofaan committed Feb 20, 2024
1 parent 09171b8 commit 47a02e6
Showing 8 changed files with 107 additions and 87 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is roughly based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [0.20.0]

- Replace `kazoo_client_factory` config with `AggregatorBackendConfig.zk_memoizer_tracking` ([#112](https://github.com/Open-EO/openeo-aggregator/issues/112))

## [0.19.0]

- Support regexes in `collection_whitelist` config (eu-cdse/openeo-cdse-infra#54)
2 changes: 1 addition & 1 deletion src/openeo_aggregator/about.py
@@ -2,7 +2,7 @@
import sys
from typing import Optional

__version__ = "0.19.0a1"
__version__ = "0.20.0a1"


def log_version_info(logger: Optional[logging.Logger] = None):
57 changes: 11 additions & 46 deletions src/openeo_aggregator/background/prime_caches.py
@@ -3,9 +3,8 @@
import functools
import logging
from pathlib import Path
from typing import Any, List, Optional, Sequence, Union
from typing import List, Optional, Union

from kazoo.client import KazooClient
from openeo.util import TimingLogger
from openeo_driver.util.logging import (
LOG_HANDLER_FILE_JSON,
@@ -18,39 +17,21 @@
setup_logging,
)

import openeo_aggregator.caching
from openeo_aggregator.about import log_version_info
from openeo_aggregator.app import get_aggregator_logging_config
from openeo_aggregator.backend import AggregatorBackendImplementation
from openeo_aggregator.config import (
OPENEO_AGGREGATOR_CONFIG,
AggregatorConfig,
get_backend_config,
get_config,
)
from openeo_aggregator.connection import MultiBackendConnection

_log = logging.getLogger(__name__)


class AttrStatsProxy:
    """
    Proxy object to wrap a given object and keep stats of attribute/method usage.
    """

    # TODO: move this to a utilities module
    # TODO: avoid all these public attributes that could collide with existing attributes of the proxied object
    __slots__ = ["target", "to_track", "stats"]

    def __init__(self, target: Any, to_track: Sequence[str], stats: Optional[dict] = None):
        self.target = target
        self.to_track = set(to_track)
        self.stats = stats if stats is not None else {}

    def __getattr__(self, name):
        if name in self.to_track:
            self.stats[name] = self.stats.get(name, 0) + 1
        return getattr(self.target, name)


FAIL_MODE_FAILFAST = "failfast"
FAIL_MODE_WARN = "warn"

@@ -116,10 +97,6 @@ def prime_caches(
config: AggregatorConfig = get_config(config)
_log.info(f"Using config: {config.get('config_source')=}")

# Inject Zookeeper operation statistics
kazoo_stats = {}
_patch_config_for_kazoo_client_stats(config, kazoo_stats)

_log.info(f"Creating AggregatorBackendImplementation with {config.aggregator_backends}")
backends = MultiBackendConnection.from_config(config)
backend_implementation = AggregatorBackendImplementation(backends=backends, config=config)
@@ -155,26 +132,14 @@
with fail_handler():
backend_implementation.processing.get_merged_process_metadata()

zk_writes = sum(kazoo_stats.get(k, 0) for k in ["create", "set"])
_log.info(f"ZooKeeper stats: {kazoo_stats=} {zk_writes=}")
if require_zookeeper_writes and zk_writes == 0:
    raise RuntimeError("No Zookeeper writes.")


def _patch_config_for_kazoo_client_stats(config: AggregatorConfig, stats: dict):
    orig_kazoo_client_factory = config.kazoo_client_factory or KazooClient
    def kazoo_client_factory(**kwargs):
        _log.info(f"AttrStatsProxy-wrapping KazooClient with {kwargs=}")
        zk = orig_kazoo_client_factory(**kwargs)
        return AttrStatsProxy(
            target=zk,
            to_track=["start", "stop", "create", "get", "set"],
            stats=stats,
        )

    _log.info(f"Patching config with {kazoo_client_factory=}")
    # TODO: create a new config instead of updating an existing one?
    config.kazoo_client_factory = kazoo_client_factory
if get_backend_config().zk_memoizer_tracking:
    kazoo_stats = openeo_aggregator.caching.zk_memoizer_stats
    zk_writes = sum(kazoo_stats.get(k, 0) for k in ["create", "set"])
    _log.info(f"ZooKeeper stats: {kazoo_stats=} {zk_writes=}")
    if require_zookeeper_writes and zk_writes == 0:
        raise RuntimeError("No ZooKeeper writes.")
else:
    _log.warning(f"ZooKeeper stats: not configured")


if __name__ == "__main__":
16 changes: 12 additions & 4 deletions src/openeo_aggregator/caching.py
@@ -12,8 +12,8 @@
from kazoo.client import KazooClient
from openeo.util import TimingLogger

from openeo_aggregator.config import AggregatorConfig
from openeo_aggregator.utils import Clock, strip_join
from openeo_aggregator.config import AggregatorConfig, get_backend_config
from openeo_aggregator.utils import AttrStatsProxy, Clock, strip_join

DEFAULT_NAMESPACE = "_default"

@@ -490,6 +490,8 @@ def _zk_connect_or_not(self) -> Union[KazooClient, None]:
_log.error(f"{self!r} failed to stop connection: {e!r}")


zk_memoizer_stats = {}

def memoizer_from_config(
config: AggregatorConfig,
namespace: str,
@@ -504,8 +506,14 @@ def get_memoizer(memoizer_type: str, memoizer_conf: dict) -> Memoizer:
elif memoizer_type == "jsondict":
return JsonDictMemoizer(namespace=namespace, default_ttl=memoizer_conf.get("default_ttl"))
elif memoizer_type == "zookeeper":
kazoo_client_factory = config.kazoo_client_factory or KazooClient
kazoo_client = kazoo_client_factory(hosts=memoizer_conf.get("zk_hosts", "localhost:2181"))
kazoo_client = KazooClient(hosts=memoizer_conf.get("zk_hosts", "localhost:2181"))
if get_backend_config().zk_memoizer_tracking:
kazoo_client = AttrStatsProxy(
target=kazoo_client,
to_track=["start", "stop", "create", "get", "set"],
# TODO: better solution than using a module level global here?
stats=zk_memoizer_stats,
)
return ZkMemoizer(
client=kazoo_client,
path_prefix=f"{config.zookeeper_prefix}/cache/{namespace}",
5 changes: 3 additions & 2 deletions src/openeo_aggregator/config.py
@@ -9,7 +9,7 @@
from openeo_driver.config.load import ConfigGetter
from openeo_driver.server import build_backend_deploy_metadata
from openeo_driver.users.oidc import OidcProvider
from openeo_driver.utils import dict_item
from openeo_driver.utils import dict_item, smart_bool

import openeo_aggregator.about

@@ -50,7 +50,6 @@ class AggregatorConfig(dict):

partitioned_job_tracking = dict_item(default=None)
zookeeper_prefix = dict_item(default="/openeo-aggregator/")
kazoo_client_factory = dict_item(default=None)

# See `memoizer_from_config` for details.
memoizer = dict_item(default={"type": "dict"})
@@ -139,6 +138,8 @@ class AggregatorBackendConfig(OpenEoBackendConfig):
# List of collection ids to cover with the aggregator (when None: support union of all upstream collections)
collection_whitelist: Optional[List[Union[str, re.Pattern]]] = None

zk_memoizer_tracking: bool = smart_bool(os.environ.get("OPENEO_AGGREGATOR_ZK_MEMOIZER_TRACKING"))


# Internal singleton
_config_getter = ConfigGetter(expected_class=AggregatorBackendConfig)
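For tests, the new flag does not have to go through the environment variable: the test changes further down toggle it per test case with the `config_overrides` helper from `openeo_aggregator.testing`, roughly like this (a minimal sketch with a hypothetical test name; assumes `config_overrides` temporarily patches fields on the active `AggregatorBackendConfig`):

```python
from openeo_aggregator.testing import config_overrides

def test_with_zk_tracking_enabled():
    # Temporarily enable AggregatorBackendConfig.zk_memoizer_tracking for this test only.
    with config_overrides(zk_memoizer_tracking=True):
        ...  # exercise code that builds a ZooKeeper-backed memoizer
```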
20 changes: 20 additions & 0 deletions src/openeo_aggregator/utils.py
@@ -12,6 +12,7 @@
List,
NamedTuple,
Optional,
Sequence,
Set,
Union,
)
@@ -306,3 +307,22 @@ def is_whitelisted(
elif isinstance(pattern, re.Pattern) and pattern.fullmatch(needle):
return True
return False


class AttrStatsProxy:
    """
    Proxy object to wrap a given object and keep stats of attribute/method usage.
    """

    # TODO: avoid all these public attributes that could collide with existing attributes of the proxied object
    __slots__ = ["target", "to_track", "stats"]

    def __init__(self, target: Any, to_track: Sequence[str], stats: Optional[dict] = None):
        self.target = target
        self.to_track = set(to_track)
        self.stats = stats if stats is not None else {}

    def __getattr__(self, name):
        if name in self.to_track:
            self.stats[name] = self.stats.get(name, 0) + 1
        return getattr(self.target, name)
72 changes: 38 additions & 34 deletions tests/background/test_prime_caches.py
@@ -1,14 +1,18 @@
import json
import logging
import re
import textwrap
from pathlib import Path
from typing import Any
from unittest import mock

import pytest
from openeo_driver.testing import DictSubSet

from openeo_aggregator.background.prime_caches import AttrStatsProxy, main, prime_caches
import openeo_aggregator.caching
from openeo_aggregator.background.prime_caches import main, prime_caches
from openeo_aggregator.config import AggregatorConfig
from openeo_aggregator.testing import DummyKazooClient
from openeo_aggregator.testing import config_overrides

FILE_FORMATS_JUST_GEOTIFF = {
"input": {"GTiff": {"gis_data_types": ["raster"], "parameters": {}, "title": "GeoTiff"}},
@@ -23,38 +27,26 @@ def config(backend1, backend2, backend1_id, backend2_id, zk_client) -> Aggregato
backend1_id: backend1,
backend2_id: backend2,
}
conf.kazoo_client_factory = lambda **kwargs: zk_client
conf.zookeeper_prefix = "/oa/"
conf.memoizer = {
"type": "zookeeper",
"config": {
"zk_hosts": "localhost:2181",
"zk_hosts": "zk.test:2181",
"default_ttl": 24 * 60 * 60,
},
}
return conf


class TestAttrStatsProxy:
def test_basic(self):
class Foo:
def bar(self, x):
return x + 1
@pytest.fixture(autouse=True)
def _mock_kazoo_client(zk_client):
with mock.patch.object(openeo_aggregator.caching, "KazooClient", return_value=zk_client):
yield

def meh(self, x):
return x * 2

foo = AttrStatsProxy(target=Foo(), to_track=["bar"])

assert foo.bar(3) == 4
assert foo.meh(6) == 12

assert foo.stats == {"bar": 1}


def test_prime_caches_basic(config, backend1, backend2, requests_mock, mbldr, caplog, zk_client):
"""Just check that bare basics of `prime_caches` work."""
mocks = [
@pytest.fixture
def upstream_request_mocks(requests_mock, backend1, backend2, mbldr) -> list:
return [
requests_mock.get(backend1 + "/file_formats", json=FILE_FORMATS_JUST_GEOTIFF),
requests_mock.get(backend2 + "/file_formats", json=FILE_FORMATS_JUST_GEOTIFF),
requests_mock.get(backend1 + "/collections", json=mbldr.collections("S2")),
@@ -63,9 +55,13 @@ def test_prime_caches_basic(config, backend1, backend2, requests_mock, mbldr, ca
requests_mock.get(backend2 + "/collections/S2", json=mbldr.collection("S2")),
]


def test_prime_caches_basic(config, upstream_request_mocks, zk_client):
"""Just check that bare basics of `prime_caches` work."""

prime_caches(config=config)

assert all([m.call_count == 1 for m in mocks])
assert all([m.call_count == 1 for m in upstream_request_mocks])

assert zk_client.get_data_deserialized() == DictSubSet(
{
@@ -85,6 +81,22 @@ def test_prime_caches_basic(config, backend1, backend2, requests_mock, mbldr, ca
)


@pytest.mark.parametrize("zk_memoizer_tracking", [False, True])
def test_prime_caches_stats(config, upstream_request_mocks, caplog, zk_client, zk_memoizer_tracking):
    """Check logging of Zookeeper operation stats."""
    caplog.set_level(logging.INFO)
    with config_overrides(zk_memoizer_tracking=zk_memoizer_tracking):
        prime_caches(config=config)

    assert all([m.call_count == 1 for m in upstream_request_mocks])

    (zk_stats,) = [r.message for r in caplog.records if r.message.startswith("ZooKeeper stats:")]
    if zk_memoizer_tracking:
        assert re.search(r"kazoo_stats=\{.*start.*create.*\} zk_writes=[1-9]\d*", zk_stats)
    else:
        assert zk_stats == "ZooKeeper stats: not configured"


def _is_primitive_construct(data: Any) -> bool:
"""Consists only of Python primitives int, float, dict, list, str, ...?"""
if isinstance(data, dict):
@@ -111,16 +123,8 @@ def _build_config_file(config: AggregatorConfig, path: Path):
)


def test_prime_caches_main_basic(backend1, backend2, requests_mock, mbldr, caplog, tmp_path, backend1_id, backend2_id):
def test_prime_caches_main_basic(backend1, backend2, upstream_request_mocks, tmp_path, backend1_id, backend2_id):
"""Just check that bare basics of `prime_caches` main work."""
mocks = [
requests_mock.get(backend1 + "/file_formats", json=FILE_FORMATS_JUST_GEOTIFF),
requests_mock.get(backend2 + "/file_formats", json=FILE_FORMATS_JUST_GEOTIFF),
requests_mock.get(backend1 + "/collections", json=mbldr.collections("S2")),
requests_mock.get(backend1 + "/collections/S2", json=mbldr.collection("S2")),
requests_mock.get(backend2 + "/collections", json=mbldr.collections("S2")),
requests_mock.get(backend2 + "/collections/S2", json=mbldr.collection("S2")),
]

# Construct config file
config = AggregatorConfig()
@@ -133,10 +137,10 @@

main(args=["--config", str(config_file)])

assert all([m.call_count == 1 for m in mocks])
assert all([m.call_count == 1 for m in upstream_request_mocks])


def test_prime_caches_main_logging(backend1, backend2, mbldr, caplog, tmp_path, backend1_id, backend2_id, pytester):
def test_prime_caches_main_logging(backend1, backend2, tmp_path, backend1_id, backend2_id, pytester):
"""Run main in subprocess (so no request mocks, and probably a lot of failures) to see if logging setup works."""

config = AggregatorConfig()
18 changes: 18 additions & 0 deletions tests/test_utils.py
@@ -4,6 +4,7 @@
import shapely.geometry

from openeo_aggregator.utils import (
AttrStatsProxy,
BoundingBox,
EventHandler,
MultiDictGetter,
@@ -320,3 +321,20 @@ def test_is_whitelisted_regex():
assert is_whitelisted("foobar", [re.compile(r"\w+bar"), "bar"])
assert not is_whitelisted("barfoo", [re.compile("f.*"), "bar"])
assert is_whitelisted("barfoo", [re.compile(".*f.*"), "bar"])


class TestAttrStatsProxy:
    def test_basic(self):
        class Foo:
            def bar(self, x):
                return x + 1

            def meh(self, x):
                return x * 2

        foo = AttrStatsProxy(target=Foo(), to_track=["bar"])

        assert foo.bar(3) == 4
        assert foo.meh(6) == 12

        assert foo.stats == {"bar": 1}
