From 45fdf2dda7bc27fcdd907bcacb657b32828b0b40 Mon Sep 17 00:00:00 2001 From: s-jse Date: Thu, 5 Sep 2024 17:21:52 +0000 Subject: [PATCH] Add experimental preprocessing scripts for semantic scholar dumps --- README.md | 16 +- docs/search_api.md | 1 - pipelines/retriever.py | 4 +- preprocessing/block.py | 51 +++ .../get_all_wiki_sizes.py | 0 .../get_common_english_words.py | 2 +- .../preprocess_semantic_scholar_dump.py | 308 ++++++++++++++++++ .../preprocess_wikipedia_html_dump.py | 85 ++--- .../upload_collections_to_hf_hub.py | 0 .../utils.py | 6 +- .../wikipedia_disambiguation.py | 0 .../word_list.txt | 0 retrieval/add_payload_index.py | 2 + retrieval/create_index.py | 28 +- retrieval/qdrant_index.py | 13 +- retrieval/retriever_server.py | 11 +- tasks/docker_utils.py | 2 + tasks/retrieval.py | 111 ++++++- tests/test_wikipedia_preprocessing.py | 4 +- 19 files changed, 524 insertions(+), 120 deletions(-) create mode 100644 preprocessing/block.py rename {wikipedia_preprocessing => preprocessing}/get_all_wiki_sizes.py (100%) rename wikipedia_preprocessing/get_word_list.py => preprocessing/get_common_english_words.py (94%) create mode 100644 preprocessing/preprocess_semantic_scholar_dump.py rename wikipedia_preprocessing/preprocess_html_dump.py => preprocessing/preprocess_wikipedia_html_dump.py (93%) rename {wikipedia_preprocessing => preprocessing}/upload_collections_to_hf_hub.py (100%) rename {wikipedia_preprocessing => preprocessing}/utils.py (98%) rename {wikipedia_preprocessing => preprocessing}/wikipedia_disambiguation.py (100%) rename {wikipedia_preprocessing => preprocessing}/word_list.txt (100%) diff --git a/README.md b/README.md index 39ce8c9..e06abc0 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,11 @@ Note that this index contains ~180M vector embeddings and therefore requires at inv start-retriever --embedding-model BAAI/bge-m3 --retriever-port ``` +3. Start WikiChat by passing in the URL of this retriever. For example: +```bash +inv demo --retriever-endpoint "http://0.0.0.0:/search" +``` + Note that this server and its embedding model run on CPU, and do not require GPU. For better performance, on compatible systems you can add `--use-onnx` to use the ONNX version of the embedding model, to significantly lower the embedding latency. ### Option 3: Build your own index @@ -228,7 +233,7 @@ inv index-wikipedia-dump --embedding-model BAAI/bge-m3 --workdir ./workdir --la `block_type` and `language` are only used to provide filtering on search results. If you do not need them, you can simply set them to `block_type=text` and `language=en`. The script will feed `full_section_title` and `content_string` to the embedding model to create embedding vectors. -See `wikipedia_preprocessing/preprocess_html_dump.py` for details on how this is implemented for Wikipedia HTML dumps. +See `preprocessing/preprocess_wikipedia_html_dump.py` for details on how this is implemented for Wikipedia HTML dumps. 1. Run the indexing command: @@ -250,6 +255,11 @@ inv start-retriever --embedding-model BAAI/bge-m3 --retriever-port curl -X POST 0.0.0.0:5100/search -H "Content-Type: application/json" -d '{"query": ["What is GPT-4?", "What is LLaMA-3?"], "num_blocks": 3}' ``` +5. Start WikiChat by passing in the URL of this retriever. For example: +```bash +inv demo --retriever-endpoint "http://0.0.0.0:/search" +``` + #### To upload a Qdrant index to 🤗 Hub: 1. Split the index into smaller parts: @@ -258,8 +268,8 @@ tar -cvf - | pigz -p 14 | split --bytes=10GB - ``` 2. Upload the resulting parts: -``` -retrieval/upload_folder_to_hf_hub.py --folder_path --repo_id +```bash +python retrieval/upload_folder_to_hf_hub.py --folder_path --repo_id ``` diff --git a/docs/search_api.md b/docs/search_api.md index 3cd1e0f..0ed09d3 100644 --- a/docs/search_api.md +++ b/docs/search_api.md @@ -18,7 +18,6 @@ The request body should be a JSON object with the following fields: - `query`: A string or a list of strings representing the search queries. - `num_blocks`: An integer representing the number of items to retrieve. - `languages`: (Optional) A string or a list of strings representing the language codes to filter the search results. -- `block_types`: (Optional) A string or a list of strings representing the block types to filter the search results. ### Example Search for the 3 most relevant text, table or infobox in the any of the 10 Wikipedia languages. diff --git a/pipelines/retriever.py b/pipelines/retriever.py index e290ea4..cd789ef 100644 --- a/pipelines/retriever.py +++ b/pipelines/retriever.py @@ -73,7 +73,7 @@ def retrieval_results_to_list(results: dict): results["score"], results["last_edit_date"], ): - if block_type not in ["text", "table", "infobox", "list"]: + if block_type not in ["text", "table", "infobox"]: logger.warning( f"Found invalid block type {str(block_type)} for {passage}." ) @@ -266,7 +266,7 @@ def _try_to_enforce_block_type_limits( results: list[RetrievalResult], block_type_limits: dict, target_num: int ) -> list[RetrievalResult]: block_type_limits_copy = {} - for k in ["table", "text", "list", "infobox"]: + for k in ["table", "text", "infobox"]: if k not in block_type_limits: block_type_limits_copy[k] = 1e6 # Infinity else: diff --git a/preprocessing/block.py b/preprocessing/block.py new file mode 100644 index 0000000..13f78b2 --- /dev/null +++ b/preprocessing/block.py @@ -0,0 +1,51 @@ +from preprocessing.utils import ( + extract_english_translations, + replace_except_first, +) + + +class Block: + """ + A paragraph, list, linearized table, or linearized Infobox + """ + + content_string: str + article_title: str + full_section_title: str + block_type: str + language: str + last_edit_date: str + num_tokens: int + + def __init__( + self, + content_string: str, + full_section_title: str, + block_type: str, + article_title: str = None, + language: str = None, + last_edit_date: str = None, + ): + self.content_string = content_string.strip() + self.article_title = article_title + self.full_section_title = full_section_title + self.block_type = block_type + self.language = language + self.last_edit_date = last_edit_date + self.num_tokens = 0 + + def to_json(self, _id: int): + ret = self.__dict__ + ret["id"] = _id + return ret + + def deduplicate_translations(self) -> None: + """ + Deduplicates (in English: ...) from each block + """ + string = self.full_section_title + " | " + self.content_string + translation_parenthesis = set(extract_english_translations(string)) + for t in translation_parenthesis: + string = replace_except_first(string, " " + t, "") + + self.full_section_title, self.content_string = tuple(string.split(" | ", 1)) diff --git a/wikipedia_preprocessing/get_all_wiki_sizes.py b/preprocessing/get_all_wiki_sizes.py similarity index 100% rename from wikipedia_preprocessing/get_all_wiki_sizes.py rename to preprocessing/get_all_wiki_sizes.py diff --git a/wikipedia_preprocessing/get_word_list.py b/preprocessing/get_common_english_words.py similarity index 94% rename from wikipedia_preprocessing/get_word_list.py rename to preprocessing/get_common_english_words.py index e5eed98..b94d9db 100644 --- a/wikipedia_preprocessing/get_word_list.py +++ b/preprocessing/get_common_english_words.py @@ -26,5 +26,5 @@ word_list.append(w) print("Extracted %d words from %s" % (len(word_list), word_frequency_page)) -with open("./wikipedia_preprocessing/word_list.txt", "w") as f: +with open("./preprocessing/word_list.txt", "w") as f: f.write("\n".join(word_list)) diff --git a/preprocessing/preprocess_semantic_scholar_dump.py b/preprocessing/preprocess_semantic_scholar_dump.py new file mode 100644 index 0000000..81deb6c --- /dev/null +++ b/preprocessing/preprocess_semantic_scholar_dump.py @@ -0,0 +1,308 @@ +from collections import defaultdict +import itertools +import json +import pathlib + +from langchain_text_splitters import RecursiveCharacterTextSplitter +import orjsonl +import sys +import argparse +import glob +import os +from multiprocessing import cpu_count, Pool + +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + + +sys.path.insert(0, "./") +from preprocessing.utils import draw_and_save_histogram_log_bins, num_tokens +from preprocessing.block import Block +from pipelines.utils import get_logger + +logger = get_logger(__name__) + + +text_splitter = RecursiveCharacterTextSplitter( + separators=[ + "\n\n", + "\n", + " ", + ".", + ",", + "\u200b", # Zero-width space + "\uff0c", # Fullwidth comma + "\u3001", # Ideographic comma + "\uff0e", # Fullwidth full stop + "\u3002", # Ideographic full stop + "", + ], + chunk_size=500, + chunk_overlap=0, + length_function=num_tokens, + is_separator_regex=False, +) + + +def is_in_subset(article: dict, subset: list[str]) -> bool: + if not article["externalids"]: + return False + + for s in subset: + if s in article["externalids"] and article["externalids"][s]: + return True + + return False + + +def get_section_boundaries(section_headers, article_length): + section_boundaries = [] + + for i in range(len(section_headers) - 1): + section_boundaries.append( + (int(section_headers[i]["end"]), int(section_headers[i + 1]["start"])) + ) + section_boundaries.append((int(section_headers[-1]["end"]), article_length)) + return section_boundaries + + +def which_section(start, end, section_boundaries): + for idx, boundary in enumerate(section_boundaries): + if start >= boundary[0] and end <= boundary[1]: + return idx + # print("start, end: ", start, end) + # print("section_boundaries = ", section_boundaries) + return None + + +def extract_article_blocks(article: dict): + # TODO extract publication date as well + section_title_texts = defaultdict(list) + + # assert ( + # article["content"]["source"]["pdfurls"] is None + # or len(article["content"]["source"]["pdfurls"]) >= 1 + # ) + article_text = article["content"]["text"] + titles = set() + if not article["content"]["annotations"]["title"]: + logger.warning("Skipping article because it has no title") + return [] + if not article["content"]["annotations"]["sectionheader"]: + logger.warning("Skipping article because it has no section headers") + return [] + if not article["content"]["annotations"]["paragraph"]: + logger.warning("Skipping article because it has no paragraphs") + return [] + title = json.loads(article["content"]["annotations"]["title"]) + for t in title: + start, end = int(t["start"]), int( + t["end"] + ) # occasionally, some of them are string instead of int. So we convert them here + titles.add(article_text[start:end]) + if len(titles) != 1: + logger.warning("Found %d titles: %s", len(titles), str(titles)) + # when we have multiple titles, the second one is often just author affiliation misclassified as title + # so we should still select the first title + article_title = list(titles)[0] + + if article["content"]["annotations"]["abstract"]: + # Some articles do not have abstracts, for example because they are just the supplementary material to another article + abstract = json.loads(article["content"]["annotations"]["abstract"]) + abstract_boundaries = set() + for a in abstract: + start, end = int(a["start"]), int(a["end"]) # TODO extract method + abstract_boundaries.add((start, end)) + + # assert len(abstract_boundaries) == 1, abstract + start, end = list(abstract_boundaries)[0] + abstract_text = article_text[start:end] + section_title_texts["Abstract"].append(abstract_text) + # print(abstract_text) + + section_headers = json.loads(article["content"]["annotations"]["sectionheader"]) + section_titles = [] + for header in section_headers: + # print(section) + start, end = int(header["start"]), int(header["end"]) + section_title = article_text[start:end] + if "attributes" in header and "n" in header["attributes"]: + section_number = header["attributes"]["n"] + section_title = f"{section_number}. " + section_title + section_titles.append(section_title) + + # print("section_headers = ", section_headers) + section_boundaries = get_section_boundaries(section_headers, len(article_text)) + # print("section_boundaries = ", section_boundaries) + paragraphs = json.loads(article["content"]["annotations"]["paragraph"]) + paragraphs_without_section = [] + for paragraph in paragraphs: + start, end = int(paragraph["start"]), int(paragraph["end"]) + paragraph_text = article_text[start:end] + section_idx = which_section(start, end, section_boundaries) + if not section_idx: + paragraphs_without_section.append((start, end, paragraph_text)) + else: + section_title_texts[section_titles[section_idx]].append(paragraph_text) + + # Merge paragraphs without section that have consecutive start and end + merged_paragraphs_without_section = [] + for i, (start, end, text) in enumerate(paragraphs_without_section): + if ( + i == 0 or start != merged_paragraphs_without_section[-1][1] + ): # Not consecutive + merged_paragraphs_without_section.append([start, end, text]) + else: # Consecutive, merge with previous + merged_paragraphs_without_section[-1][1] = end + merged_paragraphs_without_section[-1][2] += "\n\n" + text + + _id = article["corpusid"] + article_blocks = [] + assert isinstance(_id, int) + for section_title, section_text in section_title_texts.items(): + section_title_texts[section_title] = "\n\n".join(section_text) + for s in text_splitter.split_text(section_title_texts[section_title]): + + block = Block( + content_string=s, + article_title=article_title.title(), + full_section_title=f"{article_title.title()} > {section_title.title()}", + block_type="text", + language="en", + ) + block.num_tokens = num_tokens( + block.full_section_title + " " + block.content_string + ) + article_blocks.append(block) + for p in merged_paragraphs_without_section: + for s in text_splitter.split_text(p[2]): + block = Block( + content_string=s, + article_title=article_title.title(), + full_section_title=f"{article_title.title()} > -", + block_type="text", + language="en", + ) + block.num_tokens = num_tokens( + block.full_section_title + " " + block.content_string + ) + article_blocks.append(block) + + return article_blocks + + +def articles(files: str, max_articles: int, subset: list[str]): + + counter = 0 + all_articles = itertools.chain.from_iterable(orjsonl.stream(f) for f in files) + for article in all_articles: + if not is_in_subset(article, subset): + continue + yield article + counter += 1 + if counter == max_articles: + break + + +def process_file(_input): + file, max_articles, subset, counter_start = _input + all_blocks = [] + counter = counter_start + for article in articles([file], max_articles, subset): + article_blocks = extract_article_blocks(article) + for block in article_blocks: + all_blocks.append(block.to_json(counter)) + counter += 1 + return all_blocks + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + type=str, + required=True, + help="Directory containing input .jsonl.gz files", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="The output jsonl or json.gz file to write the blocks to.", + ) + parser.add_argument("--num_workers", type=int, default=max(1, cpu_count() - 4)) + parser.add_argument( + "--subset", + type=str, + nargs="+", + choices=[ + "arxiv", + "acl", + "pubmed", + "pubmedcentral", + "dblp", + "mag", # MAG is Microsoft Academic Graph, which was discontinued in 2021 + ], + required=False, + default=None, + help="Subset(s) of data to process. Can be one or more of the listed choices. If not provided, will process all available data.", + ) + parser.add_argument( + "--max_articles", + type=int, + default=-1, + help="Will stop after processing this many articles. -1 means no limit. Used for testing.", + ) + + args = parser.parse_args() + + input_files = sorted( + glob.glob(os.path.join(args.input_dir, "semantic_scholar_dump_*.jsonl.gz")) + ) + if not input_files: + logger.error(f"No matching files found in {args.input_dir}") + sys.exit(1) + + num_workers = min(args.num_workers, len(input_files)) + tasks = [ + (file, args.max_articles, args.subset, i * 2000000) + for i, file in enumerate(input_files) + ] + + if os.path.exists(args.output_path): + raise FileExistsError( + f"Output file '{args.output_path}' already exists. Please choose a different output path or remove the existing file." + ) + logger.info("Saving the collection to %s", args.output_path) + # make parent directories + pathlib.Path(os.path.dirname(args.output_path)).mkdir(parents=True, exist_ok=True) + + with logging_redirect_tqdm(): + with Pool(processes=num_workers) as pool: + for blocks in tqdm( + pool.imap_unordered(process_file, tasks), + total=len(tasks), + desc="Files", + position=0, + leave=True, + ): + orjsonl.extend(args.output_path, blocks) + + # save the collection size + count = 0 + num_tokens_list = [] + for block in tqdm(orjsonl.stream(args.output_path), desc="Calculating output size"): + count += 1 + num_tokens_list.append(block["num_tokens"]) + with open( + os.path.join(os.path.dirname(args.output_path), "collection_size.txt"), "w" + ) as f: + f.write(str(count)) + + logger.info("Total number of blocks: {:,d}".format(count)) + + # print_histogram([len(block.content_string) for block in all_blocks]) + draw_and_save_histogram_log_bins( + num_tokens_list, + args.output_path.rsplit(".", 1)[0] + "_histogram.png", + ) diff --git a/wikipedia_preprocessing/preprocess_html_dump.py b/preprocessing/preprocess_wikipedia_html_dump.py similarity index 93% rename from wikipedia_preprocessing/preprocess_html_dump.py rename to preprocessing/preprocess_wikipedia_html_dump.py index fd1688d..75b4a43 100644 --- a/wikipedia_preprocessing/preprocess_html_dump.py +++ b/preprocessing/preprocess_wikipedia_html_dump.py @@ -1,4 +1,5 @@ import argparse +import asyncio import os import pathlib import re @@ -11,13 +12,25 @@ from bs4 import BeautifulSoup, NavigableString from markdownify import MarkdownConverter from mwparserfromhtml.parse.plaintext import html_to_plaintext +import orjsonl from tqdm import tqdm -from transformers import AutoTokenizer + + +from preprocessing.block import Block sys.path.insert(0, "./") from pipelines.utils import get_logger -from wikipedia_preprocessing.utils import * -from wikipedia_preprocessing.wikipedia_disambiguation import is_disambiguation +from preprocessing.utils import ( + batch_get_wikidata_english_name, + draw_and_save_histogram_log_bins, + find_forest_roots_and_members, + get_from_translation_map, + load_translation_map, + num_tokens, + save_translation_map, + translation_prefix, +) +from preprocessing.wikipedia_disambiguation import is_disambiguation logger = get_logger(__name__) from transformers.utils import logging as transformers_logging @@ -28,54 +41,6 @@ {} ) # used to expand translation search. This map includes a map of each root to itself, for simplicity frequent_words_to_exclude = set() -tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", fast=True) - - -class Block: - """ - A paragraph, list, linearized table, or linearized Infobox - """ - - content_string: str - article_title: str - full_section_title: str - block_type: str - language: str - last_edit_date: str - num_tokens: int - - def __init__( - self, - content_string: str, - full_section_title: str, - block_type: str, - article_title: str = None, - language: str = None, - last_edit_date: str = None, - ): - self.content_string = content_string.strip() - self.article_title = article_title - self.full_section_title = full_section_title - self.block_type = block_type - self.language = language - self.last_edit_date = last_edit_date - self.num_tokens = 0 - - def to_json(self, _id: int): - ret = self.__dict__ - ret["id"] = _id - return ret - - def deduplicate_translations(self) -> None: - """ - Deduplicates (in English: ...) from each block - """ - string = self.full_section_title + " | " + self.content_string - translation_parenthesis = set(extract_english_translations(string)) - for t in translation_parenthesis: - string = replace_except_first(string, " " + t, "") - - self.full_section_title, self.content_string = tuple(string.split(" | ", 1)) banned_sections = { @@ -416,10 +381,6 @@ def get_full_heading(): return blocks -def num_tokens(text: str) -> int: - return len(tokenizer(text)["input_ids"]) - - def pack_blocks(blocks: list[tuple[str, str]], pack_to_tokens: int) -> list[Block]: """ Passages is the list of tuples where each tuple is (subsection title, passage). @@ -823,7 +784,7 @@ def articles_without_disambiguation_or_redirections( len(redirection_map), ) if args.num_exclude_frequent_words_from_translation > 0: - with open("wikipedia_preprocessing/word_list.txt") as f: + with open("preprocessing/word_list.txt") as f: for line in f: frequent_words_to_exclude.add(line.strip()) if ( @@ -891,7 +852,6 @@ def articles_without_disambiguation_or_redirections( # make parent directories pathlib.Path(os.path.dirname(args.output_path)).mkdir(parents=True, exist_ok=True) - # compact just removes the extra space in the json formatting counter = 0 while True: @@ -912,8 +872,12 @@ def articles_without_disambiguation_or_redirections( all_blocks.append(block.to_json(counter)) counter += 1 + # Wait for processes to complete + for p in all_worker_processes + [reader_process]: + p.join() + logger.info("Saving the collection to %s", args.output_path) - orjsonl.save(args.output_path, all_blocks, compression_format="gz") + orjsonl.save(args.output_path, all_blocks) # save the collection size with open( @@ -921,9 +885,6 @@ def articles_without_disambiguation_or_redirections( ) as f: f.write(str(len(all_blocks))) - # Wait for processes to complete - for p in all_worker_processes + [reader_process]: - p.join() logger.info("Found {:,d} text blocks (including lists)".format(text_count)) logger.info("Found {:,d} table blocks".format(table_count)) @@ -935,5 +896,5 @@ def articles_without_disambiguation_or_redirections( # print_histogram([len(block.content_string) for block in all_blocks]) draw_and_save_histogram_log_bins( [block["num_tokens"] for block in all_blocks], - args.output_path.split(".")[0] + "_histogram.png", + args.output_path.rsplit(".", 1)[0] + "_histogram.png", ) diff --git a/wikipedia_preprocessing/upload_collections_to_hf_hub.py b/preprocessing/upload_collections_to_hf_hub.py similarity index 100% rename from wikipedia_preprocessing/upload_collections_to_hf_hub.py rename to preprocessing/upload_collections_to_hf_hub.py diff --git a/wikipedia_preprocessing/utils.py b/preprocessing/utils.py similarity index 98% rename from wikipedia_preprocessing/utils.py rename to preprocessing/utils.py index 7389c95..dc74d49 100644 --- a/wikipedia_preprocessing/utils.py +++ b/preprocessing/utils.py @@ -9,6 +9,7 @@ from matplotlib import pyplot as plt from tqdm import tqdm, trange from tqdm.contrib.logging import logging_redirect_tqdm +from transformers import AutoTokenizer sys.path.insert(0, "./") from pipelines.utils import get_logger @@ -20,7 +21,7 @@ # this is different from the case that the key is absent, which means we have never looked up that translation in Wikidata global_translation_map = {} - +tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3", fast=True) translation_prefix = "(in English: " @@ -372,3 +373,6 @@ def find_forest_roots_and_members(incoming_edges): for root in roots: trees[root] = find_tree_members(root, incoming_edges) return trees + +def num_tokens(text: str) -> int: + return len(tokenizer(text)["input_ids"]) diff --git a/wikipedia_preprocessing/wikipedia_disambiguation.py b/preprocessing/wikipedia_disambiguation.py similarity index 100% rename from wikipedia_preprocessing/wikipedia_disambiguation.py rename to preprocessing/wikipedia_disambiguation.py diff --git a/wikipedia_preprocessing/word_list.txt b/preprocessing/word_list.txt similarity index 100% rename from wikipedia_preprocessing/word_list.txt rename to preprocessing/word_list.txt diff --git a/retrieval/add_payload_index.py b/retrieval/add_payload_index.py index 7fb9996..0bdc7ea 100644 --- a/retrieval/add_payload_index.py +++ b/retrieval/add_payload_index.py @@ -1,7 +1,9 @@ import argparse import sys + from qdrant_client import QdrantClient from qdrant_client.models import PayloadSchemaType + sys.path.insert(0, "./") from tasks.defaults import DEFAULT_QDRANT_COLLECTION_NAME diff --git a/retrieval/create_index.py b/retrieval/create_index.py index 1278c5f..a558a83 100644 --- a/retrieval/create_index.py +++ b/retrieval/create_index.py @@ -18,8 +18,6 @@ Datatype, Distance, HnswConfigDiff, - OptimizersConfig, - PayloadSchemaType, VectorParams, ) from tqdm import tqdm @@ -117,9 +115,6 @@ def commit_to_index( on_disk=True, datatype=Datatype.FLOAT16, ), - # optimizers_config=OptimizersConfig( - # indexing_threshold=10 - # ), hnsw_config=HnswConfigDiff( m=64, ef_construct=100, full_scan_threshold=10, on_disk=True ), @@ -174,34 +169,13 @@ def commit_to_index( pbar.update(len(batch_blocks)) - # index these payload fields, so that we can filter searched based on them - # logger.info("Indexing payload fields.") - # qdrant_client.create_payload_index( - # collection_name=args.collection_name, - # field_name="language", - # field_schema=PayloadSchemaType.KEYWORD, - # wait=False, - # ) - # qdrant_client.create_payload_index( - # collection_name=args.collection_name, - # field_name="block_type", - # field_schema=PayloadSchemaType.KEYWORD, - # wait=False, - # ) - # qdrant_client.create_payload_index( - # collection_name=args.collection_name, - # field_name="title", - # field_schema=PayloadSchemaType.KEYWORD, - # wait=False, - # ) - qdrant_client.close() def batch_generator(collection_file, embedding_batch_size): """Generator function to yield batches of data from the queue.""" batch = [] - for block in orjsonl.stream(collection_file, compression_format="gz"): + for block in orjsonl.stream(collection_file): batch.append(block) if len(batch) == embedding_batch_size: yield batch diff --git a/retrieval/qdrant_index.py b/retrieval/qdrant_index.py index a39f113..593b312 100644 --- a/retrieval/qdrant_index.py +++ b/retrieval/qdrant_index.py @@ -1,7 +1,7 @@ import asyncio import math from time import time -from typing import Any +from typing import Any, Optional import numpy as np import onnxruntime as ort @@ -41,6 +41,10 @@ "embedding_dimension": 384, "query_prefix": "Represent this sentence for searching relevant passages: ", }, + "Alibaba-NLP/gte-base-en-v1.5": { + "embedding_dimension": 768, + "query_prefix": "", + }, # Its maximum sequence length is 8192 tokens "Alibaba-NLP/gte-Qwen2-1.5B-instruct": { "embedding_dimension": 1536, "query_prefix": "Given a web search query, retrieve relevant passages that answer the query", @@ -95,7 +99,7 @@ class SearchResult(BaseModel): language: list[str] = Field( default_factory=list, json_schema_extra={"example": ["ja", "pt", "es"]} ) - last_edit_date: list[str] = Field( + last_edit_date: list[Optional[str]] = Field( default_factory=list, json_schema_extra={ "example": [ @@ -129,7 +133,7 @@ def __init__( use_onnx: bool = False, ): self.qdrant_client = AsyncQdrantClient( - url=qdrant_url, timeout=10, prefer_grpc=True + url=qdrant_url, timeout=20, prefer_grpc=True ) # Lower timeout so that we don't get stuck on long retrieval jobs self.collection_name = collection_name self.use_onnx = use_onnx @@ -162,7 +166,7 @@ def __init__( self.ort_session = ort.InferenceSession(onnx_model_path) else: - self.embedding_model = AutoModel.from_pretrained(embedding_model_name) + self.embedding_model = AutoModel.from_pretrained(embedding_model_name, trust_remote_code=True) # self.embedding_model.eval() logger.info("Successfully loaded the embedding model " + embedding_model_name) @@ -262,7 +266,6 @@ def search_result_to_pydantic(search_result: list[ScoredPoint]) -> SearchResult: # take the softmax passage_probs = [math.exp(score) for score in ret["score"]] ret["prob"] = [prob / sum(passage_probs) for prob in passage_probs] - return SearchResult(**ret) def embed_queries(self, queries: list[str]): diff --git a/retrieval/retriever_server.py b/retrieval/retriever_server.py index 615b64e..06f6d90 100644 --- a/retrieval/retriever_server.py +++ b/retrieval/retriever_server.py @@ -73,9 +73,6 @@ class QueryData(BaseModel): languages: Optional[Union[list[Language], Language]] = Field( default=None, description="The language codes of the results you want" ) - block_types: Optional[Union[list[BlockType], BlockType]] = Field( - default=None, description="Can be zero, one or more of the defined block types" - ) @field_validator("query") def validate_query(cls, value): @@ -95,7 +92,8 @@ def validate_query(cls, value): return value model_config = ConfigDict( - use_enum_values=True # This allows the API to accept string values for enums + extra="forbid", + use_enum_values=True, # This allows the API to accept string values for enums ) @@ -143,16 +141,13 @@ async def search(request: Request, query_data: QueryData): languages = query_data.languages if isinstance(languages, str): languages = [languages] - block_types = query_data.block_types - if isinstance(block_types, str): - block_types = [block_types] try: results = await qdrant_search( tuple(queries), evi_num, languages=tuple(languages) if languages else None, - block_types=tuple(block_types) if block_types else None, + block_types=None, ) except Exception as e: logger.error(str(e)) diff --git a/tasks/docker_utils.py b/tasks/docker_utils.py index 63c8a6d..2071172 100644 --- a/tasks/docker_utils.py +++ b/tasks/docker_utils.py @@ -265,6 +265,8 @@ def start_embedding_docker_container( "50000", "--max-concurrent-requests", "128", + "--dtype", + "float16", "--hostname", "0.0.0.0", ], diff --git a/tasks/retrieval.py b/tasks/retrieval.py index 88e34fc..bc1369c 100644 --- a/tasks/retrieval.py +++ b/tasks/retrieval.py @@ -4,6 +4,7 @@ import sys import threading from concurrent.futures import ThreadPoolExecutor +import concurrent from datetime import datetime from typing import Optional @@ -18,7 +19,7 @@ start_qdrant_docker_container, stop_all_docker_containers, ) -from tasks.main import start_redis +from tasks.main import load_api_keys, start_redis sys.path.insert(0, "./") from pipelines.utils import get_logger @@ -117,7 +118,9 @@ def download_chunk_from_url( pbar.update(len(chunk)) -def multithreaded_download(url: str, output_path: str, num_parts: int = 3) -> None: +def multithreaded_download( + url: str, output_path: str, num_parts: int = 3, disable_tqdm=False +) -> None: """ Download a file in parts concurrently using multiple threads to optimize the download process. @@ -131,6 +134,7 @@ def multithreaded_download(url: str, output_path: str, num_parts: int = 3) -> No output_path (str): The path to which the file will be downloaded. num_parts (int, optional): The number of parts into which the download will be split. Defaults to 3, based on the typical rate limiting encountered from servers. + disable_tqdm (bool, optional): Whether to disable showing a progress bar Note: The server from which files are being downloaded must support range requests, as the function relies on @@ -154,8 +158,9 @@ def multithreaded_download(url: str, output_path: str, num_parts: int = 3) -> No total=total_size, unit="iB", unit_scale=True, - desc="Downloading the HTML dump", + desc="Downloading file", smoothing=0, + disable=disable_tqdm, ) for i in range(num_parts): @@ -307,7 +312,6 @@ def test_index( @task( pre=[ stop_all_docker_containers, - start_embedding_docker_container, start_qdrant_docker_container, ], post=[stop_all_docker_containers], @@ -340,6 +344,9 @@ def index_collection( - embedding_model_port (int): The port on which the embedding model server is running. - embedding_model_name (str): The HuggingFace ID of the embedding model to use for indexing. """ + start_embedding_docker_container( + c, port=embedding_model_port, embedding_model=embedding_model_name + ) c.run( f"python retrieval/create_index.py " f"--collection_file {collection_path} " @@ -416,7 +423,7 @@ def preprocess_wikipedia_dump( # Constructing the command with parameters command = ( - f"python wikipedia_preprocessing/preprocess_html_dump.py " + f"python preprocessing/preprocess_wikipedia_html_dump.py " f"--input_path {input_path} " f"--output_path {output_path} " f"--wikidata_translation_map {wikidata_translation_map} " @@ -433,7 +440,6 @@ def preprocess_wikipedia_dump( @task( pre=[ stop_all_docker_containers, - start_embedding_docker_container, start_qdrant_docker_container, ], post=[stop_all_docker_containers], @@ -468,6 +474,13 @@ def index_wikipedia_dump( - language (str): The language edition of Wikipedia to index (e.g., "en" for English). - wikipedia_date (str, optional): The date of the Wikipedia dump to use. If not provided, the latest available dump is used. """ + start_embedding_docker_container( + c, + workdir=workdir, + port=embedding_model_port, + embedding_model=embedding_model_name, + ) + if not wikipedia_date: wikipedia_date = get_latest_wikipedia_dump_date() preprocess_wikipedia_dump( @@ -493,18 +506,29 @@ def index_wikipedia_dump( @task( pre=[ stop_all_docker_containers, - start_embedding_docker_container, start_qdrant_docker_container, ], post=[stop_all_docker_containers], iterable=["language"], ) def index_multiple_wikipedia_dumps( - c, language, workdir: str = DEFAULT_WORKDIR, wikipedia_date: Optional[str] = None + c, + language, + workdir: str = DEFAULT_WORKDIR, + wikipedia_date: Optional[str] = None, + embedding_model_port: int = DEFAULT_EMBEDDING_MODEL_PORT, + embedding_model_name: str = DEFAULT_EMBEDDING_MODEL_NAME, ): """ Index multiple Wikipedia dumps from different languages in a for loop. """ + start_embedding_docker_container( + c, + workdir=workdir, + port=embedding_model_port, + embedding_model=embedding_model_name, + ) + if not wikipedia_date: wikipedia_date = get_latest_wikipedia_dump_date() for l in language: @@ -512,3 +536,74 @@ def index_multiple_wikipedia_dumps( index_wikipedia_dump( c, workdir=workdir, language=l, wikipedia_date=wikipedia_date ) + + +@task(pre=[load_api_keys]) +def download_semantic_scholar_dump( + c, workdir: str = DEFAULT_WORKDIR, num_threads: int = 8 +): + # Read the API key from environment variables + api_key = os.environ.get("SEMANTIC_SCHOLAR_API_KEY") + if not api_key: + raise ValueError("SEMANTIC_SCHOLAR_API_KEY environment variable is not set") + + # Set up headers with the API key + headers = {"x-api-key": api_key} + + # Call the Semantic Scholar API to get the latest release information + response = requests.get("https://api.semanticscholar.org/datasets/v1/release/") + + if response.status_code != 200: + raise Exception( + f"Failed to fetch release information. Status code: {response.status_code}" + ) + + latest_release = response.json()[-1] + + response = requests.get( + f"https://api.semanticscholar.org/datasets/v1/release/{latest_release}/dataset/s2orc", + headers=headers, + ) + if response.status_code != 200: + raise Exception( + f"Failed to fetch release details. Status code: {response.status_code}" + ) + release_details = response.json() + # logger.info(release_details) + file_urls = release_details["files"] + + # Download all files + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + for file_idx, url in enumerate(file_urls): + file_path = os.path.join( + workdir, f"semantic_scholar_dump_{file_idx:03d}.jsonl.gz" + ) + future = executor.submit( + multithreaded_download, url, file_path, num_parts=1 + ) + futures.append(future) + + for future in tqdm( + concurrent.futures.as_completed(futures), total=len(futures), desc="Files" + ): + future.result() # This will raise any exceptions that occurred during download + + logger.info("All files have been downloaded.") + + # Find the part files + part_files = glob.glob( + os.path.join(workdir, "semantic_scholar_dump_part*.jsonl.gz") + ) + + # Ensure part_files is not empty + if not part_files: + raise FileNotFoundError("No part files found in the specified directory.") + + # Unzip part files + for file in tqdm(part_files, desc="Running gunzip"): + output_file = file[:-3] # Remove the .gz extension + c.run(f"gunzip -c {file} > {output_file}", pty=True) + logger.info(f"Unzipped {file} to {output_file}") + + logger.info("All part files have been unzipped.") diff --git a/tests/test_wikipedia_preprocessing.py b/tests/test_wikipedia_preprocessing.py index f854929..f324c72 100644 --- a/tests/test_wikipedia_preprocessing.py +++ b/tests/test_wikipedia_preprocessing.py @@ -5,11 +5,11 @@ from bs4 import BeautifulSoup sys.path.insert(0, "./") -from wikipedia_preprocessing.preprocess_html_dump import ( +from preprocessing.preprocess_wikipedia_html_dump import ( find_h_tags_hierarchy, get_adjacent_tags, ) -from wikipedia_preprocessing.utils import ( +from preprocessing.utils import ( extract_english_translations, find_forest_roots_and_members, get_wikidata_english_name,