-
Notifications
You must be signed in to change notification settings - Fork 9
/
utils.py
23 lines (17 loc) · 845 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
def get_distribution(logits, temperature):
    """Turn raw logits into a temperature-scaled probability distribution.

    Args:
        logits: tensor of unnormalised scores; the softmax is taken over the
            last dimension.
        temperature: divisor applied to ``logits`` before the softmax. Values
            near zero sharpen the distribution towards the argmax; a ``1e-10``
            epsilon keeps ``temperature == 0`` from dividing by zero.

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums to
        1. Floating-point inputs keep their dtype; non-floating inputs yield
        float32 probabilities.
    """
    # Scale and normalise in at least float32. With half-precision logits a
    # tiny temperature makes ``logits / 1e-10`` overflow to inf, softmax over
    # inf values produces NaN, and torch.multinomial then rejects the input.
    compute_dtype = torch.promote_types(logits.dtype, torch.float32)
    scaled = logits.to(compute_dtype) / (temperature + 1e-10)
    probs = torch.softmax(scaled, dim=-1)
    if logits.is_floating_point():
        # Hand back the caller's precision (a no-op for float32/float64).
        probs = probs.to(logits.dtype)
    return probs


def sample(logits, temperature):
    """Draw a single token index from the temperature-scaled logits.

    Args:
        logits: tensor of unnormalised scores. ``torch.multinomial`` accepts
            1-D ``(vocab,)`` or 2-D ``(rows, vocab)`` input.
        temperature: passed through to ``get_distribution``.

    Returns:
        For 2-D input, the draw for the FIRST row only, shape ``(1,)``; any
        other rows are sampled and discarded. For 1-D input, a 0-d tensor.
    """
    distribution = get_distribution(logits, temperature)
    # One draw per row: shape (rows, 1) for 2-D input, (1,) for 1-D input.
    drawn = torch.multinomial(distribution, num_samples=1)
    first_draw = drawn[0]
    return first_draw


def sample_from_draft_model(model, initial_prompt_seq, new_tokens, temperature=1.0):
    """Autoregressively extend a prompt by ``new_tokens`` sampled tokens.

    At each step the model is run on the whole current sequence. The logits
    at the last position are turned into a distribution with
    ``get_distribution`` and one token per batch row is sampled and appended.

    Args:
        model: callable whose result has a ``.logits`` tensor indexed as
            ``[:, -1, :]``. Presumably shaped (batch, seq, vocab); TODO
            confirm against callers.
        initial_prompt_seq: token-id tensor the new tokens are appended to
            along its last dimension, presumably (batch, seq). It is detached
            and cloned, so the caller's tensor is not modified.
        new_tokens: number of tokens to sample. Must be >= 1, because
            ``torch.stack`` raises on an empty list.
        temperature: sampling temperature forwarded to ``get_distribution``.

    Returns:
        Tuple ``(fin_prompt_seq, out_logits)``:

        - ``fin_prompt_seq``: the prompt with the sampled tokens appended.
        - ``out_logits``: the per-step last-position logits stacked on
          ``dim=1``, i.e. (batch, new_tokens, vocab) under the shape
          assumption above.
    """
    fin_prompt_seq = initial_prompt_seq.detach().clone()
    out_logits = []
    for _ in range(new_tokens):
        sample_token_logits = model(fin_prompt_seq).logits[:, -1, :]
        probs = get_distribution(sample_token_logits, temperature)
        # One draw per batch row gives shape (batch, 1), ready to append along
        # the sequence axis. Going through ``sample(...)[None, ...]`` kept
        # only row 0, so the concat failed for batch sizes > 1; batch size 1
        # is unchanged (same shape, same RNG draw).
        sample_token = torch.multinomial(probs, num_samples=1)
        fin_prompt_seq = torch.concat([fin_prompt_seq, sample_token], dim=-1)
        out_logits.append(sample_token_logits)
    out_logits = torch.stack(out_logits, dim=1)
    return fin_prompt_seq, out_logits