Update to the latest google-cloud-aiplatform dependency and properly handle function calls in gemini #16793

Draft: stfines-clgx wants to merge 20 commits into run-llama:main from stfines-clgx:dev/1_70_protobuf_only

Commits (20)
0a0072b
(chore) Update to 1.71.1
stfines-clgx Nov 1, 2024
1da6066
Merge branch 'run-llama:main' into dev/1_70_protobuf_only
stfines-clgx Nov 1, 2024
cf7b348
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 2, 2024
1b77e75
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 4, 2024
dd09285
(feat) Wrapped the BaseTool for Gemini
stfines-clgx Nov 7, 2024
2545dc3
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 7, 2024
5190d6f
Merge branch 'dev/1_70_protobuf_only' of github.com:stfines-clgx/llam…
stfines-clgx Nov 7, 2024
f92fbd9
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 8, 2024
f479a98
(bug) #16625 Address that Plan object is not compatible with Protobuf
stfines-clgx Nov 8, 2024
b588cf0
Merge remote-tracking branch 'origin/dev/1_70_protobuf_only' into dev…
stfines-clgx Nov 8, 2024
6214e5f
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 8, 2024
b3bda1a
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 8, 2024
8b83300
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 11, 2024
46839d7
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 18, 2024
42cfcfd
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Nov 22, 2024
1b36d96
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Dec 6, 2024
693aff1
(feat) Add feature flags for model garden
stfines-clgx Dec 6, 2024
5d80b1f
Merge Commit
stfines-clgx Dec 6, 2024
2aeec3e
Merge branch 'run-llama:main' into dev/1_70_protobuf_only
stfines-clgx Dec 10, 2024
8034a99
Merge branch 'main' into dev/1_70_protobuf_only
stfines-clgx Jan 9, 2025
llama_index/llms/vertex/base.py
@@ -27,6 +27,8 @@
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection
from llama_index.core.utilities.gemini_utils import merge_neighboring_same_role_messages

from llama_index.llms.vertex.gemini_tool import GeminiToolWrapper
from llama_index.llms.vertex.gemini_utils import create_gemini_client, is_gemini_model
from llama_index.llms.vertex.utils import (
CHAT_MODELS,
@@ -118,6 +120,10 @@ def __init__(
completion_to_prompt: Optional[Callable[[str], str]] = None,
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
is_model_garden: bool = False,
is_chat: bool = False,
is_text: bool = False,
is_code: bool = False,
) -> None:
init_vertexai(project=project, location=location, credentials=credentials)

@@ -145,23 +151,35 @@ def __init__(
self._safety_settings = safety_settings
self._is_gemini = False
self._is_chat_model = False
if model in CHAT_MODELS:
self._is_model_garden = is_model_garden
self._is_chat = is_chat
self._is_text = is_text
self._is_code = is_code
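# Model Garden deployments are not in the static CHAT_MODELS/CODE_CHAT_MODELS/
# CODE_MODELS/TEXT_MODELS lists, so these flags let callers route a custom
# model to the appropriate client class in the branches below.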
if model in CHAT_MODELS or (
self._is_model_garden and self._is_chat and not self._is_code
):
from vertexai.language_models import ChatModel

self._chat_client = ChatModel.from_pretrained(model)
self._is_chat_model = True
elif model in CODE_CHAT_MODELS:
elif model in CODE_CHAT_MODELS or (
self._is_model_garden and self._is_chat and self._is_code
):
from vertexai.language_models import CodeChatModel

self._chat_client = CodeChatModel.from_pretrained(model)
iscode = True
self._is_chat_model = True
elif model in CODE_MODELS:
elif model in CODE_MODELS or (
self._is_model_garden and not self._is_text and self._is_code
):
from vertexai.language_models import CodeGenerationModel

self._client = CodeGenerationModel.from_pretrained(model)
iscode = True
elif model in TEXT_MODELS:
elif model in TEXT_MODELS or (
self._is_model_garden and self._is_text and not self._is_code
):
from vertexai.language_models import TextGenerationModel

self._client = TextGenerationModel.from_pretrained(model)
@@ -458,11 +476,17 @@ def _prepare_chat_with_tools(

tool_dicts = []
for tool in tools:
if self._is_gemini:
tool = GeminiToolWrapper(tool)
metadata = tool.metadata
else:
metadata = tool.metadata

tool_dicts.append(
{
"name": tool.metadata.name,
"description": tool.metadata.description,
"parameters": tool.metadata.get_parameters_dict(),
"name": metadata.name,
"description": metadata.description,
"parameters": metadata.get_parameters_dict(),
}
)

@@ -503,7 +527,7 @@ def get_tool_calls_from_response(

tool_selections = []
for tool_call in tool_calls:
response_dict = MessageToDict(tool_call._pb)
response_dict = MessageToDict(tool_call._raw_message._pb)
Collaborator:
no way to do this without accessing private attributes hey? 😅 Oh google

Author:
Well, in another branch, I did it by just creating an entirely new object by hand. I kept it this way to minimize the change surface area. I can move over to the other change style if it is an issue, but yeah- their MessageToDict function doesn't work without private access. Kind of silly; but then again, I'm not entirely sure that google makes apis for use outside of google.
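
For illustration, a rough sketch of the "build it by hand" alternative mentioned above (a hypothetical helper, not part of this PR; assumes the gapic FunctionCall exposes public `name` and `args` fields):

```python
# Hypothetical sketch of the "build the dict by hand" approach, avoiding
# MessageToDict and the private _pb attribute. Assumes `function_call` is
# a gapic FunctionCall whose `args` is a protobuf Struct (mapping-like).
def function_call_to_dict(function_call) -> dict:
    return {
        "name": function_call.name,
        # dict() works because Struct implements the mapping protocol;
        # nested values may still come back as proto Struct/ListValue types.
        "args": dict(function_call.args),
    }
```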

if "args" not in response_dict or "name" not in response_dict:
raise ValueError("Invalid tool call.")
argument_dict = response_dict["args"]
llama_index/llms/vertex/gemini_tool.py (new file)
@@ -0,0 +1,105 @@
import json
from typing import Optional, Type, Any

from llama_index.core.tools import ToolMetadata
from llama_index.core.tools.types import DefaultToolFnSchema
from pydantic import BaseModel


def remap_schema(schema: dict) -> dict:
"""
Remap schema to match Gemini's internal API.
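
For example, {"title": "T", "type": "object", "properties": {...},
"required": [...], "$ref": "#/$defs/T"} keeps title, type, properties,
and required, and stores the $ref value under "defs"; all other keys
are dropped.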
"""
parameters = {}

for key, value in schema.items():
if key in ["title", "type", "properties", "required", "definitions"]:
parameters[key] = value
elif key == "$ref":
parameters["defs"] = value
else:
continue

return parameters


class GeminiToolMetadataWrapper:
"""
Represents tool metadata in a form that is compatible with
Gemini's internal APIs. The default ToolMetadata class generates
a JSON schema using $ref and $defs field types, which break
Google's protocol buffer serialization.
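For instance, a pydantic model with a nested model field produces a
schema containing "$defs" and "$ref" entries that fail to convert into
a protobuf FunctionDeclaration.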
"""

def __init__(self, base: ToolMetadata) -> None:
self._base = base
self._name = self._base.name
self._description = self._base.description
self._fn_schema = self._base.fn_schema
self._parameters = self.get_parameters_dict()

fn_schema: Optional[Type[BaseModel]] = DefaultToolFnSchema

def get_parameters_dict(self) -> dict:
parameters = {}

if self._fn_schema is None:
parameters = {
"type": "object",
"properties": {
"input": {"title": "input query string", "type": "string"},
},
"required": ["input"],
}
else:
parameters = remap_schema(
{
k: v
for k, v in self._fn_schema.model_json_schema().items()
if k in ["type", "properties", "required", "definitions", "$defs"]
}
)

return parameters

@property
def fn_schema_str(self) -> str:
"""Get fn schema as string."""
if self.fn_schema is None:
raise ValueError("fn_schema is None.")
parameters = self.get_parameters_dict()
return json.dumps(parameters)

def __getattr__(self, item) -> Any:
match item:
case "name":
return self._name
case "description":
return self._description
case "fn_schema":
return self._fn_schema
case "parameters":
return self._parameters
case _:
raise AttributeError(
f"No attribute '{item}' found in GeminiToolMetadataWrapper"
)


class GeminiToolWrapper:
"""
Wraps a base tool object to make it compatible with Gemini's
internal APIs.
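
For example, GeminiToolWrapper(tool).metadata.get_parameters_dict()
returns a parameters dict that serializes cleanly into a protobuf
FunctionDeclaration.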
"""

def __init__(self, base_obj, *args, **kwargs) -> None:
self.base_obj = base_obj
# All other attribute access is delegated to the wrapped tool (see __getattr__).

@property
def metadata(self) -> GeminiToolMetadataWrapper:
return GeminiToolMetadataWrapper(self.base_obj.metadata)

def __getattr__(self, name) -> Any:
return getattr(self.base_obj, name)
llama_index/llms/vertex/gemini_utils.py
@@ -1,17 +1,14 @@
import base64
from typing import Any, Dict, Union, Optional
from vertexai.generative_models._generative_models import SafetySettingsType
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from vertexai.generative_models import SafetySetting
from llama_index.core.llms import ChatMessage, MessageRole


def is_gemini_model(model: str) -> bool:
return model.startswith("gemini")


def create_gemini_client(
model: str, safety_settings: Optional[SafetySettingsType]
) -> Any:
def create_gemini_client(model: str, safety_settings: Optional[SafetySetting]) -> Any:
from vertexai.preview.generative_models import GenerativeModel

return GenerativeModel(model_name=model, safety_settings=safety_settings)
@@ -46,19 +43,34 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part:
raise ValueError("Only text and image_url types are supported!")
return Part.from_image(image)

if message.content == "" and "tool_calls" in message.additional_kwargs:
if (
message.role == MessageRole.ASSISTANT
and message.content == ""
and "tool_calls" in message.additional_kwargs
) or (message.role == MessageRole.TOOL):
tool_calls = message.additional_kwargs["tool_calls"]
parts = [
Part._from_gapic(raw_part=gapic_content_types.Part(function_call=tool_call))
for tool_call in tool_calls
]
else:
raw_content = message.content
if message.role != MessageRole.TOOL:
parts = [
Part.from_function_response(tool_call.name, tool_call.args)
for tool_call in tool_calls
]
parts.extend(
_convert_gemini_part_to_prompt(part) for part in handle_raw_content(message)
)
else:
# This handles the case where the Gemini API properly sets the message role to tool instead of assistant.
if "name" in message.additional_kwargs:
parts = [
Part.from_function_response(
message.additional_kwargs["name"],
message.additional_kwargs.get("args", {}),
)
]
else:
raise ValueError("Tool name must be provided!")

if raw_content is None:
raw_content = ""
if isinstance(raw_content, str):
raw_content = [raw_content]
raw_content = handle_raw_content(message)
parts.extend(_convert_gemini_part_to_prompt(part) for part in raw_content)
else:
raw_content = handle_raw_content(message)

parts = [_convert_gemini_part_to_prompt(part) for part in raw_content]

@@ -69,3 +81,12 @@ def _convert_gemini_part_to_prompt(part: Union[str, Dict]) -> Part:
)
else:
return parts


def handle_raw_content(message: ChatMessage) -> list:
"""Normalize message.content into a list of parts: None becomes [""],
a bare string becomes a single-element list, and lists pass through."""
raw_content = message.content
if raw_content is None:
raw_content = ""
if isinstance(raw_content, str):
raw_content = [raw_content]
return raw_content
llama_index/llms/vertex/utils.py
@@ -19,6 +19,8 @@

from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole

from llama_index.llms.vertex.gemini_tool import GeminiToolWrapper

CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"]
TEXT_MODELS = [
"text-bison",
@@ -56,6 +58,8 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
def to_gemini_tools(tools) -> Any:
func_list = []
for i, tool in enumerate(tools):
_gemini_tools = GeminiToolWrapper(tool)

func_name = f"func_{i}"
func_name = FunctionDeclaration(
name=tool["name"],
pyproject.toml
@@ -27,12 +27,14 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-vertex"
readme = "README.md"
version = "0.4.2"

version = "0.4.2+vai_fixes"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
google-cloud-aiplatform = "^1.39.0"
llama-index-core = "^0.12.0"
python = ">=3.9.0,<4.0"
google-cloud-aiplatform = "^1.76.0"
pyarrow = ">=15.0.2, <16.0.0"
llama-index-core = "^0.12.1"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
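
For anyone trying out the new flags, a minimal usage sketch (the model and project names are placeholders; assumes a chat-style Model Garden deployment):

```python
from llama_index.core.llms import ChatMessage
from llama_index.llms.vertex import Vertex

# Hypothetical deployment and project names (placeholders).
# is_model_garden + is_chat route this custom model to ChatModel even
# though its name is not in the built-in CHAT_MODELS list.
llm = Vertex(
    model="my-garden-chat-model",
    project="my-gcp-project",
    is_model_garden=True,
    is_chat=True,
)
response = llm.chat([ChatMessage(role="user", content="Hello!")])
print(response)
```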