search.py
from pathlib import Path
import torch
from config import get_cfg, get_model_file_path
from train import get_model, get_dataset, eval_model
from dataset import BilingualDataset
from trasnformer import Transformer

def beam_search_decode(model: Transformer, beam_size, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every decoding step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    # Each candidate is a (decoder input, cumulative score) pair
    candidates = [(decoder_initial_input, 1)]

    while True:
        # If a candidate has reached the maximum length, we have run the decoding for
        # at least max_len iterations, so stop the search
        if any(cand.size(1) == max_len for cand, _ in candidates):
            break

        # Expand each live candidate into a new list of candidates
        new_candidates = []

        for candidate, score in candidates:
            # Candidates that already produced the eos token are not expanded further,
            # but are kept in the pool so a finished hypothesis can still be selected
            if candidate[0][-1].item() == eos_idx:
                new_candidates.append((candidate, score))
                continue
            # Build the candidate's causal mask
            candidate_mask = BilingualDataset.causal_mask(candidate.size(1)).type_as(source_mask).to(device)
            # Calculate the decoder output
            out = model.decode(encoder_output, source_mask, candidate, candidate_mask)
            # Get the next-token probabilities (in log space)
            prob = model.project(out[:, -1])
            # Get the top k continuations
            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
            for i in range(beam_size):
                # For each of the top k continuations, get the token and its probability
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
                token_prob = topk_prob[0][i].item()
                # Create a new candidate by appending the token to the current candidate
                new_candidate = torch.cat([candidate, token], dim=1)
                # Scores are summed because the probabilities are in log space
                new_candidates.append((new_candidate, score + token_prob))

        # Sort the new candidates by their score and keep only the top k
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
        candidates = candidates[:beam_size]

        # If all the remaining candidates have reached the eos token, stop
        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):
            break

    # Return the best candidate
    return candidates[0][0].squeeze()
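
# The scores above are raw sums of log-probabilities, which tends to favour shorter
# hypotheses. A common refinement (not used by beam_search_decode above) is a
# GNMT-style length penalty; the helper below is only a minimal sketch of that idea,
# and the default alpha is an illustrative assumption, not a value taken from this repo.
def length_normalized_score(log_prob_sum: float, length: int, alpha: float = 0.6) -> float:
    # Divide the summed log-probability by a length penalty so that long and short
    # hypotheses can be compared more fairly when picking the final candidate
    penalty = ((5 + length) / 6) ** alpha
    return log_prob_sum / penalty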

def greedy_decode(model: Transformer, source, source_mask, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every decoding step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while True:
        if decoder_input.size(1) == max_len:
            break

        # Build the causal mask for the target
        decoder_mask = BilingualDataset.causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        # Calculate the decoder output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        # Greedily pick the most probable next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)
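
# Both decoders rely on BilingualDataset.causal_mask from dataset.py, which is not shown
# here. The sketch below illustrates what such a mask typically looks like, assuming it
# returns a (1, size, size) boolean tensor that is True on and below the diagonal; the
# actual implementation in dataset.py may differ.
def _causal_mask_sketch(size: int) -> torch.Tensor:
    # Mask out everything above the diagonal so position i can only attend to positions <= i
    upper = torch.triu(torch.ones((1, size, size), dtype=torch.int), diagonal=1)
    return upper == 0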

def run_eval(model, validation_ds, tokenizer_tgt, max_len, device, print_msg, num_examples=2):
    model.eval()
    count = 0
    console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["enc_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["enc_mask"].to(device)    # (b, 1, 1, seq_len)

            # Check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_tgt, max_len, device)
            model_out_beam = beam_search_decode(model, 3, encoder_input, encoder_mask, tokenizer_tgt, max_len, device)

            source_text = batch["src_txt"][0]
            target_text = batch["tgt_txt"][0]
            model_out_text_beam = tokenizer_tgt.decode(model_out_beam.detach().cpu().numpy())
            model_out_text_greedy = tokenizer_tgt.decode(model_out_greedy.detach().cpu().numpy())

            # Print the source, target and both model outputs
            print_msg('-' * console_width)
            print_msg(f"{'SOURCE: ':>20}{source_text}")
            print_msg(f"{'TARGET: ':>20}{target_text}")
            print_msg(f"{'PREDICTED GREEDY: ':>20}{model_out_text_greedy}")
            print_msg(f"{'PREDICTED BEAM: ':>20}{model_out_text_beam}")

            if count == num_examples:
                print_msg('-' * console_width)
                break
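
# run_eval only prints a handful of qualitative examples. If a quantitative score is
# wanted, a corpus-level BLEU could be computed over the decoded strings, for example
# with torchmetrics (an extra dependency this repo may not use). This helper is only a
# sketch of that idea and is not called anywhere in this file.
def _bleu_sketch(predictions, references):
    # predictions: list of decoded output strings; references: list of target strings
    from torchmetrics.text import BLEUScore
    bleu = BLEUScore()
    # BLEUScore expects one list of reference strings per prediction
    return bleu(predictions, [[ref] for ref in references]).item()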

cfg = get_cfg()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

_, val_dataloader, src_tokenizer, tgt_tokenizer = get_dataset(cfg)
model = get_model(cfg, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)

# cfg['model_folder'] = "models_old"
# Load the checkpoint from epoch 29; map_location keeps loading working on CPU-only machines
model_file = get_model_file_path(cfg, "29")
model_state = torch.load(model_file, map_location=device)
model.load_state_dict(model_state['model_state_dict'])

run_eval(model, val_dataloader, tgt_tokenizer, 20, device, print_msg=print, num_examples=10)
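
# A minimal sketch of translating a single raw sentence with greedy_decode, assuming the
# source tokenizer exposes token_to_id/encode(...).ids (as the tokenizers library does)
# and follows the same [SOS]/[EOS]/[PAD] convention as the dataset. The real preprocessing
# lives in dataset.py and may differ; this helper is illustrative only and is not called.
def _translate_sentence_sketch(sentence: str, seq_len: int, max_len: int = 20) -> str:
    sos = src_tokenizer.token_to_id('[SOS]')
    eos = src_tokenizer.token_to_id('[EOS]')
    pad = src_tokenizer.token_to_id('[PAD]')
    # Wrap the token ids with [SOS]/[EOS] and pad up to the fixed encoder length
    ids = src_tokenizer.encode(sentence).ids
    tokens = [sos] + ids + [eos]
    tokens = tokens + [pad] * (seq_len - len(tokens))
    enc_input = torch.tensor(tokens, dtype=torch.int64).unsqueeze(0).to(device)  # (1, seq_len)
    enc_mask = (enc_input != pad).unsqueeze(1).unsqueeze(1).int().to(device)     # (1, 1, 1, seq_len)
    out_ids = greedy_decode(model, enc_input, enc_mask, tgt_tokenizer, max_len, device)
    return tgt_tokenizer.decode(out_ids.detach().cpu().numpy())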