Add RCI Agent #38

Open · wants to merge 7 commits into main
63 changes: 33 additions & 30 deletions agent/agent.py
@@ -116,40 +116,43 @@ def set_action_set_tag(self, tag: str) -> None:
    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
    ) -> Action:
-        prompt = self.prompt_constructor.construct(
-            trajectory, intent, meta_data
-        )
        lm_config = self.lm_config
-        if lm_config.provider == "openai":
-            if lm_config.mode == "chat":
-                response = generate_from_openai_chat_completion(
-                    messages=prompt,
-                    model=lm_config.model,
-                    temperature=lm_config.gen_config["temperature"],
-                    top_p=lm_config.gen_config["top_p"],
-                    context_length=lm_config.gen_config["context_length"],
-                    max_tokens=lm_config.gen_config["max_tokens"],
-                    stop_token=None,
-                )
-            elif lm_config.mode == "completion":
-                response = generate_from_openai_completion(
-                    prompt=prompt,
-                    engine=lm_config.model,
-                    temperature=lm_config.gen_config["temperature"],
-                    max_tokens=lm_config.gen_config["max_tokens"],
-                    top_p=lm_config.gen_config["top_p"],
-                    stop_token=lm_config.gen_config["stop_token"],
-                )
-            else:
-                raise ValueError(
-                    f"OpenAI models do not support mode {lm_config.mode}"
-                )
-        else:
-            raise NotImplementedError(
-                f"Provider {lm_config.provider} not implemented"
-            )
+        def llm(prompt):
+            if lm_config.provider == "openai":
+                if lm_config.mode == "chat":
+                    response = generate_from_openai_chat_completion(
+                        messages=prompt,
+                        model=lm_config.model,
+                        temperature=lm_config.gen_config["temperature"],
+                        top_p=lm_config.gen_config["top_p"],
+                        context_length=lm_config.gen_config["context_length"],
+                        max_tokens=lm_config.gen_config["max_tokens"],
+                        stop_token=None,
+                    )
+                elif lm_config.mode == "completion":
+                    response = generate_from_openai_completion(
+                        prompt=prompt,
+                        engine=lm_config.model,
+                        temperature=lm_config.gen_config["temperature"],
+                        max_tokens=lm_config.gen_config["max_tokens"],
+                        top_p=lm_config.gen_config["top_p"],
+                        stop_token=lm_config.gen_config["stop_token"],
+                    )
+                else:
+                    raise ValueError(
+                        f"OpenAI models do not support mode {lm_config.mode}"
+                    )
+            else:
+                raise NotImplementedError(
+                    f"Provider {lm_config.provider} not implemented"
+                )
+
+            return response

        try:
+            response = self.prompt_constructor.construct(
+                trajectory, intent, meta_data, llm
+            )
            parsed_response = self.prompt_constructor.extract_action(response)
            if self.action_set_tag == "id_accessibility_tree":
                action = create_id_based_action(parsed_response)
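The refactor above turns the single generation call into an `llm` callable that the prompt constructor can invoke as many times as it needs; the RCI constructor added below calls it once per stage (plan, critique, improve, next step, and two grounding passes). A minimal sketch of driving that interface offline, assuming an agent, trajectory, intent, and meta_data are already in scope; the stub name, its canned reply, and the reply format (which really depends on the instruction file's answer_phrase and action_splitter) are illustrative, not part of the PR:

# Hypothetical stand-in for the OpenAI-backed closure defined inside next_action.
# It receives the same payload the real backend would get (chat messages or a prompt string)
# and returns a canned completion, so the multi-call prompt flow can be run without an API key.
def fake_llm(prompt):
    print(f"--- prompt sent to model ---\n{prompt}\n")
    return "In summary, the next action I will perform is ```click [1234]```"

response = agent.prompt_constructor.construct(trajectory, intent, meta_data, fake_llm)
action_str = agent.prompt_constructor.extract_action(response)  # e.g. "click [1234]"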
177 changes: 164 additions & 13 deletions agent/prompts/prompt_constructor.py
@@ -1,7 +1,7 @@
import json
import re
from pathlib import Path
-from typing import Any, TypedDict
+from typing import Any, TypedDict, Callable, Optional

import tiktoken

@@ -37,6 +37,22 @@ def __init__(
self.instruction: Instruction = instruction
self.tokenizer = tokenizer

@beartype
def _get_llm_output(
self,
intro: str,
examples: list[tuple[str, str]],
template: str,
llm: Callable,
**kwargs
) -> str:
prompt = template.format(**kwargs)
prompt = self.get_lm_api_input(intro, examples, prompt)
response = llm(prompt)

return response

@beartype
def get_lm_api_input(
self, intro: str, examples: list[tuple[str, str]], current: str
) -> APIInput:
@@ -129,7 +145,8 @@ def construct(
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
-    ) -> APIInput:
+        llm: Optional[Callable] = None,
+    ) -> str:
        """Construct prompt given the trajectory"""
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
@@ -146,18 +163,18 @@
url = page.url
previous_action_str = meta_data["action_history"][-1]

-        # input x
-        current = template.format(
+        response = self._get_llm_output(
+            intro,
+            examples,
+            template,
+            llm,
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

-        # make sure all keywords are replaced
-        assert all([f"{{k}}" not in current for k in keywords])
-        prompt = self.get_lm_api_input(intro, examples, current)
-        return prompt
+        return response

def _extract_action(self, response: str) -> str:
action_splitter = self.instruction["meta_data"]["action_splitter"]
@@ -188,7 +205,8 @@ def construct(
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
-    ) -> APIInput:
+        llm: Optional[Callable] = None,
+    ) -> str:
        intro = self.instruction["intro"]
        examples = self.instruction["examples"]
        template = self.instruction["template"]
@@ -203,17 +221,150 @@
page = state_info["info"]["page"]
url = page.url
previous_action_str = meta_data["action_history"][-1]
-        current = template.format(
+        response = self._get_llm_output(
+            intro,
+            examples,
+            template,
+            llm,
            objective=intent,
            url=self.map_url_to_real(url),
            observation=obs,
            previous_action=previous_action_str,
        )

-        assert all([f"{{k}}" not in current for k in keywords])
-        prompt = self.get_lm_api_input(intro, examples, current)
-        return prompt
+        return response
@beartype
def _extract_action(self, response: str) -> str:
        # find the first occurrence of an action
action_splitter = self.instruction["meta_data"]["action_splitter"]
pattern = rf"{action_splitter}(.*?){action_splitter}"
match = re.search(pattern, response)
if match:
return match.group(1)
else:
raise ActionParsingError(
f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
)

class RCIPromptConstructor(PromptConstructor):
def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: tiktoken.core.Encoding,
):
super().__init__(instruction_path, lm_config, tokenizer)
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
self.plan = None

@beartype
def construct(
self,
trajectory: Trajectory,
intent: str,
meta_data: dict[str, Any] = {},
        llm: Optional[Callable] = None
) -> str:
intro = self.instruction["intro"]

state_info: StateInfo = trajectory[-1] # type: ignore[assignment]

page = state_info["info"]["page"]
url = self.map_url_to_real(page.url)
history_actions = ', '.join(meta_data["action_history"])
previous_action_str = meta_data["action_history"][-1]

obs = state_info["observation"][self.obs_modality]
max_obs_length = self.lm_config.gen_config["max_obs_length"]
if max_obs_length:
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]

# Get plan
if self.plan is None:
plan = self._get_llm_output(
intro,
[],
self.instruction["template_plan"],
llm,
observation=obs,
url=url,
objective=intent,
)

# Get critique
critique = self._get_llm_output(
intro,
[],
self.instruction["template_critique"],
llm,
observation=obs,
url=url,
objective=intent,
plan=plan,
)

# Get improved plan
plan = self._get_llm_output(
intro,
[],
self.instruction["template_improve"],
llm,
observation=obs,
url=url,
objective=intent,
plan=plan,
critique=critique,
)

self.plan = plan

# Get next step
meta_next_action = self._get_llm_output(
intro,
[],
self.instruction["template_next_step"],
llm,
observation=obs,
url=url,
objective=intent,
previous_action=previous_action_str,
plan=self.plan,
)

# Get state grounding
draft_next_action = self._get_llm_output(
intro,
[],
self.instruction["template_state_grounding"],
llm,
observation=obs,
url=url,
previous_action=previous_action_str,
meta_next_action=meta_next_action,
)

# Get agent grounding
response = self._get_llm_output(
intro,
[],
self.instruction["template_agent_grounding"],
llm,
observation=obs,
url=url,
previous_action=previous_action_str,
meta_next_action=meta_next_action,
draft_next_action=draft_next_action
)

# XXX: hacky fix
# fix = input(f'fix response="{response}"?').strip()
# if fix != '':
# response = fix
# print(f'fixed response="{response}"')

return response

def _extract_action(self, response: str) -> str:
        # find the first occurrence of an action
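RCIPromptConstructor assumes its instruction file supplies one template per stage of the plan, critique, improve, and grounding loop. The sketch below lists the keys and format placeholders it actually reads in this diff, written as a Python dict mirroring the JSON instruction files; the file name and all prompt wording are hypothetical placeholders, not the PR's real prompts:

# Hypothetical instruction file (e.g. agent/prompts/jsons/p_rci_id_actree_2s.json), shown as a Python dict.
# Keys and {placeholders} are taken from the _get_llm_output calls above; the prose is illustrative only.
rci_instruction = {
    "intro": "You are an autonomous agent operating a web browser ...",
    "examples": [],  # construct() passes [] for every stage, so no few-shot examples are required
    "template_plan": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nWrite a step-by-step plan.",
    "template_critique": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPLAN: {plan}\nPoint out problems with this plan.",
    "template_improve": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPLAN: {plan}\nCRITIQUE: {critique}\nRewrite the plan.",
    "template_next_step": "OBSERVATION: {observation}\nURL: {url}\nOBJECTIVE: {objective}\nPREVIOUS ACTION: {previous_action}\nPLAN: {plan}\nState the next step.",
    "template_state_grounding": "OBSERVATION: {observation}\nURL: {url}\nPREVIOUS ACTION: {previous_action}\nNEXT STEP: {meta_next_action}\nGround this step in the current page.",
    "template_agent_grounding": "OBSERVATION: {observation}\nURL: {url}\nPREVIOUS ACTION: {previous_action}\nNEXT STEP: {meta_next_action}\nDRAFT ACTION: {draft_next_action}\nEmit exactly one action wrapped in the action splitter.",
    "meta_data": {
        "answer_phrase": "In summary, the next action I will perform is",
        "action_splitter": "```",
    },
}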