refactored into explicit parallel runner class
cathalobrien committed Jan 17, 2025
1 parent 27965ff commit 43167c5
Showing 3 changed files with 115 additions and 140 deletions.
96 changes: 0 additions & 96 deletions src/anemoi/inference/parallel.py

This file was deleted.

50 changes: 12 additions & 38 deletions src/anemoi/inference/runner.py
@@ -15,15 +15,12 @@

import numpy as np
import torch
import torch.distributed as dist
from anemoi.utils.dates import frequency_to_timedelta as to_timedelta
from anemoi.utils.text import table
from anemoi.utils.timer import Timer # , Timers

from .checkpoint import Checkpoint
from .context import Context
from .parallel import get_parallel_info
from .parallel import init_parallel
from .postprocess import Accumulator
from .postprocess import Noop
from .precisions import PRECISIONS
@@ -247,15 +244,6 @@ def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
return model.predict_step(input_tensor_torch)

def forecast(self, lead_time, input_tensor_numpy, input_state):

# determine processes rank for parallel inference and assign a device
global_rank, local_rank, world_size = get_parallel_info()
if self.device == "cuda":
self.device = f"{self.device}:{local_rank}"
torch.cuda.set_device(local_rank)

self.model.eval()

torch.set_grad_enabled(False)

# Create pytorch input tensor
@@ -264,22 +252,14 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
lead_time = to_timedelta(lead_time)
steps = lead_time // self.checkpoint.timestep

if global_rank == 0:
LOG.info("World size: %d", world_size)
LOG.info("Using autocast %s", self.autocast)
LOG.info(
"Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps
)
LOG.info("Using autocast %s", self.autocast)
LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)

result = input_state.copy() # We should not modify the input state
result["fields"] = dict()

start = input_state["date"]

# Create a model comm group for parallel inference
# A dummy comm group is created if only a single device is in use
model_comm_group = init_parallel(self.device, global_rank, world_size)

# The variable `check` is used to keep track of which variables have been updated
# In the input tensor. `reset` is used to reset `check` to False except
# when the values are of the constant in time variables
@@ -299,27 +279,25 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
for s in range(steps):
step = (s + 1) * self.checkpoint.timestep
date = start + step
if global_rank == 0:
LOG.info("Forecasting step %s (%s)", step, date)
LOG.info("Forecasting step %s (%s)", step, date)

result["date"] = date

# Predict next state of atmosphere
with torch.autocast(device_type=self.device, dtype=self.autocast):
y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s)

if global_rank == 0:
# Detach tensor and squeeze (should we detach here?)
output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables)
# Detach tensor and squeeze (should we detach here?)
output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables)

# Update state
for i in range(output.shape[1]):
result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
# Update state
for i in range(output.shape[1]):
result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]

if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_output_tensor("Output tensor", output)
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_output_tensor("Output tensor", output)

yield result
yield result

# No need to prepare next input tensor if we are at the last step
if s == steps - 1:
@@ -333,9 +311,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

del y_pred # Recover memory

input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check)
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
input_tensor_torch, input_state, date, check
)
@@ -354,8 +330,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
self._print_input_tensor("Next input tensor", input_tensor_torch)

#dist.destroy_process_group(model_comm_group)

def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check):

# input_tensor_torch is shape: (batch, multi_step_input, values, variables)
109 changes: 103 additions & 6 deletions src/anemoi/inference/runners/parallel.py
@@ -8,33 +8,53 @@
# nor does it submit to any jurisdiction.


import datetime
import logging
import os
import socket
import subprocess

import numpy as np
import torch
import torch.distributed as dist

from ..outputs import create_output
from . import runner_registry
from .default import DefaultRunner
from ..parallel import get_parallel_info
from ..outputs import create_output

LOG = logging.getLogger(__name__)


@runner_registry.register("parallel")
class ParallelRunner(DefaultRunner):
"""Runner which splits a model over multiple devices"""

def __init__(self, context):
super().__init__(context)
global_rank, local_rank, world_size = get_parallel_info()
global_rank, local_rank, world_size = self.__get_parallel_info()
self.global_rank = global_rank
self.local_rank = local_rank
self.world_size = world_size

if self.device == "cuda":
self.device = f"{self.device}:{local_rank}"
torch.cuda.set_device(local_rank)

# disable most logging on non-zero ranks
if self.global_rank != 0:
logging.getLogger().setLevel(logging.WARNING)

# Create a model comm group for parallel inference
# A dummy comm group is created if only a single device is in use
model_comm_group = self.__init_parallel(self.device, self.global_rank, self.world_size)
self.model_comm_group = model_comm_group

def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
model_comm_group = kwargs.get("model_comm_group", None)
if model_comm_group is None:
if self.model_comm_group is None:
return model.predict_step(input_tensor_torch)
else:
try:
return model.predict_step(input_tensor_torch, model_comm_group)
return model.predict_step(input_tensor_torch, self.model_comm_group)
except TypeError as err:
LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
raise err
@@ -47,3 +67,80 @@ def create_output(self):
else:
output = create_output(self, "none")
return output

def __del__(self):
if self.model_comm_group is not None:
dist.destroy_process_group()

def __init_network(self):
"""Reads Slurm environment to set master address and port for parallel communication"""

# Get the master address from the SLURM_NODELIST environment variable
slurm_nodelist = os.environ.get("SLURM_NODELIST")
if not slurm_nodelist:
raise ValueError("SLURM_NODELIST environment variable is not set.")

# Use subprocess to execute scontrol and get the first hostname
result = subprocess.run(
["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True
)
master_addr = result.stdout.splitlines()[0]

# Resolve the master address using nslookup
try:
resolved_addr = socket.gethostbyname(master_addr)
except socket.gaierror:
raise ValueError(f"Could not resolve hostname: {master_addr}")

# Set the resolved address as MASTER_ADDR
master_addr = resolved_addr

# Calculate the MASTER_PORT using SLURM_JOBID
slurm_jobid = os.environ.get("SLURM_JOBID")
if not slurm_jobid:
raise ValueError("SLURM_JOBID environment variable is not set.")

master_port = str(10000 + int(slurm_jobid[-4:]))

# Print the results for confirmation
LOG.debug(f"MASTER_ADDR: {master_addr}")
LOG.debug(f"MASTER_PORT: {master_port}")

return master_addr, master_port

def __init_parallel(self, device, global_rank, world_size):
"""Creates a model communication group to be used for parallel inference"""

if world_size > 1:

master_addr, master_port = self.__init_network()

# use 'startswith' instead of '==' in case device is 'cuda:0'
if device.startswith("cuda"):
backend = "nccl"
else:
backend = "gloo"

dist.init_process_group(
backend=backend,
init_method=f"tcp://{master_addr}:{master_port}",
timeout=datetime.timedelta(minutes=3),
world_size=world_size,
rank=global_rank,
)
LOG.info(f"Creating a model comm group with {world_size} devices with the {backend} backend")

model_comm_group_ranks = np.arange(world_size, dtype=int)
model_comm_group = torch.distributed.new_group(model_comm_group_ranks)
else:
model_comm_group = None

return model_comm_group

def __get_parallel_info(self):
"""Reads Slurm env vars, if they exist, to determine if inference is running in parallel"""
local_rank = int(os.environ.get("SLURM_LOCALID", 0)) # Rank within a node, between 0 and num_gpus
global_rank = int(os.environ.get("SLURM_PROCID", 0)) # Rank within all nodes
world_size = int(os.environ.get("SLURM_NTASKS", 1)) # Total number of processes

return global_rank, local_rank, world_size
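
For context, the Slurm-based rendezvous implemented by __get_parallel_info, __init_network and __init_parallel above can be exercised on its own. The sketch below is not part of this commit; it condenses the same environment-variable and scontrol logic into a standalone script, assuming it runs inside a Slurm allocation with torch installed, and falling back to a single-process no-op otherwise. The "gloo" backend is used here only to keep the sketch CPU-only; ParallelRunner switches to "nccl" on CUDA devices.

# Standalone sketch (not from this commit) of the Slurm rendezvous used by ParallelRunner:
# ranks come from Slurm environment variables, the master address from the first host in
# SLURM_NODELIST, and the port is derived from the last four digits of SLURM_JOBID.
import datetime
import os
import socket
import subprocess

import torch.distributed as dist

local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # rank within a node
global_rank = int(os.environ.get("SLURM_PROCID", 0))  # rank across all nodes
world_size = int(os.environ.get("SLURM_NTASKS", 1))  # total number of tasks

if world_size > 1:
    # The first hostname in the allocation acts as the master
    nodelist = os.environ["SLURM_NODELIST"]
    first_host = subprocess.run(
        ["scontrol", "show", "hostname", nodelist], stdout=subprocess.PIPE, text=True, check=True
    ).stdout.splitlines()[0]
    master_addr = socket.gethostbyname(first_host)
    master_port = 10000 + int(os.environ["SLURM_JOBID"][-4:])

    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{master_addr}:{master_port}",
        timeout=datetime.timedelta(minutes=3),
        world_size=world_size,
        rank=global_rank,
    )
    print(f"rank {global_rank}/{world_size} (local rank {local_rank}) joined tcp://{master_addr}:{master_port}")
    dist.destroy_process_group()
else:
    print("Single task: no communication group is created")

Launched with something like srun --ntasks=4 python rendezvous_sketch.py inside a job allocation, each task prints its rank and the shared rendezvous endpoint; the file name is only illustrative.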
