Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New RM] Add AzureAISearch #198

Merged
merged 5 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ You could also install the source code which allows you to modify the behavior o
Currently, our package support:

- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components

:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**

Expand Down
12 changes: 8 additions & 4 deletions examples/storm_examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
"""

import os

from argparse import ArgumentParser
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -72,6 +73,7 @@ def main(args):

# STORM is a knowledge curation system which consumes information from the retrieval module.
# Currently, the information source is the Internet and we use search engine API as the retrieval module.

match args.retriever:
case 'bing':
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
Expand All @@ -87,8 +89,10 @@ def main(args):
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True)
case 'searxng':
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
case 'azure_ai_search':
rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k)
case _:
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"')

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -113,7 +117,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand All @@ -138,4 +142,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
133 changes: 133 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,3 +1093,136 @@ def forward(
collected_results.append(r)

return collected_results


class AzureAISearch(dspy.Retrieve):
"""Retrieve information from custom queries using Azure AI Search.
General Documentation: https://learn.microsoft.com/en-us/azure/search/search-create-service-portal.
Python Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python.
"""

def __init__(
self,
azure_ai_search_api_key=None,
azure_ai_search_url=None,
azure_ai_search_index_name=None,
k=3,
is_valid_source: Callable = None,
):
"""
Params:
azure_ai_search_api_key: Azure AI Search API key. Check out https://learn.microsoft.com/en-us/azure/search/search-security-api-keys?tabs=rest-use%2Cportal-find%2Cportal-query
"API key" section
azure_ai_search_url: Custom Azure AI Search Endpoint URL. Check out https://learn.microsoft.com/en-us/azure/search/search-create-service-portal#name-the-service
azure_ai_search_index_name: Custom Azure AI Search Index Name. Check out https://learn.microsoft.com/en-us/azure/search/search-how-to-create-search-index?tabs=portal
k: Number of top results to retrieve.
is_valid_source: Optional function to filter valid sources.
min_char_count: Minimum character count for the article to be considered valid.
snippet_chunk_size: Maximum character count for each snippet.
webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
"""
super().__init__(k=k)

try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError as err:
raise ImportError(
"AzureAISearch requires `pip install azure-search-documents`."
) from err

if not azure_ai_search_api_key and not os.environ.get(
"AZURE_AI_SEARCH_API_KEY"
):
raise RuntimeError(
"You must supply azure_ai_search_api_key or set environment variable AZURE_AI_SEARCH_API_KEY"
)
elif azure_ai_search_api_key:
self.azure_ai_search_api_key = azure_ai_search_api_key
else:
self.azure_ai_search_api_key = os.environ["AZURE_AI_SEARCH_API_KEY"]

if not azure_ai_search_url and not os.environ.get("AZURE_AI_SEARCH_URL"):
raise RuntimeError(
"You must supply azure_ai_search_url or set environment variable AZURE_AI_SEARCH_URL"
)
elif azure_ai_search_url:
self.azure_ai_search_url = azure_ai_search_url
else:
self.azure_ai_search_url = os.environ["AZURE_AI_SEARCH_URL"]

if not azure_ai_search_index_name and not os.environ.get(
"AZURE_AI_SEARCH_INDEX_NAME"
):
raise RuntimeError(
"You must supply azure_ai_search_index_name or set environment variable AZURE_AI_SEARCH_INDEX_NAME"
)
elif azure_ai_search_index_name:
self.azure_ai_search_index_name = azure_ai_search_index_name
else:
self.azure_ai_search_index_name = os.environ["AZURE_AI_SEARCH_INDEX_NAME"]

self.usage = 0

# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
if is_valid_source:
self.is_valid_source = is_valid_source
else:
self.is_valid_source = lambda x: True

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0

return {"AzureAISearch": usage}

def forward(
self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
):
"""Search with Azure Open AI for self.k top passages for query or queries
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): A list of urls to exclude from the search results.
Returns:
a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
"""
try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError as err:
raise ImportError(
"AzureAISearch requires `pip install azure-search-documents`."
) from err
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
collected_results = []

client = SearchClient(
self.azure_ai_search_url,
self.azure_ai_search_index_name,
AzureKeyCredential(self.azure_ai_search_api_key),
)
for query in queries:
try:
# https://learn.microsoft.com/en-us/python/api/azure-search-documents/azure.search.documents.searchclient?view=azure-python#azure-search-documents-searchclient-search
results = client.search(search_text=query, top=1)

for result in results:
document = {
"url": result["metadata_storage_path"],
"title": result["title"],
"description": "N/A",
"snippets": [result["chunk"]],
}
collected_results.append(document)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")

return collected_results
Loading