Skip to content

Commit

Permalink
fix: issue overriding default service config from config file (#4627)
Browse files Browse the repository at this point in the history
* fix: issue overriding default service config from config file

Signed-off-by: Frost Ming <[email protected]>

* fix

Signed-off-by: Frost Ming <[email protected]>

* fix __all__

Signed-off-by: Frost Ming <[email protected]>

---------

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Apr 8, 2024
1 parent 96ab26d commit 6b0ca6b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 56 deletions.
10 changes: 4 additions & 6 deletions src/bentoml/_internal/configuration/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(
if override_defaults:
if migration is not None:
override_defaults = migration(
default_config=self.config,
override_config=dict(flatten_dict(override_defaults)),
)
config_merger.merge(self.config, override_defaults)
Expand All @@ -85,7 +84,6 @@ def __init__(
# Running migration layer if it exists
if migration is not None:
override = migration(
default_config=self.config,
override_config=dict(flatten_dict(override)),
)
config_merger.merge(self.config, override)
Expand All @@ -97,7 +95,6 @@ def __init__(
# Running migration layer if it exists
if migration is not None:
override_config_json = migration(
default_config=self.config,
override_config=dict(flatten_dict(override_config_json)),
)
config_merger.merge(self.config, override_config_json)
Expand All @@ -122,9 +119,7 @@ def __init__(
}
# Running migration layer if it exists
if migration is not None:
override_config_map = migration(
default_config=self.config, override_config=override_config_map
)
override_config_map = migration(override_config=override_config_map)
# Previous behaviour, before configuration versioning.
try:
override = unflatten(override_config_map)
Expand All @@ -133,6 +128,9 @@ def __init__(
f"Failed to parse config options from the env var:\n{e}.\n*** Note: You can use '\"' to quote the key if it contains special characters. ***"
) from None
config_merger.merge(self.config, override)

if finalize_config := getattr(spec_module, "finalize_config", None):
finalize_config(self.config)
expand_env_var_in_values(self.config)

if validate_schema:
Expand Down
42 changes: 18 additions & 24 deletions src/bentoml/_internal/configuration/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import typing as t
from copy import deepcopy
from numbers import Real

import schema as s
Expand All @@ -13,14 +14,11 @@
from ..helpers import ensure_larger_than
from ..helpers import ensure_larger_than_zero
from ..helpers import ensure_range
from ..helpers import flatten_dict
from ..helpers import is_valid_ip_address
from ..helpers import rename_fields
from ..helpers import validate_otlp_protocol
from ..helpers import validate_tracing_type

__all__ = ["SCHEMA", "migration"]

TRACING_CFG = {
"exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None),
"sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None),
Expand Down Expand Up @@ -194,7 +192,7 @@
)


def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.Any]):
def migration(*, override_config: dict[str, t.Any]):
# We will use a flattened config to make it easier to migrate,
# Then we will convert it back to a nested config.
if depth(override_config) > 1:
Expand Down Expand Up @@ -310,8 +308,13 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
replace_with=f"runners.{runner_name}.traffic.timeout",
)

return unflatten(override_config)


def finalize_config(config: dict[str, t.Any]) -> None:
from ..containers import config_merger

# 8. if runner is overriden, set the runner default values
default_runner_config = dict(flatten_dict(default_config["runners"]))
RUNNER_CFG_KEYS = [
"batching",
"resources",
Expand All @@ -320,23 +323,14 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
"traffic",
"workers_per_resource",
]
default_runner_config: dict[str, t.Any] = {}
for runner_name, runner_cfg in default_config["runners"].items():
if runner_name in RUNNER_CFG_KEYS:
default_runner_config[runner_name] = runner_cfg
default_runner_config: dict[str, t.Any] = {
key: value for key, value in config["runners"].items() if key in RUNNER_CFG_KEYS
}

for key in list(override_config):
if key.startswith("runners."):
key_parts = key.split(".")
runner_name = key_parts[1]
if runner_name in RUNNER_CFG_KEYS:
default_runner_config[".".join(key_parts[1:])] = override_config[key]
for i in range(2, len(key_parts)):
if (k := ".".join(key_parts[1:i])) in default_runner_config:
del default_runner_config[k]
else:
if runner_name not in default_config["runners"].keys():
default_config["runners"][runner_name] = unflatten(
default_runner_config
)
return unflatten(override_config)
for runner_name, runner_cfg in config["runners"].items():
if runner_name in RUNNER_CFG_KEYS:
continue
# key is a runner name
config["runners"][runner_name] = config_merger.merge(
deepcopy(default_runner_config), runner_cfg
)
45 changes: 19 additions & 26 deletions src/bentoml/_internal/configuration/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
from ..helpers import ensure_larger_than
from ..helpers import ensure_larger_than_zero
from ..helpers import ensure_range
from ..helpers import flatten_dict
from ..helpers import is_valid_ip_address
from ..helpers import validate_otlp_protocol
from ..helpers import validate_tracing_type

__all__ = ["SCHEMA", "migration"]

TRACING_CFG = {
"exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None),
"sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None),
Expand Down Expand Up @@ -187,7 +184,7 @@
)


def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.Any]):
def migration(*, override_config: dict[str, t.Any]):
# We will use a flattened config to make it easier to migrate,
# Then we will convert it back to a nested config.
if depth(override_config) > 1:
Expand All @@ -196,6 +193,12 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
if "version" not in override_config:
override_config["version"] = 2

return unflatten(override_config)


def finalize_config(config: dict[str, t.Any]) -> dict[str, t.Any]:
from ..containers import config_merger

SERVICE_CFG_KEYS = [
"batching",
"resources",
Expand All @@ -212,26 +215,16 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
"monitoring",
"tracing",
]
default_service_config: dict[str, t.Any] = {}
for svc, svc_cfg in default_config["services"].items():
if svc in SERVICE_CFG_KEYS:
default_service_config[svc] = svc_cfg
default_service_config = dict(flatten_dict(default_service_config))

for key in list(override_config):
if key.startswith("services."):
# NOTE: We need to remove the quotation in case the runner name includes dashes.
# Since unflatten_dict will include the quotes for given name
key_parts = [s.replace('"', "") for s in key.split(".")]
service_name = key_parts[1]
if service_name in SERVICE_CFG_KEYS:
default_service_config[".".join(key_parts[1:])] = override_config[key]
for i in range(2, len(key_parts)):
if (k := ".".join(key_parts[1:i])) in default_service_config:
del default_service_config[k]
else:
if service_name not in default_config["services"].keys():
default_config["services"][service_name] = unflatten(
default_service_config
)
return unflatten(override_config)
default_service_config = {
key: value
for key, value in config["services"].items()
if key in SERVICE_CFG_KEYS
}

for svc, service_config in config["services"].items():
if svc in SERVICE_CFG_KEYS:
continue
config["services"][svc] = config_merger.merge(
default_service_config, service_config
)

0 comments on commit 6b0ca6b

Please sign in to comment.