From ef81a4d9352a4900d145e599057943b71213e7dc Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Fri, 26 Apr 2024 19:23:33 +0000
Subject: [PATCH] Add first round of debugging features

---
 src/fairseq2/gang.py            | 49 +++++++++++++++++--------------
 src/fairseq2/nn/utils/module.py | 22 ++++++++++----
 src/fairseq2/utils/log.py       | 50 +++++++++++++-------------------
 src/fairseq2/utils/logging.py   | 51 ++++++++++++++++++++++-----------
 4 files changed, 97 insertions(+), 75 deletions(-)

diff --git a/src/fairseq2/gang.py b/src/fairseq2/gang.py
index 5288f7aa8..31dd5a24d 100644
--- a/src/fairseq2/gang.py
+++ b/src/fairseq2/gang.py
@@ -247,7 +247,7 @@ def init_default_process_group(
         device: Optional[Device] = None,
         timeout: Optional[timedelta] = None,
         num_threads: Optional[int] = None,
-        warn_only: bool = False,
+        debug: bool = False,
         ok_initialized: bool = False,
     ) -> ProcessGroupGang:
         """Initialize the default process group and wrap it as a gang.
@@ -259,9 +259,9 @@
             The timeout for collective operations.
         :param num_threads:
             The number of threads to use for intra-op parallelism.
-        :param warn_only:
-            If ``True``, logs a warning instead of raising an error if the gang
-            is not set up reliably.
+        :param debug:
+            If ``True``, turns on additional logging and synchronization checks
+            to help diagnose distributed training-related issues.
         :param ok_initialized:
             If ``True``, does not raise an error if the default process group
             is already initialized.
@@ -275,6 +275,11 @@
             raise RuntimeError("The default process group is already initialized.")
 
+        # Turn on `torch.distributed` debugging.
+        if debug:
+            for debug_flag in ["TORCH_CPP_LOG_LEVEL", "TORCH_DISTRIBUTED_DEBUG"]:
+                os.environ[debug_flag] = "INFO"
+
         num_procs = get_local_world_size()
 
         if num_threads is None:
@@ -303,25 +308,17 @@
         )
 
         if device.type == "cuda":
+            nccl_env_name = "NCCL_ASYNC_ERROR_HANDLING"
 
-            def check_async_handling() -> None:
-                env_name = "NCCL_ASYNC_ERROR_HANDLING"
-                if env_name in os.environ:
-                    return
-
-                if torch_greater_or_equal(2, 2):
-                    env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-                    if env_name in os.environ:
-                        return
+            if torch_greater_or_equal(2, 2):
+                try:
+                    del os.environ[nccl_env_name]  # Suppress the deprecation warning.
+                except KeyError:
+                    pass
 
-                if warn_only:
-                    log.warning("The default process group uses the `nccl` backend, but the `{}` environment variable is not set. Your collective communication calls can hang indefinitely. Learn more at https://github.com/pytorch/pytorch/issues/46874.", env_name)  # fmt: skip
-                else:
-                    raise RuntimeError(
-                        f"The default process group uses the `nccl` backend, but the `{env_name}` environment variable is not set. Learn more at https://github.com/pytorch/pytorch/issues/46874."
-                    )
+                nccl_env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
 
-            check_async_handling()
+            os.environ[nccl_env_name] = "1"
 
         if timeout is None:
             timeout = timedelta(minutes=30)
@@ -586,7 +583,10 @@ def _get_int_from_env(var_name: str, allow_zero: bool = False) -> Optional[int]:
 
 
 def setup_default_gang(
-    *, device: Optional[Device] = None, timeout: Optional[timedelta] = None
+    *,
+    device: Optional[Device] = None,
+    timeout: Optional[timedelta] = None,
+    debug: bool = False,
 ) -> Gang:
     """Set up the default gang of this process.
 
     :param device:
         If ``None``, the gang will try to infer the default
         device of the process; otherwise, it will use the CPU.
     :param timeout:
         The timeout for collective operations.
+    :param debug:
+        If ``True``, turns on additional logging and synchronization checks
+        to help diagnose distributed training-related issues.
     """
     if get_world_size() == 1:
         return FakeGang(device=device)
 
-    return ProcessGroupGang.init_default_process_group(device=device, timeout=timeout)
+    return ProcessGroupGang.init_default_process_group(
+        device=device, timeout=timeout, debug=debug
+    )
 
 
 def setup_parallel_gangs(root_gang: Gang, *, tp_size: int = 1) -> Dict[str, Gang]:
diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py
index 3b230e569..9cd50d0a5 100644
--- a/src/fairseq2/nn/utils/module.py
+++ b/src/fairseq2/nn/utils/module.py
@@ -7,6 +7,7 @@
 import re
 from dataclasses import dataclass
 from itertools import chain
+from logging import Logger
 from typing import (
     Any,
     Callable,
@@ -20,6 +21,7 @@
     Sequence,
     Set,
     Tuple,
+    Union,
     runtime_checkable,
 )
@@ -29,9 +31,17 @@
 from torch.nn.utils import remove_weight_norm
 
 from fairseq2.typing import CPU, META, Device
+from fairseq2.utils.logging import LogWriter
+
 
 # compat
-from fairseq2.utils.log import log_module as log_module  # noqa: F401
+def log_module(module: Module, log: Union[LogWriter, Logger]) -> None:
+    from fairseq2.utils.log import log_model
+
+    if isinstance(log, Logger):
+        log = LogWriter(log)
+
+    log_model(module, log)
 
 
 @runtime_checkable
@@ -298,8 +308,8 @@ def freeze_parameters(module: Optional[Module], value: bool = True) -> None:
 def select_parameters(
     module: Module, names: Sequence[str], *, exclude: bool = False
 ) -> Iterable[Tuple[str, Parameter]]:
-    """Select the parameters of ``module`` and of its descendant modules whose
-    name matches ``names``.
+    """Select the parameters of ``module`` and its descendant modules whose
+    names match ``names``.
 
     :param module:
         The module to check.
@@ -347,7 +357,7 @@ def infer_device(
         The name of the module for error reporting purposes.
     :param recurse:
         If ``True``, infers the device by checking the parameters and buffers of
-        the descendant modules as well.
+        descendant modules as well.
     """
 
     devices = set()
@@ -372,7 +382,7 @@
 
 def load_state_dict(module: Module, state_dict: Mapping[str, Any]) -> None:
     """Copy parameters and buffers from ``state_dict`` into ``module`` and its
-    descendants.
+    descendant modules.
 
     This implementation internally calls :meth:`Module.load_state_dict()` with
     ``strict`` set to ``True``, and also enforces that ``state_dict`` does not
@@ -476,7 +486,7 @@ class ModuleSizeInfo:
 
 
 def get_module_size(module: Module) -> ModuleSizeInfo:
-    """Return the size information of ``module`` and its descendants."""
+    """Return the size information of ``module`` and its descendant modules."""
     info = ModuleSizeInfo()
 
     for param in module.parameters():
diff --git a/src/fairseq2/utils/log.py b/src/fairseq2/utils/log.py
index 9dafbcf73..b88ab1ca4 100644
--- a/src/fairseq2/utils/log.py
+++ b/src/fairseq2/utils/log.py
@@ -52,26 +52,25 @@ def log_config(config: Any, log: LogWriter, file: Optional[Path] = None) -> None:
     if file is not None:
         _dump_dataclass(config, file)
 
-    log.info("Config:\n{}", config)
-
-
-def log_environment_info(
-    log: Union[LogWriter, Logger], device: Optional[Device] = None
-) -> None:
-    """Log information about the software and hardware environments."""
-    log_software_info(log, device)
-    log_hardware_info(log, device)
+    log.info("Job Config:\n{}", config)
 
 
 # compat
 # TODO: Keep only LogWriter
-def log_software_info(
+def log_environment_info(
     log: Union[LogWriter, Logger], device: Optional[Device] = None
 ) -> None:
-    """Log information about the software environment."""
+    """Log information about the installed software and the host system."""
     if isinstance(log, Logger):
         log = LogWriter(log)
 
+    log_software_info(log, device)
+
+    log_system_info(log, device)
+
+
+def log_software_info(log: LogWriter, device: Optional[Device] = None) -> None:
+    """Log information about the installed software."""
     if not log.is_enabled_for(logging.INFO):
         return
@@ -86,18 +85,11 @@
     s = f"{s} | Intraop Thread Count: {torch.get_num_threads()}"
 
-    log.info("Software Info - {}", s)
+    log.info("Software - {}", s)
 
 
-# compat
-# TODO: Keep only LogWriter
-def log_hardware_info(
-    log: Union[LogWriter, Logger], device: Optional[Device] = None
-) -> None:
-    """Log information about the host and device hardware environments."""
-    if isinstance(log, Logger):
-        log = LogWriter(log)
-
+def log_system_info(log: LogWriter, device: Optional[Device] = None) -> None:
+    """Log information about the host system."""
     if not log.is_enabled_for(logging.INFO):
         return
@@ -146,7 +138,8 @@
     memory = psutil.virtual_memory()
 
     s = (
-        f"Host: {socket.getfqdn()} | "
+        f"Name: {socket.getfqdn()} | "
+        f"PID: {os.getpid()} | "
         f"Number of CPUs: {cpu_info} | "
         f"Memory: {memory.total // (1024 * 1024 * 1024):,}GiB"
     )
@@ -165,22 +158,19 @@
             f"Compute Capability: {pr.major}.{pr.minor}"
         )
 
-    log.info("Hardware Info - {}", s)
+    log.info("Host System - {}", s)
 
 
-def log_module(module: Module, log: Union[LogWriter, Logger]) -> None:
-    """Log information about ``module`` and its descendants."""
+def log_model(model: Module, log: LogWriter) -> None:
+    """Log information about ``model``."""
     # compat
     # TODO: move to module scope.
     from fairseq2.nn.utils.module import get_module_size
 
-    if isinstance(log, Logger):
-        log = LogWriter(log)
-
     if not log.is_enabled_for(logging.INFO):
         return
 
-    si = get_module_size(module)
+    si = get_module_size(model)
 
     s = (
         f"Parameter Size: {si.param_size:,} | "
@@ -193,4 +183,4 @@
         f"Total Size (bytes): {si.total_size_bytes:,}"
     )
 
-    log.info("Module Info - {} | Layout:\n{}", s, module)
+    log.info("Model - {} | Layout:\n{}", s, model)
diff --git a/src/fairseq2/utils/logging.py b/src/fairseq2/utils/logging.py
index 70f6b1df1..eb2444af9 100644
--- a/src/fairseq2/utils/logging.py
+++ b/src/fairseq2/utils/logging.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
 import time
 from logging import (
     DEBUG,
@@ -21,7 +22,7 @@
 def setup_logging(
-    log_file: Optional[Path] = None,
+    log_file: Path,
     *,
     debug: bool = False,
     utc_time: bool = False,
@@ -43,26 +44,23 @@
 
     rank = get_rank()
 
-    handlers: List[Handler] = [StreamHandler()]  # Log to stderr.
+    filename = log_file.name.format(rank=rank)
 
-    if log_file is not None:
-        filename = log_file.name.format(rank=rank)
+    if filename == log_file.name:
+        raise ValueError(
+            f"`log_file` must contain a 'rank' replacement field (i.e. {{rank}}) in its filename, but is '{log_file}' instead."
+        )
 
-        if filename == log_file.name:
-            raise ValueError(
-                f"`log_file` must contain a 'rank' replacement field (i.e. {{rank}}) in its filename, but is '{log_file}' instead."
-            )
+    log_file = log_file.with_name(filename)
 
-        try:
-            log_file.parent.mkdir(parents=True, exist_ok=True)
-        except OSError as ex:
-            raise RuntimeError(
-                f"The log directory ({log_file.parent}) cannot be created. See nested exception for details."
-            ) from ex
+    try:
+        log_file.parent.mkdir(parents=True, exist_ok=True)
+    except OSError as ex:
+        raise RuntimeError(
+            f"The log directory ({log_file.parent}) cannot be created. See nested exception for details."
+        ) from ex
 
-        handler = FileHandler(log_file.with_name(filename))
-
-        handlers.append(handler)  # Log to file.
+    handlers: List[Handler] = [StreamHandler(), FileHandler(log_file)]
 
     fmt = f"[Rank {rank}] %(asctime)s %(levelname)s %(name)s - %(message)s"
@@ -79,6 +77,25 @@
     if utc_time:
         Formatter.converter = time.gmtime
 
+    _setup_nccl_logging(log_file, force)
+
+
+def _setup_nccl_logging(log_file: Path, force: bool) -> None:
+    if "NCCL_DEBUG" in os.environ and not force:
+        return
+
+    nccl_log_file = log_file.parent.joinpath("nccl", log_file.name)
+
+    try:
+        nccl_log_file.parent.mkdir(parents=True, exist_ok=True)
+    except OSError as ex:
+        raise RuntimeError(
+            f"The NCCL log directory ({nccl_log_file.parent}) cannot be created. See nested exception for details."
+        ) from ex
+
+    os.environ["NCCL_DEBUG"] = "INFO"
+    os.environ["NCCL_DEBUG_FILE"] = str(nccl_log_file)
+
 
 @final
 class LogWriter:
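
Usage note (illustrative, not part of the patch): the sketch below shows how the
pieces introduced above are meant to fit together; the log path and the timeout
value are hypothetical.

    from datetime import timedelta
    from pathlib import Path

    from fairseq2.gang import setup_default_gang
    from fairseq2.utils.logging import setup_logging

    # `log_file` must now contain a `{rank}` replacement field; NCCL output is
    # also redirected to a per-rank file under a sibling `nccl/` directory.
    setup_logging(Path("/tmp/train/logs/rank_{rank}.log"), debug=True)

    # With `debug=True`, `TORCH_CPP_LOG_LEVEL` and `TORCH_DISTRIBUTED_DEBUG` are
    # set to "INFO" before the default process group is initialized.
    gang = setup_default_gang(timeout=timedelta(minutes=15), debug=True)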