
Commit

fix torch compile log act (#23)
* fix renaming logic for key

* fix stuff

* fix exploding norm

* remove print
samsja authored Aug 21, 2024
1 parent 35cd120 commit 8ce08c0
Showing 2 changed files with 38 additions and 48 deletions.
20 changes: 11 additions & 9 deletions open_diloco/train_fsdp.py
@@ -52,10 +52,10 @@


from open_diloco.utils import (
ActivationNormMetric,
FakeTokenizedDataset,
get_compression_kwargs,
get_sharding_strategy,
register_metrics_hooks,
)


@@ -353,6 +353,8 @@ def scheduler_fn(opt):
if world_messenger_hv:
max_num_peers = 0

log_activations = {}

for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
real_step = (step + 1) // gradient_accumulation_steps
is_accumulating = bool((step + 1) % gradient_accumulation_steps)
@@ -362,11 +364,9 @@ def scheduler_fn(opt):
)

if logging_activations_steps:
activation_monitor = ActivationNormMetric(
target_layers=TARGET_LAYER_ACTIVATIONS,
gradient_accumulation_steps=gradient_accumulation_steps,
handles = register_metrics_hooks(
model, TARGET_LAYER_ACTIVATIONS, log_activations, gradient_accumulation_steps
)
activation_monitor.register_metrics_hooks(model)

for key in batch.keys():
batch[key] = batch[key].to("cuda")
@@ -379,6 +379,10 @@

scaler.scale(loss).backward()

if logging_activations_steps:
for handle in handles:
handle.remove()

if not is_accumulating:
if world_messenger_hv:
scaler.unscale_(optimizer=optimizer.inner_optimizer)
@@ -400,9 +404,6 @@ def scheduler_fn(opt):
scheduler.step()
optimizer.zero_grad()

if logging_activations_steps:
activation_monitor.remove_hooks()

if config.hv is not None:
if int(real_step) % config.hv.local_steps == 0:
for param in model.parameters():
@@ -442,7 +443,8 @@ def scheduler_fn(opt):
metrics["num_peers"] = num_peers

if logging_activations_steps:
metrics.update(activation_monitor.log_activations)
metrics.update(log_activations)
log_activations = {}

if world_messenger_hv and num_peers < max_num_peers:
log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
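The training-loop side of the change registers the hooks just before the forward pass, removes them right after the backward pass, and flushes the shared `log_activations` dict into the step metrics before resetting it. Below is a minimal sketch of that lifecycle, assuming `open_diloco` is importable; the toy model, layer list, and loop bounds are placeholders, not the real FSDP/torch.compile setup from `train_fsdp.py`.

```python
# Sketch of the hook lifecycle from the train_fsdp.py diff above.
# Assumptions: open_diloco is installed; the model, layer list, and loop are toys.
import torch

from open_diloco.utils import register_metrics_hooks


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.mlp(x)


model = Block()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

TARGET_LAYER_ACTIVATIONS = ["mlp"]  # suffixes of module names to monitor (toy value)
gradient_accumulation_steps = 4
log_activations = {}  # shared dict the forward hooks write into

for step in range(8):
    logging_activations_step = True  # in the real loop this is only true periodically

    if logging_activations_step:
        handles = register_metrics_hooks(
            model, TARGET_LAYER_ACTIVATIONS, log_activations, gradient_accumulation_steps
        )

    loss = model(torch.randn(2, 8)).sum()  # hooks fire here and accumulate norms
    loss.backward()

    if logging_activations_step:
        for handle in handles:  # detach hooks right after backward, as in the diff
            handle.remove()

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        if logging_activations_step:
            metrics = {"step": step, **log_activations}  # flush accumulated norms
            log_activations = {}  # reset so norms do not keep growing across steps
```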
66 changes: 27 additions & 39 deletions open_diloco/utils.py
@@ -10,73 +10,61 @@
import wandb


_FSDP_WRAPPED_MODULE = ["_forward_module.", "_fsdp_wrapped_module."]
_WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."]


def _remove_fsdp_prefix(name: str) -> str:
for prefix in _FSDP_WRAPPED_MODULE:
for prefix in _WRAPPED_NAME_TO_REMOVE:
if prefix in name:
return name.replace(prefix, "")
name = name.replace(prefix, "")
return name


@torch.compiler.disable()
@torch.no_grad()
def log_activations_hook(
_mod: torch.nn.Module,
_inp: torch.Tensor,
outp: torch.Tensor | tuple[torch.Tensor, ...],
mod_name: str,
gradient_accumulation_steps: int,
log_activations: dict[str, float],
) -> None:
if isinstance(outp, tuple):
outp = outp[0]

norm = outp.norm(p=2)

norm = outp.norm(p=2) / gradient_accumulation_steps
name = _remove_fsdp_prefix(mod_name)

if f"activation/{name}" not in log_activations:
log_activations[f"activation/{name}"] = norm
else:
log_activations[f"activation/{name}"] += norm


class ActivationNormMetric:
def register_metrics_hooks(
model: torch.nn.Module,
target_layers: list[str],
log_activations: dict[str, torch.Tensor],
gradient_accumulation_steps: int,
) -> list[RemovableHandle]:
"""
This class is used to monitor the norm of the activation of the target layers.
It attached hook to the forward of each layer that will log the output, and remove them after.
this function take a torch module, a list of layer name and apply a hook function that
monitor the output norm of the layers.
"""

def __init__(self, target_layers: list[str], gradient_accumulation_steps: int):
self.target_layers = target_layers
self.handles: list[RemovableHandle] = []
self._log_activations: dict[str, torch.Tensor] = {}
self.gradient_accumulation_steps = gradient_accumulation_steps

def register_metrics_hooks(self, model: torch.nn.Module):
"""
this function take a torch module, a list of layer name and apply a hook function that
monitor the output norm of the layers.
"""
handles = []
for name, mod in model.named_modules():
for layer in self.target_layers:
if name.endswith(layer):
handle = mod.register_forward_hook(
partial(log_activations_hook, log_activations=self._log_activations, mod_name=name)
handles = []
for name, mod in model.named_modules():
for layer in target_layers:
if name.endswith(layer):
handle = mod.register_forward_hook(
partial(
log_activations_hook,
log_activations=log_activations,
mod_name=name,
gradient_accumulation_steps=gradient_accumulation_steps,
)
handles.append(handle)
break

self.handles = handles

def remove_hooks(self) -> None:
for handle in self.handles:
handle.remove()
)
handles.append(handle)

@property
def log_activations(self) -> dict[str, torch.Tensor]:
return {k: v / self.gradient_accumulation_steps for k, v in self._log_activations.items()}
return handles


def _round_str(x: float):
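On the utils.py side, the class-based `ActivationNormMetric` is replaced by a free `register_metrics_hooks` function, the hook now divides each norm by `gradient_accumulation_steps` as it accumulates, the hook is kept out of torch.compile via `@torch.compiler.disable()`, and the `_orig_mod.` prefix that torch.compile adds to module names is stripped from the logged keys. A small self-contained check of that behaviour, assuming `open_diloco` is importable (the `TinyBlock` module and tensor shapes are made up):

```python
# Toy check of the utils.py behaviour above.
# Assumptions: open_diloco is installed; TinyBlock and the tensor shapes are made up.
import torch

from open_diloco.utils import register_metrics_hooks


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


model = TinyBlock()
log_activations: dict[str, torch.Tensor] = {}
grad_acc = 2

handles = register_metrics_hooks(model, ["proj"], log_activations, grad_acc)

for _ in range(grad_acc):  # two micro-batches of one accumulation window
    model(torch.randn(3, 4))  # each forward adds ||output||_2 / grad_acc to the entry

for handle in handles:
    handle.remove()

# One entry per monitored layer, keyed without wrapper prefixes: "_orig_mod."
# (torch.compile), "_fsdp_wrapped_module.", and "_forward_module." are stripped.
print(log_activations)  # e.g. {"activation/proj": tensor(...)}
```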
