-
Notifications
You must be signed in to change notification settings - Fork 7
/
helper.py
87 lines (74 loc) · 2.67 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import unicodedata
import re
import torch
from torch.autograd import Variable
import time
import math
import random
# Reserved vocabulary indices; every Lang pre-populates index2word with them.
SOS_token, EOS_token, PAD_token = 0, 1, 2
# When true, tensors built by this module are moved to the GPU.
use_cuda = torch.cuda.is_available()
class Lang:
    """Vocabulary for one language: maps words to integer indices and back.

    Indices 0, 1 and 2 are reserved for the "SOS", "EOS" and "PAD" markers.
    Those markers are present in ``index2word`` only, not in ``word2index``.
    Every new word seen by ``addWord`` receives the next free index,
    starting at 3.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}  # word -> index (regular words only)
        self.word2count = {}  # word -> number of times the word was added
        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
        # Next free index; starts at 3 because it counts SOS, EOS and PAD.
        self.n_words = 3

    def addSentence(self, sentence):
        """Add every word of *sentence* to the vocabulary.

        The sentence is split on a single space character exactly, so
        consecutive spaces produce an empty-string "word".
        """
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register *word*: assign a fresh index on first sight, else bump its count."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        self.word2index[word] = self.n_words
        self.word2count[word] = 1
        self.index2word[self.n_words] = word
        self.n_words += 1
def readLangs(lang1, lang2, reverse=False):
    """Load tab-separated sentence pairs from the file ``'<lang1>-<lang2>.txt'``.

    Args:
        lang1: Name of the language in the first column; also the first part
            of the file name.
        lang2: Name of the language in the second column; also the second
            part of the file name.
        reverse: If true, reverse the fields of every pair so that ``lang2``
            becomes the input language and ``lang1`` the output language.

    Returns:
        ``(input_lang, output_lang, pairs)``. The first two are freshly
        created, still empty ``Lang`` vocabularies. ``pairs`` is a list with
        one entry per non-blank line, each entry the list of that line's
        tab-separated fields (reversed when *reverse* is set).

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    print("Reading lines...")

    # Read the whole file; the context manager closes the handle promptly.
    with open('%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:
        lines = f.read().strip().split('\n')

    # Split every line into its tab-separated fields. No text normalization
    # is applied. Blank lines are skipped: they would yield a single-field
    # pair that later fails on pair[1].
    pairs = [line.split('\t') for line in lines if line]

    # NOTE(review): lines with more than two fields keep all of them, and
    # reversing then moves the LAST field to position 0. Confirm the data
    # file has exactly two columns.
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs
def _sentenceFits(sentence, max_length):
    """Return True if *sentence*'s words plus a trailing EOS take at most *max_length* slots."""
    return len(sentence.split(' ')) + 1 <= max_length


def prepareData(lang1, lang2, max_length, reverse=False):
    """Read sentence pairs, drop those too long for *max_length*, and build vocabularies.

    A sentence fits when its space-separated words plus the trailing EOS
    token occupy at most *max_length* positions. That is the length
    ``variableFromSentence`` pads to, and it never truncates. Pairs where
    either side does not fit are discarded before words are counted, so both
    vocabularies cover only the kept pairs.

    Args:
        lang1, lang2, reverse: Forwarded to ``readLangs``.
        max_length: Maximum encoded length (words + EOS) of either sentence.
            Pass ``None`` to keep every pair.

    Returns:
        ``(input_lang, output_lang, pairs)`` with both ``Lang`` vocabularies
        populated from the kept pairs.
    """
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    if max_length is not None:
        # Without this filter, overlong sentences would later produce tensors
        # longer than max_length.
        pairs = [
            pair for pair in pairs
            if _sentenceFits(pair[0], max_length)
            and _sentenceFits(pair[1], max_length)
        ]
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs
def indexesFromSentence(lang, sentence):
    """Map each single-space-separated word of *sentence* to its index in *lang*.

    Raises:
        KeyError: If a word is not in ``lang.word2index``.
    """
    tokens = sentence.split(' ')
    vocab = lang.word2index
    return [vocab[token] for token in tokens]
def variableFromSentence(lang, sentence, max_length):
    """Encode *sentence* as a 1-D ``torch.LongTensor`` of word indices.

    The word indices are followed by ``EOS_token`` and then right-padded with
    ``PAD_token`` up to *max_length* entries. A sentence with more than
    ``max_length - 1`` words is not truncated, so its tensor is longer than
    *max_length*. The tensor is moved to the GPU when ``use_cuda`` is true.
    """
    token_ids = indexesFromSentence(lang, sentence) + [EOS_token]
    token_ids += [PAD_token] * (max_length - len(token_ids))
    tensor = torch.LongTensor(token_ids)
    return tensor.cuda() if use_cuda else tensor
def variablesFromPairs(input_lang, output_lang, pairs, max_length):
    """Encode every sentence pair as an ``(input_tensor, target_tensor)`` tuple.

    ``pair[0]`` is encoded with *input_lang* and ``pair[1]`` with
    *output_lang*, both via ``variableFromSentence`` padded to *max_length*.
    The returned list preserves the order of *pairs*.
    """
    return [
        (
            variableFromSentence(input_lang, pair[0], max_length),
            variableFromSentence(output_lang, pair[1], max_length),
        )
        for pair in pairs
    ]