Skip to content

Commit

Permalink
Merge pull request #198 from dfusion-dev/main
Browse files Browse the repository at this point in the history
[New RM] Add AzureAISearch
  • Loading branch information
shaoyijia authored Oct 19, 2024
2 parents f1a99b8 + 9511be7 commit 18d0258
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ You could also install the source code which allows you to modify the behavior o
Currently, our package support:

- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components

:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**

Expand Down
12 changes: 8 additions & 4 deletions examples/storm_examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
"""

import os

from argparse import ArgumentParser
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG
from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch
from knowledge_storm.utils import load_api_key


Expand Down Expand Up @@ -72,6 +73,7 @@ def main(args):

# STORM is a knowledge curation system which consumes information from the retrieval module.
# Currently, the information source is the Internet and we use search engine API as the retrieval module.

match args.retriever:
case 'bing':
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
Expand All @@ -87,8 +89,10 @@ def main(args):
rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True)
case 'searxng':
rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k)
case 'azure_ai_search':
rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k)
case _:
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"')
raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"')

runner = STORMWikiRunner(engine_args, lm_configs, rm)

Expand All @@ -113,7 +117,7 @@ def main(args):
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'],
parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'],
help='The search engine API to use for retrieving information.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
Expand All @@ -138,4 +142,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
133 changes: 133 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,3 +1093,136 @@ def forward(
collected_results.append(r)

return collected_results


class AzureAISearch(dspy.Retrieve):
"""Retrieve information from custom queries using Azure AI Search.
General Documentation: https://learn.microsoft.com/en-us/azure/search/search-create-service-portal.
Python Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python.
"""

def __init__(
self,
azure_ai_search_api_key=None,
azure_ai_search_url=None,
azure_ai_search_index_name=None,
k=3,
is_valid_source: Callable = None,
):
"""
Params:
azure_ai_search_api_key: Azure AI Search API key. Check out https://learn.microsoft.com/en-us/azure/search/search-security-api-keys?tabs=rest-use%2Cportal-find%2Cportal-query
"API key" section
azure_ai_search_url: Custom Azure AI Search Endpoint URL. Check out https://learn.microsoft.com/en-us/azure/search/search-create-service-portal#name-the-service
azure_ai_search_index_name: Custom Azure AI Search Index Name. Check out https://learn.microsoft.com/en-us/azure/search/search-how-to-create-search-index?tabs=portal
k: Number of top results to retrieve.
is_valid_source: Optional function to filter valid sources.
min_char_count: Minimum character count for the article to be considered valid.
snippet_chunk_size: Maximum character count for each snippet.
webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
"""
super().__init__(k=k)

try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError as err:
raise ImportError(
"AzureAISearch requires `pip install azure-search-documents`."
) from err

if not azure_ai_search_api_key and not os.environ.get(
"AZURE_AI_SEARCH_API_KEY"
):
raise RuntimeError(
"You must supply azure_ai_search_api_key or set environment variable AZURE_AI_SEARCH_API_KEY"
)
elif azure_ai_search_api_key:
self.azure_ai_search_api_key = azure_ai_search_api_key
else:
self.azure_ai_search_api_key = os.environ["AZURE_AI_SEARCH_API_KEY"]

if not azure_ai_search_url and not os.environ.get("AZURE_AI_SEARCH_URL"):
raise RuntimeError(
"You must supply azure_ai_search_url or set environment variable AZURE_AI_SEARCH_URL"
)
elif azure_ai_search_url:
self.azure_ai_search_url = azure_ai_search_url
else:
self.azure_ai_search_url = os.environ["AZURE_AI_SEARCH_URL"]

if not azure_ai_search_index_name and not os.environ.get(
"AZURE_AI_SEARCH_INDEX_NAME"
):
raise RuntimeError(
"You must supply azure_ai_search_index_name or set environment variable AZURE_AI_SEARCH_INDEX_NAME"
)
elif azure_ai_search_index_name:
self.azure_ai_search_index_name = azure_ai_search_index_name
else:
self.azure_ai_search_index_name = os.environ["AZURE_AI_SEARCH_INDEX_NAME"]

self.usage = 0

# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
if is_valid_source:
self.is_valid_source = is_valid_source
else:
self.is_valid_source = lambda x: True

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0

return {"AzureAISearch": usage}

def forward(
self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
):
"""Search with Azure Open AI for self.k top passages for query or queries
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): A list of urls to exclude from the search results.
Returns:
a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
"""
try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
except ImportError as err:
raise ImportError(
"AzureAISearch requires `pip install azure-search-documents`."
) from err
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
collected_results = []

client = SearchClient(
self.azure_ai_search_url,
self.azure_ai_search_index_name,
AzureKeyCredential(self.azure_ai_search_api_key),
)
for query in queries:
try:
# https://learn.microsoft.com/en-us/python/api/azure-search-documents/azure.search.documents.searchclient?view=azure-python#azure-search-documents-searchclient-search
results = client.search(search_text=query, top=1)

for result in results:
document = {
"url": result["metadata_storage_path"],
"title": result["title"],
"description": "N/A",
"snippets": [result["chunk"]],
}
collected_results.append(document)
except Exception as e:
logging.error(f"Error occurs when searching query {query}: {e}")

return collected_results

0 comments on commit 18d0258

Please sign in to comment.