-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_loader.py
107 lines (87 loc) · 3.59 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import numpy as np
from torch.utils.data import Dataset
class TrainDataset(Dataset):
    """Training dataset over knowledge-graph triples.

    The training strategy is selected by ``params.train_strategy``:

    * ``'one_to_n'`` -- the label is a dense multi-hot vector of length
      ``params.num_ent`` with 1.0 at every id in the record's ``'label'``.
    * ``'one_to_x'`` -- the sample carries ``triple[2]`` followed by
      ``params.neg_num`` negatives sampled from entities not in ``'label'``;
      the label is ``[1, 0, ..., 0]`` (position 0 is ``triple[2]``).

    Args:
        triples: Indexable sequence of records. Each record is a mapping with
            keys ``'triple'`` (integer ids), ``'label'`` (entity ids treated
            as positives) and ``'sub_samp'`` (a scalar, presumably a
            subsampling weight -- TODO confirm against the caller).
        params: Config object; reads ``train_strategy``, ``num_ent``,
            ``lbl_smooth`` and ``neg_num``.

    NOTE(review): negatives are drawn from NumPy's global RNG. PyTorch does
    not reseed NumPy per DataLoader worker, so with ``num_workers > 0`` the
    workers may draw identical negatives unless a ``worker_init_fn`` seeds
    NumPy.
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params
        self.strategy = self.p.train_strategy
        # Candidate pool for negative sampling: every entity id.
        self.entities = np.arange(self.p.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, trp_label, neg_ent, sub_samp)`` for record ``idx``.

        ``neg_ent`` and ``sub_samp`` are ``None`` under ``'one_to_n'``.

        Raises:
            NotImplementedError: for an unknown training strategy.
        """
        ele = self.triples[idx]
        triple = torch.LongTensor(ele['triple'])
        label = np.int32(ele['label'])
        sub_samp = np.float32(ele['sub_samp'])

        trp_label = self.get_label(label)
        if self.p.lbl_smooth != 0.0:
            # NOTE(review): this adds 1/num_ent rather than the textbook
            # lbl_smooth/num_ent, and uses num_ent even for 'one_to_x' labels
            # of length neg_num + 1. Kept unchanged to preserve existing
            # training behaviour -- confirm it is intended.
            trp_label = (1.0 - self.p.lbl_smooth) * trp_label + (1.0 / self.p.num_ent)

        if self.strategy == 'one_to_n':
            return triple, trp_label, None, None
        if self.strategy == 'one_to_x':
            sub_samp = torch.FloatTensor([sub_samp])
            neg_ent = torch.LongTensor(self.get_neg_ent(triple, label))
            return triple, trp_label, neg_ent, sub_samp
        raise NotImplementedError

    @staticmethod
    def collate_fn(data):
        """Batch samples produced by ``__getitem__``.

        Returns ``(triple, trp_label, neg_ent, sub_samp)`` when the samples
        carry negatives ('one_to_x'), otherwise ``(triple, trp_label)``.
        """
        triple = torch.stack([sample[0] for sample in data], dim=0)
        trp_label = torch.stack([sample[1] for sample in data], dim=0)
        if data[0][2] is None:  # one_to_n: no negatives / sub_samp slots
            return triple, trp_label
        neg_ent = torch.stack([sample[2] for sample in data], dim=0)
        # Each sub_samp is a 1-element tensor, so cat yields shape (batch,).
        sub_samp = torch.cat([sample[3] for sample in data], dim=0)
        return triple, trp_label, neg_ent, sub_samp

    def get_label(self, label):
        """Build the float target tensor for one sample.

        Args:
            label: Integer array of positive entity ids (read only for
                ``'one_to_n'``).

        Returns:
            ``torch.FloatTensor``: multi-hot over ``num_ent`` entities for
            ``'one_to_n'``; ``[1, 0, ..., 0]`` of length ``neg_num + 1`` for
            ``'one_to_x'``.

        Raises:
            NotImplementedError: for an unknown training strategy.
        """
        if self.strategy == 'one_to_n':
            y = np.zeros([self.p.num_ent], dtype=np.float32)
            for e2 in label:
                y[e2] = 1.0
        elif self.strategy == 'one_to_x':
            y = [1] + [0] * self.p.neg_num
        else:
            raise NotImplementedError
        return torch.FloatTensor(y)

    def get_neg_ent(self, triple, label):
        """Return candidate entity ids: positive(s) first, then negatives.

        Negatives are drawn without replacement from the entities that are
        not in ``label``.

        Args:
            triple: Integer id sequence/tensor; ``triple[2]`` is the positive
                object under ``'one_to_x'``.
            label: Integer array of positive entity ids.

        Returns:
            numpy array. Under ``'one_to_x'``: ``triple[2]`` followed by
            ``neg_num`` negatives. Otherwise: the ``label`` ids followed by
            ``neg_num - len(label)`` negatives (``neg_num`` ids in total).

        Raises:
            ValueError: from ``np.random.choice`` when fewer non-positive
                entities exist than negatives requested.
        """
        if self.strategy == 'one_to_x':
            pos_obj = triple[2]
            num_neg = self.p.neg_num
        else:
            pos_obj = label
            # Positives + negatives always total exactly neg_num here.
            num_neg = self.p.neg_num - len(label)

        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        mask = np.ones([self.p.num_ent], dtype=bool)
        mask[label] = 0
        neg_ent = np.int32(
            np.random.choice(self.entities[mask], num_neg, replace=False)).reshape([-1])
        # asarray accepts a 0-d tensor, a numpy array or a plain int alike.
        return np.concatenate((np.asarray(pos_obj).reshape(-1), neg_ent))
class TestDataset(Dataset):
    """Evaluation dataset that pairs each triple with a dense multi-hot target.

    Args:
        triples: Indexable sequence of records; each record is a mapping with
            ``'triple'`` (integer ids) and ``'label'`` (entity ids to mark as
            1.0 in the target vector).
        params: Config object; only ``params.num_ent`` is read.
    """

    def __init__(self, triples, params):
        self.triples = triples
        self.p = params

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, label)`` where ``label`` has length ``num_ent``."""
        record = self.triples[idx]
        triple_ids = torch.LongTensor(record['triple'])
        positives = np.int32(record['label'])
        return triple_ids, self.get_label(positives)

    @staticmethod
    def collate_fn(data):
        """Stack a list of ``(triple, label)`` pairs into two batched tensors."""
        batched_triples = torch.stack([pair[0] for pair in data], dim=0)
        batched_labels = torch.stack([pair[1] for pair in data], dim=0)
        return batched_triples, batched_labels

    def get_label(self, label):
        """Return a float32 tensor of length ``num_ent`` with 1.0 at each id in ``label``."""
        target = np.zeros([self.p.num_ent], dtype=np.float32)
        for ent_id in label:
            target[ent_id] = 1.0
        return torch.FloatTensor(target)