Skip to content

Commit

Permalink
[Feat] Support resetting agent memory recursively or by keypath (#271)
Browse files Browse the repository at this point in the history
* update `reset` method

* fix type hints
  • Loading branch information
braisedpork1964 authored Nov 18, 2024
1 parent b2bf23d commit e304e5d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
23 changes: 20 additions & 3 deletions lagent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,26 @@ def register_hook(self, hook: Callable):
self._hooks[handle.id] = hook
return handle

def reset(self, session_id=0):
if self.memory:
self.memory.reset(session_id=session_id)
def reset(self,
session_id=0,
keypath: Optional[str] = None,
recursive: bool = False):
assert not (keypath and
recursive), 'keypath and recursive can\'t be used together'
if keypath:
keys, agent = keypath.split('.'), self
for key in keys:
agents = getattr(agent, '_agents', {})
if key not in agents:
raise KeyError(f'No sub-agent named {key} in {agent}')
agent = agents[key]
agent.reset(session_id, recursive=False)
else:
if self.memory:
self.memory.reset(session_id=session_id)
if recursive:
for agent in getattr(self, '_agents', {}).values():
agent.reset(session_id, recursive=True)

def __repr__(self):

Expand Down
5 changes: 3 additions & 2 deletions lagent/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
from typing import AsyncGenerator, Dict, List, Optional, Union

import aiohttp
import requests
Expand Down Expand Up @@ -701,7 +701,8 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
f'{max_num_retries} times. Check the logs for '
f'details. errmsg: {errmsg}')

async def _stream_chat(self, messages: List[dict], **gen_params) -> str:
async def _stream_chat(self, messages: List[dict],
**gen_params) -> AsyncGenerator[str, None]:
"""Generate completion from a list of templates.
Args:
Expand Down

0 comments on commit e304e5d

Please sign in to comment.