From 1531d916d54723ab33cee2ee4745ef66dc981527 Mon Sep 17 00:00:00 2001 From: howardchanth <55630770+howardchanth@users.noreply.github.com> Date: Wed, 29 Sep 2021 17:37:08 +0800 Subject: [PATCH 1/3] Update Base.cpp Add a stage to switch the global variables to current dataloader index --- openke/base/Base.cpp | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/openke/base/Base.cpp b/openke/base/Base.cpp index 13a1b384..9cdad3cd 100755 --- a/openke/base/Base.cpp +++ b/openke/base/Base.cpp @@ -168,12 +168,32 @@ void sampling( INT negRate = 1, INT negRelRate = 0, INT mode = 0, + INT domain_idx=0, bool filter_flag = true, bool p = false, bool val_loss = false ) { pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t)); Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter)); + + // Update the global variables to current loader index + trainList = trainLists[domain_idx]; + trainTotal = trainTotals[domain_idx]; + + trainHead = trainHeads[domain_idx]; + trainTail = trainTails[domain_idx]; + trainRel = trainRels[domain_idx]; + + lefHead = lefHeads[domain_idx]; + rigHead = rigHeads[domain_idx]; + lefTail = lefTails[domain_idx]; + rigTail = rigTails[domain_idx]; + lefRel = lefRels[domain_idx]; + rigRel = rigRels[domain_idx]; + + entityTotal = entityTotals[domain_idx]; + relationTotal = relationTotals[domain_idx]; + for (INT threads = 0; threads < workThreads; threads++) { para[threads].id = threads; para[threads].batch_h = batch_h; @@ -199,4 +219,4 @@ void sampling( int main() { importTrainFiles(); return 0; -} \ No newline at end of file +} From c85f33c4f15f28ba321c8c284c75ac2bd8a7691d Mon Sep 17 00:00:00 2001 From: howardchanth <55630770+howardchanth@users.noreply.github.com> Date: Wed, 29 Sep 2021 17:41:49 +0800 Subject: [PATCH 2/3] Update Reader.h Generalized loading to multiple data loaders --- openke/base/Reader.h | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/openke/base/Reader.h b/openke/base/Reader.h index 14c75625..196a52d5 100755 --- a/openke/base/Reader.h +++ b/openke/base/Reader.h @@ -6,6 +6,7 @@ #include #include #include +#include INT *freqRel, *freqEnt; INT *lefHead, *rigHead; @@ -14,6 +15,25 @@ INT *lefRel, *rigRel; REAL *left_mean, *right_mean; REAL *prob; +// Vector of collecting the global variables from different data loaders +std::vector< Triple* > trainLists = std::vector(); + +std::vector< Triple* > trainHeads = std::vector(); +std::vector< Triple* > trainTails = std::vector(); +std::vector< Triple* > trainRels = std::vector(); + +std::vector< INT* > lefHeads = std::vector(); +std::vector< INT* > rigHeads = std::vector(); +std::vector< INT* > lefTails = std::vector(); +std::vector< INT* > rigTails = std::vector(); +std::vector< INT* > lefRels = std::vector(); +std::vector< INT* > rigRels = std::vector(); + +std::vector< INT > trainTotals = std::vector(); +std::vector< INT > tripleTotals = std::vector(); +std::vector< INT > entityTotals = std::vector(); +std::vector< INT > relationTotals = std::vector(); + Triple *trainList; Triple *trainHead; Triple *trainTail; @@ -83,6 +103,16 @@ void importTrainFiles() { trainRel = (Triple *)calloc(trainTotal, sizeof(Triple)); freqRel = (INT *)calloc(relationTotal, sizeof(INT)); freqEnt = (INT *)calloc(entityTotal, sizeof(INT)); + + // Collect the address of train lists, thier heads, tails and relations + trainLists.push_back(trainList); + trainHeads.push_back(trainHead); + trainTails.push_back(trainTail); + trainRels.push_back(trainRel); + + // Collect number of entities and relations of this data loader + entityTotals.push_back(entityTotal); + relationTotals.push_back(relationTotal); for (INT i = 0; i < trainTotal; i++) { tmp = fscanf(fin, "%ld", &trainList[i].h); tmp = fscanf(fin, "%ld", &trainList[i].t); @@ -117,6 +147,18 @@ void importTrainFiles() { rigRel = (INT *)calloc(entityTotal, sizeof(INT)); memset(rigHead, -1, sizeof(INT)*entityTotal); memset(rigTail, -1, sizeof(INT)*entityTotal); + + // Collect left and rights of heads, tail and relations + lefHeads.push_back(lefHead); + rigHeads.push_back(rigHead); + lefTails.push_back(lefTail); + rigTails.push_back(rigTail); + lefRels.push_back(lefRel); + rigRels.push_back(rigRel); + + // Collect train totals + trainTotals.push_back(trainTotal); + memset(rigRel, -1, sizeof(INT)*entityTotal); for (INT i = 1; i < trainTotal; i++) { if (trainTail[i].t != trainTail[i - 1].t) { From 1e979f0ce3a87ad3d9a422ea0a34778d7c30f5de Mon Sep 17 00:00:00 2001 From: howardchanth <55630770+howardchanth@users.noreply.github.com> Date: Wed, 29 Sep 2021 17:43:54 +0800 Subject: [PATCH 3/3] Update TrainDataLoader.py Update samplings to include the index of the current data loader. Assuming only one data loader doing sampling at a time --- openke/data/TrainDataLoader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/openke/data/TrainDataLoader.py b/openke/data/TrainDataLoader.py index 33da8959..d9d5878b 100755 --- a/openke/data/TrainDataLoader.py +++ b/openke/data/TrainDataLoader.py @@ -52,6 +52,7 @@ def __init__(self, ctypes.c_int64, ctypes.c_int64, ctypes.c_int64, + ctypes.c_int64, ctypes.c_int64 ] self.in_path = in_path @@ -71,6 +72,7 @@ def __init__(self, self.negative_ent = neg_ent self.negative_rel = neg_rel self.sampling_mode = sampling_mode + self.domain_idx = domain_idx self.cross_sampling_flag = 0 self.read() @@ -115,6 +117,7 @@ def sampling(self): self.negative_ent, self.negative_rel, 0, + self.domain_idx, self.filter, 0, 0 @@ -137,6 +140,7 @@ def sampling_head(self): self.negative_ent, self.negative_rel, -1, + self.domain_idx, self.filter, 0, 0 @@ -159,6 +163,7 @@ def sampling_tail(self): self.negative_ent, self.negative_rel, 1, + self.domain_idx, self.filter, 0, 0 @@ -226,4 +231,4 @@ def __iter__(self): return TrainDataSampler(self.nbatches, self.cross_sampling) def __len__(self): - return self.nbatches \ No newline at end of file + return self.nbatches