-
Notifications
You must be signed in to change notification settings - Fork 27
/
lagent_example.py
52 lines (41 loc) · 1.35 KB
/
lagent_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
from lagent import GPTAPI, ActionExecutor, ReAct
from prompt_toolkit import ANSI, prompt
from agentlego.apis import load_tool
def parse_args(argv=None):
    """Parse the command-line options of the example chat script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. When ``None`` (the default) argparse falls
            back to ``sys.argv[1:]``, so existing zero-argument calls behave
            exactly as before.

    Returns:
        argparse.Namespace: namespace with two attributes:
            ``model`` (str), defaulting to ``'gpt-3.5-turbo'``, and
            ``tools`` (list of str), defaulting to
            ``['GoogleSearch', 'Calculator']``.
    """
    parser = argparse.ArgumentParser(
        description='Interactive ReAct chat that can call the given tools.',
    )
    parser.add_argument(
        '--model',
        type=str,
        default='gpt-3.5-turbo',
        help='Model type handed to the LLM API wrapper.',
    )
    parser.add_argument(
        '--tools',
        type=str,
        nargs='+',
        default=['GoogleSearch', 'Calculator'],
        help='One or more tool names to load for the agent.',
    )
    # Forwarding argv lets callers (and tests) supply arguments explicitly;
    # None keeps the command-line behaviour.
    return parser.parse_args(argv)
def main():
    """Run an interactive terminal chat backed by a ReAct agent.

    Reads the model name and tool list from the command line, wraps each
    tool for lagent via ``load_tool(...).to_lagent()``, builds a ``ReAct``
    chatbot limited to ``max_turn=3``, and then loops: read one line from the
    user, send it to the chatbot, and print every ``system`` and
    ``assistant`` entry found in ``result.inner_steps``.

    The loop ends when the user types ``exit`` or closes/interrupts the
    prompt (Ctrl-D / Ctrl-C).

    Raises:
        SystemExit: with code 0 when the session ends.
    """
    args = parse_args()
    # Set OPENAI_API_KEY in your environment, or pass the key directly to
    # GPTAPI with key='...'.
    model = GPTAPI(model_type=args.model)
    tools = [load_tool(tool_type).to_lagent() for tool_type in args.tools]
    chatbot = ReAct(
        llm=model,
        max_turn=3,
        action_executor=ActionExecutor(actions=tools),
    )

    while True:
        try:
            user = prompt(ANSI('\033[92mUser\033[0m: '))
        except UnicodeDecodeError:
            # Undecodable terminal input: report it and ask again.
            print('UnicodeDecodeError')
            continue
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave quietly instead of
            # dumping a traceback on the user.
            raise SystemExit(0) from None

        if user == 'exit':
            # SystemExit is raised directly because the bare ``exit`` helper
            # is injected by the ``site`` module and is not always available.
            raise SystemExit(0)

        result = chatbot.chat(user)
        for history in result.inner_steps:
            if history['role'] == 'system':
                print(f"\033[92mSystem\033[0m:{history['content']}")
            elif history['role'] == 'assistant':
                print(f"\033[92mBot\033[0m:\n{history['content']}")
# Start the chat loop only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()