-
Notifications
You must be signed in to change notification settings - Fork 90
/
data_manager.py
131 lines (120 loc) · 4.47 KB
/
data_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# -*- coding:utf-8 -*-
'''
@Author: yanwii
@Date: 2018-05-30 14:46:36
'''
import copy
import pickle as cPickle
import torch
class DataManager():
    """Load "word tag" sequence data and serve it as padded mini-batches.

    Data files contain one "word tag" pair per line; a line consisting of
    the literal string "end" terminates a sentence.  Words and tags are
    mapped to integer ids via ``self.vocab`` and ``self.tag_map``; in train
    mode both maps are grown on the fly, otherwise they are restored from
    ``models/data.pkl`` and unknown entries fall back to id 0.
    """
    def __init__(self, max_length=100, batch_size=20, data_type='train', tags=None):
        """
        Args:
            max_length: maximum sentence length (stored for callers; not
                enforced inside this class).
            batch_size: number of sentences per batch.
            data_type: "train", "dev" or "test"; selects the data file and
                whether vocab/tag maps are built or loaded from disk.
            tags: entity names to train on, e.g. ["ORG", "PER"]; required
                when data_type == "train".
        """
        if tags is None:
            # avoid the shared-mutable-default-argument pitfall of tags=[]
            tags = []
        self.index = 0
        self.input_size = 0
        self.batch_size = batch_size
        self.max_length = max_length
        self.data_type = data_type
        self.data = []
        self.batch_data = []
        self.vocab = {"unk": 0}
        # START/STOP are begin/end sentinel tags; "O" is the outside tag.
        self.tag_map = {"O": 0, "START": 1, "STOP": 2}
        if data_type == "train":
            if not tags:
                # raise instead of assert: asserts are stripped under -O
                raise ValueError("请指定需要训练的tag类型,如[\"ORG\", \"PER\"]")
            self.generate_tags(tags)
            self.data_path = "data/train"
        elif data_type == "dev":
            self.data_path = "data/dev"
            self.load_data_map()
        elif data_type == "test":
            self.data_path = "data/test"
            self.load_data_map()
        self.load_data()
        self.prepare_batch()

    def generate_tags(self, tags):
        """Expand entity names into B-/I-/E- prefixed tags plus "O"."""
        self.tags = [prefix + tag for tag in tags for prefix in ("B-", "I-", "E-")]
        self.tags.append("O")

    def load_data_map(self):
        """Restore the vocab and tag_map saved during training."""
        with open("models/data.pkl", "rb") as f:
            self.data_map = cPickle.load(f)
            self.vocab = self.data_map.get("vocab", {})
            self.tag_map = self.data_map.get("tag_map", {})
            # NOTE(review): this stores the data_map's *keys* ("vocab",
            # "tag_map", ...), not a saved tag list — it looks like it was
            # meant to be self.data_map.get("tags").  Harmless here because
            # self.tags is only consulted in train mode; kept as-is to
            # preserve behavior.
            self.tags = self.data_map.keys()

    def load_data(self):
        """Read self.data_path into self.data as [word_ids, tag_ids] pairs.

        In train mode unseen words/tags extend vocab/tag_map (tags only if
        listed in self.tags); otherwise unknown words map to id 0 ("unk")
        and unknown tags to id 0 ("O").  Also sets self.input_size to the
        vocab size and prints a short summary.
        """
        sentence = []
        target = []
        with open(self.data_path, encoding="utf-8") as f:
            for line in f:
                # rstrip("\n") instead of line[:-1]: the original chopped a
                # real character when the final line had no trailing newline
                line = line.rstrip("\n")
                if line == "end":
                    self.data.append([sentence, target])
                    sentence = []
                    target = []
                    continue
                try:
                    word, tag = line.split(" ")
                except ValueError:
                    # skip malformed lines (wrong number of fields)
                    continue
                if word not in self.vocab and self.data_type == "train":
                    # ids are assigned densely starting from {"unk": 0}, so
                    # the next free id is simply len(vocab) — O(1), unlike
                    # max(values()) + 1 which rescans the whole dict
                    self.vocab[word] = len(self.vocab)
                if tag not in self.tag_map and self.data_type == "train" and tag in self.tags:
                    self.tag_map[tag] = len(self.tag_map)
                sentence.append(self.vocab.get(word, 0))
                target.append(self.tag_map.get(tag, 0))
        self.input_size = len(self.vocab)
        print("{} data: {}".format(self.data_type, len(self.data)))
        print("vocab size: {}".format(self.input_size))
        print("unique tag: {}".format(len(self.tag_map)))
        print("-" * 50)

    def convert_tag(self, data):
        """Convert the tags of a [sentence, tags] pair to ids in place.

        All but the last tag go through tag_map (extending it in train
        mode); the final position is forced to id 0.  Returns the mutated
        pair.
        """
        _, tags = data
        if not tags:
            # nothing to convert; the original crashed on empty input
            return data
        converted_tags = []
        for tag in tags[:-1]:
            if tag not in self.tag_map and self.data_type == "train":
                self.tag_map[tag] = len(self.tag_map)
            converted_tags.append(self.tag_map.get(tag, 0))
        converted_tags.append(0)
        data[1] = converted_tags
        assert len(converted_tags) == len(tags), "convert error, the list doesn't match!"
        return data

    def prepare_batch(self):
        """Slice self.data into padded batches of self.batch_size.

        The final batch reuses the *last* batch_size sentences, so some
        sentences can appear twice when len(data) is not a multiple of
        batch_size (kept to preserve the original training behavior).
        """
        if not self.data:
            # avoid max() on an empty slice inside pad_data
            return
        index = 0
        while True:
            if index + self.batch_size >= len(self.data):
                self.batch_data.append(self.pad_data(self.data[-self.batch_size:]))
                break
            self.batch_data.append(self.pad_data(self.data[index:index + self.batch_size]))
            index += self.batch_size

    def pad_data(self, data):
        """Return a deep copy where each item becomes
        [padded_word_ids, padded_tag_ids, original_length], zero-padded to
        the longest sentence in the batch."""
        c_data = copy.deepcopy(data)
        max_length = max(len(item[0]) for item in c_data)
        for item in c_data:
            item.append(len(item[0]))  # remember true length before padding
            item[0] = item[0] + (max_length - len(item[0])) * [0]
            item[1] = item[1] + (max_length - len(item[1])) * [0]
        return c_data

    def iteration(self):
        """Yield batches forever, cycling through batch_data."""
        idx = 0
        while True:
            yield self.batch_data[idx]
            idx = (idx + 1) % len(self.batch_data)

    def get_batch(self):
        """Yield each prepared batch exactly once."""
        yield from self.batch_data