From 0a0072bd14453e1c322e7c5e68903e77bb9b2b35 Mon Sep 17 00:00:00 2001 From: Steven Fines Date: Fri, 1 Nov 2024 16:51:00 -0700 Subject: [PATCH 1/4] (chore) Update to 1.71.1 Google updated the cloud-aiplatform dependencies, so in a probably futile attempt to keep them somewhat current I have updated them. --- .../llama_index/llms/vertex/base.py | 2 +- .../llama_index/llms/vertex/gemini_utils.py | 43 +++++++++++++------ .../llama-index-llms-vertex/pyproject.toml | 4 +- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py index 23db7a6ffc4ab..2834295d408ea 100644 --- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py +++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py @@ -495,7 +495,7 @@ def get_tool_calls_from_response( tool_selections = [] for tool_call in tool_calls: - response_dict = MessageToDict(tool_call._pb) + response_dict = MessageToDict(tool_call._raw_message._pb) if "args" not in response_dict or "name" not in response_dict: raise ValueError("Invalid tool call.") argument_dict = response_dict["args"] diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py index dbb0a9e962fd5..4924c5f48f1c1 100644 --- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py +++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py @@ -1,7 +1,6 @@ import base64 from typing import Any, Dict, Union, Optional -from vertexai.generative_models._generative_models import SafetySettingsType -from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types +from vertexai.generative_models import SafetySetting from llama_index.core.llms import ChatMessage, MessageRole @@ -10,7 +9,7 @@ def is_gemini_model(model: str) -> bool: def create_gemini_client( - model: str, safety_settings: Optional[SafetySettingsType] + model: str, safety_settings: Optional[SafetySetting] ) -> Any: from vertexai.preview.generative_models import GenerativeModel @@ -46,19 +45,26 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part: raise ValueError("Only text and image_url types are supported!") return Part.from_image(image) - if message.content == "" and "tool_calls" in message.additional_kwargs: + if (MessageRole.ASSISTANT == message.role and message.content == "" and "tool_calls" in message.additional_kwargs)\ + or (MessageRole.TOOL == message.role): tool_calls = message.additional_kwargs["tool_calls"] - parts = [ - Part._from_gapic(raw_part=gapic_content_types.Part(function_call=tool_call)) - for tool_call in tool_calls - ] - else: - raw_content = message.content + if message.role != MessageRole.TOOL: + parts = [ + Part.from_function_response(tool_call.name, tool_call.args) + for tool_call in tool_calls + ] + parts.append(Part.from_text(handle_raw_content(message))) + else: + ## this handles the case where the Gemini api properly sets the message role to tool instead of assistant + if 'name' in message.additional_kwargs: + parts = [Part.from_function_response(message.additional_kwargs['name'], message.additional_kwargs.get('args', {}))] + else: + raise ValueError("Tool name must be provided!") - if raw_content is None: 
-            raw_content = ""
-        if isinstance(raw_content, str):
-            raw_content = [raw_content]
+            raw_content = handle_raw_content(message)
+            parts.extend(_convert_gemini_part_to_prompt(part) for part in raw_content)
+    else:
+        raw_content = handle_raw_content(message)
 
         parts = [_convert_gemini_part_to_prompt(part) for part in raw_content]
 
@@ -69,3 +75,12 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part:
         )
     else:
         return parts
+
+
+def handle_raw_content(message):
+    raw_content = message.content
+    if raw_content is None:
+        raw_content = ""
+    if isinstance(raw_content, str):
+        raw_content = [raw_content]
+    return raw_content
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
index cf188c2b3deec..8f87524f7f549 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
@@ -27,11 +27,11 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vertex"
 readme = "README.md"
-version = "0.3.7"
+version = "0.3.7+1_71_1_gai_1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-google-cloud-aiplatform = "^1.39.0"
+google-cloud-aiplatform = "^1.71.1"
 pyarrow = "^15.0.2"
 llama-index-core = "^0.11.0"

From dd092854c74855e7841c001c21d49b60befd5bb2 Mon Sep 17 00:00:00 2001
From: Steven Fines
Date: Thu, 7 Nov 2024 11:04:56 -0800
Subject: [PATCH 2/4] (feat) Wrapped the BaseTool for Gemini

The BaseTool object serializes its parameters using pydantic, but this
produces JSON schema objects which do not currently play well with
protobuf. This version is an attempt to handle that issue.

---
 .../llama_index/llms/vertex/base.py         | 14 ++-
 .../llama_index/llms/vertex/gemini_tool.py  | 91 +++++++++++++++++++
 .../llama_index/llms/vertex/gemini_utils.py | 20 ++--
 .../llama_index/llms/vertex/utils.py        |  4 +
 .../llama-index-llms-vertex/pyproject.toml  |  2 +-
 5 files changed, 120 insertions(+), 11 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py

diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
index 2834295d408ea..b1adbac8261f5 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
@@ -27,6 +27,8 @@
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection
 from llama_index.core.utilities.gemini_utils import merge_neighboring_same_role_messages
+
+from llama_index.llms.vertex.gemini_tool import GeminiToolWrapper
 from llama_index.llms.vertex.gemini_utils import create_gemini_client, is_gemini_model
 from llama_index.llms.vertex.utils import (
     CHAT_MODELS,
@@ -450,11 +452,17 @@ def _prepare_chat_with_tools(
         tool_dicts = []
         for tool in tools:
+            if self._is_gemini:
+                tool = GeminiToolWrapper(tool)
+                metadata = tool.metadata()
+            else:
+                metadata = tool.metadata
+
             tool_dicts.append(
                 {
-                    "name": tool.metadata.name,
-                    "description": tool.metadata.description,
-                    "parameters": tool.metadata.get_parameters_dict(),
+                    "name": metadata.name,
+                    "description": metadata.description,
+                    "parameters": metadata.get_parameters_dict(),
                 }
             )
 
diff --git 
a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py new file mode 100644 index 0000000000000..ffdbc80e5c6a1 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py @@ -0,0 +1,91 @@ +from typing import Optional, Type, Any + +from llama_index.core.tools import ToolMetadata +from llama_index.core.tools.types import DefaultToolFnSchema +from pydantic import BaseModel + + +def remap_schema(schema: dict) -> dict: + """ + Remap schema to match Gemini's internal API. + """ + parameters = {} + + for key, value in schema.items(): + if key in ["title", "type", "properties", "required", "definitions"]: + parameters[key] = value + elif key == "$ref": + parameters["defs"] = value + else: + continue + + return parameters + + +class GeminiToolMetadataWrapper: + """ + The purpose of this dataclass is to represent the metadata in + a manner that is compatible with Gemini's internal APIs. The + default ToolMetadata class generates a json schema using $ref + and $def field types which break google's protocol buffer + serialization. + """ + + def __init__(self, base: ToolMetadata) -> None: + self._base = base + self._name = self._base.name + self._description = self._base.description + self._fn_schema = self._base.fn_schema + self._parameters = self.get_parameters_dict() + + fn_schema: Optional[Type[BaseModel]] = DefaultToolFnSchema + + def get_parameters_dict(self) -> dict: + parameters = {} + + if self.fn_schema is None: + parameters = { + "type": "object", + "properties": { + "input": {"title": "input query string", "type": "string"}, + }, + "required": ["input"], + } + else: + parameters = remap_schema( + self.fn_schema.model_json_schema(ref_template="#/defs/{model}") + ) + + return parameters + + def __getattr__(self, item) -> Any: + match item: + case "name": + return self._name + case "description": + return self._description + case "fn_schema": + return self.fn_schema + case "parameters": + return self._parameters + case _: + raise AttributeError( + f"No attribute '{item}' found in GeminiToolMetadataWrapper" + ) + + +class GeminiToolWrapper: + """ + Wraps a base tool object to make it compatible with Gemini's + internal APIs. 
+ """ + + def __init__(self, base_obj, *args, **kwargs) -> None: + self.base_obj = base_obj + # some stuff + + def metadata(self) -> GeminiToolMetadataWrapper: + return GeminiToolMetadataWrapper(self.base_obj.metadata) + + def __getattr__(self, name) -> Any: + return getattr(self.base_obj, name) diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py index 4924c5f48f1c1..904144fa344d7 100644 --- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py +++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_utils.py @@ -8,9 +8,7 @@ def is_gemini_model(model: str) -> bool: return model.startswith("gemini") -def create_gemini_client( - model: str, safety_settings: Optional[SafetySetting] -) -> Any: +def create_gemini_client(model: str, safety_settings: Optional[SafetySetting]) -> Any: from vertexai.preview.generative_models import GenerativeModel return GenerativeModel(model_name=model, safety_settings=safety_settings) @@ -45,8 +43,11 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part: raise ValueError("Only text and image_url types are supported!") return Part.from_image(image) - if (MessageRole.ASSISTANT == message.role and message.content == "" and "tool_calls" in message.additional_kwargs)\ - or (MessageRole.TOOL == message.role): + if ( + message.role == MessageRole.ASSISTANT + and message.content == "" + and "tool_calls" in message.additional_kwargs + ) or (message.role == MessageRole.TOOL): tool_calls = message.additional_kwargs["tool_calls"] if message.role != MessageRole.TOOL: parts = [ @@ -56,8 +57,13 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part: parts.append(Part.from_text(handle_raw_content(message))) else: ## this handles the case where the Gemini api properly sets the message role to tool instead of assistant - if 'name' in message.additional_kwargs: - parts = [Part.from_function_response(message.additional_kwargs['name'], message.additional_kwargs.get('args', {}))] + if "name" in message.additional_kwargs: + parts = [ + Part.from_function_response( + message.additional_kwargs["name"], + message.additional_kwargs.get("args", {}), + ) + ] else: raise ValueError("Tool name must be provided!") diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py index f7766f5fe4631..c2b3c69a5a4f1 100644 --- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/utils.py @@ -19,6 +19,8 @@ from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole +from llama_index.llms.vertex.gemini_tool import GeminiToolWrapper + CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"] TEXT_MODELS = [ "text-bison", @@ -56,6 +58,8 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]: def to_gemini_tools(tools) -> Any: func_list = [] for i, tool in enumerate(tools): + _gemini_tools = GeminiToolWrapper(tool) + func_name = f"func_{i}" func_name = FunctionDeclaration( name=tool["name"], diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml index 
8f87524f7f549..7a14160e0fa5e 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vertex"
 readme = "README.md"
-version = "0.3.7+1_71_1_gai_1"
+version = "0.3.7+1_71_1_gai_8"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"

From f479a98621eaf88e00ef031acf2254c4d5a777ea Mon Sep 17 00:00:00 2001
From: Steven Fines
Date: Fri, 8 Nov 2024 15:21:01 -0800
Subject: [PATCH 3/4] (bug) #16625 Address that Plan object is not compatible with Protobuf

This changes how plans are serialized so that they are compatible with
protobuf. It's a workaround until Google addresses the issue in
google-cloud-aiplatform.

---
 .../llama_index/llms/vertex/gemini_tool.py      | 16 +++++++++++++++-
 .../llms/llama-index-llms-vertex/pyproject.toml |  2 +-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py
index ffdbc80e5c6a1..e6656990b2814 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/gemini_tool.py
@@ -1,3 +1,4 @@
+import json
 from typing import Optional, Type, Any
 
 from llama_index.core.tools import ToolMetadata
 from llama_index.core.tools.types import DefaultToolFnSchema
 from pydantic import BaseModel
@@ -53,11 +54,23 @@ def get_parameters_dict(self) -> dict:
             }
         else:
             parameters = remap_schema(
-                self.fn_schema.model_json_schema(ref_template="#/defs/{model}")
+                {
+                    k: v
+                    for k, v in self.fn_schema.model_json_schema().items()
+                    if k in ["type", "properties", "required", "definitions", "$defs"]
+                }
             )
 
         return parameters
 
+    @property
+    def fn_schema_str(self) -> str:
+        """Get fn schema as string."""
+        if self.fn_schema is None:
+            raise ValueError("fn_schema is None.")
+        parameters = self.get_parameters_dict()
+        return json.dumps(parameters)
+
     def __getattr__(self, item) -> Any:
         match item:
             case "name":
@@ -84,6 +97,7 @@ def __init__(self, base_obj, *args, **kwargs) -> None:
         self.base_obj = base_obj
         # some stuff
 
+    @property
     def metadata(self) -> GeminiToolMetadataWrapper:
         return GeminiToolMetadataWrapper(self.base_obj.metadata)
 
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
index 7a14160e0fa5e..235b81f23311e 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vertex"
 readme = "README.md"
-version = "0.3.7+1_71_1_gai_8"
+version = "0.3.8.7+model_garden"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"

From 693aff1352e4e065268bb3f2159b4c6a609338ae Mon Sep 17 00:00:00 2001
From: Steven Fines
Date: Fri, 6 Dec 2024 11:00:51 -0800
Subject: [PATCH 4/4] (feat) Add feature flags for model garden

Adds feature flags for model garden deployments.
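As a rough usage sketch of the new flags (the model name and project id below
are placeholders, not real deployments), a Model Garden chat deployment would
opt in to the flag-based client routing like this:

    from llama_index.llms.vertex import Vertex

    llm = Vertex(
        model="my-model-garden-chat-model",  # placeholder deployment name
        project="my-gcp-project",            # placeholder project id
        location="us-central1",
        is_model_garden=True,  # enable flag-based routing for non-catalog models
        is_chat=True,          # route to the chat client
        is_code=False,         # not a code/codechat deployment
    )
    print(llm.complete("Say hello.").text)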
---
 .../llama_index/llms/vertex/base.py         | 24 +++++++++++++++----
 .../llama-index-llms-vertex/pyproject.toml  |  8 +++----
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
index b1adbac8261f5..610e5a7bd33e5 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/llama_index/llms/vertex/base.py
@@ -115,6 +115,10 @@ def __init__(
         completion_to_prompt: Optional[Callable[[str], str]] = None,
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        is_model_garden: bool = False,
+        is_chat: bool = False,
+        is_text: bool = False,
+        is_code: bool = False,
     ) -> None:
         init_vertexai(project=project, location=location, credentials=credentials)
 
@@ -141,23 +145,35 @@ def __init__(
         self._safety_settings = safety_settings
         self._is_gemini = False
         self._is_chat_model = False
-        if model in CHAT_MODELS:
+        self._is_model_garden = is_model_garden
+        self._is_chat = is_chat
+        self._is_text = is_text
+        self._is_code = is_code
+        if model in CHAT_MODELS or (
+            self._is_model_garden and self._is_chat and not self._is_code
+        ):
             from vertexai.language_models import ChatModel
 
             self._chat_client = ChatModel.from_pretrained(model)
             self._is_chat_model = True
-        elif model in CODE_CHAT_MODELS:
+        elif model in CODE_CHAT_MODELS or (
+            self._is_model_garden and self._is_chat and self._is_code
+        ):
             from vertexai.language_models import CodeChatModel
 
             self._chat_client = CodeChatModel.from_pretrained(model)
             iscode = True
             self._is_chat_model = True
-        elif model in CODE_MODELS:
+        elif model in CODE_MODELS or (
+            self._is_model_garden and not self._is_text and self._is_code
+        ):
             from vertexai.language_models import CodeGenerationModel
 
             self._client = CodeGenerationModel.from_pretrained(model)
             iscode = True
-        elif model in TEXT_MODELS:
+        elif model in TEXT_MODELS or (
+            self._is_model_garden and self._is_text and not self._is_code
+        ):
             from vertexai.language_models import TextGenerationModel
 
             self._client = TextGenerationModel.from_pretrained(model)
diff --git a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
index 2916707b04155..ef2042ee807dc 100644
--- a/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-vertex/pyproject.toml
@@ -27,15 +27,13 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-vertex"
 readme = "README.md"
-
 version = "0.4.0+vai_fixes"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-google-cloud-aiplatform = "^1.71.1"
-
-pyarrow = "^15.0.2"
-llama-index-core = "^0.12.0"
+google-cloud-aiplatform = "^1.73.0"
+pyarrow = ">=15.0.2, <16.0.0"
+llama-index-core = "^0.12.1"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
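
Note on the schema remapping in patches 2 and 3: the wrapper exists because
pydantic v2 can emit "$defs"/"$ref" entries that Vertex's protobuf-backed
FunctionDeclaration cannot serialize. A minimal, standalone sketch of the idea
(it mirrors the intent of remap_schema, not its exact key handling, and the
SearchInput model is hypothetical):

    from pydantic import BaseModel, Field

    class SearchInput(BaseModel):
        """Hypothetical tool-argument schema, for illustration only."""
        query: str = Field(description="free-text search query")
        top_k: int = Field(default=5, description="number of results to return")

    # Keep only flat, top-level keys so the resulting dict survives the
    # protobuf Struct conversion used when declaring Gemini functions.
    ALLOWED_KEYS = {"title", "type", "properties", "required"}
    full_schema = SearchInput.model_json_schema()
    flat_schema = {k: v for k, v in full_schema.items() if k in ALLOWED_KEYS}
    print(flat_schema)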