Commit e23934e

Merge pull request #108 from ecmwf/feature/model-parallel

You can run inference in parallel now by specifying 'runner:parallel' in your inference config file

cathalobrien authored Jan 22, 2025
2 parents e9dbd48 + ecc8fa0

Showing 5 changed files with 240 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55)

### Changed
- Change `write_initial_state` default value to `true`
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -24,13 +24,15 @@ of the *Anemoi* packages.

- :doc:`overview`
- :doc:`installing`
- :doc:`parallel`

.. toctree::
:maxdepth: 1
:hidden:

overview
installing
parallel

*********************
Three levels of APIs
70 changes: 70 additions & 0 deletions docs/parallel.rst
@@ -0,0 +1,70 @@
####################
Parallel Inference
####################

If the memory requirements of your model are too large to fit within a
single GPU, you can run Anemoi-Inference in parallel across multiple
GPUs.

Parallel inference requires SLURM to launch the parallel processes and
to determine information about your network environment. If SLURM is not
available to you, please create an issue on the Anemoi-Inference GitHub
page `here <https://github.com/ecmwf/anemoi-inference/issues>`_.
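
If you are not sure whether SLURM is available on your system, a quick
check is to ask it for its version (both commands ship with a standard
SLURM installation):

.. code:: bash

   sinfo --version
   srun --version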

***************
Configuration
***************

To run in parallel, you must add ``runner: parallel`` to your inference
config file.

.. code:: yaml

   checkpoint: /path/to/inference-last.ckpt
   lead_time: 60
   runner: parallel
   input:
      grib: /path/to/input.grib
   output:
      grib: /path/to/output.grib

*******************************
Running inference in parallel
*******************************

Below is an example SLURM batch script to launch a parallel inference
job across 4 GPUs.

.. code:: bash

   #!/bin/bash
   #SBATCH --nodes=1
   #SBATCH --ntasks-per-node=4
   #SBATCH --gpus-per-node=4
   #SBATCH --cpus-per-task=8
   #SBATCH --time=0:05:00
   #SBATCH --output=outputs/parallel_inf.%j.out

   source /path/to/venv/bin/activate
   srun anemoi-inference run parallel.yaml

.. warning::

   If you specify ``runner: parallel`` but you don't launch with
   ``srun``, your anemoi-inference job may hang as only 1 process will
   be launched.

.. note::

   By default, anemoi-inference will determine your system's master
   address and port itself. If this fails (e.g. when running
   anemoi-inference inside a container), you can instead set these
   values yourself via environment variables in your SLURM batch script:

   .. code:: bash

      MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1)
      export MASTER_ADDR=$(nslookup $MASTER_ADDR | grep -oP '(?<=Address: ).*')
      export MASTER_PORT=$((10000 + RANDOM % 10000))

      srun anemoi-inference run parallel.yaml
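
Because the runner derives its rank and world size from
``SLURM_PROCID`` and ``SLURM_NTASKS``, the same recipe should extend to
more than one node. Below is a sketch of a two-node, eight-GPU batch
script (illustrative only, assuming your cluster allows multi-node GPU
jobs):

.. code:: bash

   #!/bin/bash
   #SBATCH --nodes=2
   #SBATCH --ntasks-per-node=4
   #SBATCH --gpus-per-node=4
   #SBATCH --cpus-per-task=8
   #SBATCH --time=0:05:00
   #SBATCH --output=outputs/parallel_inf.%j.out

   source /path/to/venv/bin/activate
   # 8 tasks in total, one per GPU; srun assigns SLURM_PROCID 0-7
   srun anemoi-inference run parallel.yaml

Only the ``#SBATCH`` resource lines change; the config file and the
``srun`` command stay the same.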
3 changes: 1 addition & 2 deletions src/anemoi/inference/runner.py
@@ -251,11 +251,10 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
# Create pytorch input tensor
input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)

LOG.info("Using autocast %s", self.autocast)

lead_time = to_timedelta(lead_time)
steps = lead_time // self.checkpoint.timestep

LOG.info("Using autocast %s", self.autocast)
LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)

result = input_state.copy() # We should not modify the input state
166 changes: 166 additions & 0 deletions src/anemoi/inference/runners/parallel.py
@@ -0,0 +1,166 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import datetime
import logging
import os
import socket
import subprocess

import numpy as np
import torch
import torch.distributed as dist

from ..outputs import create_output
from . import runner_registry
from .default import DefaultRunner

LOG = logging.getLogger(__name__)


@runner_registry.register("parallel")
class ParallelRunner(DefaultRunner):
    """Runner which splits a model over multiple devices"""

    def __init__(self, context):
        super().__init__(context)
        global_rank, local_rank, world_size = self.__get_parallel_info()
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.world_size = world_size

        if self.device == "cuda":
            self.device = f"{self.device}:{local_rank}"
            torch.cuda.set_device(local_rank)

        # disable most logging on non-zero ranks
        if self.global_rank != 0:
            logging.getLogger().setLevel(logging.WARNING)

        # Create a model comm group for parallel inference
        # A dummy comm group is created if only a single device is in use
        model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size)
        self.model_comm_group = model_comm_group

        # Ensure each parallel model instance uses the same seed
        if self.global_rank == 0:
            seed = torch.initial_seed()
            torch.distributed.broadcast_object_list([seed], src=0, group=model_comm_group)
        else:
            msg_buffer = np.array([1], dtype=np.uint64)
            torch.distributed.broadcast_object_list(msg_buffer, src=0, group=model_comm_group)
            seed = msg_buffer[0]
            torch.manual_seed(seed)

    def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
        if self.model_comm_group is None:
            return model.predict_step(input_tensor_torch)
        else:
            try:
                return model.predict_step(input_tensor_torch, self.model_comm_group)
            except TypeError as err:
                LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
                raise err

    def create_output(self):
        if self.global_rank == 0:
            output = create_output(self, self.config.output)
            LOG.info("Output: %s", output)
            return output
        else:
            output = create_output(self, "none")
            return output

    def __del__(self):
        if self.model_comm_group is not None:
            dist.destroy_process_group()

    def __init_network(self):
        """Reads Slurm environment to set master address and port for parallel communication"""

        # Get the master address from the SLURM_NODELIST environment variable
        slurm_nodelist = os.environ.get("SLURM_NODELIST")
        if not slurm_nodelist:
            raise ValueError("SLURM_NODELIST environment variable is not set.")

        # Check if MASTER_ADDR is given, otherwise try to set it using 'scontrol'
        master_addr = os.environ.get("MASTER_ADDR")
        if master_addr is None:
            LOG.debug("'MASTER_ADDR' environment variable not set. Trying to set via SLURM")
            try:
                result = subprocess.run(
                    ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True
                )
            except subprocess.CalledProcessError as err:
                LOG.error(
                    "Python could not execute 'scontrol show hostname $SLURM_NODELIST' while calculating MASTER_ADDR. You could avoid this error by setting the MASTER_ADDR env var manually."
                )
                raise err

            master_addr = result.stdout.splitlines()[0]

        # Resolve the master address to an IP address
        try:
            master_addr = socket.gethostbyname(master_addr)
        except socket.gaierror:
            raise ValueError(f"Could not resolve hostname: {master_addr}")

        # Check if MASTER_PORT is given, otherwise generate one based on SLURM_JOBID
        master_port = os.environ.get("MASTER_PORT")
        if master_port is None:
            LOG.debug("'MASTER_PORT' environment variable not set. Trying to set via SLURM")
            slurm_jobid = os.environ.get("SLURM_JOBID")
            if not slurm_jobid:
                raise ValueError("SLURM_JOBID environment variable is not set.")

            master_port = str(10000 + int(slurm_jobid[-4:]))

        # Print the results for confirmation
        LOG.debug(f"MASTER_ADDR: {master_addr}")
        LOG.debug(f"MASTER_PORT: {master_port}")

        return master_addr, master_port

    def __init_parallel(self, device, global_rank, world_size):
        """Creates a model communication group to be used for parallel inference"""

        if world_size > 1:

            master_addr, master_port = self.__init_network()

            # use 'startswith' instead of '==' in case device is 'cuda:0'
            if device.startswith("cuda"):
                backend = "nccl"
            else:
                backend = "gloo"

            dist.init_process_group(
                backend=backend,
                init_method=f"tcp://{master_addr}:{master_port}",
                timeout=datetime.timedelta(minutes=3),
                world_size=world_size,
                rank=global_rank,
            )
            LOG.info(f"Creating a model comm group with {world_size} devices with the {backend} backend")

            model_comm_group_ranks = np.arange(world_size, dtype=int)
            model_comm_group = torch.distributed.new_group(model_comm_group_ranks)
        else:
            model_comm_group = None

        return model_comm_group

    def __get_parallel_info(self):
        """Reads Slurm env vars, if they exist, to determine if inference is running in parallel"""
        local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # Rank within a node, between 0 and num_gpus
        global_rank = int(os.environ.get("SLURM_PROCID", 0))  # Rank within all nodes
        world_size = int(os.environ.get("SLURM_NTASKS", 1))  # Total number of processes

        return global_rank, local_rank, world_size
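
For reference, the rendezvous and broadcast pattern above can be exercised on its own before attempting a full parallel inference run. The following is a minimal, illustrative sketch (not taken from this change set); it assumes MASTER_ADDR and MASTER_PORT are exported as described in docs/parallel.rst and that the script is launched with srun, one task per GPU:

# smoke_test_dist.py -- illustrative sketch; launch with: srun python smoke_test_dist.py
import datetime
import os

import torch
import torch.distributed as dist

# Same SLURM variables the ParallelRunner reads
rank = int(os.environ.get("SLURM_PROCID", 0))  # rank across all nodes
local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # rank within this node
world_size = int(os.environ.get("SLURM_NTASKS", 1))  # total number of tasks

backend = "nccl" if torch.cuda.is_available() else "gloo"
if backend == "nccl":
    torch.cuda.set_device(local_rank)

dist.init_process_group(
    backend=backend,
    init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
    timeout=datetime.timedelta(minutes=3),
    world_size=world_size,
    rank=rank,
)

# Broadcast a value from rank 0 to all ranks, mirroring the seed synchronisation above
payload = [torch.initial_seed()] if rank == 0 else [None]
dist.broadcast_object_list(payload, src=0)
print(f"rank {rank}/{world_size} received {payload[0]}")

dist.destroy_process_group()

If every rank prints the same value, the TCP rendezvous and the chosen backend are working, which is usually enough to rule out the networking issues that the MASTER_ADDR/MASTER_PORT note in the documentation is aimed at.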
