From 3d822caec0400fde64ae3e09a4afc5c7b7816155 Mon Sep 17 00:00:00 2001 From: Rilwan Adewoyin <18564167+Rilwan-Adewoyin@users.noreply.github.com> Date: Wed, 4 Sep 2024 09:04:43 +0000 Subject: [PATCH] #47 added statistics_tendencies func Co-authored-by: Jakob Schloer --- src/anemoi/training/data/dataset.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index caaa986b..3494729a 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -32,6 +32,7 @@ def __init__( rollout: int = 1, multistep: int = 1, timeincrement: int = 1, + timestep: str = '6h', model_comm_group_rank: int = 0, model_comm_group_id: int = 0, model_comm_num_groups: int = 1, @@ -48,6 +49,8 @@ def __init__( length of rollout window, by default 12 timeincrement : int, optional time increment between samples, by default 1 + timestep : int, optional + the time frequency of the samples, by default '6h' multistep : int, optional collate (t-1, ... t - multistep) into the input state vector, by default 1 model_comm_group_rank : int, optional @@ -68,6 +71,7 @@ def __init__( self.rollout = rollout self.timeincrement = timeincrement + self.timestep = timestep # lazy init self.n_samples_per_epoch_total: int = 0 @@ -95,6 +99,14 @@ def statistics(self) -> dict: """Return dataset statistics.""" return self.data.statistics + @cached_property + def statistics_tendencies(self) -> dict: + """Return dataset tendency statistics.""" + # The statistics_tendencies are lazily loaded + if callable(self.data.statistics_tendencies): + self.data.statistics_tendencies = self.data.statistics_tendencies(self.timestep) + return self.data.statistics_tendencies + @cached_property def metadata(self) -> dict: """Return dataset metadata."""