diff --git a/README.md b/README.md index 569e7303..b5ebb642 100644 --- a/README.md +++ b/README.md @@ -79,13 +79,11 @@ Below, we provide a quick start guide to run STORM locally. ### 2. Running STORM-wiki locally -Currently, we provide example scripts under [`examples`](examples) to demonstrate how you can run STORM using different models. - -**To run STORM with `gpt` family models**: Make sure you have set up the OpenAI API key and run the following command. +**To run STORM with `gpt` family models with default configurations**: Make sure you have set up the OpenAI API key and run the following command. ``` python examples/run_storm_wiki_gpt.py \ - --output_dir $OUTPUT_DIR \ + --output-dir $OUTPUT_DIR \ --retriever you \ --do-research \ --do-generate-outline \ @@ -97,29 +95,15 @@ python examples/run_storm_wiki_gpt.py \ - `--do-generate-article`: If True, generate an article for the topic; otherwise, load the results. - `--do-polish-article`: If True, polish the article by adding a summarization section and (optionally) removing duplicate content. -**To run STORM with `mistral` family models on local VLLM server**: have a VLLM server running with the `Mistral-7B-Instruct-v0.2` model and run the following command. -``` -python examples/run_storm_wiki_mistral.py \ - --url $URL \ - --port $PORT \ - --output_dir $OUTPUT_DIR \ - --retriever you \ - --do-research \ - --do-generate-outline \ - --do-generate-article \ - --do-polish-article -``` -- `--url` URL of the VLLM server. -- `--port` Port of the VLLM server. +We provide more example scripts under [`examples`](examples) to demonstrate how you can run STORM using your favorite language models or grounding on your own corpus. - ## Customize STORM ### Customization of the Pipeline -STORM is a knowledge curation engine consisting of 4 modules: +Besides running scripts in `examples`, you can customize STORM based on your own use case. STORM engine consists of 4 modules: 1. Knowledge Curation Module: Collects a broad coverage of information about the given topic. 2. Outline Generation Module: Organizes the collected information by generating a hierarchical outline for the curated knowledge. @@ -132,9 +116,11 @@ The interface for each module is defined in `src/interface.py`, while their impl ### Customization of Retriever Module -As a knowledge curation engine, STORM grabs information from the Retriever module. The interface for the Retriever module is defined in [`src/interface.py`](src/interface.py). Please consult the interface documentation if you plan to create a new instance or replace the default search engine API. By default, STORM utilizes the You.com search engine API (see `YouRM` in [`src/rm.py`](src/rm.py)). +As a knowledge curation engine, STORM grabs information from the Retriever module. The Retriever modules are implemented in [`src/rm.py`](src/rm.py). Currently, STORM supports the following retrievers: -:new: [2024/05] We test STORM with [Bing Search](https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/endpoints). See `BingSearch` in [`src/rm.py`](src/rm.py) for the configuration and you can specify `--retriever bing` to use Bing Search in our [example scripts](examples). 
+- `YouRM`: You.com search engine API
+- `BingSearch`: Bing Search API
+- `VectorRM`: a retrieval model that retrieves information from a user-provided corpus
 
 :star2: **PRs for integrating more search engines/retrievers are highly appreciated!**
 
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..d050ec00
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,116 @@
+# Examples
+
+We host a number of example scripts for various customizations of STORM (e.g., using your favorite language models or grounding on your own corpus). These examples can be starting points for your own customizations, and you are welcome to contribute your own examples by submitting a pull request to this directory.
+
+## Run STORM with your own language model
+[run_storm_wiki_gpt.py](run_storm_wiki_gpt.py) provides an example of running STORM with GPT models, and [run_storm_wiki_claude.py](run_storm_wiki_claude.py) provides an example of running STORM with Claude models. Besides using closed-source models, you can also run STORM with models with open weights.
+
+`run_storm_wiki_mistral.py` provides an example of running STORM with `Mistral-7B-Instruct-v0.2` using a [VLLM](https://docs.vllm.ai/en/stable/) server:
+
+1. Set up a VLLM server with the `Mistral-7B-Instruct-v0.2` model running.
+2. Run the following command under the root directory of the repository:
+
+    ```
+    python examples/run_storm_wiki_mistral.py \
+        --url $URL \
+        --port $PORT \
+        --output-dir $OUTPUT_DIR \
+        --retriever you \
+        --do-research \
+        --do-generate-outline \
+        --do-generate-article \
+        --do-polish-article
+    ```
+    - `--url` URL of the VLLM server.
+    - `--port` Port of the VLLM server.
+
+Besides a VLLM server, STORM is also compatible with a [TGI](https://huggingface.co/docs/text-generation-inference/en/index) server or a [Together.ai](https://www.together.ai/products#inference) endpoint.
+
+
+## Run STORM with your own corpus
+
+By default, STORM is grounded on the Internet using a search engine, but it can also be grounded on your own corpus using `VectorRM`. [run_storm_wiki_gpt_with_VectorRM.py](run_storm_wiki_gpt_with_VectorRM.py) provides an example of running STORM grounded on your provided data.
+
+1. Set up API keys.
+   - Make sure you have set up the OpenAI API key.
+   - `VectorRM` uses [Qdrant](https://github.com/qdrant/qdrant-client) to create a vector store. If you want to set up this vector store online on a [Qdrant cloud server](https://cloud.qdrant.io/login), you need to set up `QDRANT_API_KEY` in `secrets.toml` as well; if you want to save the vector store locally, make sure you provide a location for the vector store.
+2. Prepare your corpus. The documents should be provided as a single CSV file with the following format:
+
+   | content                | title      | url        | description                        |
+   |------------------------|------------|------------|------------------------------------|
+   | I am a document.       | Document 1 | docu-n-112 | A self-explanatory document.       |
+   | I am another document. | Document 2 | docu-l-13  | Another self-explanatory document. |
+   | ...                    | ...        | ...        | ...                                |
+
+   - `url` will be used as a unique identifier of the document in the STORM engine, so ensure different documents have different urls.
+   - The contents of the `title` and `description` columns are optional. If not provided, the script will use default empty values.
+   - The `content` column is crucial and should be provided for each document.
+
+3. Run the commands below under the root directory of the repository.
+
+    To create the vector store offline, run
+
+    ```
+    python examples/run_storm_wiki_gpt_with_VectorRM.py \
+        --output-dir $OUTPUT_DIR \
+        --vector-db-mode offline \
+        --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
+        --update-vector-store \
+        --csv-file-path $CSV_FILE_PATH \
+        --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
+        --do-research \
+        --do-generate-outline \
+        --do-generate-article \
+        --do-polish-article
+    ```
+
+    To create the vector store online on a Qdrant server, run
+
+    ```
+    python examples/run_storm_wiki_gpt_with_VectorRM.py \
+        --output-dir $OUTPUT_DIR \
+        --vector-db-mode online \
+        --online-vector-db-url $ONLINE_VECTOR_DB_URL \
+        --update-vector-store \
+        --csv-file-path $CSV_FILE_PATH \
+        --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
+        --do-research \
+        --do-generate-outline \
+        --do-generate-article \
+        --do-polish-article
+    ```
+
+4. **Quick test with the Kaggle arXiv Paper Abstracts dataset**:
+
+   - Download `arxiv_data_210930-054931.csv` from [here](https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts).
+   - Run the following command under the root directory to downsample the dataset by filtering papers with terms `[cs.CV]` and get a csv file that matches the format mentioned above.
+
+     ```
+     python examples/helper/process_kaggle_arxiv_abstract_dataset.py --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
+     ```
+   - Run the following command to run STORM grounded on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only uses the abstracts of arXiv papers.)
+
+     ```
+     python examples/run_storm_wiki_gpt_with_VectorRM.py \
+         --output-dir $OUTPUT_DIR \
+         --vector-db-mode offline \
+         --offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
+         --update-vector-store \
+         --csv-file-path $PATH_TO_THE_PROCESSED_CSV \
+         --device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
+         --do-research \
+         --do-generate-outline \
+         --do-generate-article \
+         --do-polish-article
+     ```
+   - For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link).
+
+     ```
+     python examples/run_storm_wiki_gpt_with_VectorRM.py \
+         --output-dir $OUTPUT_DIR \
+         --vector-db-mode offline \
+         --offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DIR \
+         --do-research \
+         --do-generate-outline \
+         --do-generate-article \
+         --do-polish-article
+     ```
\ No newline at end of file
diff --git a/examples/helper/process_kaggle_arxiv_abstract_dataset.py b/examples/helper/process_kaggle_arxiv_abstract_dataset.py
new file mode 100644
index 00000000..4cb885c1
--- /dev/null
+++ b/examples/helper/process_kaggle_arxiv_abstract_dataset.py
@@ -0,0 +1,28 @@
+"""Process `arxiv_data_210930-054931.csv` from https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts
+to a csv file that is compatible with VectorRM.
+""" + +from argparse import ArgumentParser + +import pandas as pd + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input-path", type=str, help="Path to arxiv_data_210930-054931.csv.") + parser.add_argument("--output-path", type=str, + help="Path to store the csv file that is compatible with VectorRM.") + args = parser.parse_args() + + df = pd.read_csv(args.input_path) + print(f'The original dataset has {len(df)} samples.') + + # Downsample the dataset. + df = df[df['terms'] == "['cs.CV']"] + + # Reformat the dataset to match the VectorRM input format. + df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True) + df['url'] = ['uid_' + str(idx) for idx in range(len(df))] # Ensure the url is unique. + df['description'] = '' + + print(f'The downsampled dataset has {len(df)} samples.') + df.to_csv(args.output_path, index=False) diff --git a/examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/run_storm_wiki_gpt_with_VectorRM.py new file mode 100644 index 00000000..c5dd4354 --- /dev/null +++ b/examples/run_storm_wiki_gpt_with_VectorRM.py @@ -0,0 +1,165 @@ +""" +This STORM Wiki pipeline powered by GPT-3.5/4 and local retrieval model that uses Qdrant. +You need to set up the following environment variables to run this script: + - OPENAI_API_KEY: OpenAI API key + - OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure') + - QDRANT_API_KEY: Qdrant API key (needed ONLY if online vector store was used) + +You will also need an existing Qdrant vector store either saved in a folder locally offline or in a server online. +If not, then you would need a CSV file with documents, and the script is going to create the vector store for you. +The CSV should be in the following format: +content | title | url | description +I am a document. | Document 1 | docu-n-112 | A self-explanatory document. +I am another document. | Document 2 | docu-l-13 | Another self-explanatory document. + +Notice that the URL will be a unique identifier for the document so ensure different documents have different urls. + +Output will be structured as below +args.output_dir/ + topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash + conversation_log.json # Log of information-seeking conversation + raw_search_results.json # Raw search results from search engine + direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge + storm_gen_outline.txt # Outline refined with collected information + url_to_info.json # Sources that are used in the final article + storm_gen_article.txt # Final article generated + storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) +""" + +import os +import sys +from argparse import ArgumentParser + +sys.path.append('./src') +from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from rm import VectorRM +from lm import OpenAIModel +from utils import load_api_key + + +def main(args): + # Load API key from the specified toml file path + load_api_key(toml_file_path='secrets.toml') + + # Initialize the language model configurations + engine_lm_configs = STORMWikiLMConfigs() + openai_kwargs = { + 'api_key': os.getenv("OPENAI_API_KEY"), + 'api_provider': os.getenv('OPENAI_API_TYPE'), + 'temperature': 1.0, + 'top_p': 0.9, + } + + # STORM is a LM system so different components can be powered by different models. 
+ # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm + # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models + # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm + # which is responsible for generating sections with citations. + conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) + question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) + outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs) + article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=700, **openai_kwargs) + article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=4000, **openai_kwargs) + + engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) + engine_lm_configs.set_question_asker_lm(question_asker_lm) + engine_lm_configs.set_outline_gen_lm(outline_gen_lm) + engine_lm_configs.set_article_gen_lm(article_gen_lm) + engine_lm_configs.set_article_polish_lm(article_polish_lm) + + # Initialize the engine arguments + engine_args = STORMWikiRunnerArguments( + output_dir=args.output_dir, + max_conv_turn=args.max_conv_turn, + max_perspective=args.max_perspective, + search_top_k=args.search_top_k, + max_thread_num=args.max_thread_num, + ) + + # Setup VectorRM to retrieve information from your own data + rm = VectorRM(collection_name=args.collection_name, device=args.device, k=engine_args.search_top_k) + + # initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally): + if args.vector_db_mode == 'offline': + rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir) + elif args.vector_db_mode == 'online': + rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY')) + + # Update the vector store with the documents in the csv file + if args.update_vector_store: + rm.update_vector_store( + file_path=args.csv_file_path, + content_column='content', + title_column='title', + url_column='url', + desc_column='description', + batch_size=args.embed_batch_size + ) + + # Initialize the STORM Wiki Runner + runner = STORMWikiRunner(engine_args, engine_lm_configs, rm) + + # run the pipeline + topic = input('Topic: ') + runner.run( + topic=topic, + do_research=args.do_research, + do_generate_outline=args.do_generate_outline, + do_generate_article=args.do_generate_article, + do_polish_article=args.do_polish_article, + ) + runner.post_run() + runner.summary() + + +if __name__ == "__main__": + parser = ArgumentParser() + # global arguments + parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval', + help='Directory to store the outputs.') + parser.add_argument('--max-thread-num', type=int, default=3, + help='Maximum number of threads to use. The information seeking part and the article generation' + 'part can speed up by using multiple threads. 
Consider reducing it if keep getting ' + '"Exceed rate limit" error when calling LM API.') + # provide local corpus and set up vector db + parser.add_argument('--collection-name', type=str, default="my_documents", + help='The collection name for vector store.') + parser.add_argument('--device', type=str, default="mps", + help='The device used to run the retrieval model (mps, cuda, cpu, etc).') + parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'], + help='The mode of the Qdrant vector store (offline or online).') + parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store', + help='If use offline mode, please provide the directory to store the vector store.') + parser.add_argument('--online-vector-db-url', type=str, + help='If use online mode, please provide the url of the Qdrant server.') + parser.add_argument('--update-vector-store', action='store_true', + help='If True, update the vector store with the documents in the csv file; otherwise, ' + 'use the existing vector store.') + parser.add_argument('--csv-file-path', type=str, + help='The path of the custom document corpus in CSV format. The CSV file should include ' + 'content, title, url, and description columns.') + parser.add_argument('--embed-batch-size', type=int, default=64, + help='Batch size for embedding the documents in the csv file.') + # stage of the pipeline + parser.add_argument('--do-research', action='store_true', + help='If True, simulate conversation to research the topic; otherwise, load the results.') + parser.add_argument('--do-generate-outline', action='store_true', + help='If True, generate an outline for the topic; otherwise, load the results.') + parser.add_argument('--do-generate-article', action='store_true', + help='If True, generate an article for the topic; otherwise, load the results.') + parser.add_argument('--do-polish-article', action='store_true', + help='If True, polish the article by adding a summarization section and (optionally) removing ' + 'duplicate content.') + # hyperparameters for the pre-writing stage + parser.add_argument('--max-conv-turn', type=int, default=3, + help='Maximum number of questions in conversational question asking.') + parser.add_argument('--max-perspective', type=int, default=3, + help='Maximum number of perspectives to consider in perspective-guided question asking.') + parser.add_argument('--search-top-k', type=int, default=3, + help='Top k search results to consider for each search query.') + # hyperparameters for the writing stage + parser.add_argument('--retrieve-top-k', type=int, default=3, + help='Top k collected references for each section title.') + parser.add_argument('--remove-duplicate', action='store_true', + help='If True, remove duplicate content from the article.') + main(parser.parse_args()) diff --git a/requirements.txt b/requirements.txt index 3893cf28..b1bdfa51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,29 @@ -flair==0.13.0 -nltk==3.8.1 -dspy-ai==2.1.1 -pandas==2.1.1 -scikit-learn==1.3.2 -chardet==5.2.0 -sentence-transformers==2.2.2 -transformers==4.34.1 -scipy==1.10.1 -fastchat +dspy_ai==2.4.9 +streamlit==1.31.1 wikipedia==1.4.0 -Wikipedia-API==0.6.0 -rouge-score +streamlit_authenticator==0.2.3 +streamlit_oauth==0.1.8 +streamlit-card +google-cloud==0.34.0 +google-cloud-vision==3.5.0 +google-cloud-storage==2.14.0 +sentence_transformers toml -tqdm==4.66.0 +markdown +unidecode +extra-streamlit-components==0.1.60 +google-cloud-firestore==2.14.0 +firebase-admin==6.4.0 +streamlit_extras 
+streamlit_cookies_manager +deprecation==2.1.0 +st-pages==0.4.5 +streamlit-float +streamlit-option-menu +sentry-sdk +pdfkit==1.0.0 +langchain-text-splitters +trafilatura +langchain-huggingface +qdrant-client +langchain-qdrant diff --git a/src/rm.py b/src/rm.py index f3e0a784..5126aa5a 100644 --- a/src/rm.py +++ b/src/rm.py @@ -3,7 +3,13 @@ from typing import Callable, Union, List import dspy +import pandas as pd import requests +from langchain_huggingface import HuggingFaceEmbeddings +from langchain_core.documents import Document +from langchain_qdrant import Qdrant +from qdrant_client import QdrantClient, models +from tqdm import tqdm from utils import WebPageHelper @@ -157,3 +163,244 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st collected_results.append(r) return collected_results + + +class VectorRM(dspy.Retrieve): + """Retrieve information from custom documents using Qdrant. + + To be compatible with STORM, the custom documents should have the following fields: + - content: The main text content of the document. + - title: The title of the document. + - url: The URL of the document. STORM use url as the unique identifier of the document, so ensure different + documents have different urls. + - description (optional): The description of the document. + The documents should be stored in a CSV file. + """ + + def __init__(self, + collection_name: str = "my_documents", + embedding_model: str = 'BAAI/bge-m3', + device: str = "mps", + k: int = 3, + chunk_size: int = 500, + chunk_overlap: int = 100): + """ + Params: + collection_name: Name of the Qdrant collection. + embedding_model: Name of the Hugging Face embedding model. + device: Device to run the embeddings model on, can be "mps", "cuda", "cpu". + k: Number of top chunks to retrieve. + chunk_size: Size of each chunk if you need to build the vector store from documents. + chunk_overlap: Overlap between chunks if you need to build the vector store from documents. + """ + super().__init__(k=k) + self.usage = 0 + + model_kwargs = {"device": device} + encode_kwargs = {"normalize_embeddings": True} + self.model = HuggingFaceEmbeddings( + model_name=embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs + ) + + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + self.collection_name = collection_name + self.client = None + self.qdrant = None + + def _check_create_collection(self): + """ + Check if the Qdrant collection exists and create it if it does not. + """ + if self.client is None: + raise ValueError("Qdrant client is not initialized.") + if self.client.collection_exists(collection_name=f"{self.collection_name}"): + print(f"Collection {self.collection_name} exists. Loading the collection...") + self.qdrant = Qdrant( + client=self.client, + collection_name=self.collection_name, + embeddings=self.model, + ) + else: + print(f"Collection {self.collection_name} does not exist. Creating the collection...") + # create the collection + self.client.create_collection( + collection_name=f"{self.collection_name}", + vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE), + ) + self.qdrant = Qdrant( + client=self.client, + collection_name=self.collection_name, + embeddings=self.model, + ) + + def init_online_vector_db(self, url: str, api_key: str): + """ + Initialize the Qdrant client that is connected to an online vector store with the given URL and API key. + + Args: + url (str): URL of the Qdrant server. + api_key (str): API key for the Qdrant server. 
+ """ + if api_key is None: + if not os.getenv("QDRANT_API_KEY"): + raise ValueError("Please provide an api key.") + api_key = os.getenv("QDRANT_API_KEY") + if url is None: + raise ValueError("Please provide a url for the Qdrant server.") + + try: + self.client = QdrantClient(url=url, api_key=api_key) + self._check_create_collection() + except Exception as e: + raise ValueError(f"Error occurs when connecting to the server: {e}") + + def init_offline_vector_db(self, vector_store_path: str): + """ + Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path. + + Args: + vector_store_path (str): Path to the vector store. + """ + if vector_store_path is None: + raise ValueError("Please provide a folder path.") + + try: + self.client = QdrantClient(path=vector_store_path) + self._check_create_collection() + except Exception as e: + raise ValueError(f"Error occurs when loading the vector store: {e}") + + def update_vector_store( + self, + file_path: str, + content_column: str, + title_column: str = "title", + url_column: str = "url", + desc_column: str = "description", + batch_size: int = 64 + ): + """ + Takes a CSV file where each row is a document and has columns for content, title, url, and description. + Then it converts all these documents in the content column to vectors and add them the Qdrant collection. + + Args: + file_path (str): Path to the CSV file. + content_column (str): Name of the column containing the content. + title_column (str): Name of the column containing the title. Default is "title". + url_column (str): Name of the column containing the URL. Default is "url". + desc_column (str): Name of the column containing the description. Default is "description". + batch_size (int): Batch size for adding documents to the collection. + """ + if file_path is None: + raise ValueError("Please provide a file path.") + # check if the file is a csv file + if not file_path.endswith('.csv'): + raise ValueError(f"Not valid file format. 
Please provide a csv file.") + if content_column is None: + raise ValueError("Please provide the name of the content column.") + if url_column is None: + raise ValueError("Please provide the name of the url column.") + + if self.qdrant is None: + raise ValueError("Qdrant client is not initialized.") + + # read the csv file + df = pd.read_csv(file_path) + # check that content column exists and url column exists + if content_column not in df.columns: + raise ValueError(f"Content column {content_column} not found in the csv file.") + if url_column not in df.columns: + raise ValueError(f"URL column {url_column} not found in the csv file.") + + documents = [ + Document( + page_content=row[content_column], + metadata={ + "title": row.get(title_column, ''), + "url": row[url_column], + "description": row.get(desc_column, ''), + } + ) + for row in df.to_dict(orient='records') + ] + + # split the documents + from langchain_text_splitters import RecursiveCharacterTextSplitter + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + length_function=len, + add_start_index=True, + separators=[ + "\n\n", + "\n", + ".", + "\uff0e", # Fullwidth full stop + "\u3002", # Ideographic full stop + ",", + "\uff0c", # Fullwidth comma + "\u3001", # Ideographic comma + " ", + "\u200B", # Zero-width space + "", + ] + ) + split_documents = text_splitter.split_documents(documents) + + # update and save the vector store + num_batches = (len(split_documents) + batch_size - 1) // batch_size + for i in tqdm(range(num_batches)): + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, len(split_documents)) + self.qdrant.add_documents( + documents=split_documents[start_idx:end_idx], + batch_size=batch_size, + ) + + def get_usage_and_reset(self): + usage = self.usage + self.usage = 0 + + return {'VectorRM': usage} + + def get_vector_count(self): + """ + Get the count of vectors in the collection. + + Returns: + int: Number of vectors in the collection. + """ + return self.qdrant.client.count(collection_name=self.collection_name) + + def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]): + """ + Search in your data for self.k top passages for query or queries. + + Args: + query_or_queries (Union[str, List[str]]): The query or queries to search for. + exclude_urls (List[str]): Dummy parameter to match the interface. Does not have any effect. + + Returns: + a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url' + """ + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + self.usage += len(queries) + collected_results = [] + for query in queries: + related_docs = self.qdrant.similarity_search_with_score(query, k=self.k) + for i in range(len(related_docs)): + doc = related_docs[i][0] + collected_results.append({ + 'description': doc.metadata['description'], + 'snippets': [doc.page_content], + 'title': doc.metadata['title'], + 'url': doc.metadata['url'], + }) + + return collected_results
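
For readers who want to try the new `VectorRM` outside the full pipeline, below is a minimal sketch (not part of the patch above) of indexing a corpus and querying it directly. It assumes a `corpus.csv` in the format described in `examples/README.md` and a local `./vector_store` directory; the file paths, device, and query string are placeholders, and the end-to-end wiring into `STORMWikiRunner` is shown in `examples/run_storm_wiki_gpt_with_VectorRM.py`.

```
import sys
sys.path.append('./src')

from rm import VectorRM

# Embed on CPU here; "mps" or "cuda" also work depending on your hardware.
rm = VectorRM(collection_name="my_documents", device="cpu", k=3)

# Build (or load) a local Qdrant store and index the corpus from the CSV file.
rm.init_offline_vector_db(vector_store_path="./vector_store")
rm.update_vector_store(
    file_path="corpus.csv",  # placeholder path to your own corpus
    content_column="content",
    title_column="title",
    url_column="url",
    desc_column="description",
    batch_size=64,
)

# Retrieve the top-k chunks for a query; each result is a dict with
# 'title', 'url', 'description', and 'snippets'.
results = rm.forward("multimodal models in computer vision", exclude_urls=[])
for r in results:
    print(r["title"], r["url"])
```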