From 1410e7b7cb71bb09e5870be1adbc804d14a7a008 Mon Sep 17 00:00:00 2001 From: uma-shankar-TE Date: Tue, 27 Aug 2024 20:58:31 +0530 Subject: [PATCH 1/2] added blur check, context passed in vision --- app.py | 3 +- assistant/assistant.py | 6 ++- assistant/claude_vision.py | 96 +++++++++++++++++++++++++++----------- assistant/gpt_vision.py | 41 ++++++++++++---- requirements.txt | 3 +- util/__init__.py | 2 +- util/image_processing.py | 11 +++++ 7 files changed, 122 insertions(+), 40 deletions(-) diff --git a/app.py b/app.py index 67ffc83..ec2672f 100644 --- a/app.py +++ b/app.py @@ -20,7 +20,8 @@ from models import MultimodalRequest, MultimodalResponse from util import process_image, is_speech_present from assistant import Assistant, AssistantVisionTool - +from dotenv import load_dotenv +load_dotenv() #################################################################################################### # Configuration diff --git a/assistant/assistant.py b/assistant/assistant.py index 090047f..a936ba6 100644 --- a/assistant/assistant.py +++ b/assistant/assistant.py @@ -545,7 +545,8 @@ async def _handle_vision_tool( token_usage_by_model=token_usage_by_model, query=query, image_base64=image_base64, - media_type=media_type + media_type=media_type, + message_history=message_history, ) else: output = await vision_query_claude( @@ -553,7 +554,8 @@ async def _handle_vision_tool( token_usage_by_model=token_usage_by_model, query=query, image_base64=image_base64, - media_type=media_type + media_type=media_type, + message_history=message_history, ) t_end = timeit.default_timer() diff --git a/assistant/claude_vision.py b/assistant/claude_vision.py index 16a2fb2..1681a79 100644 --- a/assistant/claude_vision.py +++ b/assistant/claude_vision.py @@ -4,14 +4,15 @@ # Claude-based vision tool. # -from typing import Dict, Optional +from typing import Dict, Optional, List import anthropic from pydantic import BaseModel from models import Message, TokenUsage, accumulate_token_usage from .vision_tool_output import VisionToolOutput - +from util import is_blurry_image +import base64 MODEL = "claude-3-haiku-20240307" # MODEL = "claude-3-5-sonnet-20240620" @@ -21,14 +22,26 @@ seeing at the time they spoke but you NEVER mention the photo or image and instead respond as if you are actually seeing. -Always do your best with images, never comment on their quality, and answer decisively with a guess -if you are not sure. There are no negative consequences to guessing. +The camera is unfortunately VERY low quality but the user is counting on you to interpret the +blurry, pixelated images. NEVER comment on image quality. Do your best with images. + +ALWAYS respond with a valid JSON object with these fields: + +response: (String) Respond to user as best you can. Be precise, get to the point, and speak as though you actually see the image. If it needs a web search it will be a description of the image. +web_query: (String) Empty if your "response" answers everything user asked. If web search based on visual description would be more helpful, create a query (e.g. up-to-date, location-based, or product info). -ALWAYS respond with a JSON object with these fields: +examples: +1. If the user asks "What do you see?" and the image is a cat in a room, you would respond: +{ + "response": "You are looking at a cat in a room.", + "web_query": "" +} -response: (String) Respond to user as best you can. Be precise, get to the point, never comment on image quality. -web_query: (String) Web query to answer the user's request. 
-web_search_needed: (Bool) Whether to search the web. True ONLY if "response" does not answer the user query precisely enough and up-to-date, location-specific, or product-specific info is needed. +2. If the user asks "What is that?" and the image is a red shoe with white laces, you would respond: +{ + "response": "A red shoe with white laces.", + "web_query": "red shoe with white laces" +} """ class VisionResponse(BaseModel): @@ -42,14 +55,22 @@ async def vision_query_claude( token_usage_by_model: Dict[str, TokenUsage], query: str, image_base64: str | None, - media_type: str | None + media_type: str | None, + message_history: list[Message]=[] ) -> VisionToolOutput: user_message = { "role": "user", "content": [] } - + # Check if image is blurry + if image_base64 is not None and is_blurry_image(base64.b64decode(image_base64)): + print("Image is too blurry to interpret.") + return VisionToolOutput( + is_error=False, + response="The image is too blurry to interpret. Please try again.", + web_query=None + ) if image_base64 is not None and media_type is not None: image_chunk = { "type": "image", @@ -61,19 +82,10 @@ async def vision_query_claude( } user_message["content"].append(image_chunk) user_message["content"].append({ "type": "text", "text": query }) - - messages = [ - user_message, - { - # Prefill a leading '{' to force JSON output as per Anthropic's recommendations - "role": "assistant", - "content": [ - { - "type": "text", - "text": "{" - } - ] - } + clean_message_history = [mag for mag in message_history if mag.role == "assistant" or mag.role == "user"] + clean_message_history = make_alternating(messages=clean_message_history) + messages = clean_message_history + [ + user_message ] # Call Claude @@ -89,6 +101,7 @@ async def vision_query_claude( # Parse response vision_response = parse_response(content=response.content[0].text) + print(f"Vision response: {vision_response}") if vision_response is None: return VisionToolOutput(is_error=True, response="Error: Unable to parse vision tool response. Tell user a problem interpreting the image occurred and ask them to try again.", web_query=None) web_search_needed = vision_response.web_search_needed and vision_response.web_query is not None and len(vision_response.web_query) > 0 @@ -96,11 +109,42 @@ async def vision_query_claude( return VisionToolOutput(is_error=False, response=vision_response.response, web_query=web_query) def parse_response(content: str) -> Optional[VisionResponse]: - # Put the leading '{' back - json_string = "{" + content try: - return VisionResponse.model_validate_json(json_data=json_string) + return VisionResponse.model_validate_json(json_data=content) except: pass return None +def make_alternating(messages: List[Message]) -> List[Message]: + """ + Ensure that the messages are alternating between user and assistant. 
+    """
+    # Start with the first message's role
+    if len(messages) == 0:
+        return []
+    expected_role = messages[0].role
+    last_message = messages[-1]
+    alternating_messages = []
+    expected_role = "user" if expected_role == "assistant" else "assistant"
+
+    for i, message in enumerate(messages):
+        if message.content.strip() == '':
+            continue
+        if message.role != expected_role:
+            continue
+
+        alternating_messages.append(message)
+        expected_role = "assistant" if expected_role == "user" else "user"
+
+    # Ensure the last message is from the assistant
+    if alternating_messages and alternating_messages[-1].role != "assistant":
+        if last_message.role == "assistant":
+            alternating_messages.append(last_message)
+        else:
+            alternating_messages.pop()
+    # if first message is from assistant, remove it
+    if alternating_messages and alternating_messages[0].role == "assistant":
+        alternating_messages.pop(0)
+    return alternating_messages
+
+
diff --git a/assistant/gpt_vision.py b/assistant/gpt_vision.py
index a4e56b3..671a4cc 100644
--- a/assistant/gpt_vision.py
+++ b/assistant/gpt_vision.py
@@ -11,7 +11,8 @@
 from models import Message, TokenUsage, accumulate_token_usage
 from .vision_tool_output import VisionToolOutput
-
+from util import is_blurry_image
+import base64
 MODEL = "gpt-4o"
 # MODEL = "gpt-4o-mini"
@@ -27,9 +28,22 @@
 ALWAYS respond with a valid JSON object with these fields:
 
-response: (String) Respond to user as best you can. Be precise, get to the point, and speak as though you actually see the image.
+response: (String) Respond to user as best you can. Be precise, get to the point, and speak as though you actually see the image. If it needs a web search it will be a description of the image.
 web_query: (String) Empty if your "response" answers everything user asked. If web search based on visual description would be more helpful, create a query (e.g. up-to-date, location-based, or product info).
-reverse_image_search: (Bool) True if your web query from description is insufficient and including the *exact* thing user is looking at as visual target is needed.
+
+examples:
+1. If the user asks "What do you see?" and the image is a cat in a room, you would respond:
+{
+    "response": "You are looking at a cat in a room.",
+    "web_query": ""
+}
+
+2. If the user asks "What is that?" and the image is a red shoe with white laces, you would respond:
+{
+    "response": "A red shoe with white laces.",
+    "web_query": "red shoe with white laces"
+}
+
 """
@@ -44,23 +58,32 @@ async def vision_query_gpt(
     token_usage_by_model: Dict[str, TokenUsage],
     query: str,
     image_base64: str | None,
-    media_type: str | None
+    media_type: str | None,
+    message_history: list[Message]=[],
 ) -> VisionToolOutput:
     # Create messages for GPT w/ image. No message history or extra context for this tool, as we
     # will rely on second LLM call. Passing in message history and extra context necessary to
     # allow direct tool output seems to cause this to take longer, hence we don't permit it.
+
+    # Check if image is blurry
+    if image_base64 is not None and is_blurry_image(base64.b64decode(image_base64)):
+        return VisionToolOutput(
+            is_error=False,
+            response="The image is too blurry to interpret. Please try again.",
+            web_query=None
+        )
     user_message = {
         "role": "user",
         "content": [
             { "type": "text", "text": query }
         ]
     }
+    clean_message_history = [m for m in message_history if m.role == "user" or m.role == "assistant"]
     if image_base64 is not None and media_type is not None:
         user_message["content"].append({ "type": "image_url", "image_url": { "url": f"data:{media_type};base64,{image_base64}" } })
     messages = [
         { "role": "system", "content": SYSTEM_MESSAGE },
-        user_message
-    ]
+    ] + clean_message_history + [user_message]
 
     # Call GPT
     response = await client.chat.completions.create(
@@ -77,12 +100,12 @@
     json_string = content[json_start : json_end + 1] if json_start > -1 and json_end > -1 else content
     try:
         vision_response = VisionResponse.model_validate_json(json_data=json_string)
-        vision_response.reverse_image_search = vision_response.reverse_image_search is not None and vision_response.reverse_image_search == True
-        if len(vision_response.web_query) == 0 and vision_response.reverse_image_search:
+        # vision_response.reverse_image_search = vision_response.reverse_image_search is not None and vision_response.reverse_image_search == True
+        # if len(vision_response.web_query) == 0 and vision_response.reverse_image_search:
             # If no web query output but reverse image search asked for, just use user query
             # directly. This is sub-optimal and it would be better to figure out a way to ensure
             # web_query is generated when reverse_image_search is true.
-            vision_response.web_query = query
+            # vision_response.web_query = query
     except Exception as e:
         print(f"Error: Unable to parse vision response: {e}")
         return VisionToolOutput(
diff --git a/requirements.txt b/requirements.txt
index dd8fdb7..99b3175 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,4 +33,5 @@ aiohttp
 groq
 opencv-python
 webrtcvad
-numpy
\ No newline at end of file
+numpy
+python-dotenv
\ No newline at end of file
diff --git a/util/__init__.py b/util/__init__.py
index a6dfc99..909101a 100644
--- a/util/__init__.py
+++ b/util/__init__.py
@@ -1,4 +1,4 @@
 from .hexdump import hexdump
 from .media import detect_media_type
-from .image_processing import process_image
+from .image_processing import process_image, is_blurry_image
 from .vad import is_speech_present
\ No newline at end of file
diff --git a/util/image_processing.py b/util/image_processing.py
index 007b83d..6468f1b 100644
--- a/util/image_processing.py
+++ b/util/image_processing.py
@@ -293,6 +293,17 @@ def get_bytes(self):
             else:
                 return cv2.imencode('.jpg', self.filtered_image)[1].tobytes()
         return None
+
+def is_blurry_image(image: bytes, threshold: float = 10.0) -> bool:
+    """
+    Check if an image is blurry using the Laplacian method.
+    """
+    nparr = np.frombuffer(image, np.uint8)
+    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    laplacian = cv2.Laplacian(gray, cv2.CV_64F).var()
+    print("Laplacian variance: ", laplacian)
+    return laplacian < threshold
 
 def process_image(bytes: bytes)->bytes:
     filters: List[BaseFilter] = [

From 6af55a5f3bd4bb678375a9f7d39fefbbcb404b50 Mon Sep 17 00:00:00 2001
From: uma-shankar-TE
Date: Tue, 27 Aug 2024 22:57:14 +0530
Subject: [PATCH 2/2] added topic_changed tool parameter

---
 assistant/assistant.py     | 39 +++++++++++++++++++++++++++------------
 assistant/claude_vision.py | 15 ++++++++-------
 assistant/gpt_vision.py    |  9 +++++++--
 3 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/assistant/assistant.py b/assistant/assistant.py
index a936ba6..3fb010b 100644
--- a/assistant/assistant.py
+++ b/assistant/assistant.py
@@ -60,6 +60,7 @@
 SEARCH_TOOL_NAME = "web_search"
 VISION_TOOL_NAME = "analyze_photo"
 QUERY_PARAM_NAME = "query"
+TOPIC_CHANGED_PARAM_NAME = "topic_changed"
 IMAGE_GENERATION_TOOL_NAME = "generate_image"
 IMAGE_GENERATION_PARAM_NAME = "description"
@@ -69,7 +70,7 @@
         "type": "function",
         "function": {
             "name": SEARCH_TOOL_NAME,
-            "description": """Up-to-date information on news, retail products, current events, local conditions, and esoteric knowledge""",
+            "description": """Up-to-date information on news, retail products, current events, local conditions, and esoteric knowledge. Performs a web search based on the user's query.""",
             "parameters": {
                 "type": "object",
                 "properties": {
@@ -77,8 +78,12 @@
                         "type": "string",
                         "description": "search query",
                     },
+                    TOPIC_CHANGED_PARAM_NAME: {
+                        "type": "boolean",
+                        "description": "Whether the topic has changed since the last query"
+                    },
                 },
-                "required": [ QUERY_PARAM_NAME ]
+                "required": [ QUERY_PARAM_NAME, TOPIC_CHANGED_PARAM_NAME ]
             },
         },
     },
@@ -95,8 +100,12 @@
                         "type": "string",
                         "description": "User's query to answer expressed as a command that NEVER refers to the photo or image itself"
                     },
+                    TOPIC_CHANGED_PARAM_NAME: {
+                        "type": "boolean",
+                        "description": "Whether the topic has changed since the last query"
+                    },
                 },
-                "required": [ QUERY_PARAM_NAME ]
+                "required": [ QUERY_PARAM_NAME, TOPIC_CHANGED_PARAM_NAME ]
             },
         },
     },
@@ -338,7 +347,7 @@
         tool_call = ChatCompletionMessageToolCall(
             id="speculative_web_search_tool",
             function=Function(
-                arguments=json.dumps({ "query": query }),
+                arguments=json.dumps({ "query": query, "topic_changed": False }),
                 name=SEARCH_TOOL_NAME
             ),
             type="function"
@@ -529,7 +538,8 @@ async def _handle_vision_tool(
         message_history: List[Message],
         image_bytes: bytes | None,
         location_address: str | None,
-        local_time: str | None
+        local_time: str | None,
+        topic_changed: bool | None = None
     ) -> ToolOutput:
         t_start = timeit.default_timer()
@@ -538,7 +548,8 @@ async def _handle_vision_tool(
         if image_bytes:
            image_base64 = base64.b64encode(image_bytes).decode("utf-8")
            media_type = detect_media_type(image_bytes=image_bytes)
-
+
+        extra_context = CONTEXT_SYSTEM_MESSAGE_PREFIX + "\n".join([ f"<{key}>{value}" for key, value in { "location": location_address, "current_time": local_time }.items() if value is not None ])
         if self._vision_tool == AssistantVisionTool.GPT4O:
             output = await vision_query_gpt(
                 client=self._client,
@@ -546,7 +557,8 @@
                 query=query,
                 image_base64=image_base64,
                 media_type=media_type,
-                message_history=message_history,
+                message_history=[] if topic_changed else message_history,
+                extra_context=extra_context
             )
         else:
             output = await vision_query_claude(
@@ -555,7 +567,8 @@
                 query=query,
                 image_base64=image_base64,
                 media_type=media_type,
-                message_history=message_history,
+                message_history=[] if topic_changed else message_history,
+                extra_context=extra_context
             )
 
         t_end = timeit.default_timer()
@@ -576,10 +589,10 @@
             timings=timings,
             query=output.web_query,
             flavor_prompt=flavor_prompt,
-            message_history=message_history,
+            message_history=[],
             image_bytes=None,
             location_address=location_address,
-            local_time=local_time
+            local_time=local_time,
         )
         return ToolOutput(text=f"HERE IS WHAT YOU SEE: {output.response}\nEXTRA INFO FROM WEB: {web_result}", safe_for_final_response=False)
 
@@ -593,7 +606,8 @@
         message_history: List[Message],
         image_bytes: bytes | None,
         location_address: str | None,
-        local_time: str | None
+        local_time: str | None,
+        topic_changed: bool | None = None
     ) -> ToolOutput:
         t_start = timeit.default_timer()
         output = await self._web_tool.search_web(
@@ -601,7 +615,7 @@
             timings=timings,
             query=query,
             flavor_prompt=flavor_prompt,
-            message_history=message_history,
+            message_history=[] if topic_changed else message_history,
             location=location_address
         )
         t_end = timeit.default_timer()
@@ -651,6 +665,7 @@
         args: Dict[str, Any] = {}
         try:
             args = json.loads(tool_call.function.arguments)
+            print(f"Tool arguments: {args}")
         except:
             pass
         for param_name in list(args.keys()):
diff --git a/assistant/claude_vision.py b/assistant/claude_vision.py
index 1681a79..7426ede 100644
--- a/assistant/claude_vision.py
+++ b/assistant/claude_vision.py
@@ -47,7 +47,6 @@
 class VisionResponse(BaseModel):
     response: str
     web_query: Optional[str] = None
-    web_search_needed: Optional[bool] = None
 
 
 async def vision_query_claude(
@@ -56,7 +55,8 @@
     query: str,
     image_base64: str | None,
     media_type: str | None,
-    message_history: list[Message]=[]
+    message_history: list[Message]=[],
+    extra_context: str | None = None,
 ) -> VisionToolOutput:
 
     user_message = {
@@ -87,11 +87,14 @@
     messages = clean_message_history + [
         user_message
     ]
+    _system_message = SYSTEM_MESSAGE
+    if extra_context is not None:
+        _system_message = _system_message + "\n" + extra_context
 
     # Call Claude
     response = await client.messages.create(
         model=MODEL,
-        system=SYSTEM_MESSAGE,
+        system=_system_message,
         messages=messages,
         max_tokens=4096,
         temperature=0.0
@@ -101,7 +104,6 @@
 
     # Parse response
     vision_response = parse_response(content=response.content[0].text)
-    print(f"Vision response: {vision_response}")
     if vision_response is None:
         return VisionToolOutput(is_error=True, response="Error: Unable to parse vision tool response. Tell user a problem interpreting the image occurred and ask them to try again.", web_query=None)
     web_search_needed = vision_response.web_search_needed and vision_response.web_query is not None and len(vision_response.web_query) > 0
@@ -111,9 +113,8 @@
 def parse_response(content: str) -> Optional[VisionResponse]:
     try:
         return VisionResponse.model_validate_json(json_data=content)
-    except:
-        pass
-    return None
+    except Exception as e:
+        raise Exception(f"Error: Unable to parse vision response: {e}")
 
 def make_alternating(messages: List[Message]) -> List[Message]:
     """
diff --git a/assistant/gpt_vision.py b/assistant/gpt_vision.py
index 671a4cc..db4a025 100644
--- a/assistant/gpt_vision.py
+++ b/assistant/gpt_vision.py
@@ -50,7 +50,6 @@
 class VisionResponse(BaseModel):
     response: str
     web_query: Optional[str] = ""
-    reverse_image_search: Optional[bool] = None
 
 
 async def vision_query_gpt(
@@ -60,6 +59,7 @@
     image_base64: str | None,
     media_type: str | None,
     message_history: list[Message]=[],
+    extra_context: str | None = None
 ) -> VisionToolOutput:
     # Create messages for GPT w/ image. No message history or extra context for this tool, as we
     # will rely on second LLM call. Passing in message history and extra context necessary to
@@ -79,10 +79,15 @@
         ]
     }
     clean_message_history = [m for m in message_history if m.role == "user" or m.role == "assistant"]
+
+    _system_message = SYSTEM_MESSAGE
+
+    if extra_context is not None:
+        _system_message = _system_message + "\n" + extra_context
     if image_base64 is not None and media_type is not None:
         user_message["content"].append({ "type": "image_url", "image_url": { "url": f"data:{media_type};base64,{image_base64}" } })
     messages = [
-        { "role": "system", "content": SYSTEM_MESSAGE },
+        { "role": "system", "content": _system_message },
     ] + clean_message_history + [user_message]
 
     # Call GPT
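
Note on the blur check: below is a minimal standalone sketch of the Laplacian-variance test that patch 1 adds in util/image_processing.py. It assumes opencv-python and numpy (both already listed in requirements.txt); the 10.0 threshold mirrors the patch's default, and the sample file path is hypothetical.

import cv2
import numpy as np

def looks_blurry(jpeg_bytes: bytes, threshold: float = 10.0) -> bool:
    # Decode the raw bytes, convert to grayscale, and measure the variance of the
    # Laplacian; low variance means few sharp edges, i.e. a likely blurry frame.
    buffer = np.frombuffer(jpeg_bytes, np.uint8)
    frame = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var() < threshold

if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:  # hypothetical test image
        print("blurry:", looks_blurry(f.read()))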
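Note on message history: make_alternating in claude_vision.py exists because the Anthropic Messages API expects user/assistant turns to alternate, starting with a user turn. The sketch below illustrates that constraint under stated assumptions: plain role/content dicts stand in for the project's Message model, and merging consecutive same-role turns is one possible policy (the patch drops them instead).

from typing import Dict, List

def enforce_alternation(history: List[Dict[str, str]]) -> List[Dict[str, str]]:
    cleaned: List[Dict[str, str]] = []
    for msg in history:
        if msg["role"] not in ("user", "assistant") or not msg["content"].strip():
            continue  # keep only non-empty user/assistant turns
        if cleaned and cleaned[-1]["role"] == msg["role"]:
            cleaned[-1]["content"] += "\n" + msg["content"]  # merge same-role runs
        else:
            cleaned.append(dict(msg))
    if cleaned and cleaned[0]["role"] == "assistant":
        cleaned.pop(0)  # history should open with a user turn
    return cleaned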
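Note on response parsing: both vision tools now prompt for a bare JSON object and validate it with pydantic, and gpt_vision.py first slices out the outermost {...} span. A minimal sketch of that parse step, using a hypothetical stand-in model rather than the project's VisionResponse:

from typing import Optional
from pydantic import BaseModel

class VisionReply(BaseModel):  # stand-in for VisionResponse
    response: str
    web_query: Optional[str] = ""

def parse_reply(content: str) -> Optional[VisionReply]:
    # Take the outermost {...} span if present, then validate; return None on failure
    # so the caller can surface a tool error instead of raising.
    start, end = content.find("{"), content.rfind("}")
    candidate = content[start:end + 1] if start > -1 and end > -1 else content
    try:
        return VisionReply.model_validate_json(candidate)
    except Exception:
        return None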