Add first round of debugging features
cbalioglu committed Apr 26, 2024
1 parent 56b7d49 commit 4926ae5
Showing 4 changed files with 97 additions and 75 deletions.
49 changes: 27 additions & 22 deletions src/fairseq2/gang.py
@@ -247,7 +247,7 @@ def init_default_process_group(
         device: Optional[Device] = None,
         timeout: Optional[timedelta] = None,
         num_threads: Optional[int] = None,
-        warn_only: bool = False,
+        debug: bool = False,
         ok_initialized: bool = False,
     ) -> ProcessGroupGang:
         """Initialize the default process group and wrap it as a gang.
@@ -259,9 +259,9 @@ def init_default_process_group(
             The timeout for collective operations.
         :param num_threads:
             The number of threads to use for intra-op parallelism.
-        :param warn_only:
-            If ``True``, logs a warning instead of raising an error if the gang
-            is not set up reliably.
+        :param debug:
+            If ``True``, turns on additional logging and synchronization checks
+            to help diagnose distributed training-related issues.
         :param ok_initialized:
            If ``True``, does not raise an error if the default process group is
            already initialized.
@@ -275,6 +275,11 @@ def init_default_process_group(
 
             raise RuntimeError("The default process group is already initialized.")
 
+        # Turn on `torch.distributed` debugging.
+        if debug:
+            for debug_flag in ["TORCH_CPP_LOG_LEVEL", "TORCH_DISTRIBUTED_DEBUG"]:
+                os.environ[debug_flag] = "INFO"
+
         num_procs = get_local_world_size()
 
         if num_threads is None:
@@ -303,25 +308,17 @@ def init_default_process_group(
             )
 
         if device.type == "cuda":
+            nccl_env_name = "NCCL_ASYNC_ERROR_HANDLING"
 
-            def check_async_handling() -> None:
-                env_name = "NCCL_ASYNC_ERROR_HANDLING"
-                if env_name in os.environ:
-                    return
-
-                if torch_greater_or_equal(2, 2):
-                    env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-                    if env_name in os.environ:
-                        return
-
-                if warn_only:
-                    log.warning("The default process group uses the `nccl` backend, but the `{}` environment variable is not set. Your collective communication calls can hang indefinitely. Learn more at https://github.com/pytorch/pytorch/issues/46874.", env_name)  # fmt: skip
-                else:
-                    raise RuntimeError(
-                        f"The default process group uses the `nccl` backend, but the `{env_name}` environment variable is not set. Learn more at https://github.com/pytorch/pytorch/issues/46874."
-                    )
-
-            check_async_handling()
+            if torch_greater_or_equal(2, 2):
+                try:
+                    del os.environ[nccl_env_name]  # Suppress the deprecation warning.
+                except KeyError:
+                    pass
+
+                nccl_env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+
+            os.environ[nccl_env_name] = "1"
 
         if timeout is None:
             timeout = timedelta(minutes=30)
@@ -586,7 +583,10 @@ def _get_int_from_env(var_name: str, allow_zero: bool = False) -> Optional[int]:
 
 
 def setup_default_gang(
-    *, device: Optional[Device] = None, timeout: Optional[timedelta] = None
+    *,
+    device: Optional[Device] = None,
+    timeout: Optional[timedelta] = None,
+    debug: bool = False,
 ) -> Gang:
     """Set up the default gang of this process.
@@ -595,11 +595,16 @@ def setup_default_gang(
         device of the process; otherwise, it will use the CPU.
     :param timeout:
         The timeout for collective operations.
+    :param debug:
+        If ``True``, turns on additional logging and synchronization checks
+        to help diagnose distributed training-related issues.
     """
     if get_world_size() == 1:
         return FakeGang(device=device)
 
-    return ProcessGroupGang.init_default_process_group(device=device, timeout=timeout)
+    return ProcessGroupGang.init_default_process_group(
+        device=device, timeout=timeout, debug=debug
+    )
 
 
 def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Dict[str, Gang]:
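For reference, the new `debug` and NCCL paths in this file reduce to a handful of environment variables. A minimal standalone sketch of the same logic, assuming only `torch` is installed (`_torch_greater_or_equal` below is a stand-in for fairseq2's internal version check, not its actual implementation):

```python
import os

import torch


def _torch_greater_or_equal(major: int, minor: int) -> bool:
    # Stand-in for fairseq2's internal version check.
    parts = torch.__version__.split("+", 1)[0].split(".")

    return (int(parts[0]), int(parts[1])) >= (major, minor)


def enable_torch_distributed_debugging() -> None:
    # Mirrors the `debug=True` branch: verbose C++ logging plus extra
    # `torch.distributed` consistency checks.
    for debug_flag in ["TORCH_CPP_LOG_LEVEL", "TORCH_DISTRIBUTED_DEBUG"]:
        os.environ[debug_flag] = "INFO"


def enable_nccl_async_error_handling() -> None:
    # Mirrors the `cuda` branch: PyTorch 2.2 renamed the variable, and
    # removing the old name suppresses the deprecation warning.
    nccl_env_name = "NCCL_ASYNC_ERROR_HANDLING"

    if _torch_greater_or_equal(2, 2):
        os.environ.pop(nccl_env_name, None)

        nccl_env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING"

    os.environ[nccl_env_name] = "1"
```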
22 changes: 16 additions & 6 deletions src/fairseq2/nn/utils/module.py
@@ -7,6 +7,7 @@
 import re
 from dataclasses import dataclass
 from itertools import chain
+from logging import Logger
 from typing import (
     Any,
     Callable,
@@ -20,6 +21,7 @@
     Sequence,
     Set,
     Tuple,
+    Union,
     runtime_checkable,
 )
 
@@ -29,9 +31,17 @@
 from torch.nn.utils import remove_weight_norm
 
 from fairseq2.typing import CPU, META, Device
+from fairseq2.utils.logging import LogWriter
 
 
 # compat
-from fairseq2.utils.log import log_module as log_module  # noqa: F401
+def log_module(module: Module, log: Union[LogWriter, Logger]) -> None:
+    from fairseq2.utils.log import log_model
+
+    if isinstance(log, Logger):
+        log = LogWriter(log)
+
+    log_model(module, log)
 
 
 @runtime_checkable
@@ -298,8 +308,8 @@ def freeze_parameters(module: Optional[Module], value: bool = True) -> None:
 def select_parameters(
     module: Module, names: Sequence[str], *, exclude: bool = False
 ) -> Iterable[Tuple[str, Parameter]]:
-    """Select the parameters of ``module`` and of its descendant modules whose
-    name matches ``names``.
+    """Select the parameters of ``module`` and its descendant modules whose
+    names match ``names``.
 
     :param module:
         The module to check.
@@ -347,7 +357,7 @@ def infer_device(
         The name of the module for error reporting purposes.
     :param recurse:
         If ``True``, infers the device by checking the parameters and buffers of
-        the descendant modules as well.
+        descendant modules as well.
     """
     devices = set()
 
@@ -372,7 +382,7 @@
 
 def load_state_dict(module: Module, state_dict: Mapping[str, Any]) -> None:
     """Copy parameters and buffers from ``state_dict`` into ``module`` and its
-    descendants.
+    descendant modules.
 
     This implementation internally calls :meth:`Module.load_state_dict()` with
     ``strict`` set to ``True``, and also enforces that ``state_dict`` does not
@@ -476,7 +486,7 @@ class ModuleSizeInfo:
 
 
 def get_module_size(module: Module) -> ModuleSizeInfo:
-    """Return the size information of ``module`` and its descendants."""
+    """Return the size information of ``module`` and its descendant modules."""
     info = ModuleSizeInfo()
 
     for param in module.parameters():
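The compat shim above keeps existing `log_module` call sites working after the move to `fairseq2.utils.log.log_model`. A hypothetical call site, assuming fairseq2 is importable:

```python
import logging

from torch import nn

from fairseq2.nn.utils.module import log_module

logging.basicConfig(level=logging.INFO)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# A stdlib `Logger` is wrapped in a `LogWriter` before being forwarded to
# the renamed `fairseq2.utils.log.log_model`; a `LogWriter` passes through
# unchanged.
log_module(model, logging.getLogger(__name__))
```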
50 changes: 20 additions & 30 deletions src/fairseq2/utils/log.py
@@ -52,26 +52,25 @@ def log_config(config: Any, log: LogWriter, file: Optional[Path] = None) -> None
     if file is not None:
         _dump_dataclass(config, file)
 
-    log.info("Config:\n{}", config)
+    log.info("Job Config:\n{}", config)
 
 
-def log_environment_info(
-    log: Union[LogWriter, Logger], device: Optional[Device] = None
-) -> None:
-    """Log information about the software and hardware environments."""
-    log_software_info(log, device)
-    log_hardware_info(log, device)
-
-
 # compat
 # TODO: Keep only LogWriter
-def log_software_info(
+def log_environment_info(
     log: Union[LogWriter, Logger], device: Optional[Device] = None
 ) -> None:
-    """Log information about the software environment."""
+    """Log information about the installed software and the host system."""
     if isinstance(log, Logger):
         log = LogWriter(log)
 
+    log_software_info(log, device)
+
+    log_system_info(log, device)
+
+
+def log_software_info(log: LogWriter, device: Optional[Device] = None) -> None:
+    """Log information about the installed software."""
     if not log.is_enabled_for(logging.INFO):
         return
 
@@ -86,18 +85,11 @@ def log_software_info(
 
     s = f"{s} | Intraop Thread Count: {torch.get_num_threads()}"
 
-    log.info("Software Info - {}", s)
+    log.info("Software - {}", s)
 
 
-# compat
-# TODO: Keep only LogWriter
-def log_hardware_info(
-    log: Union[LogWriter, Logger], device: Optional[Device] = None
-) -> None:
-    """Log information about the host and device hardware environments."""
-    if isinstance(log, Logger):
-        log = LogWriter(log)
-
+def log_system_info(log: LogWriter, device: Optional[Device] = None) -> None:
+    """Log information about the host system."""
     if not log.is_enabled_for(logging.INFO):
         return
 
@@ -146,7 +138,8 @@ def log_hardware_info(
     memory = psutil.virtual_memory()
 
     s = (
-        f"Host: {socket.getfqdn()} | "
+        f"Name: {socket.getfqdn()} | "
+        f"PID: {os.getpid()} | "
         f"Number of CPUs: {cpu_info} | "
        f"Memory: {memory.total // (1024 * 1024 * 1024):,}GiB"
     )
@@ -165,22 +158,19 @@ def log_hardware_info(
         f"Compute Capability: {pr.major}.{pr.minor}"
     )
 
-    log.info("Hardware Info - {}", s)
+    log.info("Host System - {}", s)
 
 
-def log_module(module: Module, log: Union[LogWriter, Logger]) -> None:
-    """Log information about ``module`` and its descendants."""
+def log_model(model: Module, log: LogWriter) -> None:
+    """Log information about ``model``."""
     # compat
     # TODO: move to module scope.
     from fairseq2.nn.utils.module import get_module_size
 
-    if isinstance(log, Logger):
-        log = LogWriter(log)
-
     if not log.is_enabled_for(logging.INFO):
         return
 
-    si = get_module_size(module)
+    si = get_module_size(model)
 
     s = (
         f"Parameter Size: {si.param_size:,} | "
@@ -193,4 +183,4 @@ def log_module(module: Module, log: Union[LogWriter, Logger]) -> None:
         f"Total Size (bytes): {si.total_size_bytes:,}"
     )
 
-    log.info("Module Info - {} | Layout:\n{}", s, module)
+    log.info("Model - {} | Layout:\n{}", s, model)
51 changes: 34 additions & 17 deletions src/fairseq2/utils/logging.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
 import time
 from logging import (
     DEBUG,
@@ -21,7 +22,7 @@
 
 
 def setup_logging(
-    log_file: Optional[Path] = None,
+    log_file: Path,
     *,
     debug: bool = False,
     utc_time: bool = False,
@@ -43,26 +44,23 @@ def setup_logging(
 
     rank = get_rank()
 
-    handlers: List[Handler] = [StreamHandler()]  # Log to stderr.
-
-    if log_file is not None:
-        filename = log_file.name.format(rank=rank)
-
-        if filename == log_file.name:
-            raise ValueError(
-                f"`log_file` must contain a 'rank' replacement field (i.e. {{rank}}) in its filename, but is '{log_file}' instead."
-            )
-
-        try:
-            log_file.parent.mkdir(parents=True, exist_ok=True)
-        except OSError as ex:
-            raise RuntimeError(
-                f"The log directory ({log_file.parent}) cannot be created. See nested exception for details."
-            ) from ex
-
-        handler = FileHandler(log_file.with_name(filename))
-
-        handlers.append(handler)  # Log to file.
+    filename = log_file.name.format(rank=rank)
+
+    if filename == log_file.name:
+        raise ValueError(
+            f"`log_file` must contain a 'rank' replacement field (i.e. {{rank}}) in its filename, but is '{log_file}' instead."
+        )
+
+    log_file = log_file.with_name(filename)
+
+    try:
+        log_file.parent.mkdir(parents=True, exist_ok=True)
+    except OSError as ex:
+        raise RuntimeError(
+            f"The log directory ({log_file.parent}) cannot be created. See nested exception for details."
+        ) from ex
+
+    handlers: List[Handler] = [StreamHandler(), FileHandler(log_file)]
 
     fmt = f"[Rank {rank}] %(asctime)s %(levelname)s %(name)s - %(message)s"
 
@@ -79,6 +77,25 @@ def setup_logging(
     if utc_time:
         Formatter.converter = time.gmtime
 
+    _setup_nccl_logging(log_file, force)
+
+
+def _setup_nccl_logging(log_file: Path, force: bool) -> None:
+    if "NCCL_DEBUG" in os.environ and not force:
+        return
+
+    nccl_log_file = log_file.parent.joinpath("nccl", log_file.name)
+
+    try:
+        nccl_log_file.parent.mkdir(parents=True, exist_ok=True)
+    except OSError as ex:
+        raise RuntimeError(
+            f"The log directory ({nccl_log_file.parent}) cannot be created. See nested exception for details."
+        ) from ex
+
+    os.environ["NCCL_DEBUG"] = "INFO"
+    os.environ["NCCL_DEBUG_FILE"] = str(nccl_log_file)
+
 
 @final
 class LogWriter:
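Since `log_file` is now required and must contain a `{rank}` replacement field, and NCCL output is redirected to an `nccl/` subdirectory beside it, an invocation might look like this (the path is illustrative, and the sketch assumes the rank can be resolved from the environment):

```python
from pathlib import Path

from fairseq2.utils.logging import setup_logging

# On rank 0 this logs to logs/train_0.log; unless NCCL_DEBUG is already
# set, NCCL output is redirected to logs/nccl/train_0.log via the
# NCCL_DEBUG and NCCL_DEBUG_FILE environment variables.
setup_logging(Path("logs/train_{rank}.log"), debug=True)
```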
