-
Notifications
You must be signed in to change notification settings - Fork 4
/
metrics.py
45 lines (37 loc) · 1.6 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
def bleu_score_fn(method_no: int = 4, ref_type='corpus'):
    """
    Build a BLEU scoring function backed by NLTK.

    :param method_no: which NLTK smoothing method to use; the function looks up
        ``SmoothingFunction().method{method_no}``.
    :param ref_type: 'corpus' to get a corpus-level scorer (wraps ``corpus_bleu``)
        or 'sentence' to get a sentence-level scorer (wraps ``sentence_bleu``).
    :return: a callable computing BLEU with uniform weights over 1..n-grams.
    :raises ValueError: if ``ref_type`` is neither 'corpus' nor 'sentence'.
    :raises AttributeError: if ``SmoothingFunction`` has no ``method{method_no}``.
    """
    # Validate up front: previously an unknown ref_type fell through every branch
    # and returned None, which only failed later when the caller tried to call it.
    if ref_type not in ('corpus', 'sentence'):
        raise ValueError(f"ref_type must be 'corpus' or 'sentence', got {ref_type!r}")

    smoothing_method = getattr(SmoothingFunction(), f'method{method_no}')

    def bleu_score_corpus(reference_corpus: list, candidate_corpus: list, n: int = 4):
        """
        Corpus-level BLEU with uniform weights over 1..n-grams.

        :param reference_corpus: for each candidate, a list of reference token
            sequences (documented by the author as [b, 5, var_len]).
        :param candidate_corpus: candidate token sequences, [b, var_len].
        :param n: largest n-gram order; each order gets weight 1/n.
        :return: the value returned by ``nltk.translate.bleu_score.corpus_bleu``.
        """
        weights = [1 / n] * n
        return corpus_bleu(reference_corpus, candidate_corpus,
                           smoothing_function=smoothing_method, weights=weights)

    def bleu_score_sentence(reference_sentences: list, candidate_sentence: list, n: int = 4):
        """
        Sentence-level BLEU with uniform weights over 1..n-grams.

        :param reference_sentences: reference token sequences, [5, var_len]
            per the author's annotation.
        :param candidate_sentence: one candidate token sequence, [var_len].
        :param n: largest n-gram order; each order gets weight 1/n.
        :return: the value returned by ``nltk.translate.bleu_score.sentence_bleu``.
        """
        weights = [1 / n] * n
        return sentence_bleu(reference_sentences, candidate_sentence,
                             smoothing_function=smoothing_method, weights=weights)

    if ref_type == 'corpus':
        return bleu_score_corpus
    return bleu_score_sentence
def accuracy_fn(ignore_value: int = 0):
    """
    Build an accuracy function that skips positions whose target equals ``ignore_value``.

    :param ignore_value: target value excluded from the accuracy computation
        (presumably a padding index — confirm against the vocabulary in use).
    :return: ``accuracy_ignoring_value(source, target) -> float``.
    """
    def accuracy_ignoring_value(source: torch.Tensor, target: torch.Tensor) -> float:
        """
        Fraction of non-ignored positions where ``argmax(source, dim=1)`` equals ``target``.

        :param source: per-class scores; the class axis is dim 1, and
            ``argmax(source, dim=1)`` must have the same shape as ``target``.
        :param target: integer class labels.
        :return: accuracy in [0, 1], or NaN when every target position equals
            ``ignore_value`` (nothing to score, so accuracy is undefined).
        """
        mask = target != ignore_value
        n_valid = mask.sum().item()
        if n_valid == 0:
            # Previously an int/int division by zero raised ZeroDivisionError here.
            return float('nan')
        predictions = torch.argmax(source, dim=1)
        n_correct = (predictions[mask] == target[mask]).sum().item()
        return n_correct / n_valid
    return accuracy_ignoring_value