diff --git a/README.md b/README.md index 3f97a2ae..81f1af5f 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ runner = STORMWikiRunner(engine_args, lm_configs, rm) Currently, our package support: - `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components -- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, and `TavilySearchRM` as retrieval module components +- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components :star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py index f03e1792..80b8a385 100644 --- a/knowledge_storm/rm.py +++ b/knowledge_storm/rm.py @@ -881,3 +881,127 @@ def forward( print(f"Error occurs when searching query {query}: {e}") return collected_results + + +class GoogleSearch(dspy.Retrieve): + def __init__( + self, + google_search_api_key=None, + google_cse_id=None, + k=3, + is_valid_source: Callable = None, + min_char_count: int = 150, + snippet_chunk_size: int = 1000, + webpage_helper_max_threads=10, + ): + """ + Params: + google_search_api_key: Google API key. Check out https://developers.google.com/custom-search/v1/overview + "API key" section + google_cse_id: Custom search engine ID. Check out https://developers.google.com/custom-search/v1/overview + "Search engine ID" section + k: Number of top results to retrieve. + is_valid_source: Optional function to filter valid sources. + min_char_count: Minimum character count for the article to be considered valid. + snippet_chunk_size: Maximum character count for each snippet. + webpage_helper_max_threads: Maximum number of threads to use for webpage helper. + """ + super().__init__(k=k) + try: + from googleapiclient.discovery import build + except ImportError as err: + raise ImportError( + "GoogleSearch requires `pip install google-api-python-client`." + ) from err + if not google_search_api_key and not os.environ.get("GOOGLE_SEARCH_API_KEY"): + raise RuntimeError( + "You must supply google_search_api_key or set the GOOGLE_SEARCH_API_KEY environment variable" + ) + if not google_cse_id and not os.environ.get("GOOGLE_CSE_ID"): + raise RuntimeError( + "You must supply google_cse_id or set the GOOGLE_CSE_ID environment variable" + ) + + self.google_search_api_key = ( + google_search_api_key or os.environ["GOOGLE_SEARCH_API_KEY"] + ) + self.google_cse_id = google_cse_id or os.environ["GOOGLE_CSE_ID"] + + if is_valid_source: + self.is_valid_source = is_valid_source + else: + self.is_valid_source = lambda x: True + + self.service = build( + "customsearch", "v1", developerKey=self.google_search_api_key + ) + self.webpage_helper = WebPageHelper( + min_char_count=min_char_count, + snippet_chunk_size=snippet_chunk_size, + max_thread_num=webpage_helper_max_threads, + ) + self.usage = 0 + + def get_usage_and_reset(self): + usage = self.usage + self.usage = 0 + return {"GoogleSearch": usage} + + def forward( + self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] + ): + """Search using Google Custom Search API for self.k top results for query or queries. + + Args: + query_or_queries (Union[str, List[str]]): The query or queries to search for. + exclude_urls (List[str]): A list of URLs to exclude from the search results. + + Returns: + A list of dicts, each dict has keys: 'title', 'url', 'snippet', 'description'. + """ + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + self.usage += len(queries) + + url_to_results = {} + + for query in queries: + try: + response = ( + self.service.cse() + .list( + q=query, + cx=self.google_cse_id, + num=self.k, + ) + .execute() + ) + + for item in response.get("items", []): + if ( + self.is_valid_source(item["link"]) + and item["link"] not in exclude_urls + ): + url_to_results[item["link"]] = { + "title": item["title"], + "url": item["link"], + # "snippet": item.get("snippet", ""), # Google search snippet is very short. + "description": item.get("snippet", ""), + } + + except Exception as e: + logging.error(f"Error occurred while searching query {query}: {e}") + + valid_url_to_snippets = self.webpage_helper.urls_to_snippets( + list(url_to_results.keys()) + ) + collected_results = [] + for url in valid_url_to_snippets: + r = url_to_results[url] + r["snippets"] = valid_url_to_snippets[url]["snippets"] + collected_results.append(r) + + return collected_results