remove refined recompute deep copy #9617

Open
wants to merge 20 commits into base: develop
Changes from 16 commits
41 changes: 41 additions & 0 deletions docs/trainer.md
@@ -601,6 +601,47 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and

Recompute the forward pass to calculate gradients. Used for saving memory (default: False)

--refined_recompute
Refined recompute parameter for striking the best balance between GPU memory usage and computational speed.
It allows fine-grained control over the recomputation process to optimize resource utilization. An example configuration is as follows:
`"attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1"`

The supported parameters in the configuration include:
`attention_column_ln`
`attention_row_ln`
`mlp_column_ln`
`mlp_row_ln`
`flash_attn`

The number following each parameter, `skip_num`, determines how many times recomputation is bypassed for the specified operation. Specifically:
`skip_num` of `-1`: no recomputation at any stage, maximizing memory usage.
`skip_num` of `0`: recomputation is enforced at every stage, minimizing memory usage.

You can also set `skip_num` to any value within the range `[1, ..., num_layers]`. If `skip_num` exceeds `num_layers`, it behaves as if set to `-1`.
If a parameter is omitted from the configuration, it defaults to `xxx:0`.

(Type: `str`, optional, default: "")
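To make the format concrete, below is a minimal sketch of how such a configuration string can be parsed into per-operation skip counts. The helper name `parse_refined_recompute` is illustrative only; in PaddleNLP the equivalent parsing is done inside `TrainingArguments` (see `paddlenlp/trainer/training_args.py` in this PR).

```python
# Illustrative sketch only; not the PaddleNLP implementation.
SUPPORTED_OPS = ("attention_column_ln", "attention_row_ln", "mlp_column_ln", "mlp_row_ln", "flash_attn")

def parse_refined_recompute(spec: str) -> dict:
    # Omitted operations default to 0 (always recompute).
    skip_nums = {op: 0 for op in SUPPORTED_OPS}
    if not spec:
        return skip_nums
    for item in spec.split(","):
        op_name, skip_num = item.strip().split(":")
        if op_name not in skip_nums:
            raise ValueError(f"Unsupported refined_recompute op: {op_name}")
        # -1 = never recompute this op (max memory), 0 = always recompute (min memory).
        skip_nums[op_name] = int(skip_num)
    return skip_nums

print(parse_refined_recompute("attention_column_ln:-1,flash_attn:-1,mlp_column_ln:5"))
# {'attention_column_ln': -1, 'attention_row_ln': 0, 'mlp_column_ln': 5, 'mlp_row_ln': 0, 'flash_attn': -1}
```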

--minimum_eval_times
The minimum number of evaluations. If the current eval_steps setting would result in fewer evaluations than minimum_eval_times,
this option overrides the eval_steps parameter.
1 change: 1 addition & 0 deletions llm/docs/finetune.md
@@ -279,6 +279,7 @@ python ./predict/reft_predictor.py \
- `do_train`: Whether to enable training. Defaults to False.
- `do_eval`: Whether to enable evaluation. Defaults to False.
- `recompute`: Recomputation; currently only the full strategy is supported. Enabling it reduces GPU memory usage, allowing a larger batch size. Defaults to False.
- `refined_recompute`: Refined recomputation, which balances GPU memory and performance through fine-grained control over which parts are recomputed. Currently only the `llama` and `qwen` model families are supported. For details, see the [TrainingArguments documentation](https://paddlenlp.readthedocs.io/zh/latest/trainer.html).
- `tensor_parallel_degree`: The number of shards a single transformer layer is split into. This approach incurs significant communication overhead; tensor_parallel_degree <= 8 is recommended, keeping communication within a single machine whenever possible. Defaults to -1, which disables tensor parallelism.
- `pipeline_parallel_degree`: The size of the pipeline partition. (If set to 4 for a 12-layer model, each pp stage contains 3 layers; see the short sketch after this list.) Defaults to -1, which disables pipeline parallelism.
- `sharding_parallel_degree`: The data-parallel size for grouped parameter sharding. Defaults to 1, which disables sharded data parallelism.
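A quick sketch of the pipeline split mentioned in the `pipeline_parallel_degree` bullet above; the function name is illustrative and not part of the PaddleNLP API.

```python
# Illustrative sketch: how many decoder layers each pipeline stage holds.
def layers_per_stage(num_hidden_layers: int, pipeline_parallel_degree: int) -> int:
    assert num_hidden_layers % pipeline_parallel_degree == 0, "layers should divide evenly across pp stages"
    return num_hidden_layers // pipeline_parallel_degree

print(layers_per_stage(12, 4))  # 3, matching the example in the bullet above
```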
2 changes: 0 additions & 2 deletions llm/run_embedding.py
@@ -29,7 +29,6 @@
Qwen2SentenceEmbedding,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import DataConfig, EmbeddingTrainer, ModelConfig, SFTConfig
from paddlenlp.trl.llm_utils import compute_metrics, init_chat_template
from paddlenlp.utils.log import logger
@@ -88,7 +87,6 @@ def main():
assert isinstance(model_config, Qwen2Config), "Now only qwen2 supported"

LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.refined_recompute = update_refined_recompute(training_args.refined_recompute)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

# Config for model using dropout, such as GPT.
5 changes: 0 additions & 5 deletions llm/run_finetune.py
@@ -61,7 +61,6 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
@@ -154,10 +153,6 @@ def main():
)

LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
model_args.lora,
)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

# Config for model using dropout, such as GPT.
4 changes: 0 additions & 4 deletions llm/run_pretrain.py
@@ -44,7 +44,6 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device
@@ -406,9 +405,6 @@ def main():
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
# set all llm config
LlmMetaConfig.set_llm_config(config, training_args)
config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
59 changes: 59 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -293,6 +293,15 @@
recompute (`bool`, *optional*, defaults to `False`):
Recompute the forward pass to calculate gradients. Used for saving memory.
Only support for networks with transformer blocks.
refined_recompute (`str`, *optional*, defaults to `""`):
The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed.
An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`.
The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`.
The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation.
A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage;
A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.
You can also set `skip_num` to a value within the range [1, ..., num_layers]. If `skip_num` exceeds `num_layers`, it will behave as if set to `-1`.
If a parameter is omitted, it defaults to `xxx:0`."
scale_loss (`float`, *optional*, defaults to 32768):
The value of initial scale_loss for fp16. (default: 32768)
local_rank (`int`, *optional*, defaults to -1):
@@ -740,6 +749,19 @@
"Only support for networks with transformer blocks."
},
)
refined_recompute: str = field(
default="",
metadata={
"help": "The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed.\n"
"An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`.\n"
"The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`.\n"
"The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation.\n"
"A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage;\n"
"A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.\n"
"You can also set `skip_num` to a value within the range [1, ..., num_layers]. If `skip_num` exceeds `num_layers`, it will behave as if set to `-1`.\n"
"If a parameter is omitted, it defaults to `xxx:0`."
},
)
scale_loss: float = field(default=2**15, metadata={"help": "The value of initial scale_loss for fp16."})

@@ -1755,6 +1777,43 @@
f"The local_ran: {self.local_rank} should be consistent with the world size: {paddle.distributed.get_world_size()}."
)

# parse refined_recompute string to dict
if self.refined_recompute in [None, ""]:
self.refined_recompute = dict()
else:
refined_recompute_dict = {
"mlp_row_ln": 0,
"attention_row_ln": 0,
"attention_column_ln": 0,
"mlp_column_ln": 0,
"flash_attn": 0,
}
ops = self.refined_recompute.split(",")
enable_rr = False
for op in ops:
op = op.strip()
if ":" not in op:
raise ValueError("Illegal refined_recompute input, please check.")
op_name, skip_num = op.split(":")[0], int(op.split(":")[1])
if op_name not in refined_recompute_dict:
raise ValueError(f"Refined recompute do not support {op_name}, please check.")
if (
op_name in ["mlp_row_ln", "attention_row_ln", "attention_column_ln", "mlp_column_ln"]
and self.tensor_parallel_degree <= 1
):
logger.warning(
f"Refined recompute is only supported for the `{op_name}` operation when `tensor_parallel_degree` is greater than 1. \
This refined recompute operation will be ignored."
)
continue

refined_recompute_dict[op_name] = skip_num
if skip_num != 0:
enable_rr = True
if not enable_rr:
refined_recompute_dict = dict()
self.refined_recompute = refined_recompute_dict

def __str__(self):
self_as_dict = asdict(self)
self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
1 change: 0 additions & 1 deletion paddlenlp/transformers/configuration_utils.py
@@ -275,7 +275,6 @@ class LlmMetaConfig:
"",
"refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
),
("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
Collaborator comment: skip_recompute_ops is gone now; where is it added instead?
]

@classmethod
44 changes: 22 additions & 22 deletions paddlenlp/transformers/llama/modeling.py
@@ -35,7 +35,7 @@
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
get_skip_recompute_ops,
recompute,
)

@@ -605,8 +605,9 @@


class LlamaMLP(nn.Layer):
def __init__(self, config):
def __init__(self, config, skip_recompute_ops={}):
super().__init__()
self.skip_recompute_ops = skip_recompute_ops
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.tensor_parallel_degree = config.tensor_parallel_degree
@@ -618,19 +619,19 @@

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
if skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
if skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
if skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
if skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
@@ -682,9 +683,9 @@
class LlamaAttention(nn.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, skip_recompute_ops={}):
super().__init__()

self.skip_recompute_ops = skip_recompute_ops
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
Expand Down Expand Up @@ -746,18 +747,18 @@

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
if skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
if skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
if skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
if skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
@@ -859,11 +860,7 @@
self.attn_func = scaled_dot_product_attention

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if (
config.recompute
and not config.recompute_use_reentrant
and config.skip_recompute_ops.get("flash_attn", False)
):
if config.recompute and not config.recompute_use_reentrant and skip_recompute_ops.get("flash_attn", False):
self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)

def _init_rope(self):
@@ -1168,12 +1165,13 @@


class LlamaDecoderLayer(nn.Layer):
def __init__(self, config, layerwise_recompute: bool = False):
def __init__(self, config, layerwise_recompute: bool = False, skip_recompute_ops={}):
super().__init__()
self.config = config
self.skip_recompute_ops = skip_recompute_ops
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config, layerwise_recompute)
self.mlp = LlamaMLP(config)
self.self_attn = LlamaAttention(config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops)
self.mlp = LlamaMLP(config, skip_recompute_ops=skip_recompute_ops)
self.input_layernorm = LlamaRMSNorm(config)
self.post_attention_layernorm = LlamaRMSNorm(config)
self.sequence_parallel = config.sequence_parallel
@@ -1518,9 +1516,11 @@
self.layers = nn.LayerList(
[
LlamaDecoderLayer(
create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
config=config,
layerwise_recompute=layer_idx not in self.no_recompute_layers,
skip_recompute_ops=get_skip_recompute_ops(config, layer_idx),
)
for i in range(config.num_hidden_layers)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config)
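The modeling changes above replace the per-layer config deep copy with a call to `get_skip_recompute_ops(config, layer_idx)`, whose result is consumed via `skip_recompute_ops.get(op, False)`. The sketch below shows one way such a per-layer helper could be derived from the parsed `refined_recompute` dict; the assumption that `skip_num == -1` skips the op in every layer and a positive `skip_num` skips it in the first `skip_num` layers is mine and not taken from the PR.

```python
# Sketch only, not the PaddleNLP helper. Derives per-layer boolean skip flags
# from the parsed refined_recompute dict under the assumptions stated above.
def skip_recompute_ops_for_layer(refined_recompute: dict, layer_idx: int) -> dict:
    flags = {}
    for op_name, skip_num in refined_recompute.items():
        # -1: skip recompute of this op in every layer; k > 0: skip it in the first k layers.
        flags[op_name] = skip_num == -1 or layer_idx < skip_num
    return flags

parsed = {"mlp_column_ln": 5, "flash_attn": -1, "attention_row_ln": 0}
print(skip_recompute_ops_for_layer(parsed, layer_idx=2))
# {'mlp_column_ln': True, 'flash_attn': True, 'attention_row_ln': False}
print(skip_recompute_ops_for_layer(parsed, layer_idx=7))
# {'mlp_column_ln': False, 'flash_attn': True, 'attention_row_ln': False}
```

Under this reading, `mlp_column_ln:5` would keep the column-parallel linear outputs of the first five decoder layers instead of recomputing them in the backward pass.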
8 changes: 3 additions & 5 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -24,10 +24,7 @@
)

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
recompute,
)
from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops, recompute
from paddlenlp.utils.tools import get_env_device

from .modeling import (
@@ -377,8 +374,9 @@ def get_hcg():
self.add_sequential_layer(
LayerDesc(
LlamaDecoderLayerPipe,
config=create_skip_config_for_refined_recompute(i, config),
config=config,
layerwise_recompute=i not in self.no_recompute_layers,
skip_recompute_ops=get_skip_recompute_ops(config, i),
),
f"llama.layers.{i}",
)