Feat (llm): QuantizableBert
Giuseppe5 committed Jul 4, 2023
1 parent fd4fb20 commit b2a8f19
Showing 1 changed file with 149 additions and 13 deletions.
src/brevitas_examples/llm/llm_quant/mha_layers.py (149 additions, 13 deletions)
@@ -4,6 +4,26 @@
from torch import nn


def attention_mask_handler(
        attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
    """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
    (implicit batch_size and n_heads)."""
    if len(attention_mask.shape) == 4:
        if attention_mask.shape[0] == 1:
            attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
        if attention_mask.shape[1] == 1:
            attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
        if attention_mask.shape[2] == 1:
            attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1)
        attention_mask = attention_mask.view(
            batch_size * num_heads, query_seq_length, key_value_seq_length)
    elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1:
        # This could happen in Encoder-like architecture
        assert query_seq_length == key_value_seq_length
        attention_mask = attention_mask.repeat(query_seq_length, 1)
    return attention_mask
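
For orientation, a minimal sketch of the reshaping this helper performs on a typical HuggingFace-style broadcastable 4D mask; the shapes below are illustrative only and not part of the commit:

import torch

# Hypothetical sizes: batch of 2, 4 heads, query length 5, key/value length 5.
batch_size, num_heads, q_len, kv_len = 2, 4, 5, 5

# Additive mask broadcast over batch, heads and query positions: (1, 1, 1, kv_len).
hf_mask = torch.zeros(1, 1, 1, kv_len)

mask = attention_mask_handler(hf_mask, batch_size, num_heads, q_len, kv_len)
# nn.MultiheadAttention expects a 3D attn_mask of shape (batch * heads, q_len, kv_len).
assert mask.shape == (batch_size * num_heads, q_len, kv_len)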


class MultiheadAttentionWrapper(nn.Module):

    def __init__(
@@ -33,6 +53,41 @@ def __init__(
            device,
            dtype)


class QuantizableOPTAttention(MultiheadAttentionWrapper):

    def forward(
            self,
            hidden_states: torch.Tensor,
            key_value_states: Optional[torch.Tensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            layer_head_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if key_value_states is None:
            key_value_states = hidden_states
        if layer_head_mask is not None:
            raise RuntimeError("layer_head_mask is not supported.")
        if self.mha.batch_first:
            batch_size, query_seq_length = hidden_states.shape[:2]
            key_value_seq_length = key_value_states.shape[1]
        else:
            query_seq_length, batch_size = hidden_states.shape[:2]
            key_value_seq_length = key_value_states.shape[0]
        num_heads = self.mha.num_heads
        attention_mask = attention_mask_handler(
            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
        attn_output, attn_output_weights = self.mha(
            hidden_states,
            key_value_states,
            key_value_states,
            attn_mask=attention_mask,
            need_weights=output_attentions,
            average_attn_weights=False)
        past_key_value = None
        return attn_output, attn_output_weights, past_key_value
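
A rough call sketch for the wrapper above; it assumes the constructor simply forwards nn.MultiheadAttention's arguments (embed_dim, num_heads, batch_first, ...), which is not shown in this hunk, so treat the construction as a placeholder:

import torch

# Hypothetical sizes; construction mirrors nn.MultiheadAttention only by assumption.
embed_dim, num_heads, batch, seq = 16, 4, 2, 6
attn = QuantizableOPTAttention(embed_dim, num_heads, batch_first=True)

hidden_states = torch.randn(batch, seq, embed_dim)
# HF OPT-style additive mask, broadcast over batch and heads: (1, 1, seq, seq).
attention_mask = torch.zeros(1, 1, seq, seq)

attn_output, attn_weights, past = attn(
    hidden_states, attention_mask=attention_mask, output_attentions=True)
# attn_output: (batch, seq, embed_dim); attn_weights: (batch, num_heads, seq, seq)
# because need_weights=True and average_attn_weights=False; past is always None.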

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
@@ -99,29 +154,110 @@ def set_weight(value):
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class QuantizableBertAttention(MultiheadAttentionWrapper):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if encoder_attention_mask is not None:
            attention_mask = encoder_attention_mask
        if head_mask is not None:
            raise RuntimeError("layer_head_mask is not supported.")
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze()

        if self.mha.batch_first:
            batch_size, query_seq_length = hidden_states.shape[:2]
            key_value_seq_length = encoder_hidden_states.shape[1]
        else:
            query_seq_length, batch_size = hidden_states.shape[:2]
            key_value_seq_length = encoder_hidden_states.shape[0]
        num_heads = self.mha.num_heads
        attention_mask = attention_mask_handler(
            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
        attn_output, attn_output_weights = self.mha(
            hidden_states,
            encoder_hidden_states,
            encoder_hidden_states,
            attn_mask=attention_mask,
            need_weights=output_attentions,
            average_attn_weights=False)
        past_key_value = None
        return attn_output, attn_output_weights, past_key_value
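
A matching call sketch for the BERT variant; construction is again assumed to mirror nn.MultiheadAttention and is not part of the commit. The mask below has no singleton dimensions, so squeeze() leaves it 4D and attention_mask_handler flattens it to (batch * heads, q_len, kv_len):

import torch

embed_dim, num_heads, batch, seq = 16, 4, 2, 6
attn = QuantizableBertAttention(embed_dim, num_heads, batch_first=True)

hidden_states = torch.randn(batch, seq, embed_dim)
# Full additive mask with explicit batch and head dimensions: (batch, heads, seq, seq).
attention_mask = torch.zeros(batch, num_heads, seq, seq)

attn_output, attn_weights, past = attn(
    hidden_states, attention_mask=attention_mask, output_attentions=True)
# Self-attention: encoder_hidden_states defaults to hidden_states.
# attn_output: (batch, seq, embed_dim); attn_weights: (batch, num_heads, seq, seq).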

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):

        def set_bias(value):
            bias_name = f'{prefix}mha.in_proj_bias'
            if bias_name in state_dict:
                state_dict[bias_name] += value
            else:
                state_dict[bias_name] = value

        def set_weight(value):
            weight_name = f'{prefix}mha.in_proj_weight'
            if weight_name in state_dict:
                state_dict[weight_name] += value
            else:
                state_dict[weight_name] = value

        embed_dim = self.mha.embed_dim
        for name, value in list(state_dict.items()):
            if prefix + 'query.weight' in name:
                weight = torch.zeros((3 * embed_dim, embed_dim),
                                     device=value.device,
                                     dtype=value.dtype)
                weight[:embed_dim] = value
                set_weight(weight)
                del state_dict[name]
            elif prefix + 'key.weight' in name:
                weight = torch.zeros((3 * embed_dim, embed_dim),
                                     device=value.device,
                                     dtype=value.dtype)
                weight[embed_dim:2 * embed_dim] = value
                set_weight(weight)
                del state_dict[name]
            elif prefix + 'value.weight' in name:
                weight = torch.zeros((3 * embed_dim, embed_dim),
                                     device=value.device,
                                     dtype=value.dtype)
                weight[2 * embed_dim:3 * embed_dim] = value
                set_weight(weight)
                del state_dict[name]
            if prefix + 'query.bias' in name:
                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
                bias[:embed_dim] = value
                set_bias(bias)
                del state_dict[name]
            elif prefix + 'key.bias' in name:
                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
                bias[embed_dim:2 * embed_dim] = value
                set_bias(bias)
                del state_dict[name]
            elif prefix + 'value.bias' in name:
                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
                bias[2 * embed_dim:3 * embed_dim] = value
                set_bias(bias)
                del state_dict[name]
            state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0])
            state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape)
            # elif prefix + 'self.output.dense.weight' in name:
            #     state_dict[prefix + 'mha.out_proj.weight'] = value
            #     del state_dict[name]
            # elif prefix + 'self.output.dense.bias' in name:
            #     state_dict[prefix + 'mha.out_proj.bias'] = value
            #     del state_dict[name]
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
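
The hook above folds BERT's separate query/key/value projections into nn.MultiheadAttention's packed in_proj parameters and resets out_proj to the identity (the HF output dense layer lives outside the wrapped module). A standalone sketch of the same packing, with toy sizes chosen here purely for illustration:

import torch
from torch import nn

embed_dim, num_heads = 8, 2
q = nn.Linear(embed_dim, embed_dim)
k = nn.Linear(embed_dim, embed_dim)
v = nn.Linear(embed_dim, embed_dim)

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
with torch.no_grad():
    # Same layout the hook builds: rows [0, E) = query, [E, 2E) = key, [2E, 3E) = value.
    mha.in_proj_weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
    mha.in_proj_bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    mha.out_proj.weight.copy_(torch.eye(embed_dim))
    mha.out_proj.bias.zero_()

# The packed projection reproduces the three separate ones.
x = torch.randn(2, 5, embed_dim)
packed = torch.nn.functional.linear(x, mha.in_proj_weight, mha.in_proj_bias)
assert torch.allclose(packed, torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)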
