Skip to content

Commit

Permalink
refactor to parallel runner
Browse files Browse the repository at this point in the history
  • Loading branch information
cathalobrien committed Jan 17, 2025
1 parent 1a0ae49 commit 27965ff
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 24 deletions.
51 changes: 27 additions & 24 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import torch
import torch.distributed as dist
from anemoi.utils.dates import frequency_to_timedelta as to_timedelta
from anemoi.utils.text import table
from anemoi.utils.timer import Timer # , Timers
Expand Down Expand Up @@ -320,38 +321,40 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

yield result

# No need to prepare next input tensor if we are at the last step
if s == steps - 1:
continue
# No need to prepare next input tensor if we are at the last step
if s == steps - 1:
continue

# Update tensor for next iteration
# Update tensor for next iteration

check[:] = reset
check[:] = reset

input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)

del y_pred # Recover memory
del y_pred # Recover memory

input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)

if not check.all():
# Not all variables have been updated
missing = []
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
for i in range(check.shape[-1]):
if not check[i]:
missing.append(mapping[i])
if not check.all():
# Not all variables have been updated
missing = []
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
for i in range(check.shape[-1]):
if not check[i]:
missing.append(mapping[i])

raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")

if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_input_tensor("Next input tensor", input_tensor_torch)
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_input_tensor("Next input tensor", input_tensor_torch)

#dist.destroy_process_group(model_comm_group)

def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check):

Expand Down
49 changes: 49 additions & 0 deletions src/anemoi/inference/runners/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

from . import runner_registry
from .default import DefaultRunner
from ..parallel import get_parallel_info
from ..outputs import create_output

LOG = logging.getLogger(__name__)


@runner_registry.register("parallel")
class ParallelRunner(DefaultRunner):
    """Runner that shards a single model prediction across multiple ranks.

    Rank/world-size information is read once from the environment at
    construction time (via ``get_parallel_info``).  Only global rank 0
    produces real output; every other rank gets a "none" output so the
    forecast results are written exactly once.
    """

    def __init__(self, context):
        super().__init__(context)
        # Cache the parallel topology once; these values do not change
        # over the lifetime of the runner.
        global_rank, local_rank, world_size = get_parallel_info()
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.world_size = world_size

    def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
        """Run one prediction step, optionally within a model comm group.

        Parameters
        ----------
        model : object
            Model exposing ``predict_step``.
        input_tensor_torch : torch.Tensor
            Input tensor for this step.
        fcstep : int
            Forecast step index (unused here; kept for interface
            compatibility with the base runner).
        **kwargs
            May contain ``model_comm_group``, the process group over
            which the model parallelises its prediction.

        Raises
        ------
        TypeError
            If the installed anemoi-models does not accept a comm group
            (i.e. is too old for parallel inference).
        """
        model_comm_group = kwargs.get("model_comm_group")
        if model_comm_group is None:
            # No parallelism requested: plain single-process prediction.
            return model.predict_step(input_tensor_torch)
        try:
            return model.predict_step(input_tensor_torch, model_comm_group)
        except TypeError:
            # Older anemoi-models versions have a predict_step that does
            # not take a comm group argument; tell the user why it failed
            # and re-raise with the original traceback intact.
            LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
            raise

    def create_output(self):
        """Create the configured output on rank 0, a no-op output elsewhere.

        Returns
        -------
        object
            The configured output object on global rank 0, otherwise a
            "none" output that discards everything.
        """
        if self.global_rank == 0:
            output = create_output(self, self.config.output)
            LOG.info("Output: %s", output)
            return output
        return create_output(self, "none")

0 comments on commit 27965ff

Please sign in to comment.