
Commit

Refactor catalog write (#587)
* start refactoring write_pp_catalog to populate output catalog info from existing catalog

* fix temp_dir_root definition in filesystem.py

* add realm back to DataSourceBase query attribute
remove deprecated data source classes and functions from data_sources module

* remove filename parsers from catalog.py
add chunk_freq and path to catalog_define_pp_catalog_assets columns

* add realm back to query_catalog and a check for modeling_realm when it is used instead of realm
refactor the write_pp_catalog method to populate output catalog information from the ingested catalog and varlist attributes instead of inferring it from the output file names

* explicitly define query attributes to avoid retaining information from a prior
variable query in the preprocessor (see the sketch after this list)
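
The last two notes describe related changes in query_catalog: the per-variable query dict is rebuilt explicitly on every iteration, and the realm key is renamed when the catalog uses a modeling_realm column. A minimal sketch of that logic, assuming an intake-esm datastore cat and a translated variable var; the standalone helper build_query is illustrative and not part of the framework API:

def build_query(var, freq, path_regex, cat):
    # Rebuild the query dict from scratch for each variable so values from a
    # previous variable's query cannot leak into the current search.
    query = dict(
        frequency=freq,
        path=[path_regex],
        variable_id=var.translation.name,
        realm=var.realm + '*',
        standard_name=var.translation.standard_name,
    )
    # Some catalogs name the realm column 'modeling_realm'; rename the key so
    # the search targets the column that actually exists.
    if 'modeling_realm' in cat.df.columns:
        query['modeling_realm'] = query.pop('realm')
    return query

# usage inside the per-variable loop:
cat_subset = cat.search(**build_query(var, freq, path_regex, cat))

In the commit itself the dict lives on the case object (case_d.query) and is updated in place; the helper above only isolates the per-variable reset and the modeling_realm fallback.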
wrongkindofdoctor authored Jun 7, 2024
1 parent 0848848 commit 1f53248
Showing 5 changed files with 53 additions and 315 deletions.
2 changes: 1 addition & 1 deletion mdtf_framework.py
@@ -199,7 +199,7 @@ def backup_config(config):
# rename vars in cat_subset to align with POD convention
cat_subset = data_pp.rename_dataset_vars(cat_subset, cases)
# write the ESM intake catalog for the preprocessed files
data_pp.write_pp_catalog(cat_subset, model_paths, log.log)
data_pp.write_pp_catalog(cases, cat_subset, model_paths, log.log)
# configure the runtime environments and run the POD(s)
if not any(p.failed for p in pods.values()):
log.log.info("### %s: running pods '%s'.", [p for p in pods.keys()])
155 changes: 2 additions & 153 deletions src/data_sources.py
@@ -12,64 +12,6 @@
# RegexPattern that matches any string (path) that doesn't end with ".nc".
ignore_non_nc_regex = util.RegexPattern(r".*(?<!\.nc)")

sample_data_regex = util.RegexPattern(
r"""
(?P<sample_dataset>\S+)/ # first directory: model name
(?P<frequency>\w+)/ # subdirectory: data frequency
# file name = model name + variable name + frequency
(?P=sample_dataset)\.(?P<variable>\w+)\.(?P=frequency)\.nc
""",
input_field="remote_path",
match_error_filter=ignore_non_nc_regex
)


@util.regex_dataclass(sample_data_regex)
class SampleDataFile:
"""Dataclass describing catalog entries for sample model data files.
"""
sample_dataset: str = util.MANDATORY
frequency: util.DateFrequency = util.MANDATORY
variable: str = util.MANDATORY
remote_path: str = util.MANDATORY


@util.mdtf_dataclass
class DataSourceAttributesBase:
"""Class defining attributes that any DataSource needs to specify:
- *CASENAME*: User-supplied label to identify output of this run of the
package.
- *FIRSTYR*, *LASTYR*, *date_range*: Analysis period, specified as a closed
interval (i.e. running from 1 Jan of FIRSTYR through 31 Dec of LASTYR).
- *CASE_ROOT_DIR*: Root directory containing input model data. Different
DataSources may interpret this differently.
- *convention*: name of the variable naming convention used by the source of
model data.
"""
CASENAME: str = util.MANDATORY
FIRSTYR: str = util.MANDATORY
LASTYR: str = util.MANDATORY
date_range: util.DateRange = dataclasses.field(init=False)
CASE_ROOT_DIR: str = ""

log: dataclasses.InitVar = _log

def _set_case_root_dir(self, log=_log):
config = {}
if not self.CASE_ROOT_DIR and config.CASE_ROOT_DIR:
log.debug("Using global CASE_ROOT_DIR = '%s'.", config.CASE_ROOT_DIR)
self.CASE_ROOT_DIR = config.CASE_ROOT_DIR
# verify case root dir exists
if not os.path.isdir(self.CASE_ROOT_DIR):
log.critical("Data directory CASE_ROOT_DIR = '%s' not found.",
self.CASE_ROOT_DIR)
util.exit_handler(code=1)

def __post_init__(self, log=_log):
self._set_case_root_dir(log=log)
self.date_range = util.DateRange(self.FIRSTYR, self.LASTYR)


class DataSourceBase(util.MDTFObjectBase, util.CaseLoggerMixin):
"""DataSource for handling POD sample model data for multirun cases stored on a local filesystem.
@@ -85,7 +27,8 @@ class DataSourceBase(util.MDTFObjectBase, util.CaseLoggerMixin):
env_vars: util.WormDict()
query: dict = dict(frequency="",
path="",
standard_name=""
standard_name="",
realm=""
)

def __init__(self, case_name: str,
@@ -152,80 +95,6 @@ class CMIPDataSource(DataSourceBase):
# varlist = diagnostic.varlist
convention: str = "CMIP"

## NOTE: the __post_init__ method below is retained for reference in case
## we need to define all CMIP6 DRS attributes for the catalog query
#def __post_init__(self, log=_log, model=None, experiment=None):
# config = {}
# cv = cmip6.CMIP6_CVs()

# def _init_x_from_y(source, dest):
# if not getattr(self, dest, ""):
# try:
# source_val = getattr(self, source, "")
# if not source_val:
# raise KeyError()
# dest_val = cv.lookup_single(source_val, source, dest)
# log.debug("Set %s='%s' based on %s='%s'.",
# dest, dest_val, source, source_val)
# setattr(self, dest, dest_val)
# except KeyError:
# log.debug("Couldn't set %s from %s='%s'.",
# dest, source, source_val)
# setattr(self, dest, "")

# if not self.CASE_ROOT_DIR and config.CASE_ROOT_DIR:
# log.debug("Using global CASE_ROOT_DIR = '%s'.", config.CASE_ROOT_DIR)
# self.CASE_ROOT_DIR = config.CASE_ROOT_DIR
# verify case root dir exists
# if not os.path.isdir(self.CASE_ROOT_DIR):
# log.critical("Data directory CASE_ROOT_DIR = '%s' not found.",
# self.CASE_ROOT_DIR)
# util.exit_handler(code=1)

# should really fix this at the level of CLI flag synonyms
# if model and not self.source_id:
# self.source_id = model
# if experiment and not self.experiment_id:
# self.experiment_id = experiment

# # validate non-empty field values
# for field in dataclasses.fields(self):
# val = getattr(self, field.name, "")
# if not val:
# continue
# try:
# if not cv.is_in_cv(field.name, val):
# log.error(("Supplied value '%s' for '%s' is not recognized by "
# "the CMIP6 CV. Continuing, but queries will probably fail."),
# val, field.name)
# except KeyError:
# # raised if not a valid CMIP6 CV category
# continue
# # currently no inter-field consistency checks: happens implicitly, since
# # set_experiment will find zero experiments.

# # Attempt to determine first few fields of DRS, to avoid having to crawl
# # entire DRS structure
# _init_x_from_y('experiment_id', 'activity_id')
# _init_x_from_y('source_id', 'institution_id')
# _init_x_from_y('institution_id', 'source_id')
# # TODO: multi-column lookups
# # set CATALOG_DIR to be further down the hierarchy if possible, to
# # avoid having to crawl entire DRS structure; CASE_ROOT_DIR remains the
# # root of the DRS hierarchy
# new_root = self.CASE_ROOT_DIR
# for drs_attr in ("activity_id", "institution_id", "source_id", "experiment_id"):
# drs_val = getattr(self, drs_attr, "")
# if not drs_val:
# break
# new_root = os.path.join(new_root, drs_val)
# if not os.path.isdir(new_root):
# log.error("Data directory '%s' not found; starting crawl at '%s'.",
# new_root, self.CASE_ROOT_DIR)
# self.CATALOG_DIR = self.CASE_ROOT_DIR
# else:
# self.CATALOG_DIR = new_root


@data_source.maker
class CESMDataSource(DataSourceBase):
@@ -247,23 +116,3 @@ class GFDLDataSource(DataSourceBase):
# col_spec = sampleLocalFileDataSource_col_spec
# varlist = diagnostic.varlist
convention: str = "GFDL"



dummy_regex = util.RegexPattern(
r"""(?P<dummy_group>.*) # match everything; RegexPattern needs >= 1 named groups
""",
input_field="remote_path",
match_error_filter=ignore_non_nc_regex
)


@util.regex_dataclass(dummy_regex)
class GlobbedDataFile:
"""Applies a trivial regex to the paths returned by the glob."""
dummy_group: str = util.MANDATORY
remote_path: str = util.MANDATORY




73 changes: 40 additions & 33 deletions src/preprocessor.py
@@ -875,34 +875,35 @@ def query_catalog(self,
# path_regex = '*' + case_name + '*'
freq = case_d.varlist.T.frequency.format()

for v in case_d.varlist.iter_vars():
realm_regex = v.realm + '*'
date_range = v.translation.T.range
for var in case_d.varlist.iter_vars():
realm_regex = var.realm + '*'
date_range = var.translation.T.range
# define initial query dictionary with variable settings requirements that do not change if
# the variable is translated
# TODO: add method to convert freq from DateFrequency object to string
case_d.query['frequency'] = freq
case_d.query['path'] = [path_regex]
case_d.query['variable_id'] = v.translation.name
# search translation for further query requirements
for q in case_d.query:
if hasattr(v.translation, q):
case_d.query.update({q: getattr(v.translation, q)})
case_d.query['variable_id'] = var.translation.name
case_d.query['realm'] = realm_regex
case_d.query['standard_name'] = var.translation.standard_name

# change realm key name if necessary
if cat.df.get('modeling_realm', None) is not None:
case_d.query['modeling_realm'] = case_d.query.pop('realm')

# search catalog for convention specific query object
cat_subset = cat.search(**case_d.query)
if cat_subset.df.empty:
# check whether there is an alternate variable to substitute
if any(v.alternates):
if any(var.alternates):
try_new_query = True
for a in v.alternates:
for a in var.alternates:
case_d.query.update({'variable_id': a.name})
if any(v.translation.scalar_coords):
if any(var.translation.scalar_coords):
found_z_entry = False
# check for vertical coordinate to determine if level extraction is needed
for c in a.scalar_coords:
if c.axis == 'Z':
v.translation.requires_level_extraction = True
var.translation.requires_level_extraction = True
found_z_entry = True
break
else:
@@ -1257,6 +1258,7 @@ def process(self,
return cat_subset

def write_pp_catalog(self,
cases: dict,
input_catalog_ds: xr.Dataset,
config: util.PodPathManager,
log: logging.log):
@@ -1267,26 +1269,32 @@
pp_cat_assets = util.define_pp_catalog_assets(config, cat_file_name)
file_list = util.get_file_list(config.OUTPUT_DIR)
# fill in catalog information from pp file name
entries = [e.cat_entry for e in list(map(util.catalog.ppParser['ppParser' + 'GFDL'], file_list))]
# append columns defined in assets
columns = [att['column_name'] for att in pp_cat_assets['attributes']]
for col in columns:
for e in entries:
if col not in e.keys():
e[col] = ""
# copy information from input catalog to pp catalog entries
global_attrs = ['convention', 'realm']
for e in entries:
ds_match = input_catalog_ds[e['dataset_name']]
for att in global_attrs:
e[att] = ds_match.attrs.get(att, '')
ds_var = ds_match.data_vars.get(e['variable_id'])
for key, val in ds_var.attrs.items():
if key in columns:
e[key] = val

# create a Pandas dataframe from the catalog entries
cat_df = pd.DataFrame(entries)
cat_entries = []
# each key is a case
for case_name, case_dict in cases.items():
ds_match = input_catalog_ds[case_name]
for var in case_dict.varlist.iter_vars():
ds_var = ds_match.data_vars.get(var.translation.name, None)
if ds_var is None:
log.error(f'No var {var.translation.name}')
d = dict.fromkeys(columns, "")
for key, val in ds_match.attrs.items():
if 'intake_esm_attrs' in key:
for c in columns:
if key.split('intake_esm_attrs:')[1] == c:
d[c] = val
if var.translation.convention == 'no_translation':
d.update({'convention': var.convention})
else:
d.update({'convention': var.translation.convention})
d.update({'path': var.dest_path})
cat_entries.append(d)

# create a Pandas dataframe from the catalog entries

cat_df = pd.DataFrame(cat_entries)
cat_df.head()
# validate the catalog
try:
Expand All @@ -1298,7 +1306,7 @@ def write_pp_catalog(self,
)
)
except Exception as exc:
log.error(f'Unable to validate esm intake catalog for pp data: {exc}')
log.error(f'Error validating the ESM intake catalog for pp data: {exc}')
try:
log.debug(f'Writing pp data catalog {cat_file_name} csv and json files to {config.OUTPUT_DIR}')
validated_cat.serialize(cat_file_name,
@@ -1334,7 +1342,6 @@ def process(self, case_list: dict,
cat_subset = self.query_catalog(case_list, config.DATA_CATALOG)
for case_name, case_xr_dataset in cat_subset.items():
for v in case_list[case_name].varlist.iter_vars():
self.edit_request(v, convention=cat_subset[case_name].convention)
cat_subset[case_name] = self.parse_ds(v, case_xr_dataset)

return cat_subset
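
For reference, the csv and json files serialized by write_pp_catalog form a standard ESM intake catalog, so downstream tools can reopen it with intake-esm and filter on the same columns the method populates. A minimal sketch, assuming intake-esm is installed; the catalog path shown is hypothetical and in practice comes from cat_file_name and config.OUTPUT_DIR:

import intake

# Hypothetical file name; the real one is built from cat_file_name in write_pp_catalog.
pp_cat = intake.open_esm_datastore("OUTPUT_DIR/MDTF_pp_catalog.json")

# Filter on columns populated from the ingested catalog attributes and the varlist,
# e.g. variable_id, frequency, realm, standard_name, path.
subset = pp_cat.search(variable_id="pr", frequency="day", realm="atmos")
print(subset.df[["variable_id", "standard_name", "frequency", "realm", "path"]])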