Skip to content

Commit

Permalink
Merge pull request #58 from AMMAS1/dev-customize-retrieval-source
Browse files Browse the repository at this point in the history
add scripts and documentation to support customize retrieval source
  • Loading branch information
shaoyijia authored Jul 7, 2024
2 parents 72aebfd + e31dcf1 commit d3b36cd
Show file tree
Hide file tree
Showing 6 changed files with 591 additions and 35 deletions.
30 changes: 8 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,11 @@ Below, we provide a quick start guide to run STORM locally.

### 2. Running STORM-wiki locally

Currently, we provide example scripts under [`examples`](examples) to demonstrate how you can run STORM using different models.

**To run STORM with `gpt` family models**: Make sure you have set up the OpenAI API key and run the following command.
**To run STORM with `gpt` family models with default configurations**: Make sure you have set up the OpenAI API key and run the following command.

```
python examples/run_storm_wiki_gpt.py \
--output_dir $OUTPUT_DIR \
--output-dir $OUTPUT_DIR \
--retriever you \
--do-research \
--do-generate-outline \
Expand All @@ -97,29 +95,15 @@ python examples/run_storm_wiki_gpt.py \
- `--do-generate-article`: If True, generate an article for the topic; otherwise, load the results.
- `--do-polish-article`: If True, polish the article by adding a summarization section and (optionally) removing duplicate content.
**To run STORM with `mistral` family models on local VLLM server**: have a VLLM server running with the `Mistral-7B-Instruct-v0.2` model and run the following command.
```
python examples/run_storm_wiki_mistral.py \
--url $URL \
--port $PORT \
--output_dir $OUTPUT_DIR \
--retriever you \
--do-research \
--do-generate-outline \
--do-generate-article \
--do-polish-article
```
- `--url` URL of the VLLM server.
- `--port` Port of the VLLM server.
We provide more example scripts under [`examples`](examples) to demonstrate how you can run STORM using your favorite language models or grounding on your own corpus.
## Customize STORM
### Customization of the Pipeline
STORM is a knowledge curation engine consisting of 4 modules:
Besides running scripts in `examples`, you can customize STORM based on your own use case. STORM engine consists of 4 modules:
1. Knowledge Curation Module: Collects a broad coverage of information about the given topic.
2. Outline Generation Module: Organizes the collected information by generating a hierarchical outline for the curated knowledge.
Expand All @@ -132,9 +116,11 @@ The interface for each module is defined in `src/interface.py`, while their impl
### Customization of Retriever Module
As a knowledge curation engine, STORM grabs information from the Retriever module. The interface for the Retriever module is defined in [`src/interface.py`](src/interface.py). Please consult the interface documentation if you plan to create a new instance or replace the default search engine API. By default, STORM utilizes the You.com search engine API (see `YouRM` in [`src/rm.py`](src/rm.py)).
As a knowledge curation engine, STORM grabs information from the Retriever module. The Retriever modules are implemented in [`src/rm.py`](src/rm.py). Currently, STORM supports the following retrievers:
:new: [2024/05] We test STORM with [Bing Search](https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/endpoints). See `BingSearch` in [`src/rm.py`](src/rm.py) for the configuration and you can specify `--retriever bing` to use Bing Search in our [example scripts](examples).
- `YouRM`: You.com search engine API
- `BingSearch`: Bing Search API
- `VectorRM`: a retrieval model that retrieves information from user provide corpus
:star2: **PRs for integrating more search engines/retrievers are highly appreciated!**
Expand Down
116 changes: 116 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Examples

We host a number of example scripts for various customization of STORM (e.g., use your favorite language models, use your own corpus, etc.). These examples can be starting points for your own customizations and you are welcome to contribute your own examples by submitting a pull request to this directory.

## Run STORM with your own language model
[run_storm_wiki_gpt.py](run_storm_wiki_gpt.py) provides an example of running STORM with GPT models, and [run_storm_wiki_claude.py](run_storm_wiki_claude.py) provides an example of running STORM with Claude models. Besides using close-source models, you can also run STORM with models with open weights.

`run_storm_wiki_mistral.py` provides an example of running STORM with `Mistral-7B-Instruct-v0.2` using [VLLM](https://docs.vllm.ai/en/stable/) server:

1. Set up a VLLM server with the `Mistral-7B-Instruct-v0.2` model running.
2. Run the following command under the root directory of the repository:

```
python examples/run_storm_wiki_mistral.py \
--url $URL \
--port $PORT \
--output-dir $OUTPUT_DIR \
--retriever you \
--do-research \
--do-generate-outline \
--do-generate-article \
--do-polish-article
```
- `--url` URL of the VLLM server.
- `--port` Port of the VLLM server.
Besides VLLM server, STORM is also compatible with [TGI](https://huggingface.co/docs/text-generation-inference/en/index) server or [Together.ai](https://www.together.ai/products#inference) endpoint.
## Run STORM with your own corpus
By default, STORM is grounded on the Internet using the search engine, but it can also be grounded on your own corpus using `VectorRM`. [run_storm_wiki_with_gpt_with_VectorRM.py](run_storm_wiki_gpt_with_VectorRM.py) provides an example of running STORM grounding on your provided data.
1. Set up API keys.
- Make sure you have set up the OpenAI API key.
- `VectorRM` use [Qdrant](https://github.com/qdrant/qdrant-client) to create a vector store. If you want to set up this vector store online on a [Qdrant cloud server](https://cloud.qdrant.io/login), you need to set up `QDRANT_API_KEY` in `secrets.toml` as well; if you want to save the vector store locally, make sure you provide a location for the vector store.
2. Prepare your corpus. The documents should be provided as a single CSV file with the following format:
| content | title | url | description |
|------------------------|------------|------------|------------------------------------|
| I am a document. | Document 1 | docu-n-112 | A self-explanatory document. |
| I am another document. | Document 2 | docu-l-13 | Another self-explanatory document. |
| ... | ... | ... | ... |
- `url` will be used as a unique identifier of the document in STORM engine, so ensure different documents have different urls.
- The contents for `title` and `description` columns are optional. If not provided, the script will use default empty values.
- The content column is crucial and should be provided for each document.
3. Run the command under the root directory of the repository:
To create the vector store offline, run
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
--update-vector-store \
--csv-file-path $CSV_FILE_PATH \
--device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
--do-research \
--do-generate-outline \
--do-generate-article \
--do-polish-article
```
To create the vector store online on a Qdrant server, run
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode online \
--online-vector-db-url $ONLINE_VECTOR_DB_URL \
--update-vector-store \
--csv-file-path $CSV_FILE_PATH \
--device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
--do-research \
--do-generate-outline \
--do-generate-article \
--do-polish-article
```
4. **Quick test with Kaggle arXiv Paper Abstracts dataset**:
- Download `arxiv_data_210930-054931.csv` from [here](https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts).
- Run the following command under the root directory to downsample the dataset by filtering papers with terms `[cs.CV]` and get a csv file that match the format mentioned above.
```
python examples/helper/process_kaggle_arxiv_abstract_dataset --input-path $PATH_TO_THE_DOWNLOADED_FILE --output-path $PATH_TO_THE_PROCESSED_CSV
```
- Run the following command to run STORM grounding on the processed dataset. You can input a topic related to computer vision (e.g., "The progress of multimodal models in computer vision") to see the generated article. (Note that the generated article may not include enough details since the quick test only use the abstracts of arxiv papers.)
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $OFFLINE_VECTOR_DB_DIR \
--update-vector-store \
--csv-file-path $PATH_TO_THE_PROCESSED_CSV \
--device $DEVICE_FOR_EMBEDDING(mps, cuda, cpu) \
--do-research \
--do-generate-outline \
--do-generate-article \
--do-polish-article
```
- For a quicker run, you can also download the pre-embedded vector store directly from [here](https://drive.google.com/file/d/1bijFkw5BKU7bqcmXMhO-5hg2fdKAL9bf/view?usp=share_link).
```
python examples/run_storm_wiki_gpt_with_VectorRM.py \
--output-dir $OUTPUT_DIR \
--vector-db-mode offline \
--offline-vector-db-dir $DOWNLOADED_VECTOR_DB_DR \
--do-research \
--do-generate-outline \
--do-generate-article \
--do-polish-article
```
28 changes: 28 additions & 0 deletions examples/helper/process_kaggle_arxiv_abstract_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Process `arxiv_data_210930-054931.csv` from https://www.kaggle.com/datasets/spsayakpaul/arxiv-paper-abstracts
to a csv file that is compatible with VectorRM.
"""

from argparse import ArgumentParser

import pandas as pd

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-path", type=str, help="Path to arxiv_data_210930-054931.csv.")
parser.add_argument("--output-path", type=str,
help="Path to store the csv file that is compatible with VectorRM.")
args = parser.parse_args()

df = pd.read_csv(args.input_path)
print(f'The original dataset has {len(df)} samples.')

# Downsample the dataset.
df = df[df['terms'] == "['cs.CV']"]

# Reformat the dataset to match the VectorRM input format.
df.rename(columns={"abstracts": "content", "titles": "title"}, inplace=True)
df['url'] = ['uid_' + str(idx) for idx in range(len(df))] # Ensure the url is unique.
df['description'] = ''

print(f'The downsampled dataset has {len(df)} samples.')
df.to_csv(args.output_path, index=False)
165 changes: 165 additions & 0 deletions examples/run_storm_wiki_gpt_with_VectorRM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
This STORM Wiki pipeline powered by GPT-3.5/4 and local retrieval model that uses Qdrant.
You need to set up the following environment variables to run this script:
- OPENAI_API_KEY: OpenAI API key
- OPENAI_API_TYPE: OpenAI API type (e.g., 'openai' or 'azure')
- QDRANT_API_KEY: Qdrant API key (needed ONLY if online vector store was used)
You will also need an existing Qdrant vector store either saved in a folder locally offline or in a server online.
If not, then you would need a CSV file with documents, and the script is going to create the vector store for you.
The CSV should be in the following format:
content | title | url | description
I am a document. | Document 1 | docu-n-112 | A self-explanatory document.
I am another document. | Document 2 | docu-l-13 | Another self-explanatory document.
Notice that the URL will be a unique identifier for the document so ensure different documents have different urls.
Output will be structured as below
args.output_dir/
topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
conversation_log.json # Log of information-seeking conversation
raw_search_results.json # Raw search results from search engine
direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
storm_gen_outline.txt # Outline refined with collected information
url_to_info.json # Sources that are used in the final article
storm_gen_article.txt # Final article generated
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""

import os
import sys
from argparse import ArgumentParser

sys.path.append('./src')
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from rm import VectorRM
from lm import OpenAIModel
from utils import load_api_key


def main(args):
# Load API key from the specified toml file path
load_api_key(toml_file_path='secrets.toml')

# Initialize the language model configurations
engine_lm_configs = STORMWikiLMConfigs()
openai_kwargs = {
'api_key': os.getenv("OPENAI_API_KEY"),
'api_provider': os.getenv('OPENAI_API_TYPE'),
'temperature': 1.0,
'top_p': 0.9,
}

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
# which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
# for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
# which is responsible for generating sections with citations.
conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs)
article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=700, **openai_kwargs)
article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=4000, **openai_kwargs)

engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm)
engine_lm_configs.set_question_asker_lm(question_asker_lm)
engine_lm_configs.set_outline_gen_lm(outline_gen_lm)
engine_lm_configs.set_article_gen_lm(article_gen_lm)
engine_lm_configs.set_article_polish_lm(article_polish_lm)

# Initialize the engine arguments
engine_args = STORMWikiRunnerArguments(
output_dir=args.output_dir,
max_conv_turn=args.max_conv_turn,
max_perspective=args.max_perspective,
search_top_k=args.search_top_k,
max_thread_num=args.max_thread_num,
)

# Setup VectorRM to retrieve information from your own data
rm = VectorRM(collection_name=args.collection_name, device=args.device, k=engine_args.search_top_k)

# initialize the vector store, either online (store the db on Qdrant server) or offline (store the db locally):
if args.vector_db_mode == 'offline':
rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir)
elif args.vector_db_mode == 'online':
rm.init_online_vector_db(url=args.online_vector_db_url, api_key=os.getenv('QDRANT_API_KEY'))

# Update the vector store with the documents in the csv file
if args.update_vector_store:
rm.update_vector_store(
file_path=args.csv_file_path,
content_column='content',
title_column='title',
url_column='url',
desc_column='description',
batch_size=args.embed_batch_size
)

# Initialize the STORM Wiki Runner
runner = STORMWikiRunner(engine_args, engine_lm_configs, rm)

# run the pipeline
topic = input('Topic: ')
runner.run(
topic=topic,
do_research=args.do_research,
do_generate_outline=args.do_generate_outline,
do_generate_article=args.do_generate_article,
do_polish_article=args.do_polish_article,
)
runner.post_run()
runner.summary()


if __name__ == "__main__":
parser = ArgumentParser()
# global arguments
parser.add_argument('--output-dir', type=str, default='./results/gpt_retrieval',
help='Directory to store the outputs.')
parser.add_argument('--max-thread-num', type=int, default=3,
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
# provide local corpus and set up vector db
parser.add_argument('--collection-name', type=str, default="my_documents",
help='The collection name for vector store.')
parser.add_argument('--device', type=str, default="mps",
help='The device used to run the retrieval model (mps, cuda, cpu, etc).')
parser.add_argument('--vector-db-mode', type=str, choices=['offline', 'online'],
help='The mode of the Qdrant vector store (offline or online).')
parser.add_argument('--offline-vector-db-dir', type=str, default='./vector_store',
help='If use offline mode, please provide the directory to store the vector store.')
parser.add_argument('--online-vector-db-url', type=str,
help='If use online mode, please provide the url of the Qdrant server.')
parser.add_argument('--update-vector-store', action='store_true',
help='If True, update the vector store with the documents in the csv file; otherwise, '
'use the existing vector store.')
parser.add_argument('--csv-file-path', type=str,
help='The path of the custom document corpus in CSV format. The CSV file should include '
'content, title, url, and description columns.')
parser.add_argument('--embed-batch-size', type=int, default=64,
help='Batch size for embedding the documents in the csv file.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
help='If True, simulate conversation to research the topic; otherwise, load the results.')
parser.add_argument('--do-generate-outline', action='store_true',
help='If True, generate an outline for the topic; otherwise, load the results.')
parser.add_argument('--do-generate-article', action='store_true',
help='If True, generate an article for the topic; otherwise, load the results.')
parser.add_argument('--do-polish-article', action='store_true',
help='If True, polish the article by adding a summarization section and (optionally) removing '
'duplicate content.')
# hyperparameters for the pre-writing stage
parser.add_argument('--max-conv-turn', type=int, default=3,
help='Maximum number of questions in conversational question asking.')
parser.add_argument('--max-perspective', type=int, default=3,
help='Maximum number of perspectives to consider in perspective-guided question asking.')
parser.add_argument('--search-top-k', type=int, default=3,
help='Top k search results to consider for each search query.')
# hyperparameters for the writing stage
parser.add_argument('--retrieve-top-k', type=int, default=3,
help='Top k collected references for each section title.')
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')
main(parser.parse_args())
Loading

0 comments on commit d3b36cd

Please sign in to comment.