From 06f79ee16bc18a63106fc94055943082c9f005f3 Mon Sep 17 00:00:00 2001
From: Tianqi Xu
Date: Wed, 18 Sep 2024 19:42:18 +0300
Subject: [PATCH] Refactor gemini model and pass mypy

---
 crab/agents/backend_models/gemini_model.py | 171 +++++++++++----------
 pyproject.toml                             |   2 +-
 2 files changed, 91 insertions(+), 82 deletions(-)

diff --git a/crab/agents/backend_models/gemini_model.py b/crab/agents/backend_models/gemini_model.py
index 26123b6..24d3ea9 100644
--- a/crab/agents/backend_models/gemini_model.py
+++ b/crab/agents/backend_models/gemini_model.py
@@ -15,12 +15,19 @@
 from time import sleep
 from typing import Any
 
-from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
+from PIL.Image import Image
+
+from crab import Action, ActionOutput, BackendModel, BackendOutput, Message, MessageType
 from crab.utils.common import base64_to_image, json_expand_refs
 
 try:
     import google.generativeai as genai
-    from google.ai.generativelanguage_v1beta import FunctionDeclaration, Part, Tool
+    from google.ai.generativelanguage_v1beta import (
+        Content,
+        FunctionDeclaration,
+        Part,
+        Tool,
+    )
     from google.api_core.exceptions import ResourceExhausted
     from google.generativeai.types import content_types
 
@@ -51,51 +58,70 @@ def __init__(
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
         self.action_space = action_space
-        self.action_schema = self._convert_action_to_schema(self.action_space)
+        self.action_schema = _convert_action_to_schema(self.action_space)
         self.token_usage = 0
-        self.chat_history = []
-
-    def chat(self, message: list[tuple[str, MessageType]]) -> BackendOutput:
-        # Initialize chat history
-        request = []
-        if self.history_messages_len > 0 and len(self.chat_history) > 0:
-            for history_message in self.chat_history[-self.history_messages_len :]:
-                request = request + history_message
+        self.chat_history: list[list[dict]] = []
 
-        if not isinstance(message, list):
+    def chat(self, message: list[Message] | Message) -> BackendOutput:
+        if isinstance(message, tuple):
             message = [message]
-
-        new_message = {
-            "role": "user",
-            "parts": [self._convert_message(part) for part in message],
-        }
+        request = self.fetch_from_memory()
+        new_message = self.construct_new_message(message)
         request.append(new_message)
-
-        response = self.call_api(request)
-        response_message = response.candidates[0].content
+        response_message = self.call_api(request)
         self.record_message(new_message, response_message)
+        return self.generate_backend_output(response_message)
+
+    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+        parts: list[str | Image] = []
+        for content, msg_type in message:
+            match msg_type:
+                case MessageType.TEXT:
+                    parts.append(content)
+                case MessageType.IMAGE_JPG_BASE64:
+                    parts.append(base64_to_image(content))
+        return {
+            "role": "user",
+            "parts": parts,
+        }
 
-        tool_calls = [
-            Part.to_dict(part)["function_call"]
-            for part in response.parts
-            if "function_call" in Part.to_dict(part)
-        ]
+    def generate_backend_output(self, response_message: Content) -> BackendOutput:
+        tool_calls: list[ActionOutput] = []
+        for part in response_message.parts:
+            if "function_call" in Part.to_dict(part):
+                call = Part.to_dict(part)["function_call"]
+                tool_calls.append(
+                    ActionOutput(
+                        name=call["name"],
+                        arguments=call["args"],
+                    )
+                )
 
         return BackendOutput(
             message=response_message.parts[0].text or None,
-            action_list=self._convert_tool_calls_to_action_list(tool_calls),
+            action_list=tool_calls or None,
         )
 
+    def fetch_from_memory(self) -> list[dict]:
+        request: list[dict] = []
+        if self.history_messages_len > 0:
+            fetch_history_len = min(self.history_messages_len, len(self.chat_history))
+            for history_message in self.chat_history[-fetch_history_len:]:
+                request = request + history_message
+        return request
+
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(self, new_message: dict, response_message: dict) -> None:
+    def record_message(
+        self, new_message: dict[str, Any], response_message: Content
+    ) -> None:
         self.chat_history.append([new_message])
         self.chat_history[-1].append(
             {"role": response_message.role, "parts": response_message.parts}
         )
 
-    def call_api(self, request_messages: list):
+    def call_api(self, request_messages: list) -> Content:
         while True:
             try:
                 if self.action_schema is not None:
@@ -131,58 +157,41 @@ def call_api(self, request_messages: list):
                 break
 
         self.token_usage += response.candidates[0].token_count
-        return response
-
-    @staticmethod
-    def _convert_message(message: tuple[str, MessageType]):
-        match message[1]:
-            case MessageType.TEXT:
-                return message[0]
-            case MessageType.IMAGE_JPG_BASE64:
-                return base64_to_image(message[0])
-
-    @classmethod
-    def _convert_action_to_schema(cls, action_space):
-        if action_space is None:
-            return None
-        actions = []
-        for action in action_space:
-            actions.append(Tool(function_declarations=[cls._action_to_funcdec(action)]))
-        return actions
-
-    @staticmethod
-    def _convert_tool_calls_to_action_list(tool_calls) -> list[ActionOutput]:
-        if tool_calls:
-            return [
-                ActionOutput(
-                    name=call["name"],
-                    arguments=call["args"],
-                )
-                for call in tool_calls
+        return response.candidates[0].content
+
+
+def _convert_action_to_schema(action_space: list[Action] | None) -> list[Tool] | None:
+    if action_space is None:
+        return None
+    actions = [
+        Tool(
+            function_declarations=[
+                _action_to_funcdec(action) for action in action_space
             ]
-        else:
-            return None
-
-    @classmethod
-    def _clear_schema(cls, schema_dict: dict):
-        schema_dict.pop("title", None)
-        p_type = schema_dict.pop("type", None)
-        for prop in schema_dict.get("properties", {}).values():
-            cls._clear_schema(prop)
-        if p_type is not None:
-            schema_dict["type_"] = p_type.upper()
-        if "items" in schema_dict:
-            cls._clear_schema(schema_dict["items"])
-
-    @classmethod
-    def _action_to_funcdec(cls, action: Action) -> FunctionDeclaration:
-        "Converts crab Action to google FunctionDeclaration"
-        p_schema = action.parameters.model_json_schema()
-        if "$defs" in p_schema:
-            p_schema = json_expand_refs(p_schema)
-        cls._clear_schema(p_schema)
-        return FunctionDeclaration(
-            name=action.name,
-            description=action.description,
-            parameters=p_schema,
         )
+    ]
+    return actions
+
+
+def _clear_schema(schema_dict: dict):
+    schema_dict.pop("title", None)
+    p_type = schema_dict.pop("type", None)
+    for prop in schema_dict.get("properties", {}).values():
+        _clear_schema(prop)
+    if p_type is not None:
+        schema_dict["type_"] = p_type.upper()
+    if "items" in schema_dict:
+        _clear_schema(schema_dict["items"])
+
+
+def _action_to_funcdec(action: Action) -> FunctionDeclaration:
+    "Converts crab Action to google FunctionDeclaration"
+    p_schema = action.parameters.model_json_schema()
+    if "$defs" in p_schema:
+        p_schema = json_expand_refs(p_schema)
+    _clear_schema(p_schema)
+    return FunctionDeclaration(
+        name=action.name,
+        description=action.description,
+        parameters=p_schema,
+    )
diff --git a/pyproject.toml b/pyproject.toml
index f22670d..0019855 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -118,5 +118,5 @@ lint.ignore = ["E731"]
 exclude = ["docs/"]
 
 [[tool.mypy.overrides]]
-module = ["dill", "easyocr"]
+module = ["dill", "easyocr", "google.generativeai.*"]
 ignore_missing_imports = true
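
For context, a minimal sketch of the schema rewriting that the now module-level _clear_schema helper performs before a schema is handed to FunctionDeclaration. It assumes the patch above is applied and crab is importable; the ClickParams schema dict is made up for the example and does not come from the patch.

# Illustrative only: the schema below is a hand-written stand-in for what
# pydantic's model_json_schema() would produce for a simple click action.
from crab.agents.backend_models.gemini_model import _clear_schema

schema = {
    "title": "ClickParams",
    "type": "object",
    "properties": {
        "x": {"title": "X", "type": "integer"},
        "y": {"title": "Y", "type": "integer"},
    },
    "required": ["x", "y"],
}

# _clear_schema mutates the dict in place: "title" keys are dropped and the
# JSON-Schema "type" field becomes "type_" with an upper-cased value, the
# field name used by the proto-plus Schema that backs FunctionDeclaration.
_clear_schema(schema)
print(schema)
# {'properties': {'x': {'type_': 'INTEGER'}, 'y': {'type_': 'INTEGER'}},
#  'required': ['x', 'y'], 'type_': 'OBJECT'}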