Only download PyTorch weights in text-generation example (#284)
regisss authored Jun 30, 2023
1 parent 3761c83 commit b1f0bd7
Showing 3 changed files with 60 additions and 3 deletions.
2 changes: 2 additions & 0 deletions examples/text-generation/checkpoint_utils.py
@@ -22,6 +22,7 @@ def get_repo_root(model_name_or_path, local_rank):
model_name_or_path,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
allow_patterns=["*.bin"],
ignore_patterns=["*.safetensors"],
)

@@ -32,6 +33,7 @@ def get_repo_root(model_name_or_path, local_rank):
model_name_or_path,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
allow_patterns=["*.bin"],
ignore_patterns=["*.safetensors"],
)

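The two additions above pass allow_patterns/ignore_patterns to huggingface_hub's snapshot_download so that only the PyTorch *.bin weight shards are fetched and the duplicate *.safetensors files are skipped. A minimal standalone sketch of the same call, with an illustrative model id in place of the example's model_name_or_path:

import os

from huggingface_hub import snapshot_download
from transformers.utils import is_offline_mode

# Illustrative model id; the example script passes its own model_name_or_path here.
weights_path = snapshot_download(
    "bigscience/bloom-560m",
    local_files_only=is_offline_mode(),
    cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
    allow_patterns=["*.bin"],           # download only the PyTorch weight shards
    ignore_patterns=["*.safetensors"],  # never pull the safetensors duplicates
)
print(weights_path)  # local snapshot directory containing the matching *.bin files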
4 changes: 2 additions & 2 deletions examples/text-generation/run_generation.py
@@ -119,14 +119,14 @@ def main():
default=None,
type=str,
nargs="+",
help="Optional argument list of token ids that are not allowed to be generated.",
help="Optional argument list of words that are not allowed to be generated.",
)
parser.add_argument(
"--force_words",
default=None,
type=str,
nargs="+",
help="Optional argument list of token ids that must be generated.",
help="Optional argument list of words that must be generated.",
)
parser.add_argument("--num_return_sequences", type=int, default=1)

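The corrected help strings make it explicit that --bad_words and --force_words take plain words, not token ids. How those words become token-id constraints is not shown in this diff; a hedged sketch of the usual conversion before calling generate, using GPT-2 as a stand-in tokenizer:

from transformers import AutoTokenizer

# Stand-in tokenizer; add_prefix_space=True so the ids match words appearing
# mid-sentence with byte-level BPE vocabularies such as GPT-2's.
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

bad_words = ["ugly", "boring"]  # e.g. values given to --bad_words
force_words = ["Gaudi"]         # e.g. values given to --force_words

bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids

# The id lists are then forwarded to generation, e.g.
# model.generate(**inputs, bad_words_ids=bad_words_ids,
#                force_words_ids=force_words_ids, num_beams=4)  # forcing words requires beam search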
57 changes: 56 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -244,6 +244,10 @@ def generate(
Whether to use HPU graphs for inference.
ignore_eos (`bool`, *optional*):
Whether to ignore finished sequences (faster in lazy mode and with HPU graphs) or not (eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -889,6 +893,10 @@ def contrastive_search(
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
lazy_mode (`bool`, *optional*, defaults to `False`):
Whether the run is executed in lazy mode or not (i.e. eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -994,6 +1002,10 @@ def greedy_search(
Whether the run is executed in lazy mode or not (i.e. eager mode).
ignore_eos (`bool`, *optional*):
Whether to ignore finished sequences (faster in lazy mode and with HPU graphs) or not (eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1278,6 +1290,10 @@ def sample(
Whether the run is executed in lazy mode or not (i.e. eager mode).
ignore_eos (`bool`, *optional*):
Whether to ignore finished sequences (faster in lazy mode and with HPU graphs) or not (eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1577,6 +1593,10 @@ def beam_search(
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
lazy_mode (`bool`, *optional*, defaults to `False`):
Whether the run is executed in lazy mode or not (i.e. eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1925,6 +1945,10 @@ def beam_sample(
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
lazy_mode (`bool`, *optional*, defaults to `False`):
Whether the run is executed in lazy mode or not (i.e. eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2014,7 +2038,6 @@ def group_beam_search(
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
lazy_mode: Optional[bool] = False,
hpuy_graphs: Optional[bool] = False,
profiling_warmup_steps: Optional[int] = 0,
profiling_steps: Optional[int] = 0,
**model_kwargs,
@@ -2064,6 +2087,10 @@ def group_beam_search(
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
lazy_mode (`bool`, *optional*, defaults to `False`):
Whether the run is executed in lazy mode or not (i.e. eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2203,6 +2230,10 @@ def constrained_beam_search(
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
lazy_mode (`bool`, *optional*, defaults to `False`):
Whether the run is executed in lazy mode or not (i.e. eager mode).
profiling_warmup_steps (`int`, *optional*, defaults to 0):
Number of steps to ignore for profiling.
profiling_steps (`int`, *optional*, defaults to 0):
Number of steps to be captured when enabling profiling.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2298,30 +2329,38 @@ def constrained_beam_search(
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)

batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape

if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)

# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False # used by synced_gpus only

hb_profer = HabanaProfile(warmup=profiling_warmup_steps, active=profiling_steps)
hb_profer.start()
while True:
@@ -2334,16 +2373,20 @@
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
break

model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need

token_idx = model_kwargs.get("token_idx", None)
if token_idx is not None and outputs.logits.shape[-2] > 1:
next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
@@ -2355,9 +2398,13 @@
next_token_scores = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)

next_token_scores_processed = logits_processor(input_ids, next_token_scores)

next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

scores_for_all_vocab = next_token_scores.clone()

# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
@@ -2374,15 +2421,19 @@
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)

# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)

next_indices = (next_tokens / vocab_size).long()
next_tokens = next_tokens % vocab_size

# stateless
beam_outputs = constrained_beam_scorer.process(
input_ids[:, : token_idx.item()],
@@ -2396,6 +2447,7 @@
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]

if token_idx is not None:
input_ids = input_ids[beam_idx, :]
input_ids.index_copy_(
@@ -2408,8 +2460,10 @@
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

# increase cur_len
cur_len = cur_len + 1

hb_profer.step()
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
@@ -2427,6 +2481,7 @@
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
)

if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
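Most of this file's additions document the new profiling_warmup_steps and profiling_steps arguments and wire HabanaProfile into each decoding loop (see the constrained_beam_search hunk above). A minimal usage sketch, assuming a Gaudi environment with optimum-habana and the Habana PyTorch bridge installed; the model name and generation settings are placeholders, and the keyword arguments mirror the docstrings added in this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

import habana_frameworks.torch.core  # registers the "hpu" device
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # patches generate() with the Gaudi-specific signature documented above

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval().to("hpu")

inputs = tokenizer("Habana Gaudi is", return_tensors="pt").to("hpu")
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    lazy_mode=True,
    profiling_warmup_steps=2,  # first 2 decoding steps are ignored by the profiler
    profiling_steps=4,         # the next 4 steps are captured
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))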
