Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

qwen2-sft 训练起步阶段就卡住 #325

Open
baisechundu opened this issue Aug 22, 2024 · 6 comments
Open

qwen2-sft 训练起步阶段就卡住 #325

baisechundu opened this issue Aug 22, 2024 · 6 comments

Comments

@baisechundu
Copy link

使用的是 README.md 中推荐的镜像
目前发现有两个问题:
问题一:

megatron_core 0.7.0
Pai-Megatron-Patch 0.8.3 / Pai-Megatron-Patch 0.9.0 都试过

按照 examples/qwen2 下面的README.md,对 Qwen2-1.5B 进行的操作(A100 4ka):

1. TP1 PP1 pretrain 和 sft 均正常
2. TP1 PP2 pretrain 正常,sft 不正常,现象就是:一半的GPU占用显示为100%,另一半是0,训练卡住无法继续
3. TP2  PP2 同2

目前针对qwen2 的 sft,只要TP或者PP超过1,均会出现卡住的情况。

问题二:
megatron_patch/data/utils.py 中 代码的147行左右

sep_index = (labels[0] == tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
labels[:, :sep_index] = -100

这里得到的 sep_index 可能为空,或类型是张量,下一行作为索引时会报错(注 :Pai-Megatron-Patch 0.9.0 这个发布包默认使用 idxmap 类型的数据集,会出现此问题。Pai-Megatron-Patch 0.8.3 用的是Raw dataset,不会有这个问题)。

我看到主分支提到了qwen2 的 “hang“ 的情况,主分支的代码我也试过了,还是会卡住。pretrain 和 finetune 的模式是在哪块实现有差异吗?求解惑哈

@jerryli1981
Copy link
Collaborator

您好,更新下代码库,#326

@jerryli1981
Copy link
Collaborator

目前llama3.1的sft是同时指出json和idxmap且不会卡住的。qwen2的sft使用json微调应该也不会卡住,如果使用idxmap微调,您需要重新用最新的代码制作下idxmap微调数据,然后用同样的seqlen来启动微调脚本

@tzyodear
Copy link

tzyodear commented Sep 2, 2024

我遇到过问题一的第二点类似的现象。
我是开了micro batch size = 1正常训练, >1就会卡住。原因是在

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(

如果不开reset_attention_mask,这里生成的attention mask维度是[1, 1, seq_len, seq_len], 而不是[bs, 1, seq_len, seq_len]。导致维度不一致。
我不太清楚是不是我个人的问题,总之我手动expand了一下shape 就正常运行了

@pizts
Copy link

pizts commented Sep 10, 2024

sft训练数据是json的时候,tp=2的时候会卡住,tp=4的时候不卡,不知道啥原因

@renwuli
Copy link

renwuli commented Nov 11, 2024

我遇到过问题一的第二点类似的现象。 我是开了micro batch size = 1正常训练, >1就会卡住。原因是在

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(

如果不开reset_attention_mask,这里生成的attention mask维度是[1, 1, seq_len, seq_len], 而不是[bs, 1, seq_len, seq_len]。导致维度不一致。 我不太清楚是不是我个人的问题,总之我手动expand了一下shape 就正常运行了

我也遇到了你同样的问题,想问一下,你手动expand之后精度正常吗?

@tzyodear
Copy link

我遇到过问题一的第二点类似的现象。 我是开了micro batch size = 1正常训练, >1就会卡住。原因是在

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(

如果不开reset_attention_mask,这里生成的attention mask维度是[1, 1, seq_len, seq_len], 而不是[bs, 1, seq_len, seq_len]。导致维度不一致。 我不太清楚是不是我个人的问题,总之我手动expand了一下shape 就正常运行了

我也遇到了你同样的问题,想问一下,你手动expand之后精度正常吗?

这里mask是bool,只是expand第一个维度,对精度没啥影响吧。
卡住的原因很多,还是得自己细致检查每个环节,我这里只是个例

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

5 participants