topic_changed param, blur test, context history pass #33

Open · wants to merge 2 commits into base: bart/refactor
3 changes: 2 additions & 1 deletion app.py
@@ -20,7 +20,8 @@
from models import MultimodalRequest, MultimodalResponse
from util import process_image, is_speech_present
from assistant import Assistant, AssistantVisionTool

from dotenv import load_dotenv
load_dotenv()

####################################################################################################
# Configuration
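For context, load_dotenv() reads key-value pairs from a .env file in the working directory into the process environment before the configuration code runs. A minimal sketch of such a file, with hypothetical variable names (the diff does not show which keys the app actually reads):

# .env (illustrative only)
OPENAI_API_KEY=sk-...
ANTHROPIC_API_KEY=sk-ant-...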
41 changes: 29 additions & 12 deletions assistant/assistant.py
@@ -60,6 +60,7 @@
SEARCH_TOOL_NAME = "web_search"
VISION_TOOL_NAME = "analyze_photo"
QUERY_PARAM_NAME = "query"
TOPIC_CHANGED_PARAM_NAME = "topic_changed"

IMAGE_GENERATION_TOOL_NAME = "generate_image"
IMAGE_GENERATION_PARAM_NAME = "description"
@@ -69,16 +70,20 @@
"type": "function",
"function": {
"name": SEARCH_TOOL_NAME,
"description": """Up-to-date information on news, retail products, current events, local conditions, and esoteric knowledge""",
"description": """Up-to-date information on news, retail products, current events, local conditions, and esoteric knowledge. performs a web search based on the user's query.""",
"parameters": {
"type": "object",
"properties": {
QUERY_PARAM_NAME: {
"type": "string",
"description": "search query",
},
TOPIC_CHANGED_PARAM_NAME: {
"type": "boolean",
"description": "Whether the topic has changed since the last query"
},
},
"required": [ QUERY_PARAM_NAME ]
"required": [ QUERY_PARAM_NAME , TOPIC_CHNAGED_PARAM_NAME ]
},
},
},
@@ -95,8 +100,12 @@
"type": "string",
"description": "User's query to answer expressed as a command that NEVER refers to the photo or image itself"
},
TOPIC_CHANGED_PARAM_NAME: {
"type": "boolean",
"description": "Whether the topic has changed since the last query"
},
},
"required": [ QUERY_PARAM_NAME ]
"required": [ QUERY_PARAM_NAME , TOPIC_CHNAGED_PARAM_NAME ]
},
},
},
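With topic_changed added to both tools' required parameters, the model is expected to emit tool-call arguments like this hypothetical payload (values are illustrative):

{
    "query": "price of red running shoes",
    "topic_changed": false
}

When topic_changed is true, the handlers below pass an empty message history to the tool instead of the accumulated conversation.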
@@ -338,7 +347,7 @@ def _create_speculative_tool_calls(
tool_call = ChatCompletionMessageToolCall(
id="speculative_web_search_tool",
function=Function(
arguments=json.dumps({ "query": query }),
arguments=json.dumps({ "query": query, "topic_changed": False }),
name=SEARCH_TOOL_NAME
),
type="function"
@@ -529,7 +538,8 @@ async def _handle_vision_tool(
message_history: List[Message],
image_bytes: bytes | None,
location_address: str | None,
local_time: str | None
local_time: str | None,
topic_changed: bool | None = None
) -> ToolOutput:
t_start = timeit.default_timer()

@@ -538,22 +548,27 @@ async def _handle_vision_tool(
if image_bytes:
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
media_type = detect_media_type(image_bytes=image_bytes)


extra_context = CONTEXT_SYSTEM_MESSAGE_PREFIX + "\n".join([ f"<{key}>{value}</{key}>" for key, value in { "location": location_address, "current_time": local_time }.items() if value is not None ])
if self._vision_tool == AssistantVisionTool.GPT4O:
output = await vision_query_gpt(
client=self._client,
token_usage_by_model=token_usage_by_model,
query=query,
image_base64=image_base64,
media_type=media_type
media_type=media_type,
message_history=[] if topic_changed else message_history,
extra_context=extra_context
)
else:
output = await vision_query_claude(
client=self._anthropic_client,
token_usage_by_model=token_usage_by_model,
query=query,
image_base64=image_base64,
media_type=media_type
media_type=media_type,
message_history=[] if topic_changed else message_history,
extra_context=extra_context
)

t_end = timeit.default_timer()
@@ -574,10 +589,10 @@ async def _handle_vision_tool(
timings=timings,
query=output.web_query,
flavor_prompt=flavor_prompt,
message_history=message_history,
message_history=[],
image_bytes=None,
location_address=location_address,
local_time=local_time
local_time=local_time,
)
return ToolOutput(text=f"HERE IS WHAT YOU SEE: {output.response}\nEXTRA INFO FROM WEB: {web_result}", safe_for_final_response=False)

@@ -591,15 +606,16 @@ async def _handle_web_search_tool(
message_history: List[Message],
image_bytes: bytes | None,
location_address: str | None,
local_time: str | None
local_time: str | None,
topic_changed: bool | None = None
) -> ToolOutput:
t_start = timeit.default_timer()
output = await self._web_tool.search_web(
token_usage_by_model=token_usage_by_model,
timings=timings,
query=query,
flavor_prompt=flavor_prompt,
message_history=message_history,
message_history=[] if topic_changed else message_history,
location=location_address
)
t_end = timeit.default_timer()
@@ -649,6 +665,7 @@ def _validate_tool_args(tool_call: ChatCompletionMessageToolCall) -> Dict[str, Any]:
args: Dict[str, Any] = {}
try:
args = json.loads(tool_call.function.arguments)
print(f"Tool arguments: {args}")
except:
pass
for param_name in list(args.keys()):
107 changes: 76 additions & 31 deletions assistant/claude_vision.py
@@ -4,14 +4,15 @@
# Claude-based vision tool.
#

from typing import Dict, Optional
from typing import Dict, Optional, List

import anthropic
from pydantic import BaseModel

from models import Message, TokenUsage, accumulate_token_usage
from .vision_tool_output import VisionToolOutput

from util import is_blurry_image
import base64

MODEL = "claude-3-haiku-20240307"
# MODEL = "claude-3-5-sonnet-20240620"
@@ -21,35 +22,55 @@
seeing at the time they spoke but you NEVER mention the photo or image and instead respond as if you
are actually seeing.

Always do your best with images, never comment on their quality, and answer decisively with a guess
if you are not sure. There are no negative consequences to guessing.
The camera is unfortunately VERY low quality but the user is counting on you to interpret the
blurry, pixelated images. NEVER comment on image quality. Do your best with images.

ALWAYS respond with a valid JSON object with these fields:

response: (String) Respond to the user as best you can. Be precise, get to the point, and speak as though you actually see the image. If a web search is needed, this should be a description of the image.
web_query: (String) Empty if your "response" answers everything the user asked. If a web search based on the visual description would be more helpful, create a query (e.g. for up-to-date, location-based, or product info).

ALWAYS respond with a JSON object with these fields:
examples:
1. If the user asks "What do you see?" and the image is a cat in a room, you would respond:
{
"response": "You are looking at a cat in a room.",
"web_query": ""
}

response: (String) Respond to user as best you can. Be precise, get to the point, never comment on image quality.
web_query: (String) Web query to answer the user's request.
web_search_needed: (Bool) Whether to search the web. True ONLY if "response" does not answer the user query precisely enough and up-to-date, location-specific, or product-specific info is needed.
2. If the user asks "What is that?" and the image is a red shoe with white laces, you would respond:
{
"response": "A red shoe with white laces.",
"web_query": "red shoe with white laces"
}
"""

class VisionResponse(BaseModel):
response: str
web_query: Optional[str] = None
web_search_needed: Optional[bool] = None


async def vision_query_claude(
client: anthropic.AsyncAnthropic,
token_usage_by_model: Dict[str, TokenUsage],
query: str,
image_base64: str | None,
media_type: str | None
media_type: str | None,
message_history: List[Message] = [],
extra_context: str | None = None,
) -> VisionToolOutput:

user_message = {
"role": "user",
"content": []
}

# Check if image is blurry
if image_base64 is not None and is_blurry_image(base64.b64decode(image_base64)):
print("Image is too blurry to interpret.")
return VisionToolOutput(
is_error=False,
response="The image is too blurry to interpret. Please try again.",
web_query=None
)
if image_base64 is not None and media_type is not None:
image_chunk = {
"type": "image",
@@ -61,25 +82,19 @@ async def vision_query_claude(
}
user_message["content"].append(image_chunk)
user_message["content"].append({ "type": "text", "text": query })

messages = [
user_message,
{
# Prefill a leading '{' to force JSON output as per Anthropic's recommendations
"role": "assistant",
"content": [
{
"type": "text",
"text": "{"
}
]
}
clean_message_history = [msg for msg in message_history if msg.role == "assistant" or msg.role == "user"]
clean_message_history = make_alternating(messages=clean_message_history)
messages = clean_message_history + [
user_message
]
_system_message = SYSTEM_MESSAGE

if extra_context is not None:
_system_message = _system_message + "\n" + extra_context
# Call Claude
response = await client.messages.create(
model=MODEL,
system=SYSTEM_MESSAGE,
system=_system_message,
messages=messages,
max_tokens=4096,
temperature=0.0
@@ -96,11 +111,41 @@
return VisionToolOutput(is_error=False, response=vision_response.response, web_query=web_query)

def parse_response(content: str) -> Optional[VisionResponse]:
# Put the leading '{' back
json_string = "{" + content
try:
return VisionResponse.model_validate_json(json_data=json_string)
except:
pass
return None
return VisionResponse.model_validate_json(json_data=content)
except Exception as e:
raise Exception(f"Error: Unable to parse vision response.{e}")

def make_alternating(messages: List[Message]) -> List[Message]:
"""
Ensure that the messages are alternating between user and assistant.
"""
# Start from the first message's role and keep only turns that alternate from there
if len(messages) == 0:
return []
expected_role = messages[0].role
last_message = messages[-1]
alternating_messages = []

for message in messages:
if message.content.strip() == '':
continue
if message.role != expected_role:
continue

alternating_messages.append(message)
expected_role = "assistant" if expected_role == "user" else "user"

# Ensure the last message is from the assistant
if alternating_messages and alternating_messages[-1].role != "assistant":
if last_message.role == "assistant":
alternating_messages.append(last_message)
else:
alternating_messages.pop()
# if first message is from assistant, remove it
if alternating_messages and alternating_messages[0].role == "assistant":
alternating_messages.pop(0)
return alternating_messages
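For illustration, make_alternating behaves like this on a messy history (assuming Message is the role/content model from models.py; values are hypothetical):

history = [
    Message(role="user", content="What is this?"),
    Message(role="user", content="Hello?"),               # consecutive user turn: dropped
    Message(role="assistant", content="A red shoe."),
    Message(role="user", content="Where can I buy it?"),  # trailing user turn: dropped
]
make_alternating(messages=history)
# -> [Message(role="user", content="What is this?"), Message(role="assistant", content="A red shoe.")]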

