Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

代码问题 multi-head attention 中的 extended_attention_mask #39

Open
NiceMartin opened this issue Jul 17, 2021 · 4 comments
Open

Comments

@NiceMartin
Copy link

NiceMartin commented Jul 17, 2021

在 bert_model.py 的 445行, 下面的代码好像有点问题:
extended_attention_mask = extended_attention_mask.unsqueeze(1).unsqueeze(2)
if attention_mask is not None :
## 如果传进来的注意力mask不是null,那就直接用传进来的注意力mask 乘 原始mask
# 注意 原始mask是extended_attention_mask,这个是用来把pad部分置为0,去掉pad部分影响
extended_attention_mask = attention_mask * extended_attention_mask

在原始的Bert代码中, 是这样的:
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)

    # We create a 3D attention mask from a 2D tensor mask.
    # Sizes are [batch_size, 1, 1, to_seq_length]
    # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
    # this attention mask is more simple than the triangular masking of causal attention
    # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

上面两种操作, 产生的 extended_attention_mask 是有差异的. 不过因为在本代码的训练过程中, attention_mask 一直为None, 所以训练过程没有问题, 但实际上代码可能存在问题

@920232796
Copy link
Owner

谁说的attention mask一直为None, 如果是unilm,那就不是None呀 是一个特殊的mask,需要传进来的。

ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
a_mask = ones.tril() # 下三角矩阵
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask

enc_layers, _ = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id, attention_mask=a_mask,
output_all_encoded_layers=True)

@NiceMartin
Copy link
Author

谁说的attention mask一直为None, 如果是unilm,那就不是None呀 是一个特殊的mask,需要传进来的。

ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
a_mask = ones.tril() # 下三角矩阵
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask

enc_layers, _ = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id, attention_mask=a_mask,
output_all_encoded_layers=True)

没看到UNILM那块, 不过 你代码里的 extended_attention_mask和BERT源码中的不一样, 这个是什么问题呢?

@920232796
Copy link
Owner

谁说的attention mask一直为None, 如果是unilm,那就不是None呀 是一个特殊的mask,需要传进来的。
ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
a_mask = ones.tril() # 下三角矩阵
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask
enc_layers, _ = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id, attention_mask=a_mask,
output_all_encoded_layers=True)

没看到UNILM那块, 不过 你代码里的 extended_attention_mask和BERT源码中的不一样, 这个是什么问题呢?

不清楚,这个代码是我自己改过的,不是严格按照bert源码写的,它那个没办法做生成任务吧?

@NiceMartin
Copy link
Author

嗯, 不能做生成任务.

谁说的attention mask一直为None, 如果是unilm,那就不是None呀 是一个特殊的mask,需要传进来的。
ones = torch.ones((1, 1, seq_len, seq_len), dtype=torch.float32, device=self.device)
a_mask = ones.tril() # 下三角矩阵
s_ex12 = token_type_id.unsqueeze(1).unsqueeze(2).float()
s_ex13 = token_type_id.unsqueeze(1).unsqueeze(3).float()
a_mask = (1.0 - s_ex12) * (1.0 - s_ex13) + s_ex13 * a_mask
enc_layers, _ = self.bert(input_tensor, position_ids=position_enc, token_type_ids=token_type_id, attention_mask=a_mask,
output_all_encoded_layers=True)

没看到UNILM那块, 不过 你代码里的 extended_attention_mask和BERT源码中的不一样, 这个是什么问题呢?

不清楚,这个代码是我自己改过的,不是严格按照bert源码写的,它那个没办法做生成任务吧?

不能做生成任务. 就像我前面说的, 因为你的代码在其他任务调用时, attention_mask 为None. 这样的话, 就和BERT原始代码一致了.
不过你可以测试下, 在那些任务中 attention_mask 如果不设为 None, 结果会怎样

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants