Commit
Merge branch 'main' into skip_hpugraph_for_first_token
polisettyvarma authored Sep 16, 2023
2 parents e69683f + 8a16649 commit 7f93d1e
Showing 6 changed files with 52 additions and 20 deletions.
6 changes: 6 additions & 0 deletions examples/text-generation/run_generation.py
@@ -102,6 +102,11 @@ def main():
type=int,
help="Number of beams used for beam search generation. 1 means greedy search will be performed.",
)
parser.add_argument(
"--trim_logits",
action="store_true",
help="Calculate logits only for the last token to save memory in the first step.",
)
parser.add_argument(
"--seed",
default=27,
@@ -366,6 +371,7 @@ def check_optimum_habana_min_version(*a, **b):
generation_config.bad_words_ids = bad_words_ids
generation_config.force_words_ids = force_words_ids
generation_config.num_return_sequences = args.num_return_sequences
generation_config.trim_logits = args.trim_logits
generation_config.attn_softmax_bf16 = args.attn_softmax_bf16
generation_config.limit_hpu_graphs = args.limit_hpu_graphs

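A rough, standalone sketch of how the new --trim_logits flag reaches the generation config; SimpleNamespace stands in here for the real generation config object built in run_generation.py and is only an assumption for illustration.

import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument(
    "--trim_logits",
    action="store_true",
    help="Calculate logits only for the last token to save memory in the first step.",
)
args = parser.parse_args(["--trim_logits"])

# Stand-in for the generation config assembled in run_generation.py.
generation_config = SimpleNamespace()
generation_config.trim_logits = args.trim_logits
print(generation_config.trim_logits)  # True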
@@ -7,6 +7,8 @@ class GaudiGenerationConfig(GenerationConfig):
to add HPU-specific arguments for generation.
Args:
trim_logits (`bool`, *optional*):
Calculate logits only for the last token to save memory in the first step.
static_shapes (`bool`, *optional*):
Whether to use static shapes for generation or not. It will run faster on HPUs with static shapes
but not all models support it. If not specified, it will automatically be set to `True` if the given
@@ -23,7 +25,7 @@ class GaudiGenerationConfig(GenerationConfig):

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.trim_logits = kwargs.get("trim_logits", None)
self.static_shapes = kwargs.get("static_shapes", None)
self.ignore_eos = kwargs.get("ignore_eos", None)
self.attn_softmax_bf16 = kwargs.get("attn_softmax_bf16", None)
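A minimal sketch of the kwargs pattern used by this constructor, with a simplified stand-in class rather than the real GaudiGenerationConfig: options that are not passed default to None, so downstream code can distinguish "unset" from an explicit False.

class MiniGenerationConfig:
    # Simplified stand-in; the real GaudiGenerationConfig also calls GenerationConfig.__init__.
    def __init__(self, **kwargs):
        self.trim_logits = kwargs.get("trim_logits", None)
        self.static_shapes = kwargs.get("static_shapes", None)
        self.attn_softmax_bf16 = kwargs.get("attn_softmax_bf16", None)

cfg = MiniGenerationConfig(trim_logits=True)
print(cfg.trim_logits, cfg.static_shapes)  # True None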
4 changes: 3 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -456,7 +456,8 @@ def generate(
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

# determine whether to use the trim_logits feature
model_kwargs["trim_logits"] = generation_config.trim_logits
# determine whether attention softmax needs to execute in lower precision
model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16
# determine whether limit_hpu_graphs needs to be used
@@ -1095,6 +1096,7 @@ def greedy_search(
hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps)
hb_profer.start()
this_peer_finished = False # used by synced_gpus only

while True:
if lazy_mode:
self.htcore_generation.mark_step()
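To illustrate how the option then travels from model_kwargs into the model call, here is a hedged, self-contained sketch; prepare_inputs_sketch is a hypothetical stand-in for the much larger prepare_inputs_for_generation shown further below.

def prepare_inputs_sketch(**kwargs):
    # Hypothetical, heavily simplified: the real method builds a much larger dict.
    return {"trim_logits": kwargs.get("trim_logits")}

class DummyConfig:
    trim_logits = True  # would come from the generation config

model_kwargs = {}
model_kwargs["trim_logits"] = DummyConfig.trim_logits
print(prepare_inputs_sketch(**model_kwargs))  # {'trim_logits': True}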
11 changes: 9 additions & 2 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -380,14 +380,14 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
trim_logits: Optional[bool] = False,
attn_softmax_bf16: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
@@ -402,8 +402,14 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
)

hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
if seq_len > 1 and trim_logits and not self.training:
if token_idx is not None:
hidden_states = hidden_states.index_select(1, token_idx - 1)
else:
hidden_states = hidden_states[:, -1, :]

if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
@@ -470,6 +476,7 @@ def prepare_inputs_for_generation(
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"trim_logits": kwargs.get("trim_logits"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
}
)
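The hidden-state trimming added to the Llama forward pass above can be exercised in isolation with the following PyTorch sketch (toy shapes and the token_idx value are assumed); note that index_select keeps the sequence dimension while the fallback slice drops it.

import torch

batch, seq_len, hidden = 2, 8, 16
hidden_states = torch.randn(batch, seq_len, hidden)
token_idx = torch.tensor([5])  # assumed static-shape position of the next token
trim_logits, training = True, False

if seq_len > 1 and trim_logits and not training:
    if token_idx is not None:
        # Static shapes: keep only the position feeding the next token -> [batch, 1, hidden].
        hidden_states = hidden_states.index_select(1, token_idx - 1)
    else:
        # Dynamic shapes: keep the last position -> [batch, hidden].
        hidden_states = hidden_states[:, -1, :]

print(hidden_states.shape)  # torch.Size([2, 1, 16])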
23 changes: 12 additions & 11 deletions tests/baselines/vit_base_patch16_224_in21k.json
@@ -1,14 +1,14 @@
{
"cifar10": {
"num_train_epochs": 5,
"eval_batch_size": 32,
"num_train_epochs": 1,
"eval_batch_size": 64,
"distribution": {
"single_card": {
"learning_rate": 3e-5,
"train_batch_size": 32,
"eval_accuracy": 0.9901,
"train_runtime": 719.99,
"train_samples_per_second": 300.533,
"learning_rate": 5e-5,
"train_batch_size": 64,
"eval_accuracy": 0.982,
"train_runtime": 143.1925,
"train_samples_per_second": 338.713,
"extra_arguments": [
"--remove_unused_columns False",
"--seed 1337",
@@ -21,16 +21,17 @@
"multi_card": {
"learning_rate": 2e-4,
"train_batch_size": 64,
"eval_accuracy": 0.9912,
"train_runtime": 128.175,
"train_samples_per_second": 2407.573,
"eval_accuracy": 0.9812,
"train_runtime": 63.4907,
"train_samples_per_second": 2480.927,
"extra_arguments": [
"--remove_unused_columns False",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
"--pipelining_fwd_bwd True",
"--non_blocking_data_copy True"
"--non_blocking_data_copy True",
"--throughput_warmup_steps 8"
]
}
}
@@ -233,23 +233,27 @@ def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -
past_kv_length = max_input_length - 1
for layer in self.past_key_values:
past_keys, past_values = layer
past_keys_dims = len(past_keys.shape)
if past_keys_dims == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if is_optimized_for_gaudi:
layer[0] = past_keys[keep_indices]
del past_keys
layer[1] = past_values[keep_indices]
del past_values
else:
if len(past_keys.shape) == 3:
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
if self.keys_head_dim_last:
layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
else:
layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
del past_keys
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
if past_keys_dims == 3:
layer[0] = layer[0].view(layer[0].shape[0] * layer[0].shape[1], *layer[0].shape[-2:])
layer[1] = layer[1].view(layer[1].shape[0] * layer[1].shape[1], *layer[1].shape[-2:])

top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
@@ -378,12 +382,13 @@ def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: boo
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
# And ensure that we can update tensors in-place
kv_tuple = False
past_key_values_dims = len(batch.past_key_values[0][0].shape)
if type(batch.past_key_values[0]) == tuple:
batch.past_key_values = [
[t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
]
kv_tuple = True
elif len(batch.past_key_values[0][0].shape) == 3:
elif past_key_values_dims == 3:
for layer in batch.past_key_values:
for k, t in enumerate(layer):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
@@ -469,6 +474,15 @@ def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: boo

# Update values
start_index = end_index

if past_key_values_dims == 3:
padded_past_keys = padded_past_keys.view(
padded_past_keys.shape[0] * padded_past_keys.shape[1], *padded_past_keys.shape[-2:]
)
padded_past_values = padded_past_values.view(
padded_past_values.shape[0] * padded_past_values.shape[1], *padded_past_values.shape[-2:]
)

if kv_tuple:
past_key_values.append((padded_past_keys, padded_past_values))
else:
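The reshaping done in filter() and concatenate() above boils down to the following PyTorch sketch (toy shapes assumed): 3-dim BLOOM-style caches are temporarily viewed as 4-dim so whole requests can be indexed, then collapsed back to the layout the model expects.

import torch

batch, num_heads, seq_len, head_dim = 3, 4, 7, 8
# BLOOM-style layout: [batch * num_heads, seq_len, head_dim]
past_keys = torch.randn(batch * num_heads, seq_len, head_dim)

# Lift to [batch, num_heads, seq_len, head_dim] so rows can be selected per request.
past_keys_4d = past_keys.view(batch, -1, *past_keys.shape[-2:])
keep_indices = [0, 2]  # example: requests that survive filtering
kept = past_keys_4d[keep_indices]

# Collapse back to the 3-dim layout the model expects.
kept_3d = kept.view(kept.shape[0] * kept.shape[1], *kept.shape[-2:])
print(kept_3d.shape)  # torch.Size([8, 7, 8])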
