-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_loader.py
182 lines (142 loc) · 4.83 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
adapted from depccg's implementation @ https://github.com/masashi-y/depccg
"""
from typing import List, Tuple, NamedTuple
from base import Token, ConstituentNode, Category
class DataItem(NamedTuple):
    """One parsed CCGBank sentence: source file, ID header, tokens, and derivation tree."""
    filename: str  # AUTO file this sentence was read from
    id: str  # the full "ID ..." header line that preceded the parse line
    tokens: List[Token]  # leaf tokens in surface order
    tree_root: ConstituentNode  # root of the CCG derivation tree
class _AutoLineReader(object):
    """Recursive-descent parser for a single line of CCGBank AUTO format.

    A line is a bracketed derivation built from internal nodes
    ``(<T category head n-children> ... )`` and leaves
    ``(<L category mod-POS orig-POS word category>)``.  The parser walks the
    raw string left to right with a cursor (``self.index``), collecting the
    Token leaves and every category string it encounters.
    """

    def __init__(self, line):
        self.line = line    # raw AUTO line being parsed
        self.index = 0      # cursor position within self.line
        self.word_id = -1   # running leaf counter; incremented once per leaf
        self.tokens = []    # Token objects collected in surface order
        self.cats = set()   # every category string seen on this line

    def next(self):
        """Consume and return the space-delimited chunk at the cursor.

        NOTE(review): if no further space exists, ``find`` returns -1 and the
        slice drops the final character — assumes every chunk is followed by a
        space; confirm against the data format.
        """
        end = self.line.find(' ', self.index)
        res = self.line[self.index:end]
        self.index = end + 1
        return res

    def check(self, text, offset=0):
        """Raise RuntimeError unless the character at cursor+offset equals *text*."""
        if self.line[self.index + offset] != text:
            raise RuntimeError(f'failed to parse: {self.line}')

    def peek(self):
        """Return the character under the cursor without consuming it."""
        return self.line[self.index]

    def parse(self):
        """Parse the whole line; return (root node, tokens, category strings)."""
        tree = self.next_node()
        return tree, self.tokens, self.cats

    @property
    def next_node(self):
        # Dispatch on the marker two characters ahead of the cursor:
        # "(<L" starts a leaf, "(<T" an internal node.  Returns the *bound
        # method* to call, hence the `self.next_node()` call sites.
        if self.line[self.index + 2] == 'L':
            return self.parse_leaf
        elif self.line[self.index + 2] == 'T':
            return self.parse_tree
        else:
            raise RuntimeError(f'failed to parse: {self.line}')

    def parse_leaf(self):
        """Parse a leaf "(<L cat tag1 tag2 word ...>)" into a one-child node."""
        self.word_id += 1
        self.check('(')
        self.check('<', 1)
        self.check('L', 2)
        self.next()  # consume the "(<L" marker chunk
        cat_str = self.next()
        cat = Category.parse(cat_str)
        self.cats.add(cat_str)
        tag1 = self.next()  # modified POS tag
        tag2 = self.next()  # original POS tag (read but unused)
        word = self.next().replace('\\', '')  # strip escape backslashes
        token = Token(
            contents=word,
            POS=tag1,
            tag=cat
        )
        self.tokens.append(token)
        # NOTE(review): this bracket normalization happens *after* the Token
        # is constructed, so the stored token keeps the -LRB-/-RRB- form and
        # the rewritten local `word` is never used again — confirm intended.
        if word == '-LRB-':
            word = "("
        elif word == '-RRB-':
            word = ')'
        self.next()  # consume the trailing ">)" chunk
        return ConstituentNode(tag=token.tag, children=[token])

    def parse_tree(self):
        """Parse an internal node "(<T cat head n> child... )" recursively."""
        self.check('(')
        self.check('<', 1)
        self.check('T', 2)
        self.next()  # consume the "(<T" marker chunk
        cat_str = self.next()
        cat = Category.parse(cat_str)
        self.cats.add(cat_str)
        head_is_left = self.next() == '0'  # head-direction flag (unused below)
        self.next()  # consume the child-count field
        children = []
        while self.peek() != ')':
            children.append(self.next_node())  # next_node yields a bound method
        self.next()  # consume the closing ")"
        # CCGBank derivations are at most binary; more children means a
        # malformed line.
        if len(children) > 2:
            raise RuntimeError(f'failed to parse: {self.line}')
        else:
            node = ConstituentNode(tag=cat, children=children)
            return node
def load_auto_file(filename: str, problematic_ids: List[str] = None) -> Tuple[List[DataItem], List[str]]:
    """Read a traditional AUTO file as used for CCGBank.

    English CCGbank contains some unwanted categories such as
    (S\\NP)\\(S\\NP)[conj]; this reads the treebank while normalizing
    those categories first.

    Args:
        filename (str): path of the AUTO file to read
        problematic_ids (List[str]): optional list of '<filename>_<ID line>'
            keys to skip (sentences known to cause NaN problems in training)

    Returns:
        Tuple[List[DataItem], List[str]]: the parsed sentences and the list
        of distinct category strings seen across them.
    """
    cat_fixes = {'((S[b]\\NP)/NP)/': '(S[b]\\NP)/NP', 'conj[conj]': 'conj'}

    def _fix(cat: str) -> str:
        # Patch two known-malformed category strings and strip a trailing
        # "[conj]" feature ("[conj]" is 6 chars, hence cat[:-6]).
        if cat in cat_fixes:
            return cat_fixes[cat]
        if cat.endswith(')[conj]') or cat.endswith('][conj]'):
            return cat[:-6]
        return cat

    data_items = []
    all_cats = set()
    failed = 0
    current_id = None  # last "ID ..." header seen; None before the first one
    with open(filename, 'r') as f:
        # Stream the file line by line instead of materializing it with
        # readlines(); AUTO files can be large.
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith("ID"):
                current_id = line
            else:
                line = ' '.join(
                    _fix(token) for token in line.split(' ')
                )
                try:
                    root, tokens, cats = _AutoLineReader(line).parse()
                except Exception:
                    # Best-effort: skip unparseable derivations but count
                    # them (narrowed from a bare except, which also caught
                    # KeyboardInterrupt/SystemExit).
                    failed += 1
                    continue
                if problematic_ids is not None:
                    key = '_'.join([filename, current_id])
                    if key in problematic_ids:
                        print('Deleted ID:', key)
                        continue
                data_items.append(DataItem(filename, current_id, tokens, root))
                all_cats.update(cats)
    # The original interpolated the literal "(unknown)" here; the filename
    # placeholder was clearly intended.
    print(f'Number of failed parses for {filename}: {failed}')
    return data_items, list(all_cats)
if __name__ == '__main__':
    # sample usage: parse a small AUTO file and dump its contents
    filename = "data/ccg-sample.auto"
    items, cats = load_auto_file(filename)

    def _walk(node):
        # Pre-order traversal printing every node's category tag.
        print(node.tag)
        if isinstance(node, ConstituentNode):
            for child in node.children:
                _walk(child)

    for item in items:
        print(item.id)
        for token in item.tokens:
            print(f'{token.contents}\t{token.POS}\t{token.tag}')
        _walk(item.tree_root)
    print(cats)