Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support balanced dataset to speed-up training #506

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions internvl_chat/internvl/model/internlm2/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if 'padding_mask' in kwargs:
Expand Down Expand Up @@ -456,6 +457,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# InternLM2FlashAttention2 attention does not support output_attentions
Expand Down Expand Up @@ -510,7 +512,7 @@ def forward(
value_states = value_states.transpose(1, 2)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len
query_states, key_states, value_states, attention_mask, q_len, cu_seqlens=cu_seqlens
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.wo(attn_output)
Expand All @@ -521,7 +523,7 @@ def forward(
return attn_output, attn_weights, past_key_value

def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, cu_seqlens=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
Expand All @@ -544,7 +546,31 @@ def _flash_attention_forward(
"""
# Contains at least one padding token in the sequence
causal = self.is_causal and query_length != 1
if attention_mask is not None:
if cu_seqlens is not None:
cu_seqlens = cu_seqlens.to(query_states.device).to(torch.int32).view(-1)
cu_seqlens_offset = torch.zeros_like(cu_seqlens)
cu_seqlens_offset[:-1] = cu_seqlens[1:]
max_seqlen = max(cu_seqlens_offset[:-1] - cu_seqlens[:-1]).item()

_, _, q_heads, head_dim = query_states.shape
_, _, k_heads, head_dim = key_states.shape
query_states = query_states.view(-1, q_heads, head_dim)
key_states = key_states.view(-1, k_heads, head_dim)
value_states = value_states.view(-1, k_heads, head_dim)

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)
elif attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
query_states, key_states, value_states, attention_mask, query_length
Expand Down Expand Up @@ -640,6 +666,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand Down Expand Up @@ -674,6 +701,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
**kwargs,
)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -876,6 +904,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -952,7 +981,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return module(*inputs, output_attentions, None, cu_seqlens=cu_seqlens)

return custom_forward

Expand All @@ -971,6 +1000,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -1045,6 +1075,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -1089,6 +1120,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cu_seqlens=cu_seqlens
)

hidden_states = outputs[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cu_seqlens: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

Expand Down Expand Up @@ -185,6 +186,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cu_seqlens=cu_seqlens
)
logits = outputs.logits

Expand Down
Loading