feat: Added the option to run inference in parallel #108

Merged: 30 commits, Jan 22, 2025
Changes shown from 15 commits

Commits (30)
ce74e6e  model parallel wip (cathalobrien, Nov 25, 2024)
936c60a  logging only on rank 0 (cathalobrien, Nov 26, 2024)
d870289  fallback if env vars arent set and some work only done by rank 0 (cathalobrien, Nov 26, 2024)
b39b796  changelog (cathalobrien, Nov 26, 2024)
b95e167  pre-commit checks and no model comm group for single gpu case (cathalobrien, Nov 26, 2024)
9fe691c  changelog (cathalobrien, Nov 26, 2024)
5f92574  added parallel inf (cathalobrien, Jan 14, 2025)
71fdf0e  precommit (cathalobrien, Jan 14, 2025)
9264754  9k parallel inference works (cathalobrien, Jan 15, 2025)
06a575d  refactor (cathalobrien, Jan 15, 2025)
fa89bb8  refactor (cathalobrien, Jan 15, 2025)
a6a4ea4  tidy (cathalobrien, Jan 16, 2025)
8a73f62  more compatible with older versions of models (cathalobrien, Jan 16, 2025)
db560eb  forgot precommit (cathalobrien, Jan 16, 2025)
b21d811  remove commented code (cathalobrien, Jan 16, 2025)
48ad37b  added license (cathalobrien, Jan 16, 2025)
b9ecc14  feedback (cathalobrien, Jan 16, 2025)
1a0ae49  Merge remote-tracking branch 'origin/develop' into feature/model-para… (cathalobrien, Jan 17, 2025)
27965ff  refactor to parallel runner (cathalobrien, Jan 17, 2025)
43167c5  refactored into explicit parallel runner class (cathalobrien, Jan 17, 2025)
6974ac3  allow MASTER_ADDR and MASTER_PORT to be set as env vars before runtime (cathalobrien, Jan 20, 2025)
2016c7b  readd line accicdentally deleted (cathalobrien, Jan 21, 2025)
bd391f5  added documentation (cathalobrien, Jan 21, 2025)
1cd4982  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 21, 2025)
079036a  forgot precommit (cathalobrien, Jan 21, 2025)
d6a77ff  Merge branch 'feature/model-parallel' of github.com:ecmwf/anemoi-infe… (cathalobrien, Jan 21, 2025)
b8be926  docs feedback (cathalobrien, Jan 21, 2025)
5dd8a55  added a link to parallel inference to index (cathalobrien, Jan 22, 2025)
861161d  Ensure each model has the same seed (cathalobrien, Jan 22, 2025)
ecc8fa0  Merge branch 'develop' into feature/model-parallel (cathalobrien, Jan 22, 2025)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
 - Add CONTRIBUTORS.md file (#36)
 - Add sanetise command
 - Add support for huggingface
+- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55)
 
 ### Changed
 - Change `write_initial_state` default value to `true`
79 changes: 79 additions & 0 deletions src/anemoi/inference/parallel.py
@@ -0,0 +1,79 @@
import datetime
import logging
import os
import socket
import subprocess

import numpy as np
import torch
import torch.distributed as dist

LOG = logging.getLogger(__name__)


def init_network():
    """Reads Slurm environment to set master address and port for parallel communication"""

    # Get the master address from the SLURM_NODELIST environment variable
    slurm_nodelist = os.environ.get("SLURM_NODELIST")
    if not slurm_nodelist:
        raise ValueError("SLURM_NODELIST environment variable is not set.")

    # Use subprocess to execute scontrol and get the first hostname
    result = subprocess.run(
        ["scontrol", "show", "hostname", slurm_nodelist], stdout=subprocess.PIPE, text=True, check=True
    )
    master_addr = result.stdout.splitlines()[0]

    # Resolve the master hostname to an IP address
    try:
        resolved_addr = socket.gethostbyname(master_addr)
    except socket.gaierror:
        raise ValueError(f"Could not resolve hostname: {master_addr}")

    # Set the resolved address as MASTER_ADDR
    master_addr = resolved_addr

    # Calculate the MASTER_PORT using SLURM_JOBID
    slurm_jobid = os.environ.get("SLURM_JOBID")
    if not slurm_jobid:
        raise ValueError("SLURM_JOBID environment variable is not set.")

    master_port = str(10000 + int(slurm_jobid[-4:]))

    # Log the resolved address and port for confirmation
    LOG.debug(f"MASTER_ADDR: {master_addr}")
    LOG.debug(f"MASTER_PORT: {master_port}")

    return master_addr, master_port


def init_parallel(global_rank, world_size):
    """Creates a model communication group to be used for parallel inference"""

    if world_size > 1:

        master_addr, master_port = init_network()
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_addr}:{master_port}",
            timeout=datetime.timedelta(minutes=3),
            world_size=world_size,
            rank=global_rank,
        )

        model_comm_group_ranks = np.arange(world_size, dtype=int)
        model_comm_group = torch.distributed.new_group(model_comm_group_ranks)
    else:
        model_comm_group = None

    return model_comm_group


def get_parallel_info():
    """Reads Slurm env vars, if they exist, to determine if inference is running in parallel"""
    local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # Rank within a node, between 0 and num_gpus
    global_rank = int(os.environ.get("SLURM_PROCID", 0))  # Rank within all nodes
    world_size = int(os.environ.get("SLURM_NTASKS", 1))  # Total number of processes

    return global_rank, local_rank, world_size
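
A minimal usage sketch (not part of this diff) of how these helpers compose, assuming the job is launched under Slurm with one task per GPU so that SLURM_PROCID, SLURM_LOCALID and SLURM_NTASKS are set; outside Slurm the values fall back to a single-process run:

# Illustrative sketch only; mirrors how forecast() in runner.py uses parallel.py in this PR.
import torch

from anemoi.inference.parallel import get_parallel_info
from anemoi.inference.parallel import init_parallel

global_rank, local_rank, world_size = get_parallel_info()  # defaults to rank 0 / world size 1 outside Slurm
torch.cuda.set_device(local_rank)  # one GPU per local rank (assumed CUDA setup)

# None for a single device, otherwise an NCCL process group spanning all ranks
model_comm_group = init_parallel(global_rank, world_size)
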
96 changes: 61 additions & 35 deletions src/anemoi/inference/runner.py
@@ -21,6 +21,8 @@
 
 from .checkpoint import Checkpoint
 from .context import Context
+from .parallel import get_parallel_info
+from .parallel import init_parallel
 from .postprocess import Accumulator
 from .postprocess import Noop
 from .precisions import PRECISIONS
@@ -239,25 +241,38 @@ def model(self):
         return model
 
     def forecast(self, lead_time, input_tensor_numpy, input_state):
+
+        # determine processes rank for parallel inference and assign a device
+        global_rank, local_rank, world_size = get_parallel_info()
+        self.device = f"{self.device}:{local_rank}"
+        torch.cuda.set_device(local_rank)
+
         self.model.eval()
 
         torch.set_grad_enabled(False)
 
         # Create pytorch input tensor
         input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)
 
-        LOG.info("Using autocast %s", self.autocast)
-
         lead_time = to_timedelta(lead_time)
         steps = lead_time // self.checkpoint.timestep
 
-        LOG.info("Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps)
+        if global_rank == 0:
+            LOG.info("World size: %d", world_size)
+            LOG.info("Using autocast %s", self.autocast)
+            LOG.info(
+                "Lead time: %s, time stepping: %s Forecasting %s steps", lead_time, self.checkpoint.timestep, steps
+            )
 
         result = input_state.copy()  # We should not modify the input state
         result["fields"] = dict()
 
         start = input_state["date"]
 
+        # Create a model comm group for parallel inference
+        # A dummy comm group is created if only a single device is in use
+        model_comm_group = init_parallel(global_rank, world_size)
+
         # The variable `check` is used to keep track of which variables have been updated
         # In the input tensor. `reset` is used to reset `check` to False except
         # when the values are of the constant in time variables
@@ -277,56 +292,67 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
         for s in range(steps):
             step = (s + 1) * self.checkpoint.timestep
             date = start + step
-            LOG.info("Forecasting step %s (%s)", step, date)
+            if global_rank == 0:
+                LOG.info("Forecasting step %s (%s)", step, date)
 
             result["date"] = date
 
             # Predict next state of atmosphere
             with torch.autocast(device_type=self.device, dtype=self.autocast):
-                y_pred = self.model.predict_step(input_tensor_torch)
+                if model_comm_group is None:
+                    y_pred = self.model.predict_step(input_tensor_torch)
+                else:
+                    try:
+                        y_pred = self.model.predict_step(input_tensor_torch, model_comm_group)
+                    except TypeError as err:
+                        LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
+                        raise err
 
-            # Detach tensor and squeeze (should we detach here?)
-            output = np.squeeze(y_pred.cpu().numpy())  # shape: (values, variables)
+            if global_rank == 0:
+                # Detach tensor and squeeze (should we detach here?)
+                output = np.squeeze(y_pred.cpu().numpy())  # shape: (values, variables)
 
-            # Update state
-            for i in range(output.shape[1]):
-                result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
+                # Update state
+                for i in range(output.shape[1]):
+                    result["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
 
-            if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
-                self._print_output_tensor("Output tensor", output)
+                if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
+                    self._print_output_tensor("Output tensor", output)
 
-            yield result
+                yield result
 
-            # No need to prepare next input tensor if we are at the last step
-            if s == steps - 1:
-                continue
+                # No need to prepare next input tensor if we are at the last step
+                if s == steps - 1:
+                    continue
 
-            # Update tensor for next iteration
+                # Update tensor for next iteration
 
-            check[:] = reset
+                check[:] = reset
 
-            input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
+                input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
 
-            del y_pred  # Recover memory
+                del y_pred  # Recover memory
 
-            input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(input_tensor_torch, input_state, date, check)
-            input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
-                input_tensor_torch, input_state, date, check
-            )
+                input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
+                    input_tensor_torch, input_state, date, check
+                )
+                input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
+                    input_tensor_torch, input_state, date, check
+                )
 
-            if not check.all():
-                # Not all variables have been updated
-                missing = []
-                variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
-                mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
-                for i in range(check.shape[-1]):
-                    if not check[i]:
-                        missing.append(mapping[i])
+                if not check.all():
+                    # Not all variables have been updated
+                    missing = []
+                    variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
+                    mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
+                    for i in range(check.shape[-1]):
+                        if not check[i]:
+                            missing.append(mapping[i])
 
-                raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
+                    raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
 
-            if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
-                self._print_input_tensor("Next input tensor", input_tensor_torch)
+                if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
+                    self._print_input_tensor("Next input tensor", input_tensor_torch)
 
     def copy_prognostic_fields_to_input_tensor(self, input_tensor_torch, y_pred, check):
 
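
As background for the try/except TypeError in the hunk above: predict_step in older anemoi-models releases only accepts the input tensor, so passing a process group raises TypeError, which the runner turns into an upgrade message. A standalone sketch of that fallback pattern follows; the helper name call_predict_step is hypothetical and not part of the PR.

# Hypothetical sketch of the backwards-compatibility fallback used in forecast().
import logging

LOG = logging.getLogger(__name__)


def call_predict_step(model, input_tensor, model_comm_group=None):
    """Run predict_step with a comm group when supported, falling back for older models."""
    if model_comm_group is None:
        return model.predict_step(input_tensor)
    try:
        return model.predict_step(input_tensor, model_comm_group)
    except TypeError:
        # Older anemoi-models releases do not accept a model_comm_group argument
        LOG.error("Please upgrade to a newer version of anemoi-models to use parallel inference")
        raise
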