From d397b1368b9edc7f63204ea68cc467c970de5cf6 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 17 Jan 2025 14:26:00 -0800 Subject: [PATCH] feat(ingest): add strict_env_syntax mode to config loader (#12380) --- .../datahub/configuration/config_loader.py | 123 +++++++++++------- .../src/datahub/secret/secret_common.py | 22 ++-- .../tests/unit/config/test_config_loader.py | 23 ++++ 3 files changed, 110 insertions(+), 58 deletions(-) diff --git a/metadata-ingestion/src/datahub/configuration/config_loader.py b/metadata-ingestion/src/datahub/configuration/config_loader.py index c64e81e7c714c9..16105f69d584de 100644 --- a/metadata-ingestion/src/datahub/configuration/config_loader.py +++ b/metadata-ingestion/src/datahub/configuration/config_loader.py @@ -19,64 +19,87 @@ Environ = Mapping[str, str] -def _resolve_element(element: str, environ: Environ) -> str: - if re.search(r"(\$\{).+(\})", element): - return expand(element, nounset=True, environ=environ) - elif element.startswith("$"): - try: - return expand(element, nounset=True, environ=environ) - except UnboundVariable: - return element - else: - return element - - -def _resolve_list(ele_list: list, environ: Environ) -> list: - new_v: list = [] - for ele in ele_list: - if isinstance(ele, str): - new_v.append(_resolve_element(ele, environ=environ)) - elif isinstance(ele, list): - new_v.append(_resolve_list(ele, environ=environ)) - elif isinstance(ele, dict): - new_v.append(resolve_env_variables(ele, environ=environ)) - else: - new_v.append(ele) - return new_v - - def resolve_env_variables(config: dict, environ: Environ) -> dict: - new_dict: Dict[Any, Any] = {} - for k, v in config.items(): - if isinstance(v, dict): - new_dict[k] = resolve_env_variables(v, environ=environ) - elif isinstance(v, list): - new_dict[k] = _resolve_list(v, environ=environ) - elif isinstance(v, str): - new_dict[k] = _resolve_element(v, environ=environ) - else: - new_dict[k] = v - return new_dict + # TODO: This is kept around for backwards compatibility. + return EnvResolver(environ).resolve(config) def list_referenced_env_variables(config: dict) -> Set[str]: - # This is a bit of a hack, but expandvars does a bunch of escaping - # and other logic that we don't want to duplicate here. + # TODO: This is kept around for backwards compatibility. + return EnvResolver(environ=os.environ).list_referenced_variables(config) + + +class EnvResolver: + def __init__(self, environ: Environ, strict_env_syntax: bool = False): + self.environ = environ + self.strict_env_syntax = strict_env_syntax - vars = set() + def resolve(self, config: dict) -> dict: + return self._resolve_dict(config) - def mock_get_env(key: str, default: Optional[str] = None) -> str: - vars.add(key) - if default is not None: - return default - return "mocked_value" + @classmethod + def list_referenced_variables( + cls, + config: dict, + strict_env_syntax: bool = False, + ) -> Set[str]: + # This is a bit of a hack, but expandvars does a bunch of escaping + # and other logic that we don't want to duplicate here. - mock = unittest.mock.MagicMock() - mock.get.side_effect = mock_get_env + vars = set() - resolve_env_variables(config, environ=mock) + def mock_get_env(key: str, default: Optional[str] = None) -> str: + vars.add(key) + if default is not None: + return default + return "mocked_value" + + mock = unittest.mock.MagicMock() + mock.get.side_effect = mock_get_env + + resolver = EnvResolver(environ=mock, strict_env_syntax=strict_env_syntax) + resolver._resolve_dict(config) + + return vars + + def _resolve_element(self, element: str) -> str: + if re.search(r"(\$\{).+(\})", element): + return expand(element, nounset=True, environ=self.environ) + elif not self.strict_env_syntax and element.startswith("$"): + try: + return expand(element, nounset=True, environ=self.environ) + except UnboundVariable: + # TODO: This fallback is kept around for backwards compatibility, but + # doesn't make a ton of sense from first principles. + return element + else: + return element - return vars + def _resolve_list(self, ele_list: list) -> list: + new_v: list = [] + for ele in ele_list: + if isinstance(ele, str): + new_v.append(self._resolve_element(ele)) + elif isinstance(ele, list): + new_v.append(self._resolve_list(ele)) + elif isinstance(ele, dict): + new_v.append(self._resolve_dict(ele)) + else: + new_v.append(ele) + return new_v + + def _resolve_dict(self, config: dict) -> dict: + new_dict: Dict[Any, Any] = {} + for k, v in config.items(): + if isinstance(v, dict): + new_dict[k] = self._resolve_dict(v) + elif isinstance(v, list): + new_dict[k] = self._resolve_list(v) + elif isinstance(v, str): + new_dict[k] = self._resolve_element(v) + else: + new_dict[k] = v + return new_dict WRITE_TO_FILE_DIRECTIVE_PREFIX = "__DATAHUB_TO_FILE_" @@ -159,7 +182,7 @@ def load_config_file( config = raw_config.copy() if resolve_env_vars: - config = resolve_env_variables(config, environ=os.environ) + config = EnvResolver(environ=os.environ).resolve(config) if process_directives: config = _process_directives(config) diff --git a/metadata-ingestion/src/datahub/secret/secret_common.py b/metadata-ingestion/src/datahub/secret/secret_common.py index 2f7a584d875380..a116c70407af23 100644 --- a/metadata-ingestion/src/datahub/secret/secret_common.py +++ b/metadata-ingestion/src/datahub/secret/secret_common.py @@ -2,10 +2,7 @@ import logging from typing import List -from datahub.configuration.config_loader import ( - list_referenced_env_variables, - resolve_env_variables, -) +from datahub.configuration.config_loader import EnvResolver from datahub.secret.secret_store import SecretStore logger = logging.getLogger(__name__) @@ -42,18 +39,27 @@ def resolve_secrets(secret_names: List[str], secret_stores: List[SecretStore]) - return final_secret_values -def resolve_recipe(recipe: str, secret_stores: List[SecretStore]) -> dict: +def resolve_recipe( + recipe: str, secret_stores: List[SecretStore], strict_env_syntax: bool = True +) -> dict: + # Note: the default for `strict_env_syntax` is normally False, but here we override + # it to be true. Particularly when fetching secrets from external secret stores, we + # want to be more careful about not over-fetching secrets. + json_recipe_raw = json.loads(recipe) # 1. Extract all secrets needing resolved. - secrets_to_resolve = list_referenced_env_variables(json_recipe_raw) + secrets_to_resolve = EnvResolver.list_referenced_variables( + json_recipe_raw, strict_env_syntax=strict_env_syntax + ) # 2. Resolve secret values secret_values_dict = resolve_secrets(list(secrets_to_resolve), secret_stores) # 3. Substitute secrets into recipe file - json_recipe_resolved = resolve_env_variables( - json_recipe_raw, environ=secret_values_dict + resolver = EnvResolver( + environ=secret_values_dict, strict_env_syntax=strict_env_syntax ) + json_recipe_resolved = resolver.resolve(json_recipe_raw) return json_recipe_resolved diff --git a/metadata-ingestion/tests/unit/config/test_config_loader.py b/metadata-ingestion/tests/unit/config/test_config_loader.py index 25ee289ec4e4e7..43781acd7f80c0 100644 --- a/metadata-ingestion/tests/unit/config/test_config_loader.py +++ b/metadata-ingestion/tests/unit/config/test_config_loader.py @@ -10,6 +10,7 @@ from datahub.configuration.common import ConfigurationError from datahub.configuration.config_loader import ( + EnvResolver, list_referenced_env_variables, load_config_file, ) @@ -138,6 +139,28 @@ def test_load_success(pytestconfig, filename, golden_config, env, referenced_env # TODO check referenced env vars +def test_load_strict_env_syntax() -> None: + config = { + "foo": "${BAR}", + "baz": "$BAZ", + "qux": "qux$QUX", + } + assert EnvResolver.list_referenced_variables( + config, + strict_env_syntax=True, + ) == {"BAR"} + + assert EnvResolver( + environ={ + "BAR": "bar", + } + ).resolve(config) == { + "foo": "bar", + "baz": "$BAZ", + "qux": "qux$QUX", + } + + @pytest.mark.parametrize( "filename,env,error_type", [