From 621b17a6d6008bd455a51b55242595443f148475 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 28 May 2024 22:12:06 +0100 Subject: [PATCH] Work on NaNs --- .../checkpoint/metadata/version_0_2_0.py | 123 ++++++++++++++---- src/anemoi/inference/commands/checkpoint.py | 1 + 2 files changed, 96 insertions(+), 28 deletions(-) diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index 29942ac..b83939a 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -22,6 +22,9 @@ def __init__(self, metadata): def variables(self): return self.metadata["variables"] + def __repr__(self) -> str: + return self.__class__.__name__ + class ZarrRequest(DataRequest): def __init__(self, metadata): @@ -52,6 +55,14 @@ def param_level_ml(self): def param_step_sfc(self): return self.request["param_step"].get("sfc", []) + @property + def variables_with_nans(self): + return sorted(self.attributes.get("variables_with_nans", [])) + + def dump(self, indent): + print(" " * indent, self) + print(" " * indent, self.request) + class Forward(DataRequest): @cached_property @@ -61,8 +72,13 @@ def forward(self): def __getattr__(self, name): return getattr(self.forward, name) + def dump(self, indent): + print(" " * indent, self) + self.forward.dump(indent + 2) + class SubsetRequest(Forward): + # Subset in time pass @@ -71,25 +87,37 @@ class StatisticsRequest(Forward): class RenameRequest(Forward): - pass + @property + def variables(self): + raise NotImplementedError() + + @property + def variables_with_nans(self): + raise NotImplementedError() -class ConcatRequest(Forward): - @cached_property - def forward(self): - return data_request(self.metadata["datasets"][0]) +class MultiRequest(Forward): + def __init__(self, metadata): + super().__init__(metadata) + self.datasets = [data_request(d) for d in metadata["datasets"]] -class JoinRequest(Forward): @cached_property def forward(self): - return data_request(self.metadata["datasets"][0]) + return self.datasets[0] + def dump(self, indent): + print(" " * indent, self) + for dataset in self.datasets: + dataset.dump(indent + 2) + + +class JoinRequest(MultiRequest): @property def param_sfc(self): result = [] - for dataset in self.metadata["datasets"]: - for param in data_request(dataset).param_sfc: + for dataset in self.datasets: + for param in dataset.param_sfc: if param not in result: result.append(param) return result @@ -97,8 +125,8 @@ def param_sfc(self): @property def param_level_pl(self): result = [] - for dataset in self.metadata["datasets"]: - for param in data_request(dataset).param_level_pl: + for dataset in self.datasets: + for param in dataset.param_level_pl: if param not in result: result.append(param) return result @@ -106,8 +134,8 @@ def param_level_pl(self): @property def param_level_ml(self): result = [] - for dataset in self.metadata["datasets"]: - for param in data_request(dataset).param_level_ml: + for dataset in self.datasets: + for param in dataset.param_level_ml: if param not in result: result.append(param) return result @@ -115,26 +143,42 @@ def param_level_ml(self): @property def param_step_sfc(self): result = [] - for dataset in self.metadata["datasets"]: - for param in data_request(dataset).param_step_sfc: + for dataset in self.datasets: + for param in dataset.param_step_sfc: if param not in result: result.append(param) return result + @property + def variables(self): + raise NotImplementedError() -class EnsembleRequest(Forward): - @cached_property - def forward(self): - return data_request(self.metadata["datasets"][0]) + @property + def variables_with_nans(self): + result = set() + for dataset in self.datasets: + result.update(dataset.variables_with_nans) + return sorted(result) -class GridRequest(Forward): - @cached_property - def forward(self): - return data_request(self.metadata["datasets"][0]) + +class EnsembleRequest(MultiRequest): + pass + + +class GridRequest(MultiRequest): + @property + def grid(self): + raise NotImplementedError() + + @property + def area(self): + raise NotImplementedError() class SelectRequest(Forward): + # Select variables + @property def param_sfc(self): return [x for x in self.forward.param_sfc if x in self.variables] @@ -151,24 +195,43 @@ def param_level_ml(self): def param_step(self): return [x for x in self.forward.param_step if x[0] in self.variables] + @property + def param_step_sfc(self): + return [x for x in self.forward.param_step_sfc if x[0] in self.variables] + + @property + def variables_with_nans(self): + return [x for x in self.forward.variables_with_nans if x in self.variables] + class DropRequest(SelectRequest): - pass + @property + def variables(self): + raise NotImplementedError() -def data_request(dataset): - action = dataset["action"] + @property + def variables_with_nans(self): + result = set() + for dataset in self.metadata["datasets"]: + result.extend(dataset.variables_with_nans) + + return sorted(result) + + +def data_request(specific): + action = specific.pop("action") action = action[0].upper() + action[1:].lower() + "Request" LOG.debug(f"DataRequest: {action}") - return globals()[action](dataset) + return globals()[action](specific) class Version_0_2_0(Metadata): def __init__(self, metadata): super().__init__(metadata) specific = metadata["dataset"]["specific"] - self.data_request = data_request(specific) + self.data_request.dump(0) @property def variables(self): @@ -182,6 +245,10 @@ def area(self): def grid(self): return self.data_request.grid + @cached_property + def variables_with_nans(self): + return self.data_request.variables_with_nans + ######################### @property diff --git a/src/anemoi/inference/commands/checkpoint.py b/src/anemoi/inference/commands/checkpoint.py index 7bac20d..ba2c8c4 100644 --- a/src/anemoi/inference/commands/checkpoint.py +++ b/src/anemoi/inference/commands/checkpoint.py @@ -60,6 +60,7 @@ def run(self, args): print("select:", c.select) print("variable_to_index:", c.variable_to_index) print("variables:", c.variables) + print("variables_with_nans::", c.variables_with_nans) command = CheckpointCmd