From 7167bbdcf0f27ff8b284cfa0b6b9e776e4f33e04 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Mon, 29 Jul 2024 19:21:22 +0800 Subject: [PATCH] Mind search (#208) Support MindSearch --- lagent/actions/__init__.py | 3 +- lagent/actions/bing_browser.py | 270 +++++++++++++++++++++++++++++++ lagent/actions/parser.py | 5 +- lagent/agents/internlm2_agent.py | 20 ++- lagent/llms/base_api.py | 3 +- lagent/llms/openai.py | 7 + lagent/schema.py | 8 +- 7 files changed, 307 insertions(+), 9 deletions(-) create mode 100755 lagent/actions/bing_browser.py diff --git a/lagent/actions/__init__.py b/lagent/actions/__init__.py index 56f467a9..fdffebeb 100644 --- a/lagent/actions/__init__.py +++ b/lagent/actions/__init__.py @@ -3,6 +3,7 @@ from .action_executor import ActionExecutor from .arxiv_search import ArxivSearch from .base_action import TOOL_REGISTRY, BaseAction, tool_api +from .bing_browser import BingBrowser from .bing_map import BINGMap from .builtin_actions import FinishAction, InvalidAction, NoAction from .google_scholar_search import GoogleScholar @@ -20,7 +21,7 @@ 'GoogleScholar', 'IPythonInterpreter', 'IPythonInteractive', 'IPythonInteractiveManager', 'PythonInterpreter', 'PPT', 'BaseParser', 'JsonParser', 'TupleParser', 'tool_api', 'list_tools', 'get_tool_cls', - 'get_tool' + 'get_tool', 'BingBrowser' ] diff --git a/lagent/actions/bing_browser.py b/lagent/actions/bing_browser.py new file mode 100755 index 00000000..678d0866 --- /dev/null +++ b/lagent/actions/bing_browser.py @@ -0,0 +1,270 @@ +import json +import logging +import random +import re +import time +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Optional, Tuple, Type, Union + +import requests +from bs4 import BeautifulSoup +from cachetools import TTLCache, cached +from duckduckgo_search import DDGS + +from lagent.actions import BaseAction, tool_api +from lagent.actions.parser import BaseParser, JsonParser + + +class BaseSearch: + + def __init__(self, topk: int = 3, black_list: List[str] = None): + self.topk = topk + self.black_list = black_list + + def _filter_results(self, results: List[tuple]) -> dict: + filtered_results = {} + count = 0 + for url, snippet, title in results: + if all(domain not in url + for domain in self.black_list) and not url.endswith('.pdf'): + filtered_results[count] = { + 'url': url, + 'summ': json.dumps(snippet, ensure_ascii=False)[1:-1], + 'title': title + } + count += 1 + if count >= self.topk: + break + return filtered_results + + +class DuckDuckGoSearch(BaseSearch): + + def __init__(self, + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.proxy = kwargs.get('proxy') + self.timeout = kwargs.get('timeout', 10) + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_ddgs( + query, timeout=self.timeout, proxy=self.proxy) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from DuckDuckGo after retries.') + + def _call_ddgs(self, query: str, **kwargs) -> dict: + ddgs = DDGS(**kwargs) + response = ddgs.text(query.strip("'"), max_results=10) + return response + + def _parse_response(self, response: dict) -> dict: + raw_results = [] + for item in response: + raw_results.append( + (item['href'], item['description'] + if 'description' in item else item['body'], item['title'])) + return self._filter_results(raw_results) + + +class BingSearch(BaseSearch): + + def __init__(self, + api_key: str, + region: str = 'zh-CN', + topk: int = 3, + black_list: List[str] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + **kwargs): + self.api_key = api_key + self.market = region + self.proxy = kwargs.get('proxy') + super().__init__(topk, black_list) + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def search(self, query: str, max_retry: int = 3) -> dict: + for attempt in range(max_retry): + try: + response = self._call_bing_api(query) + return self._parse_response(response) + except Exception as e: + logging.exception(str(e)) + warnings.warn( + f'Retry {attempt + 1}/{max_retry} due to error: {e}') + time.sleep(random.randint(2, 5)) + raise Exception( + 'Failed to get search results from Bing Search after retries.') + + def _call_bing_api(self, query: str) -> dict: + endpoint = 'https://api.bing.microsoft.com/v7.0/search' + params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'} + headers = {'Ocp-Apim-Subscription-Key': self.api_key} + response = requests.get( + endpoint, headers=headers, params=params, proxies=self.proxy) + response.raise_for_status() + return response.json() + + def _parse_response(self, response: dict) -> dict: + webpages = { + w['id']: w + for w in response.get('webPages', {}).get('value', []) + } + raw_results = [] + + for item in response.get('rankingResponse', + {}).get('mainline', {}).get('items', []): + if item['answerType'] == 'WebPages': + webpage = webpages.get(item['value']['id']) + if webpage: + raw_results.append( + (webpage['url'], webpage['snippet'], webpage['name'])) + elif item['answerType'] == 'News' and item['value'][ + 'id'] == response.get('news', {}).get('id'): + for news in response.get('news', {}).get('value', []): + raw_results.append( + (news['url'], news['description'], news['name'])) + + return self._filter_results(raw_results) + + +class ContentFetcher: + + def __init__(self, timeout: int = 5): + self.timeout = timeout + + @cached(cache=TTLCache(maxsize=100, ttl=600)) + def fetch(self, url: str) -> Tuple[bool, str]: + try: + response = requests.get(url, timeout=self.timeout) + response.raise_for_status() + html = response.content + except requests.RequestException as e: + return False, str(e) + + text = BeautifulSoup(html, 'html.parser').get_text() + cleaned_text = re.sub(r'\n+', '\n', text) + return True, cleaned_text + + +class BingBrowser(BaseAction): + """Wrapper around the Web Browser Tool. + """ + + def __init__(self, + searcher_type: str = 'DuckDuckGoSearch', + timeout: int = 5, + black_list: Optional[List[str]] = [ + 'enoN', + 'youtube.com', + 'bilibili.com', + 'researchgate.net', + ], + topk: int = 20, + description: Optional[dict] = None, + parser: Type[BaseParser] = JsonParser, + enable: bool = True, + **kwargs): + self.searcher = eval(searcher_type)( + black_list=black_list, topk=topk, **kwargs) + self.fetcher = ContentFetcher(timeout=timeout) + self.search_results = None + super().__init__(description, parser, enable) + + @tool_api + def search(self, query: Union[str, List[str]]) -> dict: + """BING search API + Args: + query (List[str]): list of search query strings + """ + queries = query if isinstance(query, list) else [query] + search_results = {} + + with ThreadPoolExecutor() as executor: + future_to_query = { + executor.submit(self.searcher.search, q): q + for q in queries + } + + for future in as_completed(future_to_query): + query = future_to_query[future] + try: + results = future.result() + except Exception as exc: + warnings.warn(f'{query} generated an exception: {exc}') + else: + for result in results.values(): + if result['url'] not in search_results: + search_results[result['url']] = result + else: + search_results[ + result['url']]['summ'] += f"\n{result['summ']}" + + self.search_results = { + idx: result + for idx, result in enumerate(search_results.values()) + } + return self.search_results + + @tool_api + def select(self, select_ids: List[int]) -> dict: + """get the detailed content on the selected pages. + + Args: + select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4. + """ + if not self.search_results: + raise ValueError('No search results to select from.') + + new_search_results = {} + with ThreadPoolExecutor() as executor: + future_to_id = { + executor.submit(self.fetcher.fetch, + self.search_results[select_id]['url']): + select_id + for select_id in select_ids if select_id in self.search_results + } + + for future in as_completed(future_to_id): + select_id = future_to_id[future] + try: + web_success, web_content = future.result() + except Exception as exc: + warnings.warn(f'{select_id} generated an exception: {exc}') + else: + if web_success: + self.search_results[select_id][ + 'content'] = web_content[:8192] + new_search_results[select_id] = self.search_results[ + select_id].copy() + new_search_results[select_id].pop('summ') + + return new_search_results + + @tool_api + def open_url(self, url: str) -> dict: + print(f'Start Browsing: {url}') + web_success, web_content = self.fetcher.fetch(url) + if web_success: + return {'type': 'text', 'content': web_content} + else: + return {'error': web_content} diff --git a/lagent/actions/parser.py b/lagent/actions/parser.py index 3f4a19bc..4188ae39 100644 --- a/lagent/actions/parser.py +++ b/lagent/actions/parser.py @@ -72,7 +72,10 @@ def parse_outputs(self, outputs: Any) -> List[dict]: outputs = json.dumps(outputs, ensure_ascii=False) elif not isinstance(outputs, str): outputs = str(outputs) - return [{'type': 'text', 'content': outputs}] + return [{ + 'type': 'text', + 'content': outputs.encode('gbk', 'ignore').decode('gbk') + }] class JsonParser(BaseParser): diff --git a/lagent/agents/internlm2_agent.py b/lagent/agents/internlm2_agent.py index 0f3c023d..9bcd343e 100644 --- a/lagent/agents/internlm2_agent.py +++ b/lagent/agents/internlm2_agent.py @@ -4,6 +4,8 @@ from copy import deepcopy from typing import Dict, List, Optional, Union +from termcolor import colored + from lagent.actions import ActionExecutor from lagent.agents.base_agent import BaseAgent from lagent.llms import BaseAPIModel, BaseModel @@ -160,8 +162,10 @@ def format(self, formatted += self.format_sub_role(inner_step) return formatted - def parse(self, message, plugin_executor: ActionExecutor, - interpreter_executor: ActionExecutor): + def parse(self, + message, + plugin_executor: ActionExecutor = None, + interpreter_executor: ActionExecutor = None): if self.language['begin']: message = message.split(self.language['begin'])[-1] if self.tool['name_map']['plugin'] in message: @@ -183,9 +187,10 @@ def parse(self, message, plugin_executor: ActionExecutor, message = message.strip() code = code.split(self.tool['end'].strip())[0].strip() return 'interpreter', message, dict( - name=interpreter_executor.action_names()[0], - parameters=dict( - command=code)) if interpreter_executor else None + name=interpreter_executor.action_names()[0] if isinstance( + interpreter_executor, ActionExecutor) else + 'IPythonInterpreter', + parameters=dict(command=code)) return None, message.split(self.tool['start_token'])[0], None def format_response(self, action_return, name) -> dict: @@ -285,6 +290,7 @@ def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn: inner_history = message[:] offset = len(inner_history) agent_return = AgentReturn() + agent_return.inner_steps = deepcopy(inner_history) last_agent_state = AgentStatusCode.SESSION_READY for _ in range(self.max_turn): # list of dict @@ -348,11 +354,14 @@ def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn: agent_return.response = language last_agent_state = agent_state yield deepcopy(agent_return) + print(colored(response, 'red')) if name: action_return: ActionReturn = executor(action['name'], action['parameters']) + action_return.type = action['name'] action_return.thought = language agent_return.actions.append(action_return) + print(colored(action_return.result, 'magenta')) inner_history.append(dict(role='language', content=language)) if not name: agent_return.response = language @@ -372,6 +381,7 @@ def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn: self._protocol.format_response(action_return, name=name)) agent_state += 1 agent_return.state = agent_state + agent_return.inner_steps = deepcopy(inner_history[offset:]) yield agent_return agent_return.inner_steps = deepcopy(inner_history[offset:]) agent_return.state = AgentStatusCode.END diff --git a/lagent/llms/base_api.py b/lagent/llms/base_api.py index d6667263..6c7f4bf2 100644 --- a/lagent/llms/base_api.py +++ b/lagent/llms/base_api.py @@ -174,7 +174,8 @@ def __init__(self, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, - stop_words=stop_words) + stop_words=stop_words, + skip_special_tokens=False) def _wait(self): """Wait till the next query can be sent. diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index a8c69615..c9af4242 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -73,6 +73,8 @@ def __init__(self, retry=retry, **gen_params) self.gen_params.pop('top_k') + if not model_type.lower().startswith('internlm'): + self.gen_params.pop('skip_special_tokens') self.logger = getLogger(__name__) if isinstance(key, str): @@ -282,6 +284,10 @@ def streaming(raw_response): if decoded[:6] == 'data: ': decoded = decoded[6:] response = json.loads(decoded) + if 'code' in response and response['code'] == -20003: + # Context exceeds maximum length + yield '' + return choice = response['choices'][0] if choice['finish_reason'] == 'stop': return @@ -290,6 +296,7 @@ def streaming(raw_response): assert isinstance(messages, list) gen_params = gen_params.copy() + # Hold out 100 tokens due to potential errors in tiktoken calculation max_tokens = min(gen_params.pop('max_new_tokens'), 4096) if max_tokens <= 0: diff --git a/lagent/schema.py b/lagent/schema.py index a7f8e0cd..da7e8af3 100644 --- a/lagent/schema.py +++ b/lagent/schema.py @@ -1,6 +1,6 @@ from dataclasses import asdict, dataclass, field from enum import IntEnum -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union def enum_dict_factory(inputs): @@ -77,12 +77,18 @@ class AgentStatusCode(IntEnum): CODING = 6 # start python CODE_END = 7 # end python CODE_RETURN = 8 # python return + ANSWER_ING = 9 # final answer is in streaming @dataclass class AgentReturn: + type: str = '' + content: str = '' state: Union[AgentStatusCode, int] = AgentStatusCode.END actions: List[ActionReturn] = field(default_factory=list) response: str = '' inner_steps: List = field(default_factory=list) + nodes: Dict = None + adjacency_list: Dict = None + references: Dict = field(default_factory=dict) errmsg: Optional[str] = None