From e304e5d323cdbb631257fac9187d16b99476bc2f Mon Sep 17 00:00:00 2001 From: BraisedPork <46232992+braisedpork1964@users.noreply.github.com> Date: Mon, 18 Nov 2024 12:04:56 +0800 Subject: [PATCH] [Feat] Support resetting agent memory recursively or by keypath (#271) * update `reset` method * fix type hints --- lagent/agents/agent.py | 23 ++++++++++++++++++++--- lagent/llms/openai.py | 5 +++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index cebfe423..b1e941ba 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -165,9 +165,26 @@ def register_hook(self, hook: Callable): self._hooks[handle.id] = hook return handle - def reset(self, session_id=0): - if self.memory: - self.memory.reset(session_id=session_id) + def reset(self, + session_id=0, + keypath: Optional[str] = None, + recursive: bool = False): + assert not (keypath and + recursive), 'keypath and recursive can\'t be used together' + if keypath: + keys, agent = keypath.split('.'), self + for key in keys: + agents = getattr(agent, '_agents', {}) + if key not in agents: + raise KeyError(f'No sub-agent named {key} in {agent}') + agent = agents[key] + agent.reset(session_id, recursive=False) + else: + if self.memory: + self.memory.reset(session_id=session_id) + if recursive: + for agent in getattr(self, '_agents', {}).values(): + agent.reset(session_id, recursive=True) def __repr__(self): diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 7dc44e86..ffbd1b3d 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from logging import getLogger from threading import Lock -from typing import Dict, List, Optional, Union +from typing import AsyncGenerator, Dict, List, Optional, Union import aiohttp import requests @@ -701,7 +701,8 @@ async def _chat(self, messages: List[dict], **gen_params) -> str: f'{max_num_retries} times. Check the logs for ' f'details. errmsg: {errmsg}') - async def _stream_chat(self, messages: List[dict], **gen_params) -> str: + async def _stream_chat(self, messages: List[dict], + **gen_params) -> AsyncGenerator[str, None]: """Generate completion from a list of templates. Args: