diff --git a/colbert/data/collection.py b/colbert/data/collection.py index d5efc943..4d800fb7 100644 --- a/colbert/data/collection.py +++ b/colbert/data/collection.py @@ -33,7 +33,9 @@ def _load_file(self, path): return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path) def _load_tsv(self, path): - return load_collection(path) + collection, pid_list = load_collection(path) + self.pid_list = pid_list + return collection def _load_jsonl(self, path): raise NotImplementedError() diff --git a/colbert/evaluation/loaders.py b/colbert/evaluation/loaders.py index 251065b0..aae585ff 100644 --- a/colbert/evaluation/loaders.py +++ b/colbert/evaluation/loaders.py @@ -156,6 +156,7 @@ def load_collection(collection_path): print_message("#> Loading collection...") collection = [] + pid_list = [] with open(collection_path) as f: for line_idx, line in enumerate(f): @@ -163,7 +164,8 @@ def load_collection(collection_path): print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True) pid, passage, *rest = line.strip('\n\r ').split('\t') - assert pid == 'id' or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}" + pid_list.append(pid) + # assert pid == 'id' or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}" if len(rest) >= 1: title = rest[0] @@ -173,7 +175,7 @@ def load_collection(collection_path): print() - return collection + return collection, pid_list def load_colbert(args, do_print=True): diff --git a/colbert/searcher.py b/colbert/searcher.py index 8bc07c50..fc25e896 100644 --- a/colbert/searcher.py +++ b/colbert/searcher.py @@ -37,6 +37,7 @@ def __init__(self, index, checkpoint=None, collection=None, config=None, index_r self.config = ColBERTConfig.from_existing(self.checkpoint_config, self.index_config, initial_config) self.collection = Collection.cast(collection or self.config.collection) + self.pid_list = self.idx2pid(self.config.collection) self.configure(checkpoint=self.checkpoint, collection=self.collection) self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose) @@ -49,6 +50,14 @@ def __init__(self, index, checkpoint=None, collection=None, config=None, index_r self.ranker = IndexScorer(self.index, use_gpu, load_index_with_mmap) print_memory_stats() + + def idx2pid(self, collection_path): + pid_list = [] + with open(collection_path) as f: + for line_idx, line in enumerate(f): + pid, passage, *rest = line.strip('\n\r ').split('\t') + pid_list.append(pid) + return pid_list def configure(self, **kw_args): self.config.configure(**kw_args)