We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nezha_auto_title_train.py做修改的核心代码在4个地方:一是增加 self.bert_model=torch.nn.DataParallel(self.bert_model,device_ids=[0,1])(建议在加载完预训练模型等在训练前需要对原self.bert_model做的处理都完成之后,如原代码167行 self.bert_model.load_pretrain_params(model_path) 之后)。二是在此后调用self.bert_model时改为调用self.bert_model.module,如原代码 self.bert_model.save_all_params(save_path) 需要改成self.bert_model.module.save_all_params(save_path),因为torch.nn.DataParallel会使原模型的方法不可见。三是这样直接得到的loss就不是一个元素而是多个元素,所以原代码第153行 report_loss += loss.item() 不可用。 参考#15 ,我把代码改成了这样:
self.bert_model=torch.nn.DataParallel(self.bert_model,device_ids=[0,1])
self.bert_model.load_pretrain_params(model_path)
self.bert_model.save_all_params(save_path)
self.bert_model.module.save_all_params(save_path)
report_loss += loss.item()
loss=loss.mean() report_loss += loss.item()
事实上我并不确定这样直接求平均是否正确……看起来应该是合理的……反正能跑起来了。 第四点是在forward()函数运行之前就手动把输入数据放置到模型初始卡上(原代码148-149行之间的位置):
token_ids=token_ids.to(self.device) token_type_ids=token_type_ids.to(self.device) target_ids=target_ids.to(self.device)
对源代码的修改: bert_seq2seq/seq2seq_model.py 源代码69-75行手动将输入数据放置到模型初始卡上的代码删除:
input_tensor = input_tensor.to(self.device) token_type_id = token_type_id.to(self.device) if position_enc is not None: position_enc = position_enc.to(self.device) if labels is not None : labels = labels.to(self.device)
原80行 ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device) 删除,83-84行之间添加 ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=s_ex13.device)
ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=s_ex13.device)
另外还有一个不必要的优化工作:参考 https://blog.csdn.net/weixin_43301333/article/details/111386343 这篇博文的建议,把源代码第97行 return predictions, loss 改成了 return loss,然后nezha_auto_title_train.py原第149行 predictions, loss = self.bert_model(token_ids, 改成 loss = self.bert_model(token_ids,,因为反正predictions不需要,只输出loss能有效降低模型初始卡上的消耗,每个卡上每个batch能多放一条数据了。
return predictions, loss
return loss
predictions, loss = self.bert_model(token_ids,
loss = self.bert_model(token_ids,
max_position_embeddings=2048
The text was updated successfully, but these errors were encountered:
谢谢老哥,明天我改一下哈,感谢感谢~
Sorry, something went wrong.
No branches or pull requests
nezha_auto_title_train.py做修改的核心代码在4个地方:一是增加
self.bert_model=torch.nn.DataParallel(self.bert_model,device_ids=[0,1])
(建议在加载完预训练模型等在训练前需要对原self.bert_model做的处理都完成之后,如原代码167行self.bert_model.load_pretrain_params(model_path)
之后)。二是在此后调用self.bert_model时改为调用self.bert_model.module,如原代码self.bert_model.save_all_params(save_path)
需要改成self.bert_model.module.save_all_params(save_path)
,因为torch.nn.DataParallel会使原模型的方法不可见。三是这样直接得到的loss就不是一个元素而是多个元素,所以原代码第153行report_loss += loss.item()
不可用。参考#15 ,我把代码改成了这样:
事实上我并不确定这样直接求平均是否正确……看起来应该是合理的……反正能跑起来了。
第四点是在forward()函数运行之前就手动把输入数据放置到模型初始卡上(原代码148-149行之间的位置):
对源代码的修改:
bert_seq2seq/seq2seq_model.py
源代码69-75行手动将输入数据放置到模型初始卡上的代码删除:
原80行
ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
删除,83-84行之间添加ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=s_ex13.device)
另外还有一个不必要的优化工作:参考 https://blog.csdn.net/weixin_43301333/article/details/111386343 这篇博文的建议,把源代码第97行
return predictions, loss
改成了return loss
,然后nezha_auto_title_train.py原第149行predictions, loss = self.bert_model(token_ids,
改成loss = self.bert_model(token_ids,
,因为反正predictions不需要,只输出loss能有效降低模型初始卡上的消耗,每个卡上每个batch能多放一条数据了。我觉得这个参数看起来组合jieba.cut会很好用。
二是希望能在调用时手动改模型的参数,如 https://github.com/920232796/bert_seq2seq/blob/master/bert_seq2seq/model/nezha_model.py 的BertConfig,之前2.3.2版本时
max_position_embeddings=2048
因为太大所以我手动改成了1024。希望能改成我可以在调用时就直接修改的方式。The text was updated successfully, but these errors were encountered: