support MPI launching through MPICH & variants #672

Open: wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   spectrum.
 - Added gradient clipping to StaticCapture utilities.
 - Bistride Multiscale MeshGraphNet example.
+- Support for MPICH-based MPI launching.

 ### Changed

44 changes: 42 additions & 2 deletions modulus/distributed/manager.py
@@ -287,6 +287,39 @@ def initialize_open_mpi(addr, port):
             method="openmpi",
         )

+    @staticmethod
+    def initialize_mpich(addr, port):
+        """Setup method using MPICH initialization"""
+        rank = int(os.environ.get("PMI_RANK"))
+        world_size = int(os.environ.get("PMI_SIZE"))
+
+        # cray-mpich
+        if "PMI_LOCAL_RANK" in os.environ:
+            local_rank = int(os.environ.get("PMI_LOCAL_RANK"))
+        # mpich-4.2.1 / hydra
+        else:
+            local_rank = int(os.environ.get("MPI_LOCALRANKID"))
+
+        # for multi-node MPI jobs, determine "addr" as the
+        # address of global rank 0.
+        if "localhost" == addr:
Collaborator review comment:

Nit: we don't really use Yoda conditions in Modulus code, since accidentally assigning to a variable inside a condition is a syntax error (unless you use the walrus operator, which has a different syntax).
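
To illustrate the point in this comment, here is a small sketch that is not part of the PR. It only asks the parser whether three snippets are valid Python; the helper name `compiles` is hypothetical.

```python
def compiles(source: str) -> bool:
    """Return True if `source` parses as valid Python, False on SyntaxError."""
    try:
        compile(source, "<example>", "exec")
    except SyntaxError:
        return False
    return True


# A plain "=" inside a condition is rejected by the parser, so the
# conventional ordering carries no accidental-assignment risk.
assignment_ok = compiles('if addr = "localhost":\n    pass\n')

# The conventional ordering is an ordinary comparison.
comparison_ok = compiles('if addr == "localhost":\n    pass\n')

# Assigning inside a condition requires the distinct ":=" (walrus) syntax.
walrus_ok = compiles('if (addr := "localhost"):\n    pass\n')
```

Under this reading, `if addr == "localhost":` would be the equivalent non-Yoda spelling of the condition above.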

+            try:
+                import socket
+                from mpi4py import MPI
+                comm = MPI.COMM_WORLD
+                addr = comm.bcast(socket.gethostbyname(socket.gethostname()), root=0)
+            except ImportError: pass
Collaborator review comment:

Just checking that it's OK to silence the exception here without emitting a warning or similar.
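
One possible way to address this concern is to keep the fallback but make it visible. The sketch below is an assumption about how that could look, not the PR author's code: the helper name `resolve_master_addr` and the warning text are hypothetical, while the `mpi4py` and `socket` calls are the same ones used in the diff.

```python
import socket
import warnings


def resolve_master_addr(addr: str) -> str:
    """Return the address of global rank 0, warning if it cannot be determined."""
    if addr != "localhost":
        # An explicit address was supplied; leave it untouched.
        return addr
    try:
        from mpi4py import MPI
    except ImportError:
        # Same fallback as the diff (addr stays "localhost"), but no longer silent.
        warnings.warn(
            "mpi4py is not available; keeping addr='localhost'. "
            "Multi-node MPICH jobs may need an explicit address."
        )
        return addr
    # Every rank computes its own IP; bcast returns the value from rank 0.
    return MPI.COMM_WORLD.bcast(
        socket.gethostbyname(socket.gethostname()), root=0
    )
```

Single-node runs would behave as before, and a multi-node run without `mpi4py` would at least leave a hint in the logs.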


+        DistributedManager.setup(
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            addr=addr,
+            port=port,
+            backend=DistributedManager.get_available_backend(),
+            method="mpich",
+        )
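
A note on the environment lookups in `initialize_mpich`: `int(os.environ.get(NAME))` raises a `TypeError` from `int(None)` when the variable is unset, and that error does not name the missing variable. The following is a minimal sketch of the same lookup order with explicit error messages. The helper `read_mpich_ranks` and its messages are hypothetical; only the variable names and their precedence come from the diff.

```python
from typing import Mapping, Tuple


def read_mpich_ranks(env: Mapping[str, str]) -> Tuple[int, int, int]:
    """Return (rank, world_size, local_rank) from MPICH-style variables."""
    missing = [name for name in ("PMI_RANK", "PMI_SIZE") if name not in env]
    if missing:
        raise RuntimeError(
            f"Missing MPICH environment variables: {', '.join(missing)}"
        )

    # Same precedence as the diff: PMI_LOCAL_RANK (cray-mpich) first,
    # then MPI_LOCALRANKID (mpich-4.2.1 / hydra).
    for name in ("PMI_LOCAL_RANK", "MPI_LOCALRANKID"):
        if name in env:
            local_rank = int(env[name])
            break
    else:
        raise RuntimeError("Neither PMI_LOCAL_RANK nor MPI_LOCALRANKID is set")

    return int(env["PMI_RANK"]), int(env["PMI_SIZE"]), local_rank
```

Passing a mapping instead of reading `os.environ` directly is a choice made for this sketch so it can be exercised without an MPI launcher; calling `read_mpich_ranks(os.environ)` would read the real environment.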

     @staticmethod
     def initialize_slurm(port):
         """Setup method using SLURM initialization"""
@@ -319,6 +352,9 @@ def initialize():
             `OPENMPI`: Initialization for OpenMPI launchers.
                 Uses `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE` and
                 `OMPI_COMM_WORLD_LOCAL_RANK` environment variables.
+            `MPICH`: Initialization for MPICH-based MPI launchers.
+                Uses `PMI_RANK`, `PMI_SIZE` and
+                either `PMI_LOCAL_RANK` or `MPI_LOCALRANKID` environment variables.

         Initialization by default is done using the first valid method in the order
         listed above. Initialization method can also be explicitly controlled using the
@@ -342,9 +378,11 @@ def initialize():
                     DistributedManager.initialize_slurm(port)
                 elif "OMPI_COMM_WORLD_RANK" in os.environ:
                     DistributedManager.initialize_open_mpi(addr, port)
+                elif "PMI_RANK" in os.environ:
+                    DistributedManager.initialize_mpich(addr, port)
                 else:
                     warn(
-                        "Could not initialize using ENV, SLURM or OPENMPI methods. Assuming this is a single process job"
+                        "Could not initialize using ENV, SLURM, OPENMPI or MPICH methods. Assuming this is a single process job"
                     )
                     DistributedManager._shared_state["_is_initialized"] = True
         elif initialization_method == "ENV":
@@ -353,13 +391,15 @@ def initialize():
             DistributedManager.initialize_slurm(port)
         elif initialization_method == "OPENMPI":
             DistributedManager.initialize_open_mpi(addr, port)
+        elif initialization_method == "MPICH":
+            DistributedManager.initialize_mpich(addr, port)
         else:
             raise RuntimeError(
                 "Unknown initialization method "
                 f"{initialization_method}. "
                 "Supported values for "
                 "MODULUS_DISTRIBUTED_INITIALIZATION_METHOD are "
-                "ENV, SLURM and OPENMPI"
+                "ENV, SLURM, OPENMPI, and MPICH"
             )

         # Set per rank numpy random seed for data sampling
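
Because auto-detection picks the first matching launcher, the new `PMI_RANK` check only applies when the earlier branches do not match. The sketch below models that ordering on a plain mapping. It is a simplified illustration, not the library code: the helper `detect_launcher` is hypothetical, the initial ENV attempt is omitted, and the `SLURM_PROCID` key for the SLURM branch is an assumption, since that condition is not visible in this diff.

```python
from typing import Mapping, Optional


def detect_launcher(env: Mapping[str, str]) -> Optional[str]:
    """Return the first matching launcher name, or None for a single-process job."""
    # Assumed key for the SLURM branch; the actual condition is outside the diff.
    if "SLURM_PROCID" in env:
        return "SLURM"
    if "OMPI_COMM_WORLD_RANK" in env:
        return "OPENMPI"
    # New in this PR: MPICH-based launchers are checked last.
    if "PMI_RANK" in env:
        return "MPICH"
    return None
```

If that assumption holds, a job whose environment contains both SLURM and PMI variables would resolve to `SLURM` by default, and setting `MODULUS_DISTRIBUTED_INITIALIZATION_METHOD` to `MPICH` would be the way to force the new path.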