Issue #112 Eliminate unused AggregatorConfig arguments
soxofaan committed Mar 1, 2024
1 parent 6877dfc commit 3f4fd30
Showing 14 changed files with 336 additions and 402 deletions.
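
Note on the refactor pattern repeated across the files below: instead of loading an AggregatorConfig and threading it through every constructor and factory, each component now pulls its settings from the process-wide backend config via get_backend_config(). A minimal sketch of the new wiring (import paths assumed; only the call shapes are taken from this diff):

    # After this commit: no config object is passed around. Both the connection
    # factory and the backend implementation read get_backend_config() internally
    # (see the app.py and connection.py hunks below).
    from openeo_aggregator.backend import (
        AggregatorBackendImplementation,
        MultiBackendConnection,
    )

    backends = MultiBackendConnection.from_config()
    backend_implementation = AggregatorBackendImplementation(backends=backends)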
12 changes: 4 additions & 8 deletions src/openeo_aggregator/app.py
@@ -4,11 +4,10 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, List, Optional, Union
+from typing import List, Optional, Union
 
 import flask
 import openeo_driver.views
-from openeo_driver.config.load import ConfigGetter
 from openeo_driver.util.logging import (
     LOG_HANDLER_STDERR_JSON,
     LOGGING_CONTEXT_FLASK,
@@ -22,12 +21,11 @@
     AggregatorBackendImplementation,
     MultiBackendConnection,
 )
-from openeo_aggregator.config import AggregatorConfig, get_config, get_config_dir
 
 _log = logging.getLogger(__name__)
 
 
-def create_app(config: Any = None, auto_logging_setup: bool = True, flask_error_handling: bool = True) -> flask.Flask:
+def create_app(auto_logging_setup: bool = True, flask_error_handling: bool = True) -> flask.Flask:
     """
     Flask application factory function.
     """
@@ -39,13 +37,11 @@ def create_app(config: Any = None, auto_logging_setup: bool = True, flask_error_handling: bool = True) -> flask.Flask:
 
     log_version_info(logger=_log)
 
-    config: AggregatorConfig = get_config(config)
-    _log.info(f"Using config: {config.config_source=!r}")
 
-    backends = MultiBackendConnection.from_config(config)
+    backends = MultiBackendConnection.from_config()
 
     _log.info("Creating AggregatorBackendImplementation")
-    backend_implementation = AggregatorBackendImplementation(backends=backends, config=config)
+    backend_implementation = AggregatorBackendImplementation(backends=backends)
 
     _log.info(f"Building Flask app with {backend_implementation=!r}")
     app = openeo_driver.views.build_app(
25 changes: 10 additions & 15 deletions src/openeo_aggregator/backend.py
@@ -147,9 +147,9 @@ def __jsonserde_load__(cls, data: dict):
 
 class AggregatorCollectionCatalog(AbstractCollectionCatalog):
 
-    def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
+    def __init__(self, backends: MultiBackendConnection):
         self.backends = backends
-        self._memoizer = memoizer_from_config(config=config, namespace="CollectionCatalog")
+        self._memoizer = memoizer_from_config(namespace="CollectionCatalog")
         self.backends.on_connections_change.add(self._memoizer.invalidate)
 
     def get_all_metadata(self) -> List[dict]:
@@ -349,11 +349,10 @@ def __init__(
         self,
         backends: MultiBackendConnection,
         catalog: AggregatorCollectionCatalog,
-        config: AggregatorConfig,
     ):
         self.backends = backends
         # TODO Cache per backend results instead of output?
-        self._memoizer = memoizer_from_config(config=config, namespace="Processing")
+        self._memoizer = memoizer_from_config(namespace="Processing")
         self.backends.on_connections_change.add(self._memoizer.invalidate)
         self._catalog = catalog
 
@@ -984,12 +983,11 @@ def __init__(
         self,
         backends: MultiBackendConnection,
         processing: AggregatorProcessing,
-        config: AggregatorConfig
     ):
         super(AggregatorSecondaryServices, self).__init__()
 
         self._backends = backends
-        self._memoizer = memoizer_from_config(config=config, namespace="SecondaryServices")
+        self._memoizer = memoizer_from_config(namespace="SecondaryServices")
         self._backends.on_connections_change.add(self._memoizer.invalidate)
 
         self._processing = processing
@@ -1287,16 +1285,13 @@ class AggregatorBackendImplementation(OpenEoBackendImplementation):
     # Simplify mocking time for unit tests.
     _clock = time.time  # TODO: centralized helper for this test pattern
 
-    def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
+    def __init__(self, backends: MultiBackendConnection):
         self._backends = backends
-        catalog = AggregatorCollectionCatalog(backends=backends, config=config)
-        processing = AggregatorProcessing(
-            backends=backends, catalog=catalog,
-            config=config,
-        )
+        catalog = AggregatorCollectionCatalog(backends=backends)
+        processing = AggregatorProcessing(backends=backends, catalog=catalog)
 
         if get_backend_config().partitioned_job_tracking:
-            partitioned_job_tracker = PartitionedJobTracker.from_config(config=config, backends=self._backends)
+            partitioned_job_tracker = PartitionedJobTracker.from_config(backends=self._backends)
         else:
             partitioned_job_tracker = None
 
@@ -1307,7 +1302,7 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
             partitioned_job_tracker=partitioned_job_tracker,
         )
 
-        secondary_services = AggregatorSecondaryServices(backends=backends, processing=processing, config=config)
+        secondary_services = AggregatorSecondaryServices(backends=backends, processing=processing)
         user_defined_processes = AggregatorUserDefinedProcesses(backends=backends)
 
         super().__init__(
@@ -1321,7 +1316,7 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
         self._configured_oidc_providers: List[OidcProvider] = get_backend_config().oidc_providers
         self._auth_entitlement_check: Union[bool, dict] = get_backend_config().auth_entitlement_check
 
-        self._memoizer: Memoizer = memoizer_from_config(config=config, namespace="general")
+        self._memoizer: Memoizer = memoizer_from_config(namespace="general")
         self._backends.on_connections_change.add(self._memoizer.invalidate)
 
         # Shorter HTTP cache TTL to adapt quicker to changed back-end configurations
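
A recurring detail in the hunks above: each component builds its own memoizer and registers the memoizer's invalidate callback on backends.on_connections_change, so caches are flushed whenever the set of upstream connections changes. A rough sketch of that listener pattern (the actual on_connections_change internals are not part of this diff; this is an assumed approximation):

    from typing import Callable, Set

    class ChangeListeners:
        """Assumed approximation of `on_connections_change`: a set of zero-arg callbacks."""

        def __init__(self):
            self._listeners: Set[Callable[[], None]] = set()

        def add(self, listener: Callable[[], None]) -> None:
            self._listeners.add(listener)

        def trigger(self) -> None:
            # Presumably called when the backend connection set changes:
            # each registered cache invalidation (e.g. memoizer.invalidate) runs here.
            for listener in self._listeners:
                listener()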
4 changes: 2 additions & 2 deletions src/openeo_aggregator/background/prime_caches.py
@@ -97,8 +97,8 @@ def prime_caches(
     config: AggregatorConfig = get_config(config)
     _log.info(f"Using config: {config.get('config_source')=}")
 
-    backends = MultiBackendConnection.from_config(config)
-    backend_implementation = AggregatorBackendImplementation(backends=backends, config=config)
+    backends = MultiBackendConnection.from_config()
+    backend_implementation = AggregatorBackendImplementation(backends=backends)
 
     if fail_mode == FAIL_MODE_FAILFAST:
         # Do not intercept any exceptions.
17 changes: 8 additions & 9 deletions src/openeo_aggregator/caching.py
@@ -492,12 +492,12 @@ def _zk_connect_or_not(self) -> Union[KazooClient, None]:
 
 zk_memoizer_stats = {}
 
-def memoizer_from_config(
-    config: AggregatorConfig,
-    namespace: str,
-) -> Memoizer:
-
+def memoizer_from_config(namespace: str) -> Memoizer:
     """Factory to create `ZkMemoizer` instance from config values."""
 
+    backend_config = get_backend_config()
+
     def get_memoizer(memoizer_type: str, memoizer_conf: dict) -> Memoizer:
         if memoizer_type == "null":
             return NullMemoizer(namespace=namespace)
@@ -507,14 +507,14 @@ def get_memoizer(memoizer_type: str, memoizer_conf: dict) -> Memoizer:
             return JsonDictMemoizer(namespace=namespace, default_ttl=memoizer_conf.get("default_ttl"))
         elif memoizer_type == "zookeeper":
             kazoo_client = KazooClient(hosts=memoizer_conf.get("zk_hosts", "localhost:2181"))
-            if get_backend_config().zk_memoizer_tracking:
+            if backend_config.zk_memoizer_tracking:
                 kazoo_client = AttrStatsProxy(
                     target=kazoo_client,
                     to_track=["start", "stop", "create", "get", "set"],
                     # TODO: better solution than using a module level global here?
                     stats=zk_memoizer_stats,
                 )
-            zookeeper_prefix = get_backend_config().zookeeper_prefix
+            zookeeper_prefix = backend_config.zookeeper_prefix
             return ZkMemoizer(
                 client=kazoo_client,
                 path_prefix=f"{zookeeper_prefix}/cache/{namespace}",
@@ -530,8 +530,7 @@ def get_memoizer(memoizer_type: str, memoizer_conf: dict) -> Memoizer:
         else:
             raise ValueError(memoizer_type)
 
-    memoizer_config = get_backend_config().memoizer
     return get_memoizer(
-        memoizer_type=memoizer_config.get("type", "null"),
-        memoizer_conf=memoizer_config.get("config", {}),
+        memoizer_type=backend_config.memoizer.get("type", "null"),
+        memoizer_conf=backend_config.memoizer.get("config", {}),
    )
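
Under the new signature the memoizer choice comes entirely from get_backend_config().memoizer, a dict with "type" and "config" keys. A hypothetical test-style override (config_overrides appears in the test files below; its import path and any memoizer schema beyond the keys visible above are assumptions):

    from openeo_aggregator.caching import memoizer_from_config
    from openeo_aggregator.testing import config_overrides  # import path assumed

    # Select the JSON-dict memoizer with a 5-minute TTL, without constructing
    # or passing an AggregatorConfig object.
    with config_overrides(memoizer={"type": "jsondict", "config": {"default_ttl": 300}}):
        memoizer = memoizer_from_config(namespace="CollectionCatalog")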
4 changes: 2 additions & 2 deletions src/openeo_aggregator/connection.py
@@ -247,12 +247,12 @@ def __init__(
         self.on_connections_change.add(self._memoizer.invalidate)
 
     @staticmethod
-    def from_config(config: AggregatorConfig) -> 'MultiBackendConnection':
+    def from_config() -> "MultiBackendConnection":
         backend_config = get_backend_config()
         return MultiBackendConnection(
             backends=backend_config.aggregator_backends,
             configured_oidc_providers=backend_config.oidc_providers,
-            memoizer=memoizer_from_config(config, namespace="mbcon"),
+            memoizer=memoizer_from_config(namespace="mbcon"),
             connections_cache_ttl=backend_config.connections_cache_ttl,
         )
 
4 changes: 2 additions & 2 deletions src/openeo_aggregator/partitionedjobs/tracking.py
@@ -48,8 +48,8 @@ def __init__(self, db: ZooKeeperPartitionedJobDB, backends: MultiBackendConnection):
         self._backends = backends
 
     @classmethod
-    def from_config(cls, config: AggregatorConfig, backends: MultiBackendConnection) -> "PartitionedJobTracker":
-        return cls(db=ZooKeeperPartitionedJobDB.from_config(config), backends=backends)
+    def from_config(cls, backends: MultiBackendConnection) -> "PartitionedJobTracker":
+        return cls(db=ZooKeeperPartitionedJobDB.from_config(), backends=backends)
 
     def list_user_jobs(self, user_id: str) -> List[dict]:
         return self._db.list_user_jobs(user_id=user_id)
2 changes: 1 addition & 1 deletion src/openeo_aggregator/partitionedjobs/zookeeper.py
@@ -41,7 +41,7 @@ def __init__(self, client: KazooClient, prefix: str = None):
         self._prefix = prefix or f"/openeo-aggregator/{self.NAMESPACE}"
 
     @classmethod
-    def from_config(cls, config: AggregatorConfig) -> "ZooKeeperPartitionedJobDB":
+    def from_config(cls) -> "ZooKeeperPartitionedJobDB":
         # Get ZooKeeper client
         pjt_config = get_backend_config().partitioned_job_tracking
         if pjt_config.get("zk_client"):
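
The partitioned_job_tracking config entry read here doubles as the enable flag checked in backend.py above. Based only on the truthiness check on "zk_client" and the {"zk_client": "dummy"} value used in test_app.py below, its shape is assumed to be:

    # Assumed config shapes (whether real deployments pass a connection string,
    # a Kazoo client object, or something else as "zk_client" is not visible here):
    partitioned_job_tracking = None                    # feature off: no PartitionedJobTracker
    partitioned_job_tracking = {"zk_client": "dummy"}  # feature on: tracker built via from_config()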
24 changes: 9 additions & 15 deletions tests/conftest.py
@@ -1,11 +1,9 @@
 import os
 from pathlib import Path
-from typing import List
 
 import flask
 import pytest
 from openeo_driver.testing import ApiTester
-from openeo_driver.users.oidc import OidcProvider
 
 from openeo_aggregator.app import create_app
 from openeo_aggregator.backend import (
@@ -103,13 +101,12 @@ def config(
 
 
 @pytest.fixture
-def multi_backend_connection(config) -> MultiBackendConnection:
-    return MultiBackendConnection.from_config(config)
+def multi_backend_connection(backend1, backend2) -> MultiBackendConnection:
+    return MultiBackendConnection.from_config()
 
 
-def get_flask_app(config: AggregatorConfig) -> flask.Flask:
+def get_flask_app() -> flask.Flask:
     app = create_app(
-        config=config,
         auto_logging_setup=False,
         # flask_error_handling=False,  # Failing test debug tip: set to False for deeper stack trace insights
     )
@@ -119,8 +116,8 @@ def get_flask_app(config: AggregatorConfig) -> flask.Flask:
 
 
 @pytest.fixture
-def flask_app(config: AggregatorConfig) -> flask.Flask:
-    app = get_flask_app(config)
+def flask_app(backend1, backend2) -> flask.Flask:
+    app = get_flask_app()
     with app.app_context():
         yield app
 
@@ -141,11 +138,11 @@ def api100(flask_app: flask.Flask) -> ApiTester:
 
 
 @pytest.fixture
-def api100_with_entitlement_check(config: AggregatorConfig) -> ApiTester:
+def api100_with_entitlement_check() -> ApiTester:
     with config_overrides(
         auth_entitlement_check={"oidc_issuer_whitelist": {"https://egi.test", "https://egi.test/oidc"}}
     ):
-        yield get_api100(get_flask_app(config))
+        yield get_api100(get_flask_app())
 
 
 def assert_dict_subset(d1: dict, d2: dict):
@@ -154,11 +151,8 @@ def assert_dict_subset(d1: dict, d2: dict):
 
 
 @pytest.fixture
-def catalog(multi_backend_connection, config) -> AggregatorCollectionCatalog:
-    return AggregatorCollectionCatalog(
-        backends=multi_backend_connection,
-        config=config
-    )
+def catalog(multi_backend_connection) -> AggregatorCollectionCatalog:
+    return AggregatorCollectionCatalog(backends=multi_backend_connection)
 
 
 @pytest.fixture
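
With the config fixture retired, tests toggle individual settings through config_overrides and lean on the backend1/backend2 fixtures to stand up mocked upstream back-ends. A hypothetical test in the new style, combining helpers that appear in this diff (get_flask_app, get_api100, config_overrides; the overridden value is made up):

    def test_capabilities_without_entitlement_check(backend1, backend2):
        # Flip a single backend-config flag for this test only, then build a
        # fresh app; no AggregatorConfig object is constructed or passed.
        with config_overrides(auth_entitlement_check=False):
            api100 = get_api100(get_flask_app())
        api100.get("/").assert_status_code(200)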
14 changes: 4 additions & 10 deletions tests/partitionedjobs/test_splitting.py
@@ -20,12 +20,8 @@ def test_tile_grid_spec_from_string():
     assert TileGrid.from_string("utm-10km") == TileGrid(crs_type="utm", size=10, unit="km")
 
 
-def test_flimsy_splitter(multi_backend_connection, catalog, config):
-    splitter = FlimsySplitter(processing=AggregatorProcessing(
-        backends=multi_backend_connection,
-        catalog=catalog,
-        config=config,
-    ))
+def test_flimsy_splitter(multi_backend_connection, catalog):
+    splitter = FlimsySplitter(processing=AggregatorProcessing(backends=multi_backend_connection, catalog=catalog))
     process = {"process_graph": {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}}
     pjob = splitter.split(process)
     assert len(pjob.subjobs) == 1
@@ -39,12 +35,10 @@ def test_flimsy_splitter(multi_backend_connection, catalog, config):
 class TestTileGridSplitter:
 
     @pytest.fixture
-    def aggregator_processing(
-        self, multi_backend_connection, catalog, config, requests_mock, backend1
-    ) -> AggregatorProcessing:
+    def aggregator_processing(self, multi_backend_connection, catalog, requests_mock, backend1) -> AggregatorProcessing:
         requests_mock.get(backend1 + "/collections", json={"collections": [{"id": "S2"}]})
         requests_mock.get(backend1 + "/collections/S2", json={"id": "S2"})
-        return AggregatorProcessing(backends=multi_backend_connection, catalog=catalog, config=config)
+        return AggregatorProcessing(backends=multi_backend_connection, catalog=catalog)
 
     @pytest.mark.parametrize(["west", "south", "tile_grid", "expected_extent"], [
         # >>> from pyproj import Transformer
19 changes: 11 additions & 8 deletions tests/test_app.py
@@ -8,17 +8,20 @@
 from .conftest import get_api100, get_flask_app
 
 
-def test_create_app(config: AggregatorConfig):
-    app = create_app(config)
+def test_create_app():
+    app = create_app()
     assert isinstance(app, flask.Flask)
 
 
-@pytest.mark.parametrize(["partitioned_job_tracking", "expected"], [
-    (None, False),
-    ({"zk_client": "dummy"}, True),
-])
-def test_create_app_no_partitioned_job_tracking(config: AggregatorConfig, partitioned_job_tracking, expected):
+@pytest.mark.parametrize(
+    ["partitioned_job_tracking", "expected"],
+    [
+        (None, False),
+        ({"zk_client": "dummy"}, True),
+    ],
+)
+def test_create_app_no_partitioned_job_tracking(partitioned_job_tracking, expected):
     with config_overrides(partitioned_job_tracking=partitioned_job_tracking):
-        api100 = get_api100(get_flask_app(config))
+        api100 = get_api100(get_flask_app())
     res = api100.get("/").assert_status_code(200).json
     assert res["_partitioned_job_tracking"] is expected