-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathSS_dataset.py
105 lines (86 loc) · 2.79 KB
/
SS_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import os, gc
import cPickle
import copy
import logging
import threading
import Queue
import collections
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Seed NumPy's global random generator at import time (presumably for
# reproducibility). Note that the shuffling in SSFetcher goes through the
# `rng` object handed to SSIterator, not through this global generator.
np.random.seed(1234)
class SSFetcher(threading.Thread):
    """Background thread that fills the parent iterator's queue with batches.

    The fetcher reads everything it needs from ``parent``: ``data``,
    ``data_len``, ``rng``, ``batch_size``, ``max_len``,
    ``use_infinite_loop``, ``queue`` and ``exit_flag``.

    Each item placed on ``parent.queue`` is a list of at most
    ``parent.batch_size`` samples whose length does not exceed
    ``parent.max_len``.  A single ``None`` is queued as an end-of-data
    marker when no further batch will ever be produced.

    NOTE(review): the index order is always shuffled with ``parent.rng``;
    the parent's ``shuffle`` option is not consulted here.
    """

    # Seconds to wait on a full queue before re-checking parent.exit_flag.
    PUT_TIMEOUT = 0.1

    def __init__(self, parent):
        """Bind the fetcher to ``parent`` and build the index permutation."""
        threading.Thread.__init__(self)
        self.parent = parent
        self.indexes = np.arange(parent.data_len)

    def _enqueue(self, item):
        """Put ``item`` on the parent's queue unless a stop was requested.

        A plain blocking ``put`` would keep the thread stuck forever once
        the consumer stops reading, so the wait is sliced into short
        timeouts and ``exit_flag`` is re-checked between attempts.

        Returns True if the item was queued, False if the parent's
        ``exit_flag`` was raised first.
        """
        diter = self.parent
        while not diter.exit_flag:
            try:
                diter.queue.put(item, timeout=self.PUT_TIMEOUT)
            except Queue.Full:
                continue
            return True
        return False

    def run(self):
        """Produce batches until the data is exhausted or a stop is requested."""
        diter = self.parent
        # Shuffle with the parent's random generator.
        diter.rng.shuffle(self.indexes)

        offset = 0
        # Samples accepted since the last (re)shuffle.  Zero after a complete
        # pass means no sample can ever be produced (empty dataset, or every
        # sample longer than max_len), so looping again would spin forever.
        kept_this_pass = 0

        while not diter.exit_flag:
            last_batch = False
            triples = []

            # Collect up to batch_size samples that fit within max_len.
            while len(triples) < diter.batch_size and not diter.exit_flag:
                if offset == diter.data_len:
                    if not diter.use_infinite_loop or kept_this_pass == 0:
                        last_batch = True
                        break
                    # Infinite loop: reshuffle the indexes and start over.
                    diter.rng.shuffle(self.indexes)
                    offset = 0
                    kept_this_pass = 0

                s = diter.data[self.indexes[offset]]
                offset += 1

                # Keep only samples that are not longer than max_len.
                if len(s) <= diter.max_len:
                    triples.append(s)
                    kept_this_pass += 1

            if triples:
                self._enqueue(triples)

            if last_batch:
                # End-of-data marker for the consumer.
                self._enqueue(None)
                return
class SSIterator(object):
def __init__(self,
rng,
batch_size,
triple_file=None,
dtype="int32",
can_fit=False,
queue_size=100,
cache_size=100,
shuffle=True,
use_infinite_loop=True,
max_len=1000):
args = locals()
args.pop("self")
self.__dict__.update(args)
self.rng = rng
self.load_files()
self.exit_flag = False
def load_files(self):
self.data = cPickle.load(open(self.triple_file, 'r'))
self.data_len = len(self.data)
logger.debug('Data len is %d' % self.data_len)
def start(self):
self.exit_flag = False
self.queue = Queue.Queue(maxsize=self.queue_size)
self.gather = SSFetcher(self)
self.gather.daemon = True
self.gather.start()
def __del__(self):
if hasattr(self, 'gather'):
self.gather.exitFlag = True
self.gather.join()
def __iter__(self):
return self
def next(self):
if self.exit_flag:
return None
batch = self.queue.get()
if not batch:
self.exit_flag = True
return batch