diff --git a/src/bentoml/_internal/configuration/containers.py b/src/bentoml/_internal/configuration/containers.py index bbac6a36e7c..bdcc04cdaff 100644 --- a/src/bentoml/_internal/configuration/containers.py +++ b/src/bentoml/_internal/configuration/containers.py @@ -71,7 +71,6 @@ def __init__( if override_defaults: if migration is not None: override_defaults = migration( - default_config=self.config, override_config=dict(flatten_dict(override_defaults)), ) config_merger.merge(self.config, override_defaults) @@ -85,7 +84,6 @@ def __init__( # Running migration layer if it exists if migration is not None: override = migration( - default_config=self.config, override_config=dict(flatten_dict(override)), ) config_merger.merge(self.config, override) @@ -97,7 +95,6 @@ def __init__( # Running migration layer if it exists if migration is not None: override_config_json = migration( - default_config=self.config, override_config=dict(flatten_dict(override_config_json)), ) config_merger.merge(self.config, override_config_json) @@ -122,9 +119,7 @@ def __init__( } # Running migration layer if it exists if migration is not None: - override_config_map = migration( - default_config=self.config, override_config=override_config_map - ) + override_config_map = migration(override_config=override_config_map) # Previous behaviour, before configuration versioning. try: override = unflatten(override_config_map) @@ -133,6 +128,9 @@ def __init__( f"Failed to parse config options from the env var:\n{e}.\n*** Note: You can use '\"' to quote the key if it contains special characters. ***" ) from None config_merger.merge(self.config, override) + + if finalize_config := getattr(spec_module, "finalize_config", None): + finalize_config(self.config) expand_env_var_in_values(self.config) if validate_schema: diff --git a/src/bentoml/_internal/configuration/v1/__init__.py b/src/bentoml/_internal/configuration/v1/__init__.py index 567bf4d01ce..6ed8ea64c61 100644 --- a/src/bentoml/_internal/configuration/v1/__init__.py +++ b/src/bentoml/_internal/configuration/v1/__init__.py @@ -2,6 +2,7 @@ import re import typing as t +from copy import deepcopy from numbers import Real import schema as s @@ -13,14 +14,11 @@ from ..helpers import ensure_larger_than from ..helpers import ensure_larger_than_zero from ..helpers import ensure_range -from ..helpers import flatten_dict from ..helpers import is_valid_ip_address from ..helpers import rename_fields from ..helpers import validate_otlp_protocol from ..helpers import validate_tracing_type -__all__ = ["SCHEMA", "migration"] - TRACING_CFG = { "exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None), "sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None), @@ -194,7 +192,7 @@ ) -def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.Any]): +def migration(*, override_config: dict[str, t.Any]): # We will use a flattened config to make it easier to migrate, # Then we will convert it back to a nested config. if depth(override_config) > 1: @@ -310,8 +308,13 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t. replace_with=f"runners.{runner_name}.traffic.timeout", ) + return unflatten(override_config) + + +def finalize_config(config: dict[str, t.Any]) -> None: + from ..containers import config_merger + # 8. if runner is overriden, set the runner default values - default_runner_config = dict(flatten_dict(default_config["runners"])) RUNNER_CFG_KEYS = [ "batching", "resources", @@ -320,23 +323,14 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t. "traffic", "workers_per_resource", ] - default_runner_config: dict[str, t.Any] = {} - for runner_name, runner_cfg in default_config["runners"].items(): - if runner_name in RUNNER_CFG_KEYS: - default_runner_config[runner_name] = runner_cfg + default_runner_config: dict[str, t.Any] = { + key: value for key, value in config["runners"].items() if key in RUNNER_CFG_KEYS + } - for key in list(override_config): - if key.startswith("runners."): - key_parts = key.split(".") - runner_name = key_parts[1] - if runner_name in RUNNER_CFG_KEYS: - default_runner_config[".".join(key_parts[1:])] = override_config[key] - for i in range(2, len(key_parts)): - if (k := ".".join(key_parts[1:i])) in default_runner_config: - del default_runner_config[k] - else: - if runner_name not in default_config["runners"].keys(): - default_config["runners"][runner_name] = unflatten( - default_runner_config - ) - return unflatten(override_config) + for runner_name, runner_cfg in config["runners"].items(): + if runner_name in RUNNER_CFG_KEYS: + continue + # key is a runner name + config["runners"][runner_name] = config_merger.merge( + deepcopy(default_runner_config), runner_cfg + ) diff --git a/src/bentoml/_internal/configuration/v2/__init__.py b/src/bentoml/_internal/configuration/v2/__init__.py index 66ec42e21c9..6e52ea585a6 100644 --- a/src/bentoml/_internal/configuration/v2/__init__.py +++ b/src/bentoml/_internal/configuration/v2/__init__.py @@ -13,13 +13,10 @@ from ..helpers import ensure_larger_than from ..helpers import ensure_larger_than_zero from ..helpers import ensure_range -from ..helpers import flatten_dict from ..helpers import is_valid_ip_address from ..helpers import validate_otlp_protocol from ..helpers import validate_tracing_type -__all__ = ["SCHEMA", "migration"] - TRACING_CFG = { "exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None), "sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None), @@ -187,7 +184,7 @@ ) -def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.Any]): +def migration(*, override_config: dict[str, t.Any]): # We will use a flattened config to make it easier to migrate, # Then we will convert it back to a nested config. if depth(override_config) > 1: @@ -196,6 +193,12 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t. if "version" not in override_config: override_config["version"] = 2 + return unflatten(override_config) + + +def finalize_config(config: dict[str, t.Any]) -> dict[str, t.Any]: + from ..containers import config_merger + SERVICE_CFG_KEYS = [ "batching", "resources", @@ -212,26 +215,16 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t. "monitoring", "tracing", ] - default_service_config: dict[str, t.Any] = {} - for svc, svc_cfg in default_config["services"].items(): - if svc in SERVICE_CFG_KEYS: - default_service_config[svc] = svc_cfg - default_service_config = dict(flatten_dict(default_service_config)) - for key in list(override_config): - if key.startswith("services."): - # NOTE: We need to remove the quotation in case the runner name includes dashes. - # Since unflatten_dict will include the quotes for given name - key_parts = [s.replace('"', "") for s in key.split(".")] - service_name = key_parts[1] - if service_name in SERVICE_CFG_KEYS: - default_service_config[".".join(key_parts[1:])] = override_config[key] - for i in range(2, len(key_parts)): - if (k := ".".join(key_parts[1:i])) in default_service_config: - del default_service_config[k] - else: - if service_name not in default_config["services"].keys(): - default_config["services"][service_name] = unflatten( - default_service_config - ) - return unflatten(override_config) + default_service_config = { + key: value + for key, value in config["services"].items() + if key in SERVICE_CFG_KEYS + } + + for svc, service_config in config["services"].items(): + if svc in SERVICE_CFG_KEYS: + continue + config["services"][svc] = config_merger.merge( + default_service_config, service_config + )