
clean up, fix total grad norm reporting
epwalsh committed Jan 8, 2025
1 parent 8882a3c commit 533f016
Showing 4 changed files with 43 additions and 41 deletions.
28 changes: 0 additions & 28 deletions src/olmo_core/distributed/parallel/pipeline_parallel.py
@@ -1,12 +1,9 @@
import math
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    PipelineScheduleMulti,
@@ -163,28 +160,3 @@ def step(
        else:
            self.base_schedule.step()
            return None, None

    def clip_grad_norm_(
        self, max_norm: float, norm_type: float = 2.0, foreach: Optional[bool] = None
    ) -> torch.Tensor:
        parameters = [p for m in self.model_parts for p in m.parameters()]
        grads = [p.grad for p in parameters if p.grad is not None]

        total_norm = nn.utils.get_total_norm(grads, norm_type, False, True)
        if isinstance(total_norm, DTensor):
            # Will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor.
            # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
            # We can simply reduce the DTensor to get the total norm in this tensor's process group
            # and then convert it to a local tensor
            total_norm = total_norm.full_tensor()

        # TODO: cleanup maybe using DTensor
        if math.isinf(norm_type):
            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=self.pp_mesh.get_group())
        else:
            total_norm **= norm_type
            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=self.pp_mesh.get_group())
            total_norm **= 1.0 / norm_type

        torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach=foreach)
        return total_norm
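
The reduction above works because p-norms compose across disjoint sets of gradients: each pipeline rank raises its local norm to the p-th power, the powers are summed across the group, and the p-th root recovers the global norm. Below is a minimal single-process sketch of that identity, using made-up gradient tensors in place of real pipeline stages (no distributed setup involved):

```python
import torch

# Hypothetical gradients, split across two "pipeline stages".
stage_grads = [
    [torch.randn(4, 4), torch.randn(8)],   # stage 0
    [torch.randn(16), torch.randn(2, 3)],  # stage 1
]
norm_type = 2.0

# Each stage computes the norm of its own gradients, as get_total_norm would locally.
stage_norms = [
    torch.linalg.vector_norm(torch.cat([g.flatten() for g in grads]), norm_type)
    for grads in stage_grads
]

# Combine the same way the all_reduce(SUM) does: sum of p-th powers, then the 1/p root.
combined = sum(n**norm_type for n in stage_norms) ** (1.0 / norm_type)

# Reference: the norm over every gradient at once.
reference = torch.linalg.vector_norm(
    torch.cat([g.flatten() for grads in stage_grads for g in grads]), norm_type
)

torch.testing.assert_close(combined, reference)
```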
2 changes: 1 addition & 1 deletion src/olmo_core/distributed/utils.py
@@ -9,8 +9,8 @@

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

from ..exceptions import OLMoEnvironmentError
from ..utils import logging_configured, move_to_device, set_env_var
52 changes: 41 additions & 11 deletions src/olmo_core/train/train_module/transformer.py
@@ -1,6 +1,7 @@
import contextlib
import copy
import logging
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Generator, List, Optional, Tuple, cast
@@ -12,6 +13,7 @@
from torch.distributed import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.pipelining import PipelineStage
from torch.distributed.tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

@@ -788,17 +790,7 @@ def eval_batch(
    def optim_step(self):
        # Maybe clip gradients.
        if self.max_grad_norm is not None:
            grad_norm: torch.Tensor
            if self.train_pp_schedule is None:
                assert len(self.model_parts) == 1
                model = self.model_parts[0]
                if isinstance(model, FSDP):
                    grad_norm = model.clip_grad_norm_(self.max_grad_norm)
                else:
                    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), self.max_grad_norm)
            else:
                grad_norm = self.train_pp_schedule.clip_grad_norm_(self.max_grad_norm, foreach=True)

            grad_norm = self._clip_grad_norm(self.max_grad_norm)
            # NOTE: grad norm is already reduced over ranks, so we set `reduce_type` to `None`.
            self.trainer.record_metric(
                "total grad norm", grad_norm, reduce_type=None, namespace="optim"
@@ -966,3 +958,41 @@ def _get_state_dict(self, sd_options: dist_cp_sd.StateDictOptions) -> Dict[str,
                for k, v in sd.items()
            },
        }

    def _clip_grad_norm(
        self, max_grad_norm: float, norm_type: float = 2.0, foreach: Optional[bool] = None
    ) -> torch.Tensor:
        if not self.pp_enabled and isinstance(self.model_parts[0], FSDP):
            return self.model_parts[0].clip_grad_norm_(max_grad_norm)

        # Adapted from https://github.com/pytorch/torchtitan/blob/2a4437014e66bcf88a3f0419b816266e6326d539/torchtitan/utils.py#L348

        parameters = [p for m in self.model_parts for p in m.parameters()]
        grads = [p.grad for p in parameters if p.grad is not None]

        total_norm = nn.utils.get_total_norm(
            grads, norm_type=norm_type, error_if_nonfinite=False, foreach=foreach
        )

        # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
        # We can simply reduce the DTensor to get the total norm in this tensor's process group
        # and then convert it to a local tensor.
        # NOTE: It has two purposes:
        # 1. to make sure the total norm is computed correctly when PP is used (see below)
        # 2. to return a reduced total_norm tensor whose .item() would return the correct value
        if isinstance(total_norm, DTensor):
            # Will reach here if any non-PP parallelism is used.
            # If only using PP, total_norm will be a local tensor.
            total_norm = total_norm.full_tensor()

        if self.train_pp_schedule is not None:
            pp_mesh = self.train_pp_schedule.pp_mesh
            if math.isinf(norm_type):
                dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
            else:
                total_norm **= norm_type
                dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
                total_norm **= 1.0 / norm_type

        torch.nn.utils.clip_grads_with_norm_(parameters, max_grad_norm, total_norm, foreach=foreach)
        return total_norm
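
Once the reduced `total_norm` is in hand, `clip_grads_with_norm_` scales every gradient by a single coefficient derived from it. A rough sketch of that final step follows; the epsilon and clamping are assumptions based on the standard clip-by-norm recipe, not PyTorch's exact implementation, and the helper name is hypothetical:

```python
from typing import Iterable

import torch


def clip_with_total_norm_sketch(
    parameters: Iterable[torch.nn.Parameter],
    max_norm: float,
    total_norm: torch.Tensor,
    eps: float = 1e-6,  # assumed epsilon, not necessarily PyTorch's exact value
) -> None:
    """Scale each gradient in place by roughly min(1, max_norm / total_norm)."""
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    for p in parameters:
        if p.grad is not None:
            p.grad.mul_(clip_coef)


# Usage on a toy module, mirroring the norm_type=2.0 default above.
model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
grads = [p.grad for p in model.parameters() if p.grad is not None]
total_norm = torch.nn.utils.get_total_norm(grads, norm_type=2.0)
clip_with_total_norm_sketch(model.parameters(), max_norm=1.0, total_norm=total_norm)
```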
2 changes: 1 addition & 1 deletion src/test/nn/transformer/model_test.py
@@ -3,7 +3,7 @@
import pytest
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.tensor import DTensor, init_device_mesh

from olmo_core.distributed.checkpoint import (
load_model_and_optim_state,
