-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils_elmo.py
130 lines (103 loc) · 4.26 KB
/
utils_elmo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Utility functions for training and validating models.
"""
import time
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from mfae.utils import correct_predictions
from allennlp.modules.elmo import batch_to_ids
def train(model,
          dataloader,
          optimizer,
          criterion,
          epoch_number,
          max_gradient_norm):
    """
    Train a model for one epoch on some input data with a given optimizer and
    criterion.

    Args:
        model: A torch module that must be trained on some input data. It
            must expose a `device` attribute and return a
            `(logits, probs)` pair when called on a batch of premises and
            hypotheses.
        dataloader: A DataLoader object to iterate over the training data.
            Each batch must be a mapping with the keys "premises",
            "hypotheses" and "labels".
        optimizer: A torch optimizer to use for training on the input model.
        criterion: A loss criterion to use for training.
        epoch_number: The number of the epoch for which training is
            performed. Currently unused; kept so existing callers keep
            working.
        max_gradient_norm: Max. norm for gradient norm clipping.

    Returns:
        epoch_time: The total time necessary to train the epoch.
        epoch_loss: The mean of the per-batch training losses (0.0 if the
            dataloader yields no batch).
        epoch_accuracy: The accumulated result of `correct_predictions`
            divided by the number of labels seen (0.0 if no label was seen).
    """
    # Switch the model to train mode.
    model.train()
    device = model.device

    epoch_start = time.time()
    batch_time_avg = 0.0
    running_loss = 0.0
    correct_preds = 0
    total_num = 0

    tqdm_batch_iterator = tqdm(dataloader)
    for batch_index, batch in enumerate(tqdm_batch_iterator):
        batch_start = time.time()

        # Move input and output data to the GPU if it is used.
        # NOTE(review): squeeze() without a `dim` removes every size-1 axis,
        # which includes the batch axis when a batch holds a single example
        # - confirm the expected input layout and pass an explicit dim.
        premises_ids = batch["premises"].squeeze().to(device)
        hypotheses_ids = batch["hypotheses"].squeeze().to(device)
        # as_tensor accepts lists and existing tensors alike, and avoids the
        # copy-construct warning torch.tensor() raises for tensor inputs.
        labels = torch.as_tensor(batch["labels"]).to(device)

        optimizer.zero_grad()

        logits, probs = model(premises_ids, hypotheses_ids)
        loss = criterion(logits, labels)
        loss.backward()

        # Clip before the optimizer step so the update uses bounded gradients.
        nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
        optimizer.step()

        batch_time_avg += time.time() - batch_start
        running_loss += loss.item()
        correct_preds += correct_predictions(probs, labels)
        total_num += len(labels)

        description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}"\
                      .format(batch_time_avg/(batch_index+1),
                              running_loss/(batch_index+1))
        tqdm_batch_iterator.set_description(description)

    epoch_time = time.time() - epoch_start

    # Guard the averages so an empty dataloader does not raise
    # ZeroDivisionError.
    num_batches = len(dataloader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_accuracy = correct_preds / total_num if total_num else 0.0

    return epoch_time, epoch_loss, epoch_accuracy
def validate(model, dataloader, criterion):
    """
    Compute the loss and accuracy of a model on some validation dataset.

    Args:
        model: A torch module for which the loss and accuracy must be
            computed. It must expose a `device` attribute and return a
            `(logits, probs)` pair when called on a batch of premises and
            hypotheses.
        dataloader: A DataLoader object to iterate over the validation data.
            Each batch must be a mapping with the keys "premises",
            "hypotheses" and "labels".
        criterion: A loss criterion to use for computing the loss.

    Returns:
        epoch_time: The total time to compute the loss and accuracy on the
            entire validation set.
        epoch_loss: The mean of the per-batch losses over the validation set
            (0.0 if the dataloader yields no batch).
        epoch_accuracy: The accumulated result of `correct_predictions`
            divided by the number of labels seen (0.0 if no label was seen).
    """
    # Switch to evaluate mode.
    model.eval()
    device = model.device

    epoch_start = time.time()
    running_loss = 0.0
    correct_preds = 0.0
    total_num = 0

    # Deactivate autograd for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            # Move input and output data to the GPU if one is used.
            # NOTE(review): squeeze() without a `dim` removes every size-1
            # axis, which includes the batch axis when a batch holds a
            # single example - confirm the expected input layout and pass
            # an explicit dim.
            premises_ids = batch["premises"].squeeze().to(device)
            hypotheses_ids = batch["hypotheses"].squeeze().to(device)
            # as_tensor accepts lists and existing tensors alike, and avoids
            # the copy-construct warning torch.tensor() raises for tensor
            # inputs.
            labels = torch.as_tensor(batch["labels"]).to(device)

            logits, probs = model(premises_ids, hypotheses_ids)
            loss = criterion(logits, labels)

            running_loss += loss.item()
            correct_preds += correct_predictions(probs, labels)
            total_num += len(labels)

    epoch_time = time.time() - epoch_start

    # Guard the averages so an empty dataloader does not raise
    # ZeroDivisionError.
    num_batches = len(dataloader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_accuracy = correct_preds / total_num if total_num else 0.0

    return epoch_time, epoch_loss, epoch_accuracy