From b6429c06b9f370eb00cf10c20bd7c1a9ce163762 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 11 Dec 2024 18:02:23 +0800 Subject: [PATCH 01/24] remove rr deep copy --- paddlenlp/transformers/llama/modeling.py | 34 +++++++++++---------- paddlenlp/transformers/qwen/modeling.py | 30 +++++++++--------- paddlenlp/transformers/qwen2/modeling.py | 33 ++++++++++---------- paddlenlp/transformers/refined_recompute.py | 12 ++++---- 4 files changed, 56 insertions(+), 53 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 8bf0d5938902..3ee8509382b6 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -605,7 +605,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Layer): - def __init__(self, config): + def __init__(self, config, layer_idx: int = 0): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -618,9 +618,9 @@ def __init__(self, config): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("mlp_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops.get("mlp_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear @@ -628,9 +628,9 @@ def __init__(self, config): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("mlp_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops.get("mlp_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -682,7 +682,7 @@ def forward(self, x): class LlamaAttention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): + def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, layer_idx: int = 0): super().__init__() self.config = config @@ -746,18 +746,18 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("attention_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops.get("attention_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("attention_column_ln", False): + if 
config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops.get("attention_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -862,7 +862,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): if ( config.recompute and not config.recompute_use_reentrant - and config.skip_recompute_ops.get("flash_attn", False) + and config.skip_recompute_ops[layer_idx].get("flash_attn", False) ): self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) @@ -1168,12 +1168,12 @@ def forward( class LlamaDecoderLayer(nn.Layer): - def __init__(self, config, layerwise_recompute: bool = False): + def __init__(self, config, layerwise_recompute: bool = False, layer_idx: int = 0): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config, layerwise_recompute) - self.mlp = LlamaMLP(config) + self.self_attn = LlamaAttention(config, layerwise_recompute, layer_idx=layer_idx) + self.mlp = LlamaMLP(config, layer_idx=layer_idx) self.input_layernorm = LlamaRMSNorm(config) self.post_attention_layernorm = LlamaRMSNorm(config) self.sequence_parallel = config.sequence_parallel @@ -1518,9 +1518,11 @@ def __init__(self, config: LlamaConfig): self.layers = nn.LayerList( [ LlamaDecoderLayer( - create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers + create_skip_config_for_refined_recompute(layer_idx, config), + layerwise_recompute=layer_idx not in self.no_recompute_layers, + layer_idx=layer_idx, ) - for i in range(config.num_hidden_layers) + for layer_idx in range(config.num_hidden_layers) ] ) self.norm = LlamaRMSNorm(config) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 58e49e69d989..d58304a8e334 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -138,9 +138,9 @@ def get_triangle_upper_mask(x, mask=None): class QWenAttention(nn.Layer): - def __init__(self, config): + def __init__(self, config, layer_idx: int = 0): super().__init__() - + self.layer_idx = layer_idx self.config = config self.seq_length = config.seq_length self.hidden_size = config.hidden_size @@ -166,18 +166,18 @@ def __init__(self, config): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("attention_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops.get("attention_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("attention_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops.get("attention_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): 
RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -252,7 +252,7 @@ def _attn(self, query, key, value, attention_mask=None): skip_recompute = ( self.config.recompute and not self.config.recompute_use_reentrant - and self.config.skip_recompute_ops.get("flash_attn", False) + and self.config.skip_recompute_ops[self.layer_idx].get("flash_attn", False) ) attn_output = no_recompute( F.scaled_dot_product_attention, @@ -409,7 +409,7 @@ def forward( class QWenMLP(nn.Layer): - def __init__(self, config): + def __init__(self, config, layer_idx: int = 0): super().__init__() ff_dim_in = config.intermediate_size // 2 self.fuse_attention_ffn = config.fuse_attention_ffn @@ -420,18 +420,18 @@ def __init__(self, config): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("mlp_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops.get("mlp_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("mlp_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops.get("mlp_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -484,13 +484,13 @@ def forward(self, hidden_states): class QWenBlock(nn.Layer): - def __init__(self, config): + def __init__(self, config, layer_idx: int = 0): super().__init__() self.sequence_parallel = config.sequence_parallel self.ln_1 = QWenRMSNorm(config) - self.attn = QWenAttention(config) + self.attn = QWenAttention(config, layer_idx=layer_idx) self.ln_2 = QWenRMSNorm(config) - self.mlp = QWenMLP(config) + self.mlp = QWenMLP(config, layer_idx=layer_idx) def forward( self, diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 9b1ab534cc0e..9c2586f332e3 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -364,7 +364,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class Qwen2MLP(nn.Layer): - def __init__(self, config: Qwen2Config, is_shared=False): + def __init__(self, config: Qwen2Config, is_shared=False, layer_idx: int = 0): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -377,9 +377,9 @@ def __init__(self, config: Qwen2Config, is_shared=False): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("mlp_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops.get("mlp_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = 
linear_utils.ColumnParallelLinear @@ -387,9 +387,9 @@ def __init__(self, config: Qwen2Config, is_shared=False): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("mlp_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops.get("mlp_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -441,7 +441,7 @@ class Qwen2Attention(nn.Layer): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, layer_idx: int = 0): super().__init__() self.config = config @@ -493,9 +493,9 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("attention_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops.get("attention_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear @@ -503,9 +503,9 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops.get("attention_column_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops.get("attention_row_ln", False): + if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -531,7 +531,7 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True): if ( config.recompute and not config.recompute_use_reentrant - and config.skip_recompute_ops.get("flash_attn", False) + and config.skip_recompute_ops[layer_idx].get("flash_attn", False) ): self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) @@ -652,13 +652,13 @@ def forward( class Qwen2DecoderLayer(nn.Layer): - def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, layer_idx: int = 0): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.self_attn = Qwen2Attention(config, layerwise_recompute) + self.self_attn = Qwen2Attention(config, layerwise_recompute, layer_idx=layer_idx) - self.mlp = Qwen2MLP(config) + self.mlp = Qwen2MLP(config, layer_idx=layer_idx) self.input_layernorm = Qwen2RMSNorm(config) self.post_attention_layernorm = Qwen2RMSNorm(config) @@ -949,8 +949,9 @@ def __init__(self, config: Qwen2Config): self.layers = nn.LayerList( [ Qwen2DecoderLayer( - create_skip_config_for_refined_recompute(layer_idx, config), - layerwise_recompute=layer_idx not in self.no_recompute_layers, + 
create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), + layerwise_recompute=config.skip_recompute_ops[layer_idx] not in self.no_recompute_layers, + layer_idx=layer_idx, ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index 0884e4d688df..da7d8780d4c7 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -20,7 +20,6 @@ import queue import uuid import weakref -from copy import deepcopy import paddle import paddle.autograd @@ -520,16 +519,15 @@ def create_skip_config_for_refined_recompute(layer_idx, config): Returns: dict: Returns an updated configuration file containing the following key-value pairs: - - skip_recompute_ops (dict): A dictionary with each operation's name and a boolean - indicating whether to skip recomputation, defaults to None. + - skip_recompute_ops (dict): A dictionary with each model layer's each operation's name + and a boolean indicating whether to skip recomputation, defaults to None. - If the refined_recompute key does not exist or recompute is set to False, the original configuration file is returned. """ - if not config.recompute: + if not config.recompute or config.refined_recompute is None: return config skip_config = dict() - config = deepcopy(config) try: hcg = fleet.get_hybrid_communicate_group() @@ -557,7 +555,9 @@ def create_skip_config_for_refined_recompute(layer_idx, config): skip_config[op_name] = True else: skip_config[op_name] = False - config.skip_recompute_ops = skip_config + + config.skip_recompute_ops[layer_idx] = skip_config + return config From 64084ca308d1752b27cc78fcb2d02904c740b343 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 18 Dec 2024 17:32:22 +0800 Subject: [PATCH 02/24] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=AF=B9=E5=BA=94?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/trainer.md | 41 ++++++++++++++++++++++++++++++ llm/docs/finetune.md | 1 + paddlenlp/trainer/training_args.py | 13 ++++++++++ 3 files changed, 55 insertions(+) diff --git a/docs/trainer.md b/docs/trainer.md index d643c99268e4..4b651d6af97e 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -601,6 +601,47 @@ Trainer 是一个简单,但功能完整的 Paddle 训练和评估模块,并 Recompute the forward pass to calculate gradients. Used for saving memory (default: False) + --refined_recompute + 精化重新计算参数,用于在GPU显存使用和计算速度之间寻求最佳平衡。 + 此参数允许用户对重新计算过程进行细致控制,以优化资源利用。具体配置示例如下: + `"attention_column_ln:-1, attention_row_ln:-1, flash_attn:-1, mlp_column_ln:5, mlp_row_ln:-1"` + + 在配置中,支持的参数包括: + `attention_column_ln` + `attention_row_ln` + `mlp_column_ln` + `mlp_row_ln` + `flash_attn` + + 每个参数后的数字,即`skip_num`,决定了对应操作跳过重计算的次数。具体解释如下: + `skip_num` 为 `-1`:表示在所有阶段均不进行重新计算,从而最大化显存使用。 + `skip_num` 为 `0`:表示在每个阶段都强制进行重新计算,以最小化显存使用。 + + 此外,您还可以将`skip_num`设置为`[1, ..., num_layers]`范围内的任意值。若`skip_num`超出`num_layers`,其行为将等同于设置为`-1`。 + 若配置中省略了某个参数,则系统默认将其设置为`xxx:0`。 + + (类型: `str`, 可选, 默认为: "") + + Refined recompute parameter for optimizing the balance between GPU memory usage and computational speed. + This parameter allows fine-grained control over the recomputation process to optimize resource utilization. 
An example configuration is as follows: + `"attention_column_ln:-1, attention_row_ln:-1, flash_attn:-1, mlp_column_ln:5, mlp_row_ln:-1"` + + The supported parameters in the configuration include: + `attention_column_ln` + `attention_row_ln` + `mlp_column_ln` + `mlp_row_ln` + `flash_attn` + + The number following each parameter, `skip_num`, determines the number of times to bypass recomputation for the specified operation. Specifically: + `skip_num of -1`: Indicates no recomputation across all stages, maximizing memory usage. + `skip_num of 0`: Enforces recomputation at every stage, minimizing memory usage. + + Additionally, you can set skip_num to any value within the range `[1, ..., num_layers]`. If `skip_num` exceeds `num_layers`, it will behave as if set to `-1`. + If a parameter is omitted from the configuration, it defaults to `xxx:0`. + + (Type: `str`, optional, default: "") + --minimum_eval_times 最少评估次数,如果当前设置的eval_steps,评估次数少于minimum_eval_times, 此选项会覆盖eval_steps参数。 diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md index 233213a9b73b..737b4d1ba6e4 100644 --- a/llm/docs/finetune.md +++ b/llm/docs/finetune.md @@ -180,6 +180,7 @@ python merge_lora_params.py \ - `do_train`: 是否打开训练,默认为 False。 - `do_eval`: 是否打开评估,默认为 False。 - `recompute`: 重计算,暂支持 full 策略。开启后可降低显存以达到增大 batch size 的目的,默认为 False。 +- `refined_recompute`: 精细化重计算,通过精细化控制所需重计算的部分从而达到显存和性能之间的均衡,当前仅支持`llama`系列模型以及`qwen`系列模型,详细使用请参考[TrainingArguments 文档](https://paddlenlp.readthedocs.io/zh/latest/trainer.html)。 - `tensor_parallel_degree`: 此参数 tensor_parallel_degree 表示将一层 transformer 结构的份数,该方法对通信开销较大, 建议 tensor_parallel_degree<=8, 尽量使用机器内部通信。默认为-1,表示不启用张量并行。 - `pipeline_parallel_degree`: 表示划分流水线的大小.(假设该参数为4, 模型12层, 则每一个 pp stage 包含3层模型) 默认值-1, 表示不启用流水线并行。 - `sharding_parallel_degree`: 表示分组参数切片的数据并行大小. 默认值1, 表示不启用分组参数切片的数据并行。 diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index ab9aeed74449..058c641d504f 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -740,6 +740,19 @@ class TrainingArguments: "Only support for networks with transformer blocks." }, ) + refined_recompute: str = field( + default="", + metadata={ + "help": "The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed.\n" + "An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`.\n" + "The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`.\n" + "The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation.\n" + "A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage;\n" + "A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.\n" + "You can also set `skip_num` to a value within the range [1, ..., num_layers]. If `skip_num` exceeds `num_layers`, it will behave as if set to `-1`.\n" + "If a parameter is omitted, it defaults to `xxx:0`." 
+ }, + ) scale_loss: float = field(default=2**15, metadata={"help": "The value of initial scale_loss for fp16."}) From 258b48ba69095ca92a7a3d8b6ae361314a09a30c Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 18 Dec 2024 17:34:38 +0800 Subject: [PATCH 03/24] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=AF=B9=E5=BA=94?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddlenlp/transformers/qwen/modeling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index d58304a8e334..1d8598e83bd8 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -726,9 +726,10 @@ def __init__(self, config): self.h = nn.LayerList( [ QWenBlock( - create_skip_config_for_refined_recompute(i, config), + config=create_skip_config_for_refined_recompute(layer_idx, config), + layer_idx=layer_idx, ) - for i in range(config.num_hidden_layers) + for layer_idx in range(config.num_hidden_layers) ] ) self.ln_f = QWenRMSNorm(config) From 79b86f6c29fed575ecc2e5c09d73b19d276ffe15 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 18 Dec 2024 17:39:25 +0800 Subject: [PATCH 04/24] fix --- paddlenlp/transformers/llama/modeling.py | 2 +- paddlenlp/transformers/llama/modeling_pp.py | 2 +- paddlenlp/transformers/qwen/modeling.py | 2 +- paddlenlp/transformers/qwen/modeling_pp.py | 5 ++++- paddlenlp/transformers/qwen2/modeling.py | 2 +- paddlenlp/transformers/qwen2/modeling_pp.py | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index d571d8e0bb17..24d8a0aea093 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1518,7 +1518,7 @@ def __init__(self, config: LlamaConfig): self.layers = nn.LayerList( [ LlamaDecoderLayer( - create_skip_config_for_refined_recompute(layer_idx, config), + config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), layerwise_recompute=layer_idx not in self.no_recompute_layers, layer_idx=layer_idx, ) diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index 1ec4c027a72a..db8f42661665 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -376,7 +376,7 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( LlamaDecoderLayerPipe, - config=create_skip_config_for_refined_recompute(i, config), + config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[i], config), layerwise_recompute=i not in self.no_recompute_layers, ), f"llama.layers.{i}", diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 1d8598e83bd8..3bd215216965 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -726,7 +726,7 @@ def __init__(self, config): self.h = nn.LayerList( [ QWenBlock( - config=create_skip_config_for_refined_recompute(layer_idx, config), + config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), layer_idx=layer_idx, ) for layer_idx in range(config.num_hidden_layers) diff --git a/paddlenlp/transformers/qwen/modeling_pp.py b/paddlenlp/transformers/qwen/modeling_pp.py index 889ed60e5416..f29fb58a0447 100644 --- 
a/paddlenlp/transformers/qwen/modeling_pp.py +++ b/paddlenlp/transformers/qwen/modeling_pp.py @@ -173,7 +173,10 @@ def get_hcg(): self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen") for i in range(config.num_hidden_layers): self.add_sequential_layer( - LayerDesc(QWenBlockPipe, config=create_skip_config_for_refined_recompute(i, config)), + LayerDesc( + QWenBlockPipe, + config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[i], config), + ), f"qwen.h.{i}", ) self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f") diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index e94191b620b3..ff4804be95cf 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -953,7 +953,7 @@ def __init__(self, config: Qwen2Config): self.layers = nn.LayerList( [ Qwen2DecoderLayer( - create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), + config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), layerwise_recompute=config.skip_recompute_ops[layer_idx] not in self.no_recompute_layers, layer_idx=layer_idx, ) diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index 916baad328ce..3343149e2798 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -300,7 +300,7 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( Qwen2DecoderLayerPipe, - config=create_skip_config_for_refined_recompute(i, config), + config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[i], config), layerwise_recompute=i not in self.no_recompute_layers, ), f"qwen2.layers.{i}", From 927d157635e6f267393a64ed158e0637bebde89f Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 18 Dec 2024 17:44:45 +0800 Subject: [PATCH 05/24] update --- paddlenlp/transformers/llama/modeling.py | 2 +- paddlenlp/transformers/llama/modeling_pp.py | 2 +- paddlenlp/transformers/qwen/modeling.py | 2 +- paddlenlp/transformers/qwen/modeling_pp.py | 2 +- paddlenlp/transformers/qwen2/modeling.py | 4 ++-- paddlenlp/transformers/qwen2/modeling_pp.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 24d8a0aea093..987387db7be5 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1518,7 +1518,7 @@ def __init__(self, config: LlamaConfig): self.layers = nn.LayerList( [ LlamaDecoderLayer( - config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), + config=create_skip_config_for_refined_recompute(layer_idx, config), layerwise_recompute=layer_idx not in self.no_recompute_layers, layer_idx=layer_idx, ) diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index db8f42661665..1ec4c027a72a 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -376,7 +376,7 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( LlamaDecoderLayerPipe, - config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[i], config), + config=create_skip_config_for_refined_recompute(i, config), layerwise_recompute=i not in self.no_recompute_layers, ), f"llama.layers.{i}", diff --git a/paddlenlp/transformers/qwen/modeling.py 
b/paddlenlp/transformers/qwen/modeling.py index 3bd215216965..0383bca25e05 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -726,7 +726,7 @@ def __init__(self, config): self.h = nn.LayerList( [ QWenBlock( - config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), + config=create_skip_config_for_refined_recompute(layer_idx, config), layer_idx=layer_idx, ) for layer_idx in range(config.num_hidden_layers) diff --git a/paddlenlp/transformers/qwen/modeling_pp.py b/paddlenlp/transformers/qwen/modeling_pp.py index f29fb58a0447..52cc528f39eb 100644 --- a/paddlenlp/transformers/qwen/modeling_pp.py +++ b/paddlenlp/transformers/qwen/modeling_pp.py @@ -175,7 +175,7 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( QWenBlockPipe, - config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[i], config), + config=create_skip_config_for_refined_recompute(i, config), ), f"qwen.h.{i}", ) diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index ff4804be95cf..7e52739cd33f 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -953,8 +953,8 @@ def __init__(self, config: Qwen2Config): self.layers = nn.LayerList( [ Qwen2DecoderLayer( - config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[layer_idx], config), - layerwise_recompute=config.skip_recompute_ops[layer_idx] not in self.no_recompute_layers, + config=create_skip_config_for_refined_recompute(layer_idx, config), + layerwise_recompute=layer_idx not in self.no_recompute_layers, layer_idx=layer_idx, ) for layer_idx in range(config.num_hidden_layers) diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index 3343149e2798..916baad328ce 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -300,7 +300,7 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( Qwen2DecoderLayerPipe, - config=create_skip_config_for_refined_recompute(config.skip_recompute_ops[i], config), + config=create_skip_config_for_refined_recompute(i, config), layerwise_recompute=i not in self.no_recompute_layers, ), f"qwen2.layers.{i}", From 287817911361051a68e07407338f2a32b0a03bc4 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 18 Dec 2024 17:52:25 +0800 Subject: [PATCH 06/24] update --- paddlenlp/transformers/refined_recompute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index da7d8780d4c7..3a22df19525c 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -563,7 +563,7 @@ def create_skip_config_for_refined_recompute(layer_idx, config): def update_refined_recompute(rr, lora=False): """update refined recompute dict.""" - if rr == "": + if rr is None or rr == "": return {} else: From 92d7e8247386d2bc956fbcf96aa141b12552d57d Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Thu, 19 Dec 2024 16:08:18 +0800 Subject: [PATCH 07/24] update --- paddlenlp/transformers/configuration_utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py index eb7fb6060a50..59edfa5b3b09 100644 --- a/paddlenlp/transformers/configuration_utils.py +++ 
b/paddlenlp/transformers/configuration_utils.py @@ -268,14 +268,6 @@ class LlmMetaConfig: "Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']", ), ("recompute_use_reentrant", bool, False, "recompute_use_reentrant"), - # refined_recompute attributes - ( - "refined_recompute", - str, - "", - "refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']", - ), - ("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"), ] @classmethod @@ -569,6 +561,11 @@ def __init__(self, **kwargs): self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False) self.fuse_attention_ffn = kwargs.pop("fuse_attention_ffn", False) + # for refined_recompute + self.refined_recompute = kwargs.pop("refined_recompute", {}) + self.skip_recompute_ops = kwargs.pop("skip_recompute_ops", {}) + self.register_unsavable_keys(["refined_recompute", "skip_recompute_ops"]) + if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], Dict): kwargs["quantization_config"] = QuantizationConfig.from_dict(kwargs["quantization_config"]) self.quantization_config = kwargs.pop("quantization_config", QuantizationConfig()) From aa91e8e9a7905d03b37f742bc67dc102ae2fc14d Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Fri, 20 Dec 2024 12:06:34 +0800 Subject: [PATCH 08/24] add missing docs --- paddlenlp/trainer/training_args.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 058c641d504f..1ca83317de8a 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -293,6 +293,15 @@ class TrainingArguments: recompute (`bool`, *optional*, defaults to `False`): Recompute the forward pass to calculate gradients. Used for saving memory. Only support for networks with transformer blocks. + refined_recompute (`str`, *optional*, defaults to `""`): + The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed. + An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`. + The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`. + The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation. + A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage; + A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage. + You can also set `skip_num` to a value within the range [1, ..., num_layers]. If `skip_num` exceeds `num_layers`, it will behave as if set to `-1`. + If a parameter is omitted, it defaults to `xxx:0`." scale_loss (`float`, *optional*, defaults to 32768): The value of initial scale_loss for fp16. 
(default: 32768) local_rank (`int`, *optional*, defaults to -1): From 2486c2a54b913496ca85aabe5ca22ae2b673981f Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 23 Dec 2024 15:56:26 +0800 Subject: [PATCH 09/24] add refined_recompute yaml --- docs/trainer.md | 4 +-- tests/fixtures/llm/refined_recompute.yaml | 41 +++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 tests/fixtures/llm/refined_recompute.yaml diff --git a/docs/trainer.md b/docs/trainer.md index 4b651d6af97e..44f4597f065d 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -604,7 +604,7 @@ Trainer 是一个简单,但功能完整的 Paddle 训练和评估模块,并 --refined_recompute 精化重新计算参数,用于在GPU显存使用和计算速度之间寻求最佳平衡。 此参数允许用户对重新计算过程进行细致控制,以优化资源利用。具体配置示例如下: - `"attention_column_ln:-1, attention_row_ln:-1, flash_attn:-1, mlp_column_ln:5, mlp_row_ln:-1"` + `"attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1"` 在配置中,支持的参数包括: `attention_column_ln` @@ -624,7 +624,7 @@ Trainer 是一个简单,但功能完整的 Paddle 训练和评估模块,并 Refined recompute parameter for optimizing the balance between GPU memory usage and computational speed. This parameter allows fine-grained control over the recomputation process to optimize resource utilization. An example configuration is as follows: - `"attention_column_ln:-1, attention_row_ln:-1, flash_attn:-1, mlp_column_ln:5, mlp_row_ln:-1"` + `"attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1"` The supported parameters in the configuration include: `attention_column_ln` diff --git a/tests/fixtures/llm/refined_recompute.yaml b/tests/fixtures/llm/refined_recompute.yaml new file mode 100644 index 000000000000..b34e8f401413 --- /dev/null +++ b/tests/fixtures/llm/refined_recompute.yaml @@ -0,0 +1,41 @@ +finetune: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 2 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-05 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + use_flash_attention: true + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + refined_recompute: "attention_column_ln:1,attention_row_ln:2,flash_attn:-1,mlp_column_ln:2,mlp_row_ln:-1" + save_total_limit: 1 + tensor_parallel_degree: 4 + pipeline_parallel_degree: 1 + ignore_save_lr_and_optim: 1 + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + sequence_parallel: 1 + qwen2: + model_name_or_path: __internal_testing__/tiny-random-qwen2 + sequence_parallel: 1 + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + sequence_parallel: 0 \ No newline at end of file From 688eb186ed774f6b61a0cf46d3b86e09f2c32c27 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 23 Dec 2024 15:57:20 +0800 Subject: [PATCH 10/24] strip op --- paddlenlp/transformers/refined_recompute.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index 3a22df19525c..b43ad2520039 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -577,6 +577,7 @@ def update_refined_recompute(rr, lora=False): ops = rr.split(",") enable_rr = False for op in ops: + op = op.strip() if ":" 
not in op: raise ValueError("Illegal refined_recompute input, please check.") op_name, skip_num = op.split(":")[0], int(op.split(":")[1]) From 1b14cbab208919154b9169a1c9f556878582f77f Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 23 Dec 2024 18:10:55 +0800 Subject: [PATCH 11/24] refined code --- llm/run_embedding.py | 2 - llm/run_finetune.py | 5 -- llm/run_pretrain.py | 4 -- paddlenlp/trainer/training_args.py | 37 ++++++++++ paddlenlp/transformers/configuration_utils.py | 12 ++-- paddlenlp/transformers/llama/modeling.py | 44 ++++++------ paddlenlp/transformers/llama/modeling_pp.py | 8 +-- paddlenlp/transformers/qwen/modeling.py | 36 +++++----- paddlenlp/transformers/qwen/modeling_pp.py | 7 +- paddlenlp/transformers/qwen2/modeling.py | 41 ++++++----- paddlenlp/transformers/qwen2/modeling_pp.py | 8 +-- paddlenlp/transformers/refined_recompute.py | 71 ++++--------------- tests/fixtures/llm/refined_recompute.yaml | 20 ++++-- 13 files changed, 144 insertions(+), 151 deletions(-) diff --git a/llm/run_embedding.py b/llm/run_embedding.py index e598f24839cf..785188008148 100644 --- a/llm/run_embedding.py +++ b/llm/run_embedding.py @@ -29,7 +29,6 @@ Qwen2SentenceEmbedding, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig -from paddlenlp.transformers.refined_recompute import update_refined_recompute from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template from paddlenlp.utils.log import logger @@ -88,7 +87,6 @@ def main(): assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported" LlmMetaConfig.set_llm_config(model_config, training_args) - model_config.refined_recompute = update_refined_recompute(training_args.refined_recompute) model_config.use_fast_layer_norm = model_args.use_fast_layer_norm # Config for model using dropout, such as GPT. diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 0a1df229c142..c63a1724cacf 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -61,7 +61,6 @@ register_sequence_parallel_allreduce_hooks, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig -from paddlenlp.transformers.refined_recompute import update_refined_recompute from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer from paddlenlp.trl.llm_utils import ( ZeroPaddingIterDatasetCallback, @@ -154,10 +153,6 @@ def main(): ) LlmMetaConfig.set_llm_config(model_config, training_args) - model_config.refined_recompute = update_refined_recompute( - training_args.refined_recompute, - model_args.lora, - ) model_config.use_fast_layer_norm = model_args.use_fast_layer_norm # Config for model using dropout, such as GPT. 
diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index 18436f015e54..fc5e4510cc4c 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -44,7 +44,6 @@
     register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
-from paddlenlp.transformers.refined_recompute import update_refined_recompute
 from paddlenlp.utils.batch_sampler import DistributedBatchSampler
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.tools import get_env_device
@@ -406,9 +405,6 @@ def main():
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     # set all llm config
     LlmMetaConfig.set_llm_config(config, training_args)
-    config.refined_recompute = update_refined_recompute(
-        training_args.refined_recompute,
-    )
     config.use_fast_layer_norm = model_args.use_fast_layer_norm
 
     config.seq_length = data_args.max_seq_length
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 1ca83317de8a..d651cea87283 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1777,6 +1777,43 @@ def is_segment_parallel_supported():
                     f"The local_ran: {self.local_rank} should be consistent with the world size: {paddle.distributed.get_world_size()}."
                 )
 
+        # parse refined_recompute string to dict
+        if self.refined_recompute in [None, ""]:
+            self.refined_recompute = dict()
+        else:
+            refined_recompute_dict = {
+                "mlp_row_ln": 0,
+                "attention_row_ln": 0,
+                "attention_column_ln": 0,
+                "mlp_column_ln": 0,
+                "flash_attn": 0,
+            }
+            ops = self.refined_recompute.split(",")
+            enable_rr = False
+            for op in ops:
+                op = op.strip()
+                if ":" not in op:
+                    raise ValueError("Illegal refined_recompute input, please check.")
+                op_name, skip_num = op.split(":")[0], int(op.split(":")[1])
+                if op_name not in refined_recompute_dict:
+                    raise ValueError(f"Refined recompute does not support {op_name}, please check.")
+                if (
+                    op_name in ["mlp_row_ln", "attention_row_ln", "attention_column_ln", "mlp_column_ln"]
+                    and self.tensor_parallel_degree <= 1
+                ):
+                    logger.warning(
+                        f"Refined recompute is only supported for the `{op_name}` operation when `tensor_parallel_degree` is greater than 1. \
+                            This refined recompute operation will be ignored."
+ ) + continue + + refined_recompute_dict[op_name] = skip_num + if skip_num != 0: + enable_rr = True + if not enable_rr: + refined_recompute_dict = dict() + self.refined_recompute = refined_recompute_dict + def __str__(self): self_as_dict = asdict(self) self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()} diff --git a/paddlenlp/transformers/configuration_utils.py b/paddlenlp/transformers/configuration_utils.py index 59edfa5b3b09..761e7fd4096f 100644 --- a/paddlenlp/transformers/configuration_utils.py +++ b/paddlenlp/transformers/configuration_utils.py @@ -268,6 +268,13 @@ class LlmMetaConfig: "Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']", ), ("recompute_use_reentrant", bool, False, "recompute_use_reentrant"), + # refined_recompute attributes + ( + "refined_recompute", + str, + "", + "refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']", + ), ] @classmethod @@ -561,11 +568,6 @@ def __init__(self, **kwargs): self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False) self.fuse_attention_ffn = kwargs.pop("fuse_attention_ffn", False) - # for refined_recompute - self.refined_recompute = kwargs.pop("refined_recompute", {}) - self.skip_recompute_ops = kwargs.pop("skip_recompute_ops", {}) - self.register_unsavable_keys(["refined_recompute", "skip_recompute_ops"]) - if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], Dict): kwargs["quantization_config"] = QuantizationConfig.from_dict(kwargs["quantization_config"]) self.quantization_config = kwargs.pop("quantization_config", QuantizationConfig()) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 987387db7be5..a4110f0ed4a6 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -35,7 +35,7 @@ RRColumnSequenceParallelLinear, RRRowParallelLinear, RRRowSequenceParallelLinear, - create_skip_config_for_refined_recompute, + get_skip_recompte_ops, recompute, ) @@ -605,8 +605,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Layer): - def __init__(self, config, layer_idx: int = 0): + def __init__(self, config, skip_recompute_ops={}): super().__init__() + self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.tensor_parallel_degree = config.tensor_parallel_degree @@ -618,9 +619,9 @@ def __init__(self, config, layer_idx: int = 0): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): + if skip_recompute_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): + if skip_recompute_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear @@ -628,9 +629,9 @@ def __init__(self, config, layer_idx: int = 0): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): + if skip_recompute_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if 
config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): + if skip_recompute_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -682,9 +683,11 @@ def forward(self, x): class LlamaAttention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, layer_idx: int = 0): + def __init__( + self, config: LlamaConfig, layerwise_recompute: bool = False, layer_idx: int = 0, skip_recompute_ops={} + ): super().__init__() - + self.skip_recompute_ops = skip_recompute_ops self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -746,18 +749,18 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, layer # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): + if skip_recompute_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): + if skip_recompute_ops.get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): + if skip_recompute_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): + if skip_recompute_ops.get("attention_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -859,11 +862,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, layer self.attn_func = scaled_dot_product_attention # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` - if ( - config.recompute - and not config.recompute_use_reentrant - and config.skip_recompute_ops[layer_idx].get("flash_attn", False) - ): + if config.recompute and not config.recompute_use_reentrant and skip_recompute_ops.get("flash_attn", False): self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) def _init_rope(self): @@ -1168,12 +1167,13 @@ def forward( class LlamaDecoderLayer(nn.Layer): - def __init__(self, config, layerwise_recompute: bool = False, layer_idx: int = 0): + def __init__(self, config, layerwise_recompute: bool = False, skip_recompte_ops={}): super().__init__() self.config = config + self.skip_recompte_ops = skip_recompte_ops self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config, layerwise_recompute, layer_idx=layer_idx) - self.mlp = LlamaMLP(config, layer_idx=layer_idx) + self.self_attn = LlamaAttention(config, layerwise_recompute, skip_recompte_ops=self.skip_recompte_ops) + self.mlp = LlamaMLP(config, skip_recompte_ops=self.skip_recompte_ops) self.input_layernorm = LlamaRMSNorm(config) self.post_attention_layernorm = LlamaRMSNorm(config) self.sequence_parallel = config.sequence_parallel @@ -1518,9 +1518,9 @@ def __init__(self, config: LlamaConfig): self.layers = nn.LayerList( [ LlamaDecoderLayer( - 
config=create_skip_config_for_refined_recompute(layer_idx, config), + config=config, layerwise_recompute=layer_idx not in self.no_recompute_layers, - layer_idx=layer_idx, + skip_recompte_ops=get_skip_recompte_ops(config, layer_idx), ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index 1ec4c027a72a..bf28d01c8292 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -24,10 +24,7 @@ ) from paddlenlp.transformers.model_utils import PipelinePretrainedModel -from paddlenlp.transformers.refined_recompute import ( - create_skip_config_for_refined_recompute, - recompute, -) +from paddlenlp.transformers.refined_recompute import get_skip_recompte_ops, recompute from paddlenlp.utils.tools import get_env_device from .modeling import ( @@ -376,8 +373,9 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( LlamaDecoderLayerPipe, - config=create_skip_config_for_refined_recompute(i, config), + config=config, layerwise_recompute=i not in self.no_recompute_layers, + skip_recompte_ops=get_skip_recompte_ops(config, i), ), f"llama.layers.{i}", ) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 0383bca25e05..ab656731d5b1 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -31,7 +31,7 @@ RRColumnSequenceParallelLinear, RRRowParallelLinear, RRRowSequenceParallelLinear, - create_skip_config_for_refined_recompute, + get_skip_recompte_ops, no_recompute, recompute, ) @@ -138,9 +138,9 @@ def get_triangle_upper_mask(x, mask=None): class QWenAttention(nn.Layer): - def __init__(self, config, layer_idx: int = 0): + def __init__(self, config, skip_recompte_ops={}): super().__init__() - self.layer_idx = layer_idx + self.skip_recompte_ops = skip_recompte_ops self.config = config self.seq_length = config.seq_length self.hidden_size = config.hidden_size @@ -166,18 +166,18 @@ def __init__(self, config, layer_idx: int = 0): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): + if skip_recompte_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): + if skip_recompte_ops.get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): + if skip_recompte_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): + if skip_recompte_ops.get("attention_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -252,7 +252,7 @@ def _attn(self, query, key, value, attention_mask=None): skip_recompute = ( self.config.recompute and not self.config.recompute_use_reentrant - and self.config.skip_recompute_ops[self.layer_idx].get("flash_attn", False) + and self.skip_recompute_ops.get("flash_attn", False) ) attn_output = 
no_recompute( F.scaled_dot_product_attention, @@ -409,10 +409,11 @@ def forward( class QWenMLP(nn.Layer): - def __init__(self, config, layer_idx: int = 0): + def __init__(self, config, skip_recompte_ops={}): super().__init__() ff_dim_in = config.intermediate_size // 2 self.fuse_attention_ffn = config.fuse_attention_ffn + self.skip_recompte_ops = skip_recompte_ops if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear @@ -420,18 +421,18 @@ def __init__(self, config, layer_idx: int = 0): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): + if skip_recompte_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): + if skip_recompte_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): + if skip_recompte_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): + if skip_recompte_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -484,13 +485,13 @@ def forward(self, hidden_states): class QWenBlock(nn.Layer): - def __init__(self, config, layer_idx: int = 0): + def __init__(self, config, skip_recompte_ops={}): super().__init__() self.sequence_parallel = config.sequence_parallel self.ln_1 = QWenRMSNorm(config) - self.attn = QWenAttention(config, layer_idx=layer_idx) + self.attn = QWenAttention(config, skip_recompte_ops=skip_recompte_ops) self.ln_2 = QWenRMSNorm(config) - self.mlp = QWenMLP(config, layer_idx=layer_idx) + self.mlp = QWenMLP(config, skip_recompte_ops=skip_recompte_ops) def forward( self, @@ -726,8 +727,9 @@ def __init__(self, config): self.h = nn.LayerList( [ QWenBlock( - config=create_skip_config_for_refined_recompute(layer_idx, config), + config=config, layer_idx=layer_idx, + skip_recompte_ops=get_skip_recompte_ops(config, layer_idx), ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/qwen/modeling_pp.py b/paddlenlp/transformers/qwen/modeling_pp.py index 52cc528f39eb..95ac91834a9e 100644 --- a/paddlenlp/transformers/qwen/modeling_pp.py +++ b/paddlenlp/transformers/qwen/modeling_pp.py @@ -18,9 +18,7 @@ from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer from paddlenlp.transformers.model_utils import PipelinePretrainedModel -from paddlenlp.transformers.refined_recompute import ( - create_skip_config_for_refined_recompute, -) +from paddlenlp.transformers.refined_recompute import get_skip_recompte_ops from .modeling import ( QWenBlock, @@ -175,7 +173,8 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( QWenBlockPipe, - config=create_skip_config_for_refined_recompute(i, config), + config=config, + skip_recompte_ops=get_skip_recompte_ops(config, i), ), f"qwen.h.{i}", ) diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 7e52739cd33f..e2be4c9e8021 100644 --- 
a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -39,7 +39,7 @@ RRColumnSequenceParallelLinear, RRRowParallelLinear, RRRowSequenceParallelLinear, - create_skip_config_for_refined_recompute, + get_skip_recompte_ops, recompute, ) from paddlenlp.utils.tools import get_env_device @@ -368,8 +368,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class Qwen2MLP(nn.Layer): - def __init__(self, config: Qwen2Config, is_shared=False, layer_idx: int = 0): + def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops={}): super().__init__() + self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -381,9 +382,9 @@ def __init__(self, config: Qwen2Config, is_shared=False, layer_idx: int = 0): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): + if skip_recompute_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): + if skip_recompute_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear @@ -391,9 +392,9 @@ def __init__(self, config: Qwen2Config, is_shared=False, layer_idx: int = 0): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False): + if skip_recompute_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False): + if skip_recompute_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -445,10 +446,11 @@ class Qwen2Attention(nn.Layer): and "Generating Long Sequences with Sparse Transformers". 
""" - def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, layer_idx: int = 0): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_recompute_ops={}): super().__init__() self.config = config + self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -497,9 +499,9 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, layer_ # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): + if skip_recompute_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): + if skip_recompute_ops.get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear @@ -507,9 +509,9 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, layer_ # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False): + if skip_recompute_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False): + if skip_recompute_ops.get("attention_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -532,11 +534,7 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, layer_ self.attn_func = scaled_dot_product_attention # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` - if ( - config.recompute - and not config.recompute_use_reentrant - and config.skip_recompute_ops[layer_idx].get("flash_attn", False) - ): + if config.recompute and not config.recompute_use_reentrant and skip_recompute_ops.get("flash_attn", False): self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) def forward( @@ -656,13 +654,14 @@ def forward( class Qwen2DecoderLayer(nn.Layer): - def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, layer_idx: int = 0): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompte_ops={}): super().__init__() self.config = config + self.skip_recompte_ops = skip_recompte_ops self.hidden_size = config.hidden_size - self.self_attn = Qwen2Attention(config, layerwise_recompute, layer_idx=layer_idx) + self.self_attn = Qwen2Attention(config, layerwise_recompute, skip_recompte_ops=skip_recompte_ops) - self.mlp = Qwen2MLP(config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config, skip_recompte_ops=skip_recompte_ops) self.input_layernorm = Qwen2RMSNorm(config) self.post_attention_layernorm = Qwen2RMSNorm(config) @@ -953,9 +952,9 @@ def __init__(self, config: Qwen2Config): self.layers = nn.LayerList( [ Qwen2DecoderLayer( - config=create_skip_config_for_refined_recompute(layer_idx, config), + config=config, layerwise_recompute=layer_idx not in self.no_recompute_layers, - layer_idx=layer_idx, + skip_recompte_ops=get_skip_recompte_ops(config, layer_idx), ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index 
916baad328ce..7f6829ad2892 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -24,10 +24,7 @@ SharedLayerDesc, ) -from paddlenlp.transformers.refined_recompute import ( - create_skip_config_for_refined_recompute, - recompute, -) +from paddlenlp.transformers.refined_recompute import get_skip_recompte_ops, recompute from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel @@ -300,8 +297,9 @@ def get_hcg(): self.add_sequential_layer( LayerDesc( Qwen2DecoderLayerPipe, - config=create_skip_config_for_refined_recompute(i, config), + config=config, layerwise_recompute=i not in self.no_recompute_layers, + skip_recompte_ops=get_skip_recompte_ops(config, i), ), f"qwen2.layers.{i}", ) diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index b43ad2520039..079b9b441d57 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -43,7 +43,6 @@ RowParallelLinear, RowSequenceParallelLinear, ) -from paddlenlp.utils.log import logger try: from paddle.base import core, framework @@ -55,7 +54,7 @@ "no_recompute", "recompute", "get_global_rr_queue_dict", - "update_refined_recompute", + "get_skip_recompte_ops", "RRColumnSequenceParallelLinear", "RRRowSequenceParallelLinear", "RRColumnParallelLinear", @@ -508,14 +507,14 @@ def get_pp_vp_split_layers(layer_num, pp_size, vp_size, skip_recompute_num=-1): return set(sum(no_recompute_layer_num, [])) -def create_skip_config_for_refined_recompute(layer_idx, config): +def get_skip_recompte_ops(config, layer_idx): """ - Creates a configuration for skipping recomputation based on the configuration file, + Creates a dictionary for skipping recomputation based on the configuration file, effective only at the specified layer index. Args: - layer_idx (int): The layer index used to check whether recomputation should be skipped. config (dict): The configuration file of the input model. + layer_idx (int): The layer index used to check whether recomputation should be skipped. Returns: dict: Returns an updated configuration file containing the following key-value pairs: @@ -525,9 +524,9 @@ def create_skip_config_for_refined_recompute(layer_idx, config): the original configuration file is returned. 
""" - if not config.recompute or config.refined_recompute is None: - return config - skip_config = dict() + skip_recompute_ops = dict() + if not config.recompute or not isinstance(config.refined_recompute, dict): + return skip_recompute_ops try: hcg = fleet.get_hybrid_communicate_group() @@ -542,62 +541,20 @@ def create_skip_config_for_refined_recompute(layer_idx, config): layer_num = config.num_layers if hasattr(config, "num_layers") else config.num_hidden_layers no_recompute_layers = get_pp_vp_split_layers(layer_num, pp_size, vp_size, skip_num) if layer_idx in no_recompute_layers: - skip_config[op_name] = True + skip_recompute_ops[op_name] = True else: - skip_config[op_name] = False + skip_recompute_ops[op_name] = False else: if skip_num == 0: # 0 means all recompute - skip_config[op_name] = False + skip_recompute_ops[op_name] = False elif skip_num < 0: # < 0 means all skip recompute - skip_config[op_name] = True + skip_recompute_ops[op_name] = True else: if layer_idx < skip_num: # < the number of layers to skip recompute - skip_config[op_name] = True + skip_recompute_ops[op_name] = True else: - skip_config[op_name] = False - - config.skip_recompute_ops[layer_idx] = skip_config - - return config - - -def update_refined_recompute(rr, lora=False): - """update refined recompute dict.""" - if rr is None or rr == "": - return {} - else: - - rr_res = { - "mlp_row_ln": 0, - "attention_row_ln": 0, - "attention_column_ln": 0, - "mlp_column_ln": 0, - "flash_attn": 0, - } - ops = rr.split(",") - enable_rr = False - for op in ops: - op = op.strip() - if ":" not in op: - raise ValueError("Illegal refined_recompute input, please check.") - op_name, skip_num = op.split(":")[0], int(op.split(":")[1]) - if op_name not in rr_res: - raise ValueError(f"Refined recompute do not support {op_name}, please check.") - - if op_name in ["mlp_row_ln", "attention_row_ln", "attention_column_ln", "mlp_column_ln"]: - if lora: - logger.warning( - "Currently, LoRA does not support refined recompute " - f"for the `{op_name}` op. This refined recompute op will be ignored." 
- ) - continue - rr_res[op_name] = skip_num - if skip_num != 0: - enable_rr = True - - if not enable_rr: - rr_res = {} - return rr_res + skip_recompute_ops[op_name] = False + return skip_recompute_ops class RRColumnParallelLinear(ColumnParallelLinear): diff --git a/tests/fixtures/llm/refined_recompute.yaml b/tests/fixtures/llm/refined_recompute.yaml index b34e8f401413..020a5412ec36 100644 --- a/tests/fixtures/llm/refined_recompute.yaml +++ b/tests/fixtures/llm/refined_recompute.yaml @@ -30,12 +30,24 @@ finetune: ignore_save_lr_and_optim: 1 default: - llama: + llama_sft_tp4_sp: model_name_or_path: __internal_testing__/tiny-random-llama sequence_parallel: 1 - qwen2: + llama_sft_tp4_pp2_sp: + model_name_or_path: __internal_testing__/tiny-random-llama + sequence_parallel: 1 + pipeline_parallel_degree: 2 + qwen2_sft_tp4: + model_name_or_path: __internal_testing__/tiny-random-qwen2 + sequence_parallel: 0 + qwen2_sft_tp4_sp: model_name_or_path: __internal_testing__/tiny-random-qwen2 sequence_parallel: 1 - qwen: + qwen_sft_tp4: model_name_or_path: __internal_testing__/tiny-fused-qwen - sequence_parallel: 0 \ No newline at end of file + sequence_parallel: 0 + qwen_lora_tp4_sp: + model_name_or_path: __internal_testing__/tiny-fused-qwen + sequence_parallel: 1 + learning_rate: 3e-04 + lora: true \ No newline at end of file From 0035f4f2854b1cf50ac41c331b742780370a9717 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Tue, 24 Dec 2024 10:07:01 +0800 Subject: [PATCH 12/24] fix typo and fix rr recompute --- paddlenlp/transformers/llama/modeling.py | 16 +++++----- paddlenlp/transformers/llama/modeling_pp.py | 4 +-- paddlenlp/transformers/qwen/modeling.py | 35 ++++++++++----------- paddlenlp/transformers/qwen/modeling_pp.py | 4 +-- paddlenlp/transformers/qwen2/modeling.py | 12 +++---- paddlenlp/transformers/qwen2/modeling_pp.py | 4 +-- paddlenlp/transformers/refined_recompute.py | 26 +++++---------- 7 files changed, 43 insertions(+), 58 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index a4110f0ed4a6..91cb5d882d47 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -35,7 +35,7 @@ RRColumnSequenceParallelLinear, RRRowParallelLinear, RRRowSequenceParallelLinear, - get_skip_recompte_ops, + get_skip_recompute_ops, recompute, ) @@ -683,9 +683,7 @@ def forward(self, x): class LlamaAttention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__( - self, config: LlamaConfig, layerwise_recompute: bool = False, layer_idx: int = 0, skip_recompute_ops={} - ): + def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, skip_recompute_ops={}): super().__init__() self.skip_recompute_ops = skip_recompute_ops self.config = config @@ -1167,13 +1165,13 @@ def forward( class LlamaDecoderLayer(nn.Layer): - def __init__(self, config, layerwise_recompute: bool = False, skip_recompte_ops={}): + def __init__(self, config, layerwise_recompute: bool = False, skip_recompute_ops={}): super().__init__() self.config = config - self.skip_recompte_ops = skip_recompte_ops + self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config, layerwise_recompute, skip_recompte_ops=self.skip_recompte_ops) - self.mlp = LlamaMLP(config, skip_recompte_ops=self.skip_recompte_ops) + self.self_attn = LlamaAttention(config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops) + self.mlp = LlamaMLP(config, 
skip_recompute_ops=skip_recompute_ops) self.input_layernorm = LlamaRMSNorm(config) self.post_attention_layernorm = LlamaRMSNorm(config) self.sequence_parallel = config.sequence_parallel @@ -1520,7 +1518,7 @@ def __init__(self, config: LlamaConfig): LlamaDecoderLayer( config=config, layerwise_recompute=layer_idx not in self.no_recompute_layers, - skip_recompte_ops=get_skip_recompte_ops(config, layer_idx), + skip_recompute_ops=get_skip_recompute_ops(config, layer_idx), ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index bf28d01c8292..b8a0fd849e1d 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -24,7 +24,7 @@ ) from paddlenlp.transformers.model_utils import PipelinePretrainedModel -from paddlenlp.transformers.refined_recompute import get_skip_recompte_ops, recompute +from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops, recompute from paddlenlp.utils.tools import get_env_device from .modeling import ( @@ -375,7 +375,7 @@ def get_hcg(): LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers, - skip_recompte_ops=get_skip_recompte_ops(config, i), + skip_recompute_ops=get_skip_recompute_ops(config, i), ), f"llama.layers.{i}", ) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index ab656731d5b1..13e24898e94e 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -31,7 +31,7 @@ RRColumnSequenceParallelLinear, RRRowParallelLinear, RRRowSequenceParallelLinear, - get_skip_recompte_ops, + get_skip_recompute_ops, no_recompute, recompute, ) @@ -138,9 +138,9 @@ def get_triangle_upper_mask(x, mask=None): class QWenAttention(nn.Layer): - def __init__(self, config, skip_recompte_ops={}): + def __init__(self, config, skip_recompute_ops={}): super().__init__() - self.skip_recompte_ops = skip_recompte_ops + self.skip_recompute_ops = skip_recompute_ops self.config = config self.seq_length = config.seq_length self.hidden_size = config.hidden_size @@ -166,18 +166,18 @@ def __init__(self, config, skip_recompte_ops={}): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if skip_recompte_ops.get("attention_column_ln", False): + if skip_recompute_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if skip_recompte_ops.get("attention_row_ln", False): + if skip_recompute_ops.get("attention_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if skip_recompte_ops.get("attention_column_ln", False): + if skip_recompute_ops.get("attention_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if skip_recompte_ops.get("attention_row_ln", False): + if skip_recompute_ops.get("attention_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -409,11 +409,11 @@ def forward( class QWenMLP(nn.Layer): - def __init__(self, config, skip_recompte_ops={}): + def __init__(self, config, skip_recompute_ops={}): super().__init__() ff_dim_in = config.intermediate_size // 
2 self.fuse_attention_ffn = config.fuse_attention_ffn - self.skip_recompte_ops = skip_recompte_ops + self.skip_recompute_ops = skip_recompute_ops if config.sequence_parallel: ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear @@ -421,18 +421,18 @@ def __init__(self, config, skip_recompte_ops={}): # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if skip_recompte_ops.get("mlp_column_ln", False): + if skip_recompute_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnSequenceParallelLinear - if skip_recompte_ops.get("mlp_row_ln", False): + if skip_recompute_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowSequenceParallelLinear else: ColumnParallelLinear = linear_utils.ColumnParallelLinear RowParallelLinear = linear_utils.RowParallelLinear # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` if config.recompute and not config.recompute_use_reentrant: - if skip_recompte_ops.get("mlp_column_ln", False): + if skip_recompute_ops.get("mlp_column_ln", False): ColumnParallelLinear = RRColumnParallelLinear - if skip_recompte_ops.get("mlp_row_ln", False): + if skip_recompute_ops.get("mlp_row_ln", False): RowParallelLinear = RRRowParallelLinear if config.tensor_parallel_degree > 1: @@ -485,13 +485,13 @@ def forward(self, hidden_states): class QWenBlock(nn.Layer): - def __init__(self, config, skip_recompte_ops={}): + def __init__(self, config, skip_recompute_ops={}): super().__init__() self.sequence_parallel = config.sequence_parallel self.ln_1 = QWenRMSNorm(config) - self.attn = QWenAttention(config, skip_recompte_ops=skip_recompte_ops) + self.attn = QWenAttention(config, skip_recompute_ops=skip_recompute_ops) self.ln_2 = QWenRMSNorm(config) - self.mlp = QWenMLP(config, skip_recompte_ops=skip_recompte_ops) + self.mlp = QWenMLP(config, skip_recompute_ops=skip_recompute_ops) def forward( self, @@ -728,8 +728,7 @@ def __init__(self, config): [ QWenBlock( config=config, - layer_idx=layer_idx, - skip_recompte_ops=get_skip_recompte_ops(config, layer_idx), + skip_recompute_ops=get_skip_recompute_ops(config, layer_idx), ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/qwen/modeling_pp.py b/paddlenlp/transformers/qwen/modeling_pp.py index 95ac91834a9e..570465bb9f5e 100644 --- a/paddlenlp/transformers/qwen/modeling_pp.py +++ b/paddlenlp/transformers/qwen/modeling_pp.py @@ -18,7 +18,7 @@ from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer from paddlenlp.transformers.model_utils import PipelinePretrainedModel -from paddlenlp.transformers.refined_recompute import get_skip_recompte_ops +from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops from .modeling import ( QWenBlock, @@ -174,7 +174,7 @@ def get_hcg(): LayerDesc( QWenBlockPipe, config=config, - skip_recompte_ops=get_skip_recompte_ops(config, i), + skip_recompute_ops=get_skip_recompute_ops(config, i), ), f"qwen.h.{i}", ) diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index a01a7a1af0f7..91f24d14aced 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -39,7 +39,7 @@ RRColumnSequenceParallelLinear, RRRowParallelLinear, RRRowSequenceParallelLinear, - get_skip_recompte_ops, + get_skip_recompute_ops, recompute, ) from paddlenlp.utils.tools import get_env_device @@ -654,14 +654,14 @@ def forward( class Qwen2DecoderLayer(nn.Layer): 
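With the kwarg spelled consistently, every stack is wired the same way: the config object is shared across layers and only the small dict returned by get_skip_recompute_ops changes per layer, which is what makes the earlier per-layer config copies unnecessary. A minimal sketch of that wiring; `build_decoder_stack`, `BlockCls`, and `num_layers` are illustrative stand-ins, assuming `BlockCls` accepts the same keyword as the decoder layers in the diff:

    import paddle.nn as nn

    from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops

    def build_decoder_stack(config, BlockCls, num_layers):
        # One shared config object; only skip_recompute_ops differs per layer.
        return nn.LayerList(
            [
                BlockCls(
                    config=config,
                    skip_recompute_ops=get_skip_recompute_ops(config, layer_idx),
                )
                for layer_idx in range(num_layers)
            ]
        )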
- def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompte_ops={}): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompute_ops={}): super().__init__() self.config = config - self.skip_recompte_ops = skip_recompte_ops + self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size - self.self_attn = Qwen2Attention(config, layerwise_recompute, skip_recompte_ops=skip_recompte_ops) + self.self_attn = Qwen2Attention(config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops) - self.mlp = Qwen2MLP(config, skip_recompte_ops=skip_recompte_ops) + self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops) self.input_layernorm = Qwen2RMSNorm(config) self.post_attention_layernorm = Qwen2RMSNorm(config) @@ -954,7 +954,7 @@ def __init__(self, config: Qwen2Config): Qwen2DecoderLayer( config=config, layerwise_recompute=layer_idx not in self.no_recompute_layers, - skip_recompte_ops=get_skip_recompte_ops(config, layer_idx), + skip_recompute_ops=get_skip_recompute_ops(config, layer_idx), ) for layer_idx in range(config.num_hidden_layers) ] diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index 7f6829ad2892..aae86311f100 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -24,7 +24,7 @@ SharedLayerDesc, ) -from paddlenlp.transformers.refined_recompute import get_skip_recompte_ops, recompute +from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops, recompute from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel @@ -299,7 +299,7 @@ def get_hcg(): Qwen2DecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers, - skip_recompte_ops=get_skip_recompte_ops(config, i), + skip_recompute_ops=get_skip_recompute_ops(config, i), ), f"qwen2.layers.{i}", ) diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index 079b9b441d57..1dc7bc261e17 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -54,7 +54,7 @@ "no_recompute", "recompute", "get_global_rr_queue_dict", - "get_skip_recompte_ops", + "get_skip_recompute_ops", "RRColumnSequenceParallelLinear", "RRRowSequenceParallelLinear", "RRColumnParallelLinear", @@ -329,10 +329,7 @@ def _recompute_without_reentrant(function, preserve_rng_state=True, *args, **kwa amp_white_list, amp_black_list = tracer._get_amp_op_list() class IntermediateHolder: - def __init__(self, name, shape, dtype) -> None: - self.name = name - self.shape = shape - self.dtype = dtype + pass storage = weakref.WeakKeyDictionary() holder_list = [] @@ -341,11 +338,11 @@ def __init__(self, name, shape, dtype) -> None: def pack(x): # [PACK] in no recompute context or input tensor no need recompute, return the input tensor directly - if x.persistable or (in_no_recompute_ctx() and not x.name.endswith(recompute_suffix)): + if x is not None and x.persistable or (in_no_recompute_ctx() and not x.name.endswith(recompute_suffix)): return share_buffer_to_tensor_or_param(x) # remove the recompute suffix - res = IntermediateHolder(x.name, x.shape, x.dtype) + res = IntermediateHolder() holder_list.append(weakref.ref(res)) return res @@ -358,7 +355,7 @@ def unpack(x): if len(storage) == 0: def inner_pack(inner_x): - if inner_x.persistable: + if inner_x is not None and inner_x.persistable: return nonlocal unpack_counter @@ 
-404,16 +401,7 @@ def inner_unpack(inner_x): raise Exception( "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute." ) - tensor = storage.pop(x) - assert x.shape == tensor.shape, ( - f"The shape:{x.shape} of the tensor saved by autograd is not " - f"consistent with the original tensor shape:{tensor.shape}! " - ) - assert x.dtype == tensor.dtype, ( - f"The dtype:{x.dtype} of the tensor saved by autograd is not" - f"consistent with the original tensor dtype:{tensor.dtype}! " - ) - return tensor + return storage[x] with switch_recompute_id_ctx(recompute_id + "@first"): with paddle.autograd.saved_tensors_hooks(pack, unpack): @@ -507,7 +495,7 @@ def get_pp_vp_split_layers(layer_num, pp_size, vp_size, skip_recompute_num=-1): return set(sum(no_recompute_layer_num, [])) -def get_skip_recompte_ops(config, layer_idx): +def get_skip_recompute_ops(config, layer_idx): """ Creates a dictionary for skipping recomputation based on the configuration file, effective only at the specified layer index. From b1f0ef86b03557258a18addae04b33a0d5c60dcb Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Tue, 24 Dec 2024 12:09:55 +0800 Subject: [PATCH 13/24] test --- paddlenlp/transformers/llama/modeling.py | 12 +++- paddlenlp/transformers/qwen/modeling.py | 12 +++- paddlenlp/transformers/qwen2/modeling.py | 13 ++-- tests/transformers/test_refined_recompute.py | 65 ++++++++++++++++++++ 4 files changed, 92 insertions(+), 10 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 91cb5d882d47..7dd3389fac8e 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -605,8 +605,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Layer): - def __init__(self, config, skip_recompute_ops={}): + def __init__(self, config, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -683,8 +685,10 @@ def forward(self, x): class LlamaAttention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, skip_recompute_ops={}): + def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} self.skip_recompute_ops = skip_recompute_ops self.config = config self.hidden_size = config.hidden_size @@ -1165,9 +1169,11 @@ def forward( class LlamaDecoderLayer(nn.Layer): - def __init__(self, config, layerwise_recompute: bool = False, skip_recompute_ops={}): + def __init__(self, config, layerwise_recompute: bool = False, skip_recompute_ops=None): super().__init__() self.config = config + if skip_recompute_ops is None: + skip_recompute_ops = {} self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size self.self_attn = LlamaAttention(config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 13e24898e94e..a2a34ffbc8c2 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -138,8 +138,10 @@ def get_triangle_upper_mask(x, mask=None): class QWenAttention(nn.Layer): - def __init__(self, config, 
skip_recompute_ops={}): + def __init__(self, config, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} self.skip_recompute_ops = skip_recompute_ops self.config = config self.seq_length = config.seq_length @@ -409,8 +411,10 @@ def forward( class QWenMLP(nn.Layer): - def __init__(self, config, skip_recompute_ops={}): + def __init__(self, config, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} ff_dim_in = config.intermediate_size // 2 self.fuse_attention_ffn = config.fuse_attention_ffn self.skip_recompute_ops = skip_recompute_ops @@ -485,8 +489,10 @@ def forward(self, hidden_states): class QWenBlock(nn.Layer): - def __init__(self, config, skip_recompute_ops={}): + def __init__(self, config, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} self.sequence_parallel = config.sequence_parallel self.ln_1 = QWenRMSNorm(config) self.attn = QWenAttention(config, skip_recompute_ops=skip_recompute_ops) diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 91f24d14aced..928246e3a470 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -368,8 +368,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class Qwen2MLP(nn.Layer): - def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops={}): + def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -446,9 +448,10 @@ class Qwen2Attention(nn.Layer): and "Generating Long Sequences with Sparse Transformers". 
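The recurring edit in this patch, replacing the `{}` default with `None` plus an explicit check, avoids Python's shared-mutable-default pitfall: a `{}` default is created once at function definition time and would be shared by every layer that does not pass its own dict. A tiny standalone illustration (both function names are illustrative, not from the patch):

    def with_mutable_default(ops={}):
        # The same dict object is reused across calls.
        ops.setdefault("calls", 0)
        ops["calls"] += 1
        return ops

    def with_none_default(ops=None):
        # Fresh dict per call, mirroring the pattern used in the patch.
        if ops is None:
            ops = {}
        ops.setdefault("calls", 0)
        ops["calls"] += 1
        return ops

    assert with_mutable_default() is with_mutable_default()  # shared object
    assert with_none_default() is not with_none_default()    # independent objects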
""" - def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_recompute_ops={}): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_recompute_ops=None): super().__init__() - + if skip_recompute_ops is None: + skip_recompute_ops = {} self.config = config self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size @@ -654,8 +657,10 @@ def forward( class Qwen2DecoderLayer(nn.Layer): - def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompute_ops={}): + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompute_ops=None): super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} self.config = config self.skip_recompute_ops = skip_recompute_ops self.hidden_size = config.hidden_size diff --git a/tests/transformers/test_refined_recompute.py b/tests/transformers/test_refined_recompute.py index 25a1cdee7bf5..3798f7d459e6 100644 --- a/tests/transformers/test_refined_recompute.py +++ b/tests/transformers/test_refined_recompute.py @@ -29,6 +29,7 @@ import paddle.nn.functional as F from paddle.distributed.fleet.recompute import recompute as original_recompute +from paddlenlp.trainer.training_args import TrainingArguments from paddlenlp.transformers.refined_recompute import no_recompute as rr_no_recompute from paddlenlp.transformers.refined_recompute import recompute as rr_recompute from paddlenlp.utils.import_utils import is_paddle_cuda_available @@ -557,3 +558,67 @@ def test_refined_recompute_pp(self): del layer1, layer2, layer3 paddle.device.cuda.empty_cache() paddle.set_default_dtype(raw_dtype) + + +class TestRefinedRecomputeModel(unittest.TestCase): + def setUp(self): + self.args = TrainingArguments( + output_dir="./", + do_train=True, + max_steps=100, + tensor_parallel_degree=1, + pipeline_parallel_degree=1, + refined_recompute="attention_column_ln:1,attention_row_ln:2,flash_attn:-1,mlp_column_ln:2,mlp_row_ln:-1", + ) + + @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") + def test_llama_refined_recompute(self): + from paddlenlp.transformers.llama import LlamaConfig, LlamaModel + + llama_model = "__internal_testing__/tiny-random-llama" + config = LlamaConfig.from_pretrained(llama_model) + config.recompute = True + config.recompute_granularity = "full" + config.recompute_use_reentrant = False + config.sequence_parallel = False + config.use_flash_attention = True + config.refined_recompute = self.args.refined_recompute + model = LlamaModel.from_config(config=config, dtype="bfloat16") + input_ids = paddle.randint(0, 100, shape=[1, 1024], dtype="int64") + output = model(input_ids) + output[0].mean().backward() + + @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") + def test_qwen_refined_recompute(self): + from paddlenlp.transformers.qwen import QWenConfig, QWenModel + + llama_model = "__internal_testing__/tiny-random-qwen" + config = QWenConfig.from_pretrained(llama_model) + config.recompute = True + config.recompute_granularity = "full" + config.recompute_use_reentrant = False + config.sequence_parallel = False + config.use_flash_attention = True + config.refined_recompute = self.args.refined_recompute + config.seq_length = 1024 + model = QWenModel.from_config(config=config, dtype="bfloat16") + input_ids = paddle.randint(0, 100, shape=[1, 1024], dtype="int64") + output = model(input_ids) + output[0].mean().backward() + + @unittest.skipIf(not 
is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") + def test_qwen2_refined_recompute(self): + from paddlenlp.transformers.qwen2 import Qwen2Config, Qwen2Model + + llama_model = "__internal_testing__/tiny-random-qwen2" + config = Qwen2Config.from_pretrained(llama_model) + config.recompute = True + config.recompute_granularity = "full" + config.recompute_use_reentrant = False + config.sequence_parallel = False + config.use_flash_attention = True + config.refined_recompute = self.args.refined_recompute + model = Qwen2Model.from_config(config=config, dtype="bfloat16") + input_ids = paddle.randint(0, 100, shape=[1, 1024], dtype="int64") + output = model(input_ids) + output[0].mean().backward() From a8f6ac8af5e257f18cd0ba37871a78bab0d580ad Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Wed, 25 Dec 2024 14:45:41 +0800 Subject: [PATCH 14/24] pop storage --- paddlenlp/transformers/refined_recompute.py | 8 ++-- tests/fixtures/llm/refined_recompute.yaml | 53 --------------------- 2 files changed, 5 insertions(+), 56 deletions(-) delete mode 100644 tests/fixtures/llm/refined_recompute.yaml diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index 1dc7bc261e17..21dc69e40cc4 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -395,13 +395,13 @@ def inner_unpack(inner_x): ): with switch_recompute_id_ctx(recompute_id + "@second"): with paddle.autograd.saved_tensors_hooks(inner_pack, inner_unpack): - unused_outputs = function(*args, **kwargs) # noqa: F841 + function(*args, **kwargs) if x not in storage: raise Exception( "Not supported to retrieve a tensor saved by autograd multiple times that is no need to recompute." ) - return storage[x] + return storage.pop(x) with switch_recompute_id_ctx(recompute_id + "@first"): with paddle.autograd.saved_tensors_hooks(pack, unpack): @@ -521,12 +521,14 @@ def get_skip_recompute_ops(config, layer_idx): pp_size = max(hcg.get_pipe_parallel_world_size(), 1) except: pp_size = 1 + layer_num = config.num_layers if hasattr(config, "num_layers") else config.num_hidden_layers + if hasattr(config, "add_tail_layer") and config.add_tail_layer: + layer_num += 1 for op_name, skip_num in config.refined_recompute.items(): # is pp model if pp_size > 1: vp_size = max(config.virtual_pp_degree, 1) - layer_num = config.num_layers if hasattr(config, "num_layers") else config.num_hidden_layers no_recompute_layers = get_pp_vp_split_layers(layer_num, pp_size, vp_size, skip_num) if layer_idx in no_recompute_layers: skip_recompute_ops[op_name] = True diff --git a/tests/fixtures/llm/refined_recompute.yaml b/tests/fixtures/llm/refined_recompute.yaml deleted file mode 100644 index 020a5412ec36..000000000000 --- a/tests/fixtures/llm/refined_recompute.yaml +++ /dev/null @@ -1,53 +0,0 @@ -finetune: - base: - dataset_name_or_path: "./data" - per_device_train_batch_size: 2 - gradient_accumulation_steps: 4 - per_device_eval_batch_size: 8 - eval_accumulation_steps: 16 - num_train_epochs: 3 - learning_rate: 3e-05 - warmup_steps: 30 - logging_steps: 1 - evaluation_strategy: "epoch" - save_strategy: "epoch" - src_length: 1024 - max_length: 2048 - fp16: true - fp16_opt_level: "O2" - do_train: true - do_eval: true - use_flash_attention: true - disable_tqdm: true - load_best_model_at_end: true - eval_with_do_generation: false - metric_for_best_model: "accuracy" - recompute: true - refined_recompute: 
"attention_column_ln:1,attention_row_ln:2,flash_attn:-1,mlp_column_ln:2,mlp_row_ln:-1" - save_total_limit: 1 - tensor_parallel_degree: 4 - pipeline_parallel_degree: 1 - ignore_save_lr_and_optim: 1 - - default: - llama_sft_tp4_sp: - model_name_or_path: __internal_testing__/tiny-random-llama - sequence_parallel: 1 - llama_sft_tp4_pp2_sp: - model_name_or_path: __internal_testing__/tiny-random-llama - sequence_parallel: 1 - pipeline_parallel_degree: 2 - qwen2_sft_tp4: - model_name_or_path: __internal_testing__/tiny-random-qwen2 - sequence_parallel: 0 - qwen2_sft_tp4_sp: - model_name_or_path: __internal_testing__/tiny-random-qwen2 - sequence_parallel: 1 - qwen_sft_tp4: - model_name_or_path: __internal_testing__/tiny-fused-qwen - sequence_parallel: 0 - qwen_lora_tp4_sp: - model_name_or_path: __internal_testing__/tiny-fused-qwen - sequence_parallel: 1 - learning_rate: 3e-04 - lora: true \ No newline at end of file From 150c4dcd008da7c07c34b4670e54bd168302776d Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Fri, 27 Dec 2024 12:02:17 +0800 Subject: [PATCH 15/24] =?UTF-8?q?=E5=8F=AA=E6=9C=89=E5=BD=93=E5=BC=80?= =?UTF-8?q?=E5=90=AFrefined=20recompute=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E6=89=8D=E4=BC=9A=E4=BD=BF=E7=94=A8=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E7=9A=84recompute,=E9=81=BF=E5=85=8D?= =?UTF-8?q?=E6=98=BE=E8=91=97=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddlenlp/transformers/llama/modeling.py | 12 ++++++++---- paddlenlp/transformers/qwen/modeling.py | 9 ++++++--- paddlenlp/transformers/qwen2/modeling.py | 12 ++++++++---- paddlenlp/transformers/refined_recompute.py | 12 ++++++++++++ 4 files changed, 34 insertions(+), 11 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 7dd3389fac8e..5248cc97b885 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -29,6 +29,7 @@ from paddle.autograd import PyLayer from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute from paddlenlp.transformers.refined_recompute import ( RRColumnParallelLinear, @@ -36,8 +37,8 @@ RRRowParallelLinear, RRRowSequenceParallelLinear, get_skip_recompute_ops, - recompute, ) +from paddlenlp.transformers.refined_recompute import recompute as rr_recompute try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -1114,7 +1115,8 @@ def forward( and has_gradient and self.recompute_granularity == "core_attn" ): - outputs = recompute( + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( self.attn_func, query_states, self.config, @@ -1225,7 +1227,8 @@ def forward( and has_gradient and self.recompute_granularity == "full_attn" ): - outputs = recompute( + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( self.self_attn, hidden_states, position_ids, @@ -1633,7 +1636,8 @@ def custom_forward(*inputs): return custom_forward - hidden_states = recompute( + recompute_fn = rr_recompute if self.config.refined_recompute else recompute + hidden_states = recompute_fn( create_custom_forward(layer_module), hidden_states, position_ids, diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index a2a34ffbc8c2..f0f73c70427f 100755 
--- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -24,6 +24,7 @@ from paddle import Tensor, nn from paddle.distributed import fleet from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute from paddle.utils import try_import from paddlenlp.transformers.refined_recompute import ( @@ -33,8 +34,8 @@ RRRowSequenceParallelLinear, get_skip_recompute_ops, no_recompute, - recompute, ) +from paddlenlp.transformers.refined_recompute import recompute as rr_recompute try: from paddle.incubate.nn.functional import swiglu @@ -390,7 +391,8 @@ def forward( has_gradient = not (query.stop_gradient and key.stop_gradient and value.stop_gradient) if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "core_attn": - attn_output, attn_weight = recompute( + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + attn_output, attn_weight = recompute_fn( self._attn, query, key, @@ -799,7 +801,8 @@ def custom_forward(*inputs): return custom_forward - hidden_states = recompute( + recompute_fn = rr_recompute if any(block.skip_recompute_ops.values()) else recompute + hidden_states = recompute_fn( create_custom_forward(block), hidden_states, layer_past, diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 928246e3a470..6df986d48c46 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -32,6 +32,7 @@ from paddle import Tensor, nn from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss from paddlenlp.transformers.refined_recompute import ( @@ -40,8 +41,8 @@ RRRowParallelLinear, RRRowSequenceParallelLinear, get_skip_recompute_ops, - recompute, ) +from paddlenlp.transformers.refined_recompute import recompute as rr_recompute from paddlenlp.utils.tools import get_env_device from .. import linear_utils @@ -605,7 +606,8 @@ def forward( and has_gradient and self.recompute_granularity == "core_attn" ): - outputs = recompute( + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( self.attn_func, query_states, self.config, @@ -714,7 +716,8 @@ def forward( and has_gradient and self.recompute_granularity == "full_attn" ): - outputs = recompute( + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( self.self_attn, hidden_states, position_ids, @@ -1058,7 +1061,8 @@ def custom_forward(*inputs): return custom_forward - hidden_states = recompute( + recompute_fn = rr_recompute if any(layer_module.skip_recompute_ops.values()) else recompute + hidden_states = recompute_fn( create_custom_forward(layer_module), hidden_states, position_ids, diff --git a/paddlenlp/transformers/refined_recompute.py b/paddlenlp/transformers/refined_recompute.py index 21dc69e40cc4..722383715424 100644 --- a/paddlenlp/transformers/refined_recompute.py +++ b/paddlenlp/transformers/refined_recompute.py @@ -429,9 +429,21 @@ def recompute(function, *args, **kwargs): Returns: Output of function on args. 
""" + # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop("preserve_rng_state", True) + + # whether to use reentrant method to implement recompute use_reentrant = kwargs.pop("use_reentrant", True) + + if not paddle.in_dynamic_mode(): + from paddle.distributed.auto_parallel.interface import ( + recompute as static_auto_recompute, + ) + + return static_auto_recompute(function)(*args, **kwargs) + if not use_reentrant: + _ = kwargs.pop("offload_indices", []) # currently not support offload_indices if framework._dygraph_tracer()._has_grad: check_args = list(args) check_args.extend(list(kwargs.values())) From 00b5e2c0edf1c93cef3d22f7361bd23fdd3a7020 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Fri, 27 Dec 2024 14:31:40 +0800 Subject: [PATCH 16/24] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dllama=20qwen=E7=9A=84PP?= =?UTF-8?q?,=20=E4=BD=BF=E7=94=A8=E5=90=8C=E6=A0=B7=E7=9A=84=E9=80=BB?= =?UTF-8?q?=E8=BE=91,=E4=BB=85=E6=9C=89=E5=BD=93=E5=BC=80=E5=90=AFrr?= =?UTF-8?q?=E7=9A=84=E6=97=B6=E5=80=99=E6=89=8D=E4=BC=9A=E4=BD=BF=E7=94=A8?= =?UTF-8?q?rr=5Frecompute?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddlenlp/transformers/llama/modeling_pp.py | 9 ++++++--- paddlenlp/transformers/qwen2/modeling_pp.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index 9a4030d5e520..048967960ee0 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -22,9 +22,11 @@ PipelineLayer, SharedLayerDesc, ) +from paddle.distributed.fleet.recompute.recompute import recompute from paddlenlp.transformers.model_utils import PipelinePretrainedModel -from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops, recompute +from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops +from paddlenlp.transformers.refined_recompute import recompute as rr_recompute from paddlenlp.utils.tools import get_env_device from .modeling import ( @@ -248,8 +250,9 @@ def forward(self, args): and self.config.recompute_granularity == "full" and has_gradient ): + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute if attention_mask is not None or alibi is not None or attn_mask_startend_row_indices is not None: - hidden_states = recompute( + hidden_states = recompute_fn( super().forward, hidden_states, position_ids=position_ids, @@ -260,7 +263,7 @@ def forward(self, args): ) else: # for pretrain - hidden_states = recompute( + hidden_states = recompute_fn( super().forward, hidden_states, position_ids=position_ids, diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index cf78b50e6284..80ce448e368b 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -23,8 +23,10 @@ PipelineLayer, SharedLayerDesc, ) +from paddle.distributed.fleet.recompute.recompute import recompute -from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops, recompute +from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops +from paddlenlp.transformers.refined_recompute import recompute as rr_recompute from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel @@ -170,8 +172,9 @@ def forward(self, args): attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices if 
self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute if attention_mask is not None or attn_mask_startend_row_indices is not None: - hidden_states = recompute( + hidden_states = recompute_fn( super().forward, hidden_states, position_ids=position_ids, @@ -181,7 +184,7 @@ def forward(self, args): ) else: # for pretrain - hidden_states = recompute( + hidden_states = recompute_fn( super().forward, hidden_states, position_ids=position_ids, From b19fd124a1e53e19bcc6cb41e0f2853baf1f8bee Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 30 Dec 2024 12:13:03 +0800 Subject: [PATCH 17/24] fix missing --- paddlenlp/transformers/qwen/modeling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index f0f73c70427f..2f465e9c3d8c 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -495,6 +495,7 @@ def __init__(self, config, skip_recompute_ops=None): super().__init__() if skip_recompute_ops is None: skip_recompute_ops = {} + self.skip_recompute_ops = skip_recompute_ops self.sequence_parallel = config.sequence_parallel self.ln_1 = QWenRMSNorm(config) self.attn = QWenAttention(config, skip_recompute_ops=skip_recompute_ops) From 13b4d9f2408f770e8ec7f28b4c74b605aa31d38e Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 30 Dec 2024 12:16:25 +0800 Subject: [PATCH 18/24] dpo test refined_recompute --- tests/fixtures/llm/dpo.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/fixtures/llm/dpo.yaml b/tests/fixtures/llm/dpo.yaml index 1fdc486866bd..0e52138119eb 100644 --- a/tests/fixtures/llm/dpo.yaml +++ b/tests/fixtures/llm/dpo.yaml @@ -21,6 +21,7 @@ dpo: use_flash_attention: true disable_tqdm: true recompute: true + refined_recompute: "flash_attn:-1" save_total_limit: 1 tensor_parallel_degree: 1 pipeline_parallel_degree: 1 From db26410a83d0626cf7e43a84983cc9dbc589363a Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 30 Dec 2024 13:05:49 +0800 Subject: [PATCH 19/24] fix dpo --- llm/alignment/dpo/run_dpo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py index ad69941beef8..d80ee0b7e3ea 100644 --- a/llm/alignment/dpo/run_dpo.py +++ b/llm/alignment/dpo/run_dpo.py @@ -80,14 +80,14 @@ def main(): hasattr(training_args, "pipeline_parallel_config") and "enable_clear_every_step_cache" in training_args.pipeline_parallel_config ), "Should set '--pipeline_parallel_config enable_clear_every_step_cache' in bash script for pp." - if model_args.sequence_parallel: + if training_args.sequence_parallel: if training_args.pipeline_parallel_degree > 1: assert ( hasattr(training_args, "pipeline_parallel_config") and "disable_partial_send_recv" in training_args.pipeline_parallel_config ), "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp." if training_args.tensor_parallel_degree <= 1: - model_args.sequence_parallel = False + training_args.sequence_parallel = False logger.info("Tensor_parallel_degree = 1. 
Set sequence_parallel to False.") training_args.print_config(model_args, "Model") training_args.print_config(data_args, "Data") From fc5bbf9192b3b676d3db018f5c810ddc981efb45 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Mon, 30 Dec 2024 18:12:45 +0800 Subject: [PATCH 20/24] fix --- paddlenlp/transformers/llama/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 5248cc97b885..dc3318b621a2 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1636,7 +1636,7 @@ def custom_forward(*inputs): return custom_forward - recompute_fn = rr_recompute if self.config.refined_recompute else recompute + recompute_fn = rr_recompute if any(layer_module.skip_recompute_ops.values()) else recompute hidden_states = recompute_fn( create_custom_forward(layer_module), hidden_states, From 592c964ef9479a5dd7db63654e9821a4c33d200c Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Tue, 31 Dec 2024 11:38:13 +0800 Subject: [PATCH 21/24] set device gpu --- tests/transformers/test_refined_recompute.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/transformers/test_refined_recompute.py b/tests/transformers/test_refined_recompute.py index 3798f7d459e6..c87378c6d5f8 100644 --- a/tests/transformers/test_refined_recompute.py +++ b/tests/transformers/test_refined_recompute.py @@ -562,6 +562,7 @@ def test_refined_recompute_pp(self): class TestRefinedRecomputeModel(unittest.TestCase): def setUp(self): + paddle.set_device("gpu") self.args = TrainingArguments( output_dir="./", do_train=True, @@ -573,6 +574,7 @@ def setUp(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_llama_refined_recompute(self): + paddle.set_device("gpu") from paddlenlp.transformers.llama import LlamaConfig, LlamaModel llama_model = "__internal_testing__/tiny-random-llama" @@ -590,6 +592,7 @@ def test_llama_refined_recompute(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_qwen_refined_recompute(self): + paddle.set_device("gpu") from paddlenlp.transformers.qwen import QWenConfig, QWenModel llama_model = "__internal_testing__/tiny-random-qwen" @@ -608,6 +611,7 @@ def test_qwen_refined_recompute(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_qwen2_refined_recompute(self): + paddle.set_device("gpu") from paddlenlp.transformers.qwen2 import Qwen2Config, Qwen2Model llama_model = "__internal_testing__/tiny-random-qwen2" From 9a3b541f5e35ed9d831a7393676ff7a0f3c415e3 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Tue, 31 Dec 2024 11:51:48 +0800 Subject: [PATCH 22/24] is_compiled_with_cuda only test gpu --- tests/common_test.py | 2 +- tests/transformers/test_refined_recompute.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/common_test.py b/tests/common_test.py index c1e1191089dc..0b6424dd27ce 100644 --- a/tests/common_test.py +++ b/tests/common_test.py @@ -36,7 +36,7 @@ def __init__(self, methodName="runTest"): self.config = {} self.places = ["cpu"] if paddle.is_compiled_with_cuda(): - self.places.append("gpu") + self.places = ["gpu"] @classmethod def setUpClass(cls): diff --git a/tests/transformers/test_refined_recompute.py b/tests/transformers/test_refined_recompute.py index c87378c6d5f8..3798f7d459e6 100644 --- 
a/tests/transformers/test_refined_recompute.py +++ b/tests/transformers/test_refined_recompute.py @@ -562,7 +562,6 @@ def test_refined_recompute_pp(self): class TestRefinedRecomputeModel(unittest.TestCase): def setUp(self): - paddle.set_device("gpu") self.args = TrainingArguments( output_dir="./", do_train=True, @@ -574,7 +573,6 @@ def setUp(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_llama_refined_recompute(self): - paddle.set_device("gpu") from paddlenlp.transformers.llama import LlamaConfig, LlamaModel llama_model = "__internal_testing__/tiny-random-llama" @@ -592,7 +590,6 @@ def test_llama_refined_recompute(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_qwen_refined_recompute(self): - paddle.set_device("gpu") from paddlenlp.transformers.qwen import QWenConfig, QWenModel llama_model = "__internal_testing__/tiny-random-qwen" @@ -611,7 +608,6 @@ def test_qwen_refined_recompute(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_qwen2_refined_recompute(self): - paddle.set_device("gpu") from paddlenlp.transformers.qwen2 import Qwen2Config, Qwen2Model llama_model = "__internal_testing__/tiny-random-qwen2" From 55d4bbf4270cca52ff902a826f8ad00e69b22ce8 Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Thu, 2 Jan 2025 09:47:22 +0800 Subject: [PATCH 23/24] fix test --- llm/alignment/dpo/run_dpo.py | 10 +--------- llm/alignment/kto/run_kto.py | 10 +--------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py index 5f1960e3cf9f..bc58fb51f612 100644 --- a/llm/alignment/dpo/run_dpo.py +++ b/llm/alignment/dpo/run_dpo.py @@ -46,7 +46,6 @@ register_sequence_parallel_allreduce_hooks, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig -from paddlenlp.transformers.refined_recompute import update_refined_recompute from paddlenlp.trl import ( DPOTrainer, calculate_effective_tokens, @@ -122,17 +121,10 @@ def main(): model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype) LlmMetaConfig.set_llm_config(model_config, training_args) - model_config.refined_recompute = update_refined_recompute( - training_args.refined_recompute, - dpo_config.lora, - ) + if not dpo_config.reference_free and not dpo_config.lora: ref_model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype) LlmMetaConfig.set_llm_config(ref_model_config, training_args) - ref_model_config.refined_recompute = update_refined_recompute( - training_args.refined_recompute, - dpo_config.lora, - ) if training_args.pipeline_parallel_degree > 1: model_class = AutoModelForCausalLMPipe diff --git a/llm/alignment/kto/run_kto.py b/llm/alignment/kto/run_kto.py index 41a5b68d7608..94fe99dd8b75 100644 --- a/llm/alignment/kto/run_kto.py +++ b/llm/alignment/kto/run_kto.py @@ -41,7 +41,6 @@ register_sequence_parallel_allreduce_hooks, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig -from paddlenlp.transformers.refined_recompute import update_refined_recompute from paddlenlp.trl import ( KTOTrainer, calculate_effective_tokens, @@ -108,17 +107,10 @@ def main(): logger.info("Start to load model & tokenizer.") model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype) LlmMetaConfig.set_llm_config(model_config, training_args) - model_config.refined_recompute = update_refined_recompute( - 
training_args.refined_recompute, - kto_config.lora, - ) + if not kto_config.lora: ref_model_config = AutoConfig.from_pretrained(model_args.model_name_or_path, dtype=dtype) LlmMetaConfig.set_llm_config(ref_model_config, training_args) - ref_model_config.refined_recompute = update_refined_recompute( - training_args.refined_recompute, - kto_config.lora, - ) if training_args.pipeline_parallel_degree > 1: model_class = AutoModelForCausalLMPipe From aa17c8097db8176383aa9c992ca1c1dda9b257ba Mon Sep 17 00:00:00 2001 From: yujun <573009727@qq.com> Date: Thu, 2 Jan 2025 12:07:18 +0800 Subject: [PATCH 24/24] make sure test on gpu --- tests/transformers/test_refined_recompute.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/transformers/test_refined_recompute.py b/tests/transformers/test_refined_recompute.py index 3798f7d459e6..4217c82791bf 100644 --- a/tests/transformers/test_refined_recompute.py +++ b/tests/transformers/test_refined_recompute.py @@ -530,6 +530,7 @@ def pp_fwd_bwd( @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_refined_recompute_pp(self): + paddle.set_device("gpu") raw_dtype = paddle.get_default_dtype() grad1, layer1 = self.pp_fwd_bwd(recompute=True, use_rr_recompute=False) grad2, layer2 = self.pp_fwd_bwd(recompute=True, use_rr_recompute=True) @@ -573,6 +574,7 @@ def setUp(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_llama_refined_recompute(self): + paddle.set_device("gpu") from paddlenlp.transformers.llama import LlamaConfig, LlamaModel llama_model = "__internal_testing__/tiny-random-llama" @@ -590,6 +592,7 @@ def test_llama_refined_recompute(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_qwen_refined_recompute(self): + paddle.set_device("gpu") from paddlenlp.transformers.qwen import QWenConfig, QWenModel llama_model = "__internal_testing__/tiny-random-qwen" @@ -608,6 +611,7 @@ def test_qwen_refined_recompute(self): @unittest.skipIf(not is_paddle_cuda_available(), "refined-recompute-pp only support on gpu") def test_qwen2_refined_recompute(self): + paddle.set_device("gpu") from paddlenlp.transformers.qwen2 import Qwen2Config, Qwen2Model llama_model = "__internal_testing__/tiny-random-qwen2"