diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
index 357ca93e91549..8407c1d87628a 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
@@ -27,6 +27,8 @@
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection
 from llama_index.core.utilities.gemini_utils import merge_neighboring_same_role_messages
+
+from llama_index.llms.vertex.gemini_tool import GeminiToolWrapper
 from llama_index.llms.vertex.gemini_utils import create_gemini_client, is_gemini_model
 from llama_index.llms.vertex.utils import (
     CHAT_MODELS,
@@ -118,6 +120,10 @@ def __init__(
         completion_to_prompt: Optional[Callable[[str], str]] = None,
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        is_model_garden: bool = False,
+        is_chat: bool = False,
+        is_text: bool = False,
+        is_code: bool = False,
     ) -> None:
         init_vertexai(project=project, location=location, credentials=credentials)

@@ -145,23 +151,35 @@ def __init__(
         self._safety_settings = safety_settings
         self._is_gemini = False
         self._is_chat_model = False
-        if model in CHAT_MODELS:
+        self._is_model_garden = is_model_garden
+        self._is_chat = is_chat
+        self._is_text = is_text
+        self._is_code = is_code
+        if model in CHAT_MODELS or (
+            self._is_model_garden and self._is_chat and not self._is_code
+        ):
             from vertexai.language_models import ChatModel

             self._chat_client = ChatModel.from_pretrained(model)
             self._is_chat_model = True
-        elif model in CODE_CHAT_MODELS:
+        elif model in CODE_CHAT_MODELS or (
+            self._is_model_garden and self._is_chat and self._is_code
+        ):
             from vertexai.language_models import CodeChatModel

             self._chat_client = CodeChatModel.from_pretrained(model)
             iscode = True
             self._is_chat_model = True
-        elif model in CODE_MODELS:
+        elif model in CODE_MODELS or (
+            self._is_model_garden and not self._is_text and self._is_code
+        ):
             from vertexai.language_models import CodeGenerationModel

             self._client = CodeGenerationModel.from_pretrained(model)
             iscode = True
-        elif model in TEXT_MODELS:
+        elif model in TEXT_MODELS or (
+            self._is_model_garden and self._is_text and not self._is_code
+        ):
             from vertexai.language_models import TextGenerationModel

             self._client = TextGenerationModel.from_pretrained(model)
@@ -458,11 +476,17 @@ def _prepare_chat_with_tools(
         tool_dicts = []
         for tool in tools:
+            if self._is_gemini:
+                # Wrap the tool so its metadata is emitted in a Gemini-compatible form.
+                tool = GeminiToolWrapper(tool)
+            metadata = tool.metadata
+
             tool_dicts.append(
                 {
-                    "name": tool.metadata.name,
-                    "description": tool.metadata.description,
-                    "parameters": tool.metadata.get_parameters_dict(),
+                    "name": metadata.name,
+                    "description": metadata.description,
+                    "parameters": metadata.get_parameters_dict(),
                 }
             )
@@ -503,7 +527,7 @@ def get_tool_calls_from_response(
         tool_selections = []
         for tool_call in tool_calls:
-            response_dict = MessageToDict(tool_call._pb)
+            response_dict = MessageToDict(tool_call._raw_message._pb)
             if "args" not in response_dict or "name" not in response_dict:
                 raise ValueError("Invalid tool call.")
             argument_dict = response_dict["args"]
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py
new file mode 100644
index 0000000000000..e6656990b2814
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py
@@ -0,0 +1,105 @@
+import json
+from typing import Any, Optional, Type
+
+from llama_index.core.tools import ToolMetadata
+from llama_index.core.tools.types import DefaultToolFnSchema
+from pydantic import BaseModel
+
+
+def remap_schema(schema: dict) -> dict:
+    """
+    Remap a JSON schema to match Gemini's internal API.
+    """
+    parameters = {}
+
+    for key, value in schema.items():
+        if key in ["title", "type", "properties", "required", "definitions"]:
+            parameters[key] = value
+        elif key == "$ref":
+            parameters["defs"] = value
+        else:
+            continue
+
+    return parameters
+
+
+class GeminiToolMetadataWrapper:
+    """
+    Represents tool metadata in a manner that is compatible with
+    Gemini's internal APIs. The default ToolMetadata class generates a
+    JSON schema containing $ref and $defs fields, which break Google's
+    protocol buffer serialization.
+    """
+
+    def __init__(self, base: ToolMetadata) -> None:
+        self._base = base
+        self._name = self._base.name
+        self._description = self._base.description
+        # Fall back to the default single-input schema when the wrapped tool has none.
+        self._fn_schema: Optional[Type[BaseModel]] = (
+            self._base.fn_schema or DefaultToolFnSchema
+        )
+        self._parameters = self.get_parameters_dict()
+
+    def get_parameters_dict(self) -> dict:
+        parameters = {}
+
+        if self._fn_schema is None:
+            parameters = {
+                "type": "object",
+                "properties": {
+                    "input": {"title": "input query string", "type": "string"},
+                },
+                "required": ["input"],
+            }
+        else:
+            parameters = remap_schema(
+                {
+                    k: v
+                    for k, v in self._fn_schema.model_json_schema().items()
+                    if k in ["type", "properties", "required", "definitions", "$defs"]
+                }
+            )
+
+        return parameters
+
+    @property
+    def fn_schema_str(self) -> str:
+        """Get fn schema as string."""
+        if self._fn_schema is None:
+            raise ValueError("fn_schema is None.")
+        parameters = self.get_parameters_dict()
+        return json.dumps(parameters)
+
+    def __getattr__(self, item: str) -> Any:
+        if item == "name":
+            return self._name
+        if item == "description":
+            return self._description
+        if item == "fn_schema":
+            return self._fn_schema
+        if item == "parameters":
+            return self._parameters
+        raise AttributeError(
+            f"No attribute '{item}' found in GeminiToolMetadataWrapper"
+        )
+
+
+class GeminiToolWrapper:
+    """
+    Wraps a base tool object to make it compatible with Gemini's
+    internal APIs.
+ """ + + def __init__(self, base_obj, *args, **kwargs) -> None: + self.base_obj = base_obj + # some stuff + + @property + def metadata(self) -> GeminiToolMetadataWrapper: + return GeminiToolMetadataWrapper(self.base_obj.metadata) + + def __getattr__(self, name) -> Any: + return getattr(self.base_obj, name) diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py index dbb0a9e962fd5..904144fa344d7 100644 --- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py +++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py @@ -1,7 +1,6 @@ import base64 from typing import Any, Dict, Union, Optional -from vertexai.generative_models._generative_models import SafetySettingsType -from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types +from vertexai.generative_models import SafetySetting from llama_index.core.llms import ChatMessage, MessageRole @@ -9,9 +8,7 @@ def is_gemini_model(model: str) -> bool: return model.startswith("gemini") -def create_gemini_client( - model: str, safety_settings: Optional[SafetySettingsType] -) -> Any: +def create_gemini_client(model: str, safety_settings: Optional[SafetySetting]) -> Any: from vertexai.preview.generative_models import GenerativeModel return GenerativeModel(model_name=model, safety_settings=safety_settings) @@ -46,19 +43,34 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part: raise ValueError("Only text and image_url types are supported!") return Part.from_image(image) - if message.content == "" and "tool_calls" in message.additional_kwargs: + if ( + message.role == MessageRole.ASSISTANT + and message.content == "" + and "tool_calls" in message.additional_kwargs + ) or (message.role == MessageRole.TOOL): tool_calls = message.additional_kwargs["tool_calls"] - parts = [ - Part._from_gapic(raw_part=gapic_content_types.Part(function_call=tool_call)) - for tool_call in tool_calls - ] - else: - raw_content = message.content + if message.role != MessageRole.TOOL: + parts = [ + Part.from_function_response(tool_call.name, tool_call.args) + for tool_call in tool_calls + ] + parts.append(Part.from_text(handle_raw_content(message))) + else: + ## this handles the case where the Gemini api properly sets the message role to tool instead of assistant + if "name" in message.additional_kwargs: + parts = [ + Part.from_function_response( + message.additional_kwargs["name"], + message.additional_kwargs.get("args", {}), + ) + ] + else: + raise ValueError("Tool name must be provided!") - if raw_content is None: - raw_content = "" - if isinstance(raw_content, str): - raw_content = [raw_content] + raw_content = handle_raw_content(message) + parts.append(_convert_gemini_part_to_prompt(part) for part in raw_content) + else: + raw_content = handle_raw_content(message) parts = [_convert_gemini_part_to_prompt(part) for part in raw_content] @@ -69,3 +81,12 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part: ) else: return parts + + +def handle_raw_content(message): + raw_content = message.content + if raw_content is None: + raw_content = "" + if isinstance(raw_content, str): + raw_content = [raw_content] + return raw_content diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py 
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py
index f7766f5fe4631..c2b3c69a5a4f1 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py
@@ -19,6 +19,8 @@
 from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole

+from llama_index.llms.vertex.gemini_tool import GeminiToolWrapper
+
 CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"]
 TEXT_MODELS = [
     "text-bison",
@@ -56,6 +58,8 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
 def to_gemini_tools(tools) -> Any:
     func_list = []
     for i, tool in enumerate(tools):
+        _gemini_tools = GeminiToolWrapper(tool)
+
         func_name = f"func_{i}"
         func_name = FunctionDeclaration(
             name=tool["name"],
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
index 9c2893a409008..6317450c8fa1e 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
@@ -27,12 +27,14 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vertex"
 readme = "README.md"
-version = "0.4.2"
+
+version = "0.4.2+vai_fixes"

 [tool.poetry.dependencies]
-python = ">=3.9,<4.0"
-google-cloud-aiplatform = "^1.39.0"
-llama-index-core = "^0.12.0"
+python = ">=3.9.0,<4.0"
+google-cloud-aiplatform = "^1.76.0"
+pyarrow = ">=15.0.2, <16.0.0"
+llama-index-core = "^0.12.1"

 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
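Below is a minimal usage sketch of the behaviour these changes are meant to enable, assuming the new constructor flags and the GeminiToolWrapper plumbing land exactly as in the diff above. The GCP project id, the Gemini model name, the Model Garden deployment name, and the multiply tool are illustrative placeholders, not part of this change.

    from llama_index.core.llms import ChatMessage
    from llama_index.core.tools import FunctionTool
    from llama_index.llms.vertex import Vertex


    def multiply(a: int, b: int) -> int:
        """Multiply two integers and return the result."""
        return a * b


    multiply_tool = FunctionTool.from_defaults(fn=multiply)

    # Gemini models take the function-calling path patched in base.py: each tool is
    # wrapped in GeminiToolWrapper so its JSON schema is remapped into a form that
    # survives protobuf serialization.
    gemini_llm = Vertex(
        model="gemini-1.5-pro", project="my-gcp-project", temperature=0.0
    )
    response = gemini_llm.chat_with_tools(
        tools=[multiply_tool], user_msg="What is 3 times 7?"
    )
    for tool_call in gemini_llm.get_tool_calls_from_response(
        response, error_on_no_tool_call=False
    ):
        print(tool_call.tool_name, tool_call.tool_kwargs)

    # A Model Garden deployment is not in the hard-coded CHAT/TEXT/CODE model lists,
    # so the new flags tell __init__ which vertexai client class to instantiate
    # (here ChatModel, because is_chat=True and is_code is left False).
    garden_llm = Vertex(
        model="my-model-garden-chat-model",  # placeholder deployment name
        project="my-gcp-project",
        is_model_garden=True,
        is_chat=True,
    )
    chat_response = garden_llm.chat([ChatMessage(role="user", content="Say hello.")])
    print(chat_response.message.content)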