Skip to content

Commit

Permalink
update load_yaml function
Browse files Browse the repository at this point in the history
  • Loading branch information
nsheff committed Feb 22, 2024
1 parent 1496e46 commit 31744b5
Showing 1 changed file with 124 additions and 26 deletions.
150 changes: 124 additions & 26 deletions yacman/yacman.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from ubiquerg import create_lock, expandpath, is_url, make_lock_path, mkabs, remove_lock

from .const import *
from ._version import __version__
from typing import Union
from pathlib import Path


_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -403,36 +407,130 @@ def _check_filepath(filepath):
return filepath


# Thought about moving this to ubiquerg -- but `yaml` isn't built-in
# so it would add a new dependency there, which I'd rather not do.
def load_yaml(filepath):
"""Load a yaml file into a python dict"""

def read_yaml_file(filepath):
"""
Read a YAML file

:param str filepath: path to the file to read
:return dict: read data
"""
with open(filepath, "r") as f:
data = yaml.safe_load(f)
return data
def load_yaml(filepath: Union[str, Path]) -> dict:
"""
Load a local or remote YAML file into a Python dict
:param str filepath: path to the file to read
:raises ConnectionError: if the remote YAML file reading fails
:return dict: loaded yaml data
"""
if is_url(filepath):
_LOGGER.debug(f"Got URL: {filepath}")
try: # python3
from urllib.error import HTTPError
from urllib.request import urlopen
except: # python2
from urllib2 import URLError as HTTPError
from urllib2 import urlopen
try:
response = urlopen(filepath)
except HTTPError as e:
raise e
data = response.read() # a `bytes` object
text = data.decode("utf-8")
return yaml.safe_load(text)
except Exception as e:
raise ConnectionError(
f"Could not load remote file: {filepath}. "
f"Original exception: {getattr(e, 'message', repr(e))}"
)
else:
data = response.read().decode("utf-8")
return yaml.safe_load(data)
else:
return read_yaml_file(filepath)
with open(os.path.abspath(filepath), "r") as f:
data = yaml.safe_load(f)
return data


def select_config(
config_filepath: str = None,
config_env_vars=None,
default_config_filepath: str = None,
check_exist: bool = True,
on_missing=lambda fp: IOError(fp),
strict_env: bool = False,
config_name=None,
) -> str:
"""
Selects the config file to load.
This uses a priority ordering to first choose a config filepath if it's given,
but if not, then look in a priority list of environment variables and choose
the first available filepath to return.
:param str | NoneType config_filepath: direct filepath specification
:param Iterable[str] | NoneType config_env_vars: names of environment
variables to try for config filepaths
:param str default_config_filepath: default value if no other alternative
resolution succeeds
:param bool check_exist: whether to check for path existence as file
:param function(str) -> object on_missing: what to do with a filepath if it
doesn't exist
:param bool strict_env: whether to raise an exception if no file path provided
and environment variables do not point to any files
raise: OSError: when strict environment variables validation is not passed
"""

# First priority: given file
if type(config_name) is str:
config_name = f"{config_name} "
else:
config_name = ""

if type(config_env_vars) is str:
config_env_vars = [config_env_vars]

if config_filepath:
config_filepath = os.path.expandvars(config_filepath)
if not check_exist or os.path.isfile(config_filepath):
return os.path.abspath(config_filepath)
_LOGGER.error(f"{config_name}config file path isn't a file: {config_filepath}")
result = on_missing(config_filepath)
if isinstance(result, Exception):
raise result
return os.path.abspath(result)

_LOGGER.debug(f"No local {config_name}config file was provided.")
selected_filepath = None

# Second priority: environment variables (in order)
if config_env_vars:
_LOGGER.debug(
f"Checking environment variables '{config_env_vars}' for {config_name}config"
)

for env_var in config_env_vars:
result = os.environ.get(env_var, None)
if result == None:
_LOGGER.debug(f"Env var '{env_var}' not set.")
continue
elif result == "":
_LOGGER.debug(f"Env var '{env_var}' exists, but value is empty.")
continue
elif not os.path.isfile(result):
_LOGGER.debug(f"Env var '{env_var}' file not found: {result}")
continue
else:
_LOGGER.debug(f"Found {config_name}config file in {env_var}: {result}")
selected_filepath = result

if selected_filepath is None:
# Third priority: default filepath
if default_config_filepath:
_LOGGER.info(
f"Using default {config_name}config. You may specify in env var: {str(config_env_vars)}"
)
return default_config_filepath
else:
if strict_env:
raise OSError("Unable to select config file.")

_LOGGER.info(f"Could not locate {config_name}config file.")
return None
return (
os.path.abspath(selected_filepath) if selected_filepath else selected_filepath
)


def deep_update(old, new):
"""
Recursively update nested dict, modifying source
"""
for key, value in new.items():
if isinstance(value, Mapping) and value:
old[key] = deep_update(old.get(key, {}), value)
else:
old[key] = new[key]
return old

0 comments on commit 31744b5

Please sign in to comment.