Commit

Refactor gemini model and pass mypy
dandansamax committed Sep 18, 2024
1 parent aff284e · commit 06f79ee
Showing 2 changed files with 91 additions and 82 deletions.
171 changes: 90 additions & 81 deletions crab/agents/backend_models/gemini_model.py
@@ -15,12 +15,19 @@
 from time import sleep
 from typing import Any
 
-from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
+from PIL.Image import Image
+
+from crab import Action, ActionOutput, BackendModel, BackendOutput, Message, MessageType
 from crab.utils.common import base64_to_image, json_expand_refs
 
 try:
     import google.generativeai as genai
-    from google.ai.generativelanguage_v1beta import FunctionDeclaration, Part, Tool
+    from google.ai.generativelanguage_v1beta import (
+        Content,
+        FunctionDeclaration,
+        Part,
+        Tool,
+    )
     from google.api_core.exceptions import ResourceExhausted
     from google.generativeai.types import content_types
 
Expand Down Expand Up @@ -51,51 +58,70 @@ def __init__(
     def reset(self, system_message: str, action_space: list[Action] | None) -> None:
         self.system_message = system_message
         self.action_space = action_space
-        self.action_schema = self._convert_action_to_schema(self.action_space)
+        self.action_schema = _convert_action_to_schema(self.action_space)
         self.token_usage = 0
-        self.chat_history = []
-
-    def chat(self, message: list[tuple[str, MessageType]]) -> BackendOutput:
-        # Initialize chat history
-        request = []
-        if self.history_messages_len > 0 and len(self.chat_history) > 0:
-            for history_message in self.chat_history[-self.history_messages_len :]:
-                request = request + history_message
-
-        if not isinstance(message, list):
-            message = [message]
-
-        new_message = {
-            "role": "user",
-            "parts": [self._convert_message(part) for part in message],
-        }
-        request.append(new_message)
-
-        response = self.call_api(request)
-        response_message = response.candidates[0].content
-        self.record_message(new_message, response_message)
-
-        tool_calls = [
-            Part.to_dict(part)["function_call"]
-            for part in response.parts
-            if "function_call" in Part.to_dict(part)
-        ]
-
-        return BackendOutput(
-            message=response_message.parts[0].text or None,
-            action_list=self._convert_tool_calls_to_action_list(tool_calls),
-        )
+        self.chat_history: list[list[dict]] = []
+
+    def chat(self, message: list[Message] | Message) -> BackendOutput:
+        if isinstance(message, tuple):
+            message = [message]
+        request = self.fetch_from_memory()
+        new_message = self.construct_new_message(message)
+        request.append(new_message)
+        response_message = self.call_api(request)
+        self.record_message(new_message, response_message)
+        return self.generate_backend_output(response_message)
+
+    def construct_new_message(self, message: list[Message]) -> dict[str, Any]:
+        parts: list[str | Image] = []
+        for content, msg_type in message:
+            match msg_type:
+                case MessageType.TEXT:
+                    parts.append(content)
+                case MessageType.IMAGE_JPG_BASE64:
+                    parts.append(base64_to_image(content))
+        return {
+            "role": "user",
+            "parts": parts,
+        }
+
+    def generate_backend_output(self, response_message: Content) -> BackendOutput:
+        tool_calls: list[ActionOutput] = []
+        for part in response_message.parts:
+            if "function_call" in Part.to_dict(part):
+                call = Part.to_dict(part)["function_call"]
+                tool_calls.append(
+                    ActionOutput(
+                        name=call["name"],
+                        arguments=call["args"],
+                    )
+                )
+
+        return BackendOutput(
+            message=response_message.parts[0].text or None,
+            action_list=tool_calls or None,
+        )
+
+    def fetch_from_memory(self) -> list[dict]:
+        request: list[dict] = []
+        if self.history_messages_len > 0:
+            fetch_hisotry_len = min(self.history_messages_len, len(self.chat_history))
+            for history_message in self.chat_history[-fetch_hisotry_len:]:
+                request = request + history_message
+        return request
 
     def get_token_usage(self):
         return self.token_usage
 
-    def record_message(self, new_message: dict, response_message: dict) -> None:
+    def record_message(
+        self, new_message: dict[str, Any], response_message: Content
+    ) -> None:
         self.chat_history.append([new_message])
         self.chat_history[-1].append(
             {"role": response_message.role, "parts": response_message.parts}
         )
 
-    def call_api(self, request_messages: list):
+    def call_api(self, request_messages: list) -> Content:
         while True:
             try:
                 if self.action_schema is not None:
@@ -131,58 +157,41 @@ def call_api(self, request_messages: list):
                     break
 
         self.token_usage += response.candidates[0].token_count
-        return response
-
-    @staticmethod
-    def _convert_message(message: tuple[str, MessageType]):
-        match message[1]:
-            case MessageType.TEXT:
-                return message[0]
-            case MessageType.IMAGE_JPG_BASE64:
-                return base64_to_image(message[0])
-
-    @classmethod
-    def _convert_action_to_schema(cls, action_space):
-        if action_space is None:
-            return None
-        actions = []
-        for action in action_space:
-            actions.append(Tool(function_declarations=[cls._action_to_funcdec(action)]))
-        return actions
-
-    @staticmethod
-    def _convert_tool_calls_to_action_list(tool_calls) -> list[ActionOutput]:
-        if tool_calls:
-            return [
-                ActionOutput(
-                    name=call["name"],
-                    arguments=call["args"],
-                )
-                for call in tool_calls
-            ]
-        else:
-            return None
-
-    @classmethod
-    def _clear_schema(cls, schema_dict: dict):
-        schema_dict.pop("title", None)
-        p_type = schema_dict.pop("type", None)
-        for prop in schema_dict.get("properties", {}).values():
-            cls._clear_schema(prop)
-        if p_type is not None:
-            schema_dict["type_"] = p_type.upper()
-        if "items" in schema_dict:
-            cls._clear_schema(schema_dict["items"])
-
-    @classmethod
-    def _action_to_funcdec(cls, action: Action) -> FunctionDeclaration:
-        "Converts crab Action to google FunctionDeclaration"
-        p_schema = action.parameters.model_json_schema()
-        if "$defs" in p_schema:
-            p_schema = json_expand_refs(p_schema)
-        cls._clear_schema(p_schema)
-        return FunctionDeclaration(
-            name=action.name,
-            description=action.description,
-            parameters=p_schema,
-        )
+        return response.candidates[0].content
+
+
+def _convert_action_to_schema(action_space: list[Action] | None) -> list[Tool] | None:
+    if action_space is None:
+        return None
+    actions = [
+        Tool(
+            function_declarations=[
+                _action_to_funcdec(action) for action in action_space
+            ]
+        )
+    ]
+    return actions
+
+
+def _clear_schema(schema_dict: dict):
+    schema_dict.pop("title", None)
+    p_type = schema_dict.pop("type", None)
+    for prop in schema_dict.get("properties", {}).values():
+        _clear_schema(prop)
+    if p_type is not None:
+        schema_dict["type_"] = p_type.upper()
+    if "items" in schema_dict:
+        _clear_schema(schema_dict["items"])
+
+
+def _action_to_funcdec(action: Action) -> FunctionDeclaration:
+    "Converts crab Action to google FunctionDeclaration"
+    p_schema = action.parameters.model_json_schema()
+    if "$defs" in p_schema:
+        p_schema = json_expand_refs(p_schema)
+    _clear_schema(p_schema)
+    return FunctionDeclaration(
+        name=action.name,
+        description=action.description,
+        parameters=p_schema,
+    )
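
The refactor also moves the schema helpers out of the class: _clear_schema rewrites a Pydantic-generated JSON schema in place into the shape the FunctionDeclaration parameters expect (it drops "title", renames "type" to an upper-cased "type_", and recurses through "properties" and "items"), and _action_to_funcdec wraps the cleaned schema in a FunctionDeclaration. A minimal illustrative sketch of _clear_schema's effect on a plain dict, assuming the crab package and its Gemini dependencies are installed so the private helper can be imported:

# Hypothetical usage sketch, not part of the commit.
from crab.agents.backend_models.gemini_model import _clear_schema

schema = {
    "title": "AddParams",
    "type": "object",
    "properties": {
        "a": {"title": "A", "type": "integer"},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
}
_clear_schema(schema)  # mutates the dict in place
print(schema)
# {'properties': {'a': {'type_': 'INTEGER'},
#                 'tags': {'items': {'type_': 'STRING'}, 'type_': 'ARRAY'}},
#  'type_': 'OBJECT'}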
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -118,5 +118,5 @@ lint.ignore = ["E731"]
 exclude = ["docs/"]
 
 [[tool.mypy.overrides]]
-module = ["dill", "easyocr"]
+module = ["dill", "easyocr", "google.generativeai.*"]
 ignore_missing_imports = true
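
The new entry extends the existing dill/easyocr override so that mypy also ignores missing-import errors for the google.generativeai modules; after this change the override table reads:

[[tool.mypy.overrides]]
module = ["dill", "easyocr", "google.generativeai.*"]
ignore_missing_imports = true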
