Skip to content

Commit

Permalink
Work on NaNs
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 28, 2024
1 parent 91f9f79 commit 621b17a
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 28 deletions.
123 changes: 95 additions & 28 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __init__(self, metadata):
def variables(self):
return self.metadata["variables"]

def __repr__(self) -> str:
return self.__class__.__name__


class ZarrRequest(DataRequest):
def __init__(self, metadata):
Expand Down Expand Up @@ -52,6 +55,14 @@ def param_level_ml(self):
def param_step_sfc(self):
return self.request["param_step"].get("sfc", [])

# Sorted list of the names stored under the "variables_with_nans" key of
# self.attributes; an empty list when that key is absent.
@property
def variables_with_nans(self):
return sorted(self.attributes.get("variables_with_nans", []))

def dump(self, indent):
print(" " * indent, self)
print(" " * indent, self.request)


class Forward(DataRequest):
@cached_property
Expand All @@ -61,8 +72,13 @@ def forward(self):
def __getattr__(self, name):
return getattr(self.forward, name)

# Print this request prefixed by `indent` spaces, then ask the wrapped
# request (self.forward) to dump itself two columns deeper, so nested
# requests are shown as an indented tree on stdout.
def dump(self, indent):
print(" " * indent, self)
self.forward.dump(indent + 2)


class SubsetRequest(Forward):
# Subset in time
pass


Expand All @@ -71,70 +87,98 @@ class StatisticsRequest(Forward):


class RenameRequest(Forward):
pass

@property
def variables(self):
raise NotImplementedError()

@property
def variables_with_nans(self):
raise NotImplementedError()

class ConcatRequest(Forward):
@cached_property
def forward(self):
return data_request(self.metadata["datasets"][0])

class MultiRequest(Forward):
def __init__(self, metadata):
super().__init__(metadata)
self.datasets = [data_request(d) for d in metadata["datasets"]]

class JoinRequest(Forward):
@cached_property
def forward(self):
return data_request(self.metadata["datasets"][0])
return self.datasets[0]

def dump(self, indent):
print(" " * indent, self)
for dataset in self.datasets:
dataset.dump(indent + 2)


class JoinRequest(MultiRequest):
@property
def param_sfc(self):
result = []
for dataset in self.metadata["datasets"]:
for param in data_request(dataset).param_sfc:
for dataset in self.datasets:
for param in dataset.param_sfc:
if param not in result:
result.append(param)
return result

@property
def param_level_pl(self):
result = []
for dataset in self.metadata["datasets"]:
for param in data_request(dataset).param_level_pl:
for dataset in self.datasets:
for param in dataset.param_level_pl:
if param not in result:
result.append(param)
return result

@property
def param_level_ml(self):
result = []
for dataset in self.metadata["datasets"]:
for param in data_request(dataset).param_level_ml:
for dataset in self.datasets:
for param in dataset.param_level_ml:
if param not in result:
result.append(param)
return result

@property
def param_step_sfc(self):
result = []
for dataset in self.metadata["datasets"]:
for param in data_request(dataset).param_step_sfc:
for dataset in self.datasets:
for param in dataset.param_step_sfc:
if param not in result:
result.append(param)
return result

@property
def variables(self):
raise NotImplementedError()

class EnsembleRequest(Forward):
@cached_property
def forward(self):
return data_request(self.metadata["datasets"][0])
# Union of the `variables_with_nans` reported by every entry of
# self.datasets, returned as a sorted list without duplicates.
@property
def variables_with_nans(self):
# A set is used so that a name reported by several datasets appears once.
result = set()
for dataset in self.datasets:
result.update(dataset.variables_with_nans)

return sorted(result)

class GridRequest(Forward):
@cached_property
def forward(self):
return data_request(self.metadata["datasets"][0])

class EnsembleRequest(MultiRequest):
pass


class GridRequest(MultiRequest):
@property
def grid(self):
raise NotImplementedError()

@property
def area(self):
raise NotImplementedError()


class SelectRequest(Forward):
# Select variables

@property
def param_sfc(self):
return [x for x in self.forward.param_sfc if x in self.variables]
Expand All @@ -151,24 +195,43 @@ def param_level_ml(self):
def param_step(self):
return [x for x in self.forward.param_step if x[0] in self.variables]

@property
def param_step_sfc(self):
return [x for x in self.forward.param_step_sfc if x[0] in self.variables]

# The forwarded request's `variables_with_nans`, restricted to the names
# that are also present in self.variables (order of the forwarded list is kept).
@property
def variables_with_nans(self):
return [x for x in self.forward.variables_with_nans if x in self.variables]


class DropRequest(SelectRequest):
pass

@property
def variables(self):
raise NotImplementedError()

def data_request(dataset):
action = dataset["action"]
@property
def variables_with_nans(self):
    """Return the sorted union of ``variables_with_nans`` over all datasets.

    Iterates over ``self.metadata["datasets"]`` and merges each entry's
    ``variables_with_nans`` into a set so that duplicates are removed.

    Returns:
        A sorted list of unique variable names (empty when there are no
        datasets).
    """
    # NOTE(review): each entry is accessed via attribute, so it must expose
    # ``variables_with_nans``; raw metadata mappings would not -- confirm
    # against how the sibling request classes build their dataset objects.
    result = set()
    for dataset in self.metadata["datasets"]:
        # ``set`` has no ``extend`` method (that is a ``list`` API); calling it
        # raised AttributeError. ``update`` merges an iterable into the set.
        result.update(dataset.variables_with_nans)

    return sorted(result)


def data_request(specific):
action = specific.pop("action")
action = action[0].upper() + action[1:].lower() + "Request"
LOG.debug(f"DataRequest: {action}")
return globals()[action](dataset)
return globals()[action](specific)


class Version_0_2_0(Metadata):
def __init__(self, metadata):
super().__init__(metadata)
specific = metadata["dataset"]["specific"]

self.data_request = data_request(specific)
self.data_request.dump(0)

@property
def variables(self):
Expand All @@ -182,6 +245,10 @@ def area(self):
def grid(self):
return self.data_request.grid

@cached_property
def variables_with_nans(self):
return self.data_request.variables_with_nans

#########################

@property
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/inference/commands/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def run(self, args):
print("select:", c.select)
print("variable_to_index:", c.variable_to_index)
print("variables:", c.variables)
print("variables_with_nans::", c.variables_with_nans)


command = CheckpointCmd

0 comments on commit 621b17a

Please sign in to comment.