Commit

Merge branch 'develop' into feat/loong-train-mla

BingyangWu committed Jan 2, 2025
2 parents d421847 + d03c6f9 commit 526503b

Showing 9 changed files with 594 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint_check.yaml
@@ -10,7 +10,7 @@ on:
jobs:
  # lint check can be auto-executed by the workflow
  lint-check:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v3

3 changes: 3 additions & 0 deletions internlm/core/parallel/comm/__init__.py
@@ -0,0 +1,3 @@
from .attn_offload import get_offload_manager, initialize_offload_manager

__all__ = ["initialize_offload_manager", "get_offload_manager"]
127 changes: 127 additions & 0 deletions internlm/core/parallel/comm/attn_offload.py
@@ -0,0 +1,127 @@
import torch

from internlm.utils.common import get_current_device

global_attn_offload = None


class AttnOffloadManager:
    """
    A manager for attention output CPU offloading and GPU prefetch loading.
    """

    def __init__(self, enable_cpu_offload: bool = False) -> None:
        # whether attention outputs are offloaded to CPU (overlapped with compute)
        self.cpu_offload = enable_cpu_offload
        # mapping from layer id to flash attention output tensors
        self.fa_output_mapping = {}
        self.fa_stream = torch.cuda.Stream()
        self.d2h_final_event = torch.cuda.Event()
        self.h2d_final_event = torch.cuda.Event()
        # reusable GPU buffers for offloaded tensors, keyed by layer parity and tensor id
        self.tensor_id_to_tensor_bufs = {}

    def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id):
        """Get a tensor buffer for the offloaded tensor."""
        layer_id = layer_id % 2
        if layer_id not in self.tensor_id_to_tensor_bufs:
            self.tensor_id_to_tensor_bufs[layer_id] = {}

        if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]:
            allocate_new_buf = True
        else:
            tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
            # reuse the cached buffer only if its size and dtype still match
            allocate_new_buf = tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype

        if allocate_new_buf:
            # expected to execute only once per (layer parity, tensor id) slot
            buffer = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                device=tensor.device,
            )

            self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer

        return self.tensor_id_to_tensor_bufs[layer_id][tensor_id]

    def insert_fa_output_with_layer(self, layer_idx, output):
        assert layer_idx not in self.fa_output_mapping
        if self.cpu_offload is False:
            self.fa_output_mapping[layer_idx] = output
            return

        tensors = []
        for tensor_id, tensor in enumerate(output):
            if tensor is None:
                tensors.append(None)
                continue
            tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id)
            tensor_buf.copy_(tensor)
            tensors.append(tensor_buf)
        self.fa_output_mapping[layer_idx] = tensors

    def get_fa_output_with_layer(self, layer_idx):
        assert layer_idx in self.fa_output_mapping
        return self.fa_output_mapping.pop(layer_idx)

    def offload_fa_output_with_layer(self, layer_idx):
        assert layer_idx in self.fa_output_mapping

        self.fa_stream.wait_stream(torch.cuda.current_stream())
        self.fa_stream.wait_event(self.d2h_final_event)

        with torch.cuda.stream(self.fa_stream):
            _gpu_tensors = self.fa_output_mapping.pop(layer_idx)
            _cpu_tensors = []
            for _tensor in _gpu_tensors:
                if _tensor is None:
                    _cpu_tensors.append(_tensor)
                    continue

                _cpu_backup = torch.empty(
                    _tensor.size(),
                    dtype=_tensor.dtype,
                    layout=_tensor.layout,
                    device="cpu",
                    pin_memory=True,
                )
                _cpu_backup.copy_(_tensor, non_blocking=True)
                _cpu_tensors.append(_cpu_backup)

                # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False))

            self.fa_output_mapping[layer_idx] = _cpu_tensors

        self.fa_stream.record_event(self.d2h_final_event)

    def preload_fa_output_with_layer(self, layer_idx):
        assert layer_idx in self.fa_output_mapping

        self.fa_stream.wait_stream(torch.cuda.current_stream())
        self.fa_stream.wait_event(self.h2d_final_event)

        # Important: query the device before entering the stream context;
        # calling get_current_device() inside the stream context raises an error.
        _device = get_current_device()
        with torch.cuda.stream(self.fa_stream):
            _cpu_tensors = self.fa_output_mapping.pop(layer_idx)
            self.fa_output_mapping[layer_idx] = [
                _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor
                for _tensor in _cpu_tensors
            ]

        self.fa_stream.record_event(self.h2d_final_event)


def initialize_offload_manager(enable_cpu_offload: bool = False):
    global global_attn_offload
    if global_attn_offload is None:
        global_attn_offload = AttnOffloadManager(enable_cpu_offload)

    return global_attn_offload


def get_offload_manager():
    assert global_attn_offload is not None
    return global_attn_offload
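
Taken together, the manager caches each layer's flash-attention output, copies it to pinned CPU memory on a dedicated stream, and copies it back to the GPU before the layer is recomputed in backward. A minimal usage sketch follows (the layer index and tensor shapes are hypothetical, it assumes a CUDA device, and it is run outside a real training loop, so it glosses over the launcher/context setup internlm normally performs):

import torch

from internlm.core.parallel.comm import get_offload_manager, initialize_offload_manager

# enable CPU offload (normally driven by the selective_checkpoint_offload config entry)
initialize_offload_manager(enable_cpu_offload=True)
mgr = get_offload_manager()

layer_idx = 0  # hypothetical layer index
fa_output = (torch.randn(2, 1024, 8, 128, device="cuda"), None)  # stand-in flash-attn outputs

mgr.insert_fa_output_with_layer(layer_idx, fa_output)  # cached right after the layer's forward
mgr.offload_fa_output_with_layer(layer_idx)            # async D2H copy on the side stream
mgr.preload_fa_output_with_layer(layer_idx)            # async H2D copy ahead of recomputation
recovered, _ = mgr.get_fa_output_with_layer(layer_idx)  # consumed during backward recompute
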
74 changes: 68 additions & 6 deletions internlm/core/parallel/comm/isp.py
@@ -37,6 +37,8 @@
    params_dispatch_with_condition,
)

from .attn_offload import get_offload_manager


# not really useful, only for code hints.
class WPCommunicator(ABC):
@@ -306,6 +308,8 @@ def __init__(
        overlap: bool = False,
        process_group: dist.ProcessGroup = None,
        is_moe: bool = False,
        selective_ckpt_offload: bool = False,
        early_reduce_scatter_release: bool = True,
    ) -> None:
        self.process_group = process_group
        self.overlap = overlap
@@ -314,8 +318,21 @@ def __init__(
        self.is_forward = True
        self.reduce_scatter_handlers = {}
        self._forward_prefetch_prerequisites = []
        self._zero_const_pool = {}

        self._enable_early_reduce_scatter_release = early_reduce_scatter_release
        self._early_prev_layer_rs_handles = []
        self._early_curr_layer_rs_handles = []
        self._forward_overlap_per = self._get_forward_overlap_granularity()
        self._launch_before_module = self._get_launch_before_module()
        # As an optimization, do not release the weight after forward for the last
        # transformer block, since wp would prefetch it immediately
        self.layers_wp_not_release = []  # [gpc.config.isp_num_layers - 1]
        self.layers_fa_not_release = [
            gpc.config.isp_num_layers - 1,
            int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1,
        ]
        self.sc_offload = selective_ckpt_offload

        # real overlap state for each chunk.
        self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -411,6 +428,7 @@ def is_allgather_launch_module(name, module):
                self._overlap_states[cid].index_to_isp_modules[idx].append(child)

                setattr(child, "isp_name", name)
                setattr(child, "isp_layer_idx", idx)

                full_name = f"{cid}.{idx}.{name}"
                setattr(
@@ -506,6 +524,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args)
        if block_index + 1 < self._num_blocks:
            self._all_gather_block_weight(block_index + 1)

        # register offload and prefetch hooks for selective ckpt with the wo linear
        if self.sc_offload is True:
            # move the current layer's attn output from GPU to CPU asynchronously
            if (
                self.is_forward is True
                and gpc.config.selective_checkpoint
                and block_index not in self.layers_fa_not_release
                and block_index < self._ckpt_block_num
            ):
                get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index)

            # load the previous layer's attn output from CPU to GPU asynchronously
            if (
                self.is_forward is False
                and gpc.config.selective_checkpoint
                and (0 <= (block_index - 1) < self._ckpt_block_num)
            ):
                get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1)

    def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
        if module not in self._weight_global_handle:
            self._all_gather_module_weight(module)
@@ -539,6 +576,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
            self._all_gather_module_weight(next_module)

    def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
        if int(module.isp_layer_idx) in self.layers_wp_not_release:
            # print(f"the layer {module.isp_layer_idx} after forward not clear weight")
            return
        if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
            self._clear_handle(module)
            self._clear_weight(module)
@@ -561,6 +601,13 @@ def _post_backward_hook_for_module(self, module, *args):  # pylint: disable=W0613
        self._clear_handle(module)
        self._clear_weight(module)

    def _early_reduce_scatter_release_hook(self, *args):  # pylint: disable=W0613
        for handle in self._early_prev_layer_rs_handles:
            handle.wait()

        self._early_prev_layer_rs_handles = self._early_curr_layer_rs_handles
        self._early_curr_layer_rs_handles = []

    def _register_sync_parameters_hook(self) -> None:
        """
        register forward hooks and backward hooks for isp modules.
@@ -591,12 +638,18 @@ def _register_sync_parameters_hook(self) -> None:
        for module in self._isp_modules:
            module.register_full_backward_hook(self._post_backward_hook_for_module)

        if self._enable_early_reduce_scatter_release:
            for block_idx in range(self._num_blocks):
                block = self._index_to_block[block_idx]
                block.register_full_backward_hook(self._early_reduce_scatter_release_hook)

    def _get_constant_zero(self, size: tuple) -> torch.Tensor:
-        return torch.zeros(
-            *size,
-            dtype=self.model_conf.dtype,
-            device=self.model_conf.device,
-        ).contiguous()
+        if size not in self._zero_const_pool:
+            self._zero_const_pool[size] = torch.zeros(
+                *size, dtype=self.model_conf.dtype, device=self.model_conf.device
+            ).contiguous()
+
+        return self._zero_const_pool[size]

    def communication_mode(self) -> str:
        return "wp"
@@ -683,13 +736,18 @@ def grad_hook(
            assert hasattr(module.weight, "isp_reduce_scatter_name")
            key = getattr(module.weight, "isp_reduce_scatter_name")

-            self.reduce_scatter_handlers[key] = reduce_scatter_raw(
+            output, handle = reduce_scatter_raw(
                tensor,
                self.process_group,
                op=reduce_op,
                async_op=async_op,
            )

            if self._enable_early_reduce_scatter_release:
                self._early_curr_layer_rs_handles.append(handle)

            self.reduce_scatter_handlers[key] = (output, handle)

            result, handle = (
                self._get_constant_zero(
                    (
@@ -744,6 +802,10 @@ def after_backward(self, scheduler, inputs_grad) -> None:  # pylint: disable=W0613
        ):
            self._zero_optim.reduce_left_grads_after_backward()

        if self._isp_communicator and self._isp_communicator._enable_early_reduce_scatter_release:
            self._isp_communicator._early_prev_layer_rs_handles = []
            self._isp_communicator._early_curr_layer_rs_handles = []

    def post_helper_func(self, scheduler, outputs, label) -> None:  # pylint: disable=W0613
        pass

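In short, grad_hook now unpacks the reduce-scatter output and handle so the handle can also be appended to _early_curr_layer_rs_handles; the full-backward hook registered on every block waits on the previous block's handles before swapping the two lists, and after_backward clears whatever remains. A toy sketch of that rotation with stand-in handles (not part of the commit; _FakeHandle and the three-block loop are illustrative only):

# Toy sketch of the double-buffered handle rotation behind early reduce-scatter release.
# _FakeHandle stands in for the async work handles returned by reduce_scatter_raw.
class _FakeHandle:
    def __init__(self, name):
        self.name = name

    def wait(self):
        print(f"waited on {self.name}")


prev_layer_rs_handles = []  # mirrors _early_prev_layer_rs_handles
curr_layer_rs_handles = []  # mirrors _early_curr_layer_rs_handles


def early_release_hook():
    # runs at each block's backward boundary, like _early_reduce_scatter_release_hook:
    # drain the previous block's reduce-scatters, then promote the current block's handles
    global prev_layer_rs_handles, curr_layer_rs_handles
    for handle in prev_layer_rs_handles:
        handle.wait()
    prev_layer_rs_handles, curr_layer_rs_handles = curr_layer_rs_handles, []


for layer in reversed(range(3)):                                    # backward pass over 3 blocks
    curr_layer_rs_handles.append(_FakeHandle(f"rs-layer-{layer}"))  # grad_hook appends a handle
    early_release_hook()                                            # the block's full-backward hook fires
# any handles still pending here are cleared in after_backward
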
4 changes: 4 additions & 0 deletions internlm/core/trainer_builder.py
@@ -11,6 +11,7 @@
from internlm.checkpoint.checkpoint_manager import CheckpointManager
from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
from internlm.core.parallel.comm import initialize_offload_manager
from internlm.core.trainer import Trainer
from internlm.data.streaming.utils import streaming_simple_resume
from internlm.data.train_state import get_train_state
@@ -119,6 +120,9 @@ def __init__(
        # initialize isp communicator
        isp_communicator = initialize_parallel_communicator(model)

        # initialize cpu offload manager for selective checkpoint
        initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))

        # initialize train state
        train_state = get_train_state(train_dl)

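The offload manager is created unconditionally here, with CPU offload switched on only when the config enables it. A hypothetical config excerpt showing the keys this commit reads (the surrounding layout and the 0.5 value are illustrative, not part of the commit):

# Hypothetical excerpt from a training config consumed via gpc.config;
# only the keys read in this commit are meaningful here.
model = dict(
    checkpoint=0.5,  # fraction of blocks under activation checkpointing (gpc.config.model.checkpoint)
)

selective_checkpoint = True           # gates the offload/preload hooks in isp.py
selective_checkpoint_offload = True   # read by trainer_builder.py to enable CPU offload
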
(Diffs for the remaining 4 changed files are not shown.)
