Skip to content

Commit

Permalink
Merge pull request #1 from MeraX/feature/validate_datetime
Browse files Browse the repository at this point in the history
Check valid_datetime of input data
  • Loading branch information
b8raoult authored Sep 8, 2024
2 parents 24e27c8 + 2751c17 commit 7bb7e34
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,33 @@ def run(

LOGGER.info("Loading input: %d fields (lagged=%d)", len(input_fields), len(self.lagged))

input_fields_numpy = input_fields.to_numpy(dtype=np.float32)
if start_datetime is None:
start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].metadata("valid_datetime")

num_fields_per_date = len(input_fields) // len(self.lagged) # assumed

# Check valid_datetime of input data
# The subsequent reshape operation assumes that input_fields are chunkable by datetime
for i, lag in enumerate(self.lagged):
date = start_datetime + datetime.timedelta(hours=lag)
dates_found = set(
field.datetime() for field in input_fields[i * num_fields_per_date : (i + 1) * num_fields_per_date]
)
# All chunks must have the same datetime that must agree with the lag
if dates_found != {date}:
raise RuntimeError(
"Inconsistent datetimes detected.\n"
f"Datetimes in data: {', '.join(d.isoformat() for d in dates_found)}.\n"
f"Expected datetime: {date.isoformat()} (for lag {lag})"
)

input_fields_numpy = input_fields.to_numpy(dtype=np.float32, reshape=False)

print(input_fields_numpy.shape)

input_fields_numpy = input_fields_numpy.reshape(
len(self.lagged),
len(input_fields) // len(self.lagged),
num_fields_per_date,
number_of_grid_points,
) # nlags, nparams, ngrid

Expand Down Expand Up @@ -223,10 +244,6 @@ def run(
:, constant_data_from_retrieved_fields_mask
]

if start_datetime is None:
start_datetime_str = input_fields.order_by(valid_datetime="ascending")[-1].metadata("valid_datetime")
start_datetime = datetime.datetime.fromisoformat(start_datetime_str)

constants = forcing_and_constants(
source=input_fields[:1],
param=self.checkpoint.computed_constants,
Expand Down

0 comments on commit 7bb7e34

Please sign in to comment.