diff --git a/.github/workflows/lint_check.yaml b/.github/workflows/lint_check.yaml
index ab1a532e..fe86bd05 100644
--- a/.github/workflows/lint_check.yaml
+++ b/.github/workflows/lint_check.yaml
@@ -10,7 +10,7 @@ on:
 jobs:
   # lint check can be auto-executed by the workflow
   lint-check:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-20.04
     steps:
     - uses: actions/checkout@v3
diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py
index e69de29b..be170f28 100644
--- a/internlm/core/parallel/comm/__init__.py
+++ b/internlm/core/parallel/comm/__init__.py
@@ -0,0 +1,3 @@
+from .attn_offload import get_offload_manager, initialize_offload_manager
+
+__all__ = ["initialize_offload_manager", "get_offload_manager"]
diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py
new file mode 100644
index 00000000..da23f3ae
--- /dev/null
+++ b/internlm/core/parallel/comm/attn_offload.py
@@ -0,0 +1,127 @@
+import torch
+
+from internlm.utils.common import get_current_device
+
+global_attn_offload = None
+
+
+class AttnOffloadManager:
+    """
+    A manager for attention output CPU offloading and GPU prefetch loading.
+    """
+
+    def __init__(self, enable_cpu_offload: bool = False) -> None:
+        # cpu offload overlapping
+        self.cpu_offload = enable_cpu_offload
+        # layer id mapping to flash attn output
+        self.fa_output_mapping = {}
+        self.fa_stream = torch.cuda.Stream()
+        self.d2h_final_event = torch.cuda.Event()
+        self.h2d_final_event = torch.cuda.Event()
+        # prepare for tensor buffer
+        self.tensor_id_to_tensor_bufs = {}
+
+    def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id):
+        """Get tensor buffer for offloaded tensor."""
+        layer_id = layer_id % 2
+        if layer_id not in self.tensor_id_to_tensor_bufs:
+            self.tensor_id_to_tensor_bufs[layer_id] = {}
+
+        if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]:
+            allocate_new_buf = True
+        else:
+            tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
+            allocate_new_buf = tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype
+
+        if allocate_new_buf:
+            # supposed to only execute once per buffer slot
+            buffer = torch.empty(
+                tensor.size(),
+                dtype=tensor.dtype,
+                layout=tensor.layout,
+                device=tensor.device,
+            )
+
+            self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer
+
+        return self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
+
+    def insert_fa_output_with_layer(self, layer_idx, output):
+        assert layer_idx not in self.fa_output_mapping
+        if self.cpu_offload is False:
+            self.fa_output_mapping[layer_idx] = output
+            return
+
+        tensors = []
+        for tensor_id, tensor in enumerate(output):
+            if tensor is None:
+                tensors.append(None)
+                continue
+            tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id)
+            tensor_buf.copy_(tensor)
+            tensors.append(tensor_buf)
+        self.fa_output_mapping[layer_idx] = tensors
+
+    def get_fa_output_with_layer(self, layer_idx):
+        assert layer_idx in self.fa_output_mapping
+        return self.fa_output_mapping.pop(layer_idx)
+
+    def offload_fa_output_with_layer(self, layer_idx):
+        assert layer_idx in self.fa_output_mapping
+
+        self.fa_stream.wait_stream(torch.cuda.current_stream())
+        self.fa_stream.wait_event(self.d2h_final_event)
+
+        with torch.cuda.stream(self.fa_stream):
+            _gpu_tensors = self.fa_output_mapping.pop(layer_idx)
+            _cpu_tensors = []
+            for _tensor in _gpu_tensors:
+                if _tensor is None:
+                    _cpu_tensors.append(_tensor)
+                    continue
+
+                _cpu_backup = torch.empty(
+                    _tensor.size(),
+                    dtype=_tensor.dtype,
+                    layout=_tensor.layout,
+                    device="cpu",
+                    pin_memory=True,
+                )
+                _cpu_backup.copy_(_tensor, non_blocking=True)
+                _cpu_tensors.append(_cpu_backup)
+
+                # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False))
+
+            self.fa_output_mapping[layer_idx] = _cpu_tensors
+
+        self.fa_stream.record_event(self.d2h_final_event)
+
+    def preload_fa_output_with_layer(self, layer_idx):
+        assert layer_idx in self.fa_output_mapping
+
+        self.fa_stream.wait_stream(torch.cuda.current_stream())
+        self.fa_stream.wait_event(self.h2d_final_event)
+
+        # Important: query the target device before entering the stream context; doing it inside the stream is an error
+        _device = get_current_device()
+        with torch.cuda.stream(self.fa_stream):
+            _cpu_tensors = self.fa_output_mapping.pop(layer_idx)
+            self.fa_output_mapping[layer_idx] = [
+                _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor
+                for _tensor in _cpu_tensors
+            ]
+
+        self.fa_stream.record_event(self.h2d_final_event)
+
+
+def initialize_offload_manager(enable_cpu_offload: bool = False):
+    global global_attn_offload
+    if global_attn_offload is None:
+        global_attn_offload = AttnOffloadManager(enable_cpu_offload)
+
+    return global_attn_offload
+
+
+def get_offload_manager():
+    assert global_attn_offload is not None
+    return global_attn_offload
diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 7e722c2f..23a92980 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -37,6 +37,8 @@
     params_dispatch_with_condition,
 )
 
+from .attn_offload import get_offload_manager
+
 
 # not really useful, only for code hint.
 class WPCommunicator(ABC):
@@ -306,6 +308,8 @@ def __init__(
         overlap: bool = False,
         process_group: dist.ProcessGroup = None,
         is_moe: bool = False,
+        selective_ckpt_offload: bool = False,
+        early_reduce_scatter_release: bool = True,
     ) -> None:
         self.process_group = process_group
         self.overlap = overlap
@@ -314,8 +318,21 @@ def __init__(
         self.is_forward = True
         self.reduce_scatter_handlers = {}
         self._forward_prefetch_prerequisites = []
+        self._zero_const_pool = {}
+
+        self._enable_early_reduce_scatter_release = early_reduce_scatter_release
+        self._early_prev_layer_rs_handles = []
+        self._early_curr_layer_rs_handles = []
         self._forward_overlap_per = self._get_forward_overlap_granularity()
         self._launch_before_module = self._get_launch_before_module()
+        # As an optimization, do not release the weight after forward for the last
+        # transformer block, since wp would prefetch it again immediately
+        self.layers_wp_not_release = []  # [gpc.config.isp_num_layers - 1]
+        self.layers_fa_not_release = [
+            gpc.config.isp_num_layers - 1,
+            int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1,
+        ]
+        self.sc_offload = selective_ckpt_offload
 
         # real overlap state for each chunk.
        self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -411,6 +428,7 @@ def is_allgather_launch_module(name, module):
                         self._overlap_states[cid].index_to_isp_modules[idx].append(child)
 
                         setattr(child, "isp_name", name)
+                        setattr(child, "isp_layer_idx", idx)
 
                         full_name = f"{cid}.{idx}.{name}"
                         setattr(
@@ -506,6 +524,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args)
         if block_index + 1 < self._num_blocks:
             self._all_gather_block_weight(block_index + 1)
 
+        # register offload and prefetch hooks for selective ckpt with the wo linear
+        if self.sc_offload is True:
+            # move the current layer's attn output from GPU to CPU asynchronously
+            if (
+                self.is_forward is True
+                and gpc.config.selective_checkpoint
+                and block_index not in self.layers_fa_not_release
+                and block_index < self._ckpt_block_num
+            ):
+                get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index)
+
+            # load the previous layer's attn output from CPU back to GPU asynchronously
+            if (
+                self.is_forward is False
+                and gpc.config.selective_checkpoint
+                and (0 <= (block_index - 1) < self._ckpt_block_num)
+            ):
+                get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1)
+
     def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         if module not in self._weight_global_handle:
             self._all_gather_module_weight(module)
@@ -539,6 +576,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: dis
             self._all_gather_module_weight(next_module)
 
     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
+        if int(module.isp_layer_idx) in self.layers_wp_not_release:
+            # skip releasing the weight after forward for these layers
+            return
         if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_handle(module)
             self._clear_weight(module)
@@ -561,6 +601,13 @@ def _post_backward_hook_for_module(self, module, *args):  # pylint: disable=W061
             self._clear_handle(module)
             self._clear_weight(module)
 
+    def _early_reduce_scatter_release_hook(self, *args):  # pylint: disable=W0613
+        for handle in self._early_prev_layer_rs_handles:
+            handle.wait()
+
+        self._early_prev_layer_rs_handles = self._early_curr_layer_rs_handles
+        self._early_curr_layer_rs_handles = []
+
     def _register_sync_parameters_hook(self) -> None:
         """
         register forward hooks and backward hooks for isp modules.
@@ -591,12 +638,18 @@ def _register_sync_parameters_hook(self) -> None: for module in self._isp_modules: module.register_full_backward_hook(self._post_backward_hook_for_module) + if self._enable_early_reduce_scatter_release: + for block_idx in range(self._num_blocks): + block = self._index_to_block[block_idx] + block.register_full_backward_hook(self._early_reduce_scatter_release_hook) + def _get_constant_zero(self, size: tuple) -> torch.Tensor: - return torch.zeros( - *size, - dtype=self.model_conf.dtype, - device=self.model_conf.device, - ).contiguous() + if size not in self._zero_const_pool: + self._zero_const_pool[size] = torch.zeros( + *size, dtype=self.model_conf.dtype, device=self.model_conf.device + ).contiguous() + + return self._zero_const_pool[size] def communication_mode(self) -> str: return "wp" @@ -683,13 +736,18 @@ def grad_hook( assert hasattr(module.weight, "isp_reduce_scatter_name") key = getattr(module.weight, "isp_reduce_scatter_name") - self.reduce_scatter_handlers[key] = reduce_scatter_raw( + output, handle = reduce_scatter_raw( tensor, self.process_group, op=reduce_op, async_op=async_op, ) + if self._enable_early_reduce_scatter_release: + self._early_curr_layer_rs_handles.append(handle) + + self.reduce_scatter_handlers[key] = (output, handle) + result, handle = ( self._get_constant_zero( ( @@ -744,6 +802,10 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06 ): self._zero_optim.reduce_left_grads_after_backward() + if self._isp_communicator and self._isp_communicator._enable_early_reduce_scatter_release: + self._isp_communicator._early_prev_layer_rs_handles = [] + self._isp_communicator._early_curr_layer_rs_handles = [] + def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index dfe64373..f7c182a3 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -11,6 +11,7 @@ from internlm.checkpoint.checkpoint_manager import CheckpointManager from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode +from internlm.core.parallel.comm import initialize_offload_manager from internlm.core.trainer import Trainer from internlm.data.streaming.utils import streaming_simple_resume from internlm.data.train_state import get_train_state @@ -119,6 +120,9 @@ def __init__( # initialize isp communicator isp_communicator = initialize_parallel_communicator(model) + # initialize cpu offload manager for selective checkpoint + initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) + # initialize train state train_state = get_train_state(train_dl) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index c8b16516..d6038d18 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -66,6 +66,8 @@ def get_default_parser(): def args_sanity_check(): assert gpc.config is not None, "config is not load!" 
+ gpc.is_forward = True + if "JOB_NAME" not in gpc.config: gpc.config._add_item("JOB_NAME", "AnonymousJob") @@ -73,6 +75,13 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) + if gpc.config.model_type == "InternLM3_M": + # TODO: need check for isp overlap + num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers + else: + num_layers = gpc.config.model.num_layers + gpc.config.isp_num_layers = num_layers + if "use_apex_adam" not in gpc.config: gpc.config._add_item("use_apex_adam", False) @@ -388,17 +397,18 @@ def args_sanity_check(): gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name - assert ( - gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0 - ), "VOCAB_SIZE must be integer multiple of tensor parallel size" if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp" assert ( torch.__version__ >= "2.1.0" ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}" - assert ( - gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0 - ), "VOCAB_SIZE must be integer multiple of wp size" + + assert ( + gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0 + ), "model.vocab_size must be integer multiple of weight parallel size" + assert ( + gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0 + ), "model.vocab_size must be integer multiple of tensor parallel size" assert gpc.config.parallel["tensor"].get("mode", None) in [ TensorParallelMode.mtp.name, @@ -446,11 +456,19 @@ def args_sanity_check(): gpc.config.parallel["weight"]["overlap"] = False if gpc.config.parallel["tensor"]["mode"] != TensorParallelMode.isp.name: assert gpc.config.parallel["weight"]["size"] <= 1, "weight parallel is only supported with isp" + + if "early_reduce_scatter_release" not in gpc.config.parallel.weight: + gpc.config.parallel.weight["early_reduce_scatter_release"] = True + # set default value for expert_weight parallel if gpc.config.parallel["expert_weight"].get("overlap", None) is None: gpc.config.parallel["expert_weight"]["overlap"] = False if gpc.config.parallel["expert"].get("no_tp", None) is None: gpc.config.parallel["expert"]["no_tp"] = False + + if "early_reduce_scatter_release" not in gpc.config.parallel.expert_weight: + gpc.config.parallel.expert_weight["early_reduce_scatter_release"] = True + # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: assert ( @@ -524,7 +542,20 @@ def args_sanity_check(): gpc.config.loss._add_item("moe_loss_coeff", 1.0) if "selective_checkpoint" not in gpc.config: - gpc.config._add_item("selective_checkpoint", False) + gpc.config.selective_checkpoint = False + if "selective_checkpoint_offload" not in gpc.config: + gpc.config.selective_checkpoint_offload = False + if gpc.config.selective_checkpoint is True: + assert ( + gpc.config.parallel["tensor"]["mode"] == "isp" + ), "When using selective_checkpoint, tensor parallel mode must be isp" + if gpc.config.selective_checkpoint_offload is True: + assert ( + gpc.config.selective_checkpoint is True + ), "When using selective_checkpoint_offload, selective_checkpoint must be True" 
+ assert ( + gpc.config.parallel.weight.launch_allgather_before == "wo" + ), "When using selective_checkpoint_offload, wp launch allgather communication should be set before 'wo' module" # moe not support overlap and zero1.5 for now if gpc.config.model.get("num_experts", 1) > 1: diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py new file mode 100644 index 00000000..87aac2eb --- /dev/null +++ b/internlm/model/ops/_flash_attn.py @@ -0,0 +1,331 @@ +# Copyright (c) InternLM. All rights reserved. +import torch + +from internlm.accelerator import get_accelerator +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import get_offload_manager + +try: + import flash_attn + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward, + _flash_attn_varlen_forward, + ) + + gpu_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + gpu_flash_attn_impl = False + +internlm_accelerator = get_accelerator() +device_backend = internlm_accelerator.get_accelerator_backend() + + +class FlashAttnVarlenKVPackedFunc_V263(torch.autograd.Function): + """ + Varlen KVPacked Func from Flash Attn v2.6.3. + """ + + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + layer_idx, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k, v = kv[:, 0], kv[:, 1] + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + + if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) + else: + ( + out, + q, + k, + v, + out_padded, + softmax_lse, + S_dmask, + rng_state, + ) = _flash_attn_varlen_forward( # pylint: disable=E1123 + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + get_offload_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) + ) + + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): # pylint: disable=W0613 + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( # pylint: disable=E1121,E1124 + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + 
ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenKVPackedFunc_V221(torch.autograd.Function): + """ + Varlen KVPacked Func from Flash Attn v2.2.1. + """ + + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + layer_idx, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k, v = kv[:, 0], kv[:, 1] + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + + if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) + else: + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=return_softmax and dropout_p > 0, + ) + + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num: + get_offload_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): # pylint: disable=W0613 + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + layer_idx=0, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
+ """ + + assert gpu_flash_attn_impl is True and flash_attn.__version__ in [ + "2.2.1", + "2.6.3", + ], "flash-attn should be installed and version must be v2.2.1 or v2.6.3" + + if flash_attn.__version__ == "2.2.1": + return FlashAttnVarlenKVPackedFunc_V221.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_attn_probs, + layer_idx, + ) + + return FlashAttnVarlenKVPackedFunc_V263.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + layer_idx, + ) diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 604ea77a..3aec51f5 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -93,13 +93,14 @@ from flash_attn.flash_attn_interface import ( flash_attn_varlen_func as _flash_varlen_qkvsplited_func, ) - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, - ) from flash_attn.flash_attn_interface import ( flash_attn_varlen_qkvpacked_func as _flash_varlen_qkvpacked_func, ) + from ._flash_attn import ( + flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, + ) + gpu_flash_attn_impl = True except (ModuleNotFoundError, ImportError): gpu_flash_attn_impl = False @@ -187,6 +188,7 @@ def _flash_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, ): # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) @@ -204,6 +206,7 @@ def _flash_varlen_kvpacked_attn( dropout_p, softmax_scale, causal, + layer_idx=layer_idx, ) return output.unsqueeze(dim=0) @@ -521,6 +524,7 @@ def _npu_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): # TODO: support npu native varlen flash attention k, v = kv.unbind(dim=2) @@ -579,6 +583,7 @@ def _deeplink_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) @@ -1012,7 +1017,17 @@ def _q_kv_with_cu_seqlens( extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( - q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout, + softmax_scale, + causal, + *extra_args, + layer_idx=self.layer_idx, ) @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 79e9caf4..784a5305 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -363,6 +363,8 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.weight.overlap, gpc.get_group(ParallelMode.WEIGHT), is_moe=False, + selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), + early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release, ) # register communicator for isp column parallel linear. 
ColumnParallelLinear.register_cls_communicator(isp_communicator) @@ -388,6 +390,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.expert_weight.overlap, gpc.get_group(ParallelMode.EXPERT_WEIGHT), is_moe=True, + early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release, ) for moe in _submodule_filter(model, Experts): for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):
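
Configuration note: the options introduced by this diff are read from the training config at the paths checked in launch.py and pipeline.py above. The fragment below is a minimal sketch of how they might be enabled; only the keys this diff actually reads (model.checkpoint, model.vocab_size, selective_checkpoint, selective_checkpoint_offload, parallel.tensor.mode, parallel.weight.launch_allgather_before, parallel.weight.early_reduce_scatter_release, parallel.expert_weight.early_reduce_scatter_release) come from the code above; every other value is an illustrative placeholder, not a complete config.

# Minimal config sketch (illustrative values only).
model = dict(
    num_layers=32,
    vocab_size=92544,   # must be divisible by both tensor and weight parallel sizes
    checkpoint=0.5,     # fraction of layers with activation checkpoint; bounds the offloaded layers
)

parallel = dict(
    tensor=dict(size=1, mode="isp"),       # selective_checkpoint requires the isp tensor mode
    weight=dict(
        size=8,
        overlap=True,
        launch_allgather_before="wo",      # required when selective_checkpoint_offload is True
        early_reduce_scatter_release=True, # defaults to True if omitted
    ),
    expert_weight=dict(overlap=False, early_reduce_scatter_release=True),
)

selective_checkpoint = True            # keep attention outputs instead of recomputing attention
selective_checkpoint_offload = True    # additionally stage the kept outputs in pinned CPU memory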
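
The sketch below exercises the new AttnOffloadManager API directly, outside the trainer, assuming a CUDA device; the shapes, dtypes, and the explicit wait_stream call are illustrative. In the real flow the same insert/offload/preload/get calls are issued by the patched flash-attn autograd functions and the ISP pre-forward hooks rather than by user code.

import torch

from internlm.core.parallel.comm import get_offload_manager, initialize_offload_manager

initialize_offload_manager(enable_cpu_offload=True)
manager = get_offload_manager()

layer_idx = 0
out = torch.randn(4096, 32, 128, device="cuda", dtype=torch.bfloat16)    # attention output (total, nheads, headdim)
softmax_lse = torch.randn(32, 4096, device="cuda", dtype=torch.float32)  # per-row log-sum-exp

# forward: cache this layer's attention outputs (None entries are preserved as-is)
manager.insert_fa_output_with_layer(layer_idx, (out, None, softmax_lse, None, None))

# copy the cached outputs into pinned CPU buffers on the side stream (device-to-host)
manager.offload_fa_output_with_layer(layer_idx)

# backward/recompute: prefetch them back to the GPU (host-to-device), then consume them
manager.preload_fa_output_with_layer(layer_idx)
torch.cuda.current_stream().wait_stream(manager.fa_stream)  # make the prefetched tensors safe to read
out_again, _, lse_again, _, _ = manager.get_fa_output_with_layer(layer_idx)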