chore: use google's new sdk #11801

Draft (wants to merge 8 commits into base: main)
api/core/model_runtime/model_providers/google/llm/llm.py (260 changes: 119 additions & 141 deletions)
@@ -3,15 +3,12 @@
 import os
 import tempfile
 import time
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from typing import Optional, Union

-import google.ai.generativelanguage as glm
-import google.generativeai as genai
 import requests
-from google.api_core import exceptions
-from google.generativeai.types import ContentType, File, GenerateContentResponse
-from google.generativeai.types.content_types import to_part
+from google import genai
+from google.genai import errors, types

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
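
The import swap is the heart of the migration: the old `google.generativeai` package configures a module-level global, while the new `google-genai` package routes everything through a client object. A minimal sketch of the new entry point, assuming `pip install google-genai` and an API key in the environment (the model name is illustrative):

```python
import os

from google import genai
from google.genai import types

# The new SDK is client-centric: no module-level genai.configure().
client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

# One-shot, non-streaming call.
response = client.models.generate_content(
    model="gemini-1.5-flash",  # illustrative model name
    contents="Say hello in one word.",
    config=types.GenerateContentConfig(temperature=0.0),
)
print(response.text)
```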
@@ -98,7 +95,7 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:

         return text.rstrip()

-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> types.Tool:
         """
         Convert tool messages to glm tools

@@ -109,29 +106,31 @@ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool
         for tool in tools:
             properties = {}
             for key, value in tool.parameters.get("properties", {}).items():
-                properties[key] = {
-                    "type_": glm.Type.STRING,
+                property_def = {
+                    "type": "STRING",
                     "description": value.get("description", ""),
-                    "enum": value.get("enum", []),
                 }
+                if "enum" in value:
+                    property_def["enum"] = value["enum"]
+                properties[key] = property_def

             if properties:
-                parameters = glm.Schema(
-                    type=glm.Type.OBJECT,
+                parameters = types.Schema(
+                    type="OBJECT",
                     properties=properties,
                     required=tool.parameters.get("required", []),
                 )
             else:
                 parameters = None

-            function_declaration = glm.FunctionDeclaration(
+            functions = types.FunctionDeclaration(
                 name=tool.name,
                 parameters=parameters,
                 description=tool.description,
             )
-            function_declarations.append(function_declaration)
+            function_declarations.append(functions)

-        return glm.Tool(function_declarations=function_declarations)
+        return types.Tool(function_declarations=function_declarations)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -177,42 +176,51 @@ def _generate(
             try:
                 schema = json.loads(schema)
             except:
-                raise exceptions.InvalidArgument("Invalid JSON Schema")
+                raise errors.FunctionInvocationError("Invalid JSON Schema")
             if tools:
-                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
+                raise errors.FunctionInvocationError("gemini not support use Tools and JSON Schema at same time")
             config_kwargs["response_schema"] = schema
             config_kwargs["response_mime_type"] = "application/json"

         if stop:
             config_kwargs["stop_sequences"] = stop

-        genai.configure(api_key=credentials["google_api_key"])
-        google_model = genai.GenerativeModel(model_name=model)
+        config_kwargs["tools"] = []
+        if tools:
+            config_kwargs["tools"].append(self._convert_tools_to_glm_tool(tools))
+
+        self.client = genai.Client(api_key=credentials["google_api_key"])

         history = []

         for msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
-            if history and history[-1]["role"] == content["role"]:
-                history[-1]["parts"].extend(content["parts"])
+            if history and history[-1].role == content.role:
+                history[-1].parts.extend(content.parts)
             else:
                 history.append(content)

-        response = google_model.generate_content(
-            contents=history,
-            generation_config=genai.types.GenerationConfig(**config_kwargs),
-            stream=stream,
-            tools=self._convert_tools_to_glm_tool(tools) if tools else None,
-            request_options={"timeout": 600},
-        )
-
         if stream:
+            response = self.client.models.generate_content_stream(
+                model=model,
+                contents=history,
+                config=types.GenerateContentConfig(**config_kwargs),
+            )
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

+        response = self.client.models.generate_content(
+            model=model,
+            contents=history,
+            config=types.GenerateContentConfig(**config_kwargs),
+        )
         return self._handle_generate_response(model, credentials, response, prompt_messages)

     def _handle_generate_response(
-        self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
+        self,
+        model: str,
+        credentials: dict,
+        response: types.GenerateContentResponse,
+        prompt_messages: list[PromptMessage],
     ) -> LLMResult:
         """
         Handle llm response
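
The old `stream=` flag becomes two separate methods on `client.models`. A sketch of both paths (model name and prompt are illustrative):

```python
import os

from google import genai
from google.genai import types

client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
config = types.GenerateContentConfig(temperature=0.7, stop_sequences=["\n\n"])

# Blocking call: returns a single GenerateContentResponse.
resp = client.models.generate_content(
    model="gemini-1.5-pro",  # illustrative model name
    contents="Tell me a joke.",
    config=config,
)
print(resp.text)

# Streaming call: returns an iterator of partial responses.
for chunk in client.models.generate_content_stream(
    model="gemini-1.5-pro",
    contents="Tell me a joke.",
    config=config,
):
    print(chunk.text or "", end="")
```

Note the dropped `request_options={"timeout": 600}`: this call has no slot for a per-request timeout, and the closest knob in the new SDK appears to be the client-level `http_options`, an assumption worth verifying before merge.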
@@ -248,7 +256,11 @@ def _handle_generate_response(
         return result

     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
+        self,
+        model: str,
+        credentials: dict,
+        response: Iterator[types.GenerateContentResponse],
+        prompt_messages: list[PromptMessage],
     ) -> Generator:
         """
         Handle llm stream response
@@ -260,56 +272,52 @@
         :return: llm response chunk generator result
         """
         index = -1
-        for chunk in response:
-            for part in chunk.parts:
-                assistant_prompt_message = AssistantPromptMessage(content="")
-
-                if part.text:
-                    assistant_prompt_message.content += part.text
-
-                if part.function_call:
-                    assistant_prompt_message.tool_calls = [
-                        AssistantPromptMessage.ToolCall(
-                            id=part.function_call.name,
-                            type="function",
-                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                name=part.function_call.name,
-                                arguments=json.dumps(dict(part.function_call.args.items())),
-                            ),
-                        )
-                    ]
-
-                index += 1
-
-                if not response._done:
-                    # transform assistant message to prompt message
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
-                    )
-                else:
-                    # calculate num tokens
-                    if hasattr(response, "usage_metadata") and response.usage_metadata:
-                        prompt_tokens = response.usage_metadata.prompt_token_count
-                        completion_tokens = response.usage_metadata.candidates_token_count
-                    else:
-                        prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-                        completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-                    # transform usage
-                    usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=index,
-                            message=assistant_prompt_message,
-                            finish_reason=str(chunk.candidates[0].finish_reason),
-                            usage=usage,
-                        ),
-                    )
-
+        for r in response:
+            assistant_prompt_message = AssistantPromptMessage(content="")
+            parts = r.candidates[0].content.parts
+            index += 1
+            if parts is None:
+                # calculate num tokens
+                prompt_tokens = r.usage_metadata.prompt_token_count or self.get_num_tokens(
+                    model, credentials, prompt_messages
+                )
+                completion_tokens = r.usage_metadata.candidates_token_count or self.get_num_tokens(
+                    model, credentials, [assistant_prompt_message]
+                )
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        finish_reason=str(r.candidates[0].finish_reason),
+                        usage=usage,
+                    ),
+                )
+
+            else:
+                for part in parts:
+                    if part.text:
+                        assistant_prompt_message.content += part.text
+                    elif part.function_call:
+                        assistant_prompt_message.tool_calls = [
+                            AssistantPromptMessage.ToolCall(
+                                id=part.function_call.name,
+                                type="function",
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=part.function_call.name,
+                                    arguments=json.dumps(dict(part.function_call.args.items())),
+                                ),
+                            )
+                        ]
+                # transform assistant message to prompt message
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
+                )

     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
@@ -336,11 +344,11 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:

         return message_text

-    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
+    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> types.File:
         key = f"{message_content.type.value}:{hash(message_content.data)}"
         if redis_client.exists(key):
             try:
-                return genai.get_file(redis_client.get(key).decode())
+                return self.client.files.get(name=redis_client.get(key).decode())
             except:
                 pass
         with tempfile.NamedTemporaryFile(delete=False) as temp_file:
@@ -356,10 +364,10 @@ def _upload_file_content_to_google(self, message_content: PromptMessageContent)
                     raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}")
             temp_file.flush()

-        file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type)
-        while file.state.name == "PROCESSING":
+        file = self.client.files.upload(path=temp_file.name, config={"mime_type": message_content.mime_type})
+        while file.state == "PROCESSING":
             time.sleep(5)
-            file = genai.get_file(file.name)
+            file = self.client.files.get(name=file.name)
         # google will delete your upload files in 2 days.
         redis_client.setex(key, 47 * 60 * 60, file.name)

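A sketch of the Files API round trip the two hunks above rely on. Caveat: the upload parameter was renamed across google-genai releases (`path=` early on, `file=` later), so this mirrors the PR's spelling and assumes a matching pinned SDK version; the string comparison against `file.state` assumes the state field behaves as a string enum:

```python
import os
import time

from google import genai

client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

# Upload, then poll until processing finishes (filename is illustrative).
uploaded = client.files.upload(path="report.pdf")  # early-SDK argument name
while uploaded.state == "PROCESSING":  # assumption: str-enum comparison
    time.sleep(5)
    uploaded = client.files.get(name=uploaded.name)
print(uploaded.uri, uploaded.mime_type)
```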
@@ -370,94 +378,64 @@ def _upload_file_content_to_google(self, message_content: PromptMessageContent)
             pass
         return file

-    def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> types.Content:
         """
         Format a single message into glm.Content for Google API

         :param message: one PromptMessage
         :return: glm Content representation of message
         """
         if isinstance(message, UserPromptMessage):
-            glm_content = {"role": "user", "parts": []}
+            glm_content = types.Content(role="user", parts=[])
             if isinstance(message.content, str):
-                glm_content["parts"].append(to_part(message.content))
+                glm_content.parts.append(types.Part.from_text(message.content))
             else:
                 for c in message.content:
                     if c.type == PromptMessageContentType.TEXT:
-                        glm_content["parts"].append(to_part(c.data))
+                        glm_content.parts.append(types.Part.from_text(c.data))
                     else:
-                        glm_content["parts"].append(self._upload_file_content_to_google(c))
+                        f = self._upload_file_content_to_google(c)
+                        glm_content.parts.append(types.Part.from_uri(file_uri=f.uri, mime_type=f.mime_type))

             return glm_content
         elif isinstance(message, AssistantPromptMessage):
-            glm_content = {"role": "model", "parts": []}
+            glm_content = types.Content(role="model", parts=[])
             if message.content:
-                glm_content["parts"].append(to_part(message.content))
+                glm_content.parts.append(types.Part.from_text(message.content))
             if message.tool_calls:
-                glm_content["parts"].append(
-                    to_part(
-                        glm.FunctionCall(
-                            name=message.tool_calls[0].function.name,
-                            args=json.loads(message.tool_calls[0].function.arguments),
-                        )
-                    )
+                glm_content.parts.append(
+                    types.Part.from_function_call(
+                        name=message.tool_calls[0].function.name,
+                        args=json.loads(message.tool_calls[0].function.arguments),
+                    )
                 )
             return glm_content
         elif isinstance(message, SystemPromptMessage):
-            return {"role": "user", "parts": [to_part(message.content)]}
+            return types.Content(role="user", parts=[types.Part.from_text(message.content)])
         elif isinstance(message, ToolPromptMessage):
-            return {
-                "role": "function",
-                "parts": [
-                    glm.Part(
-                        function_response=glm.FunctionResponse(
-                            name=message.name, response={"response": message.content}
-                        )
-                    )
-                ],
-            }
+            return types.Content(
+                role="function",
+                parts=[types.Part.from_function_response(name=message.name, response={"response": message.content})],
+            )
         else:
             raise ValueError(f"Got unknown type {message}")

     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
         Map model invoke error to unified error
         The key is the error type thrown to the caller
         The value is the error type thrown by the model,
         which needs to be converted into a unified error type for the caller.

         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [exceptions.RetryError],
+            InvokeConnectionError: [errors.APIError],
             InvokeServerUnavailableError: [
-                exceptions.ServiceUnavailable,
-                exceptions.InternalServerError,
-                exceptions.BadGateway,
-                exceptions.GatewayTimeout,
-                exceptions.DeadlineExceeded,
-            ],
-            InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
-            InvokeAuthorizationError: [
-                exceptions.Unauthenticated,
-                exceptions.PermissionDenied,
-                exceptions.Unauthenticated,
-                exceptions.Forbidden,
+                errors.ServerError,
             ],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
             InvokeBadRequestError: [
-                exceptions.BadRequest,
-                exceptions.InvalidArgument,
-                exceptions.FailedPrecondition,
-                exceptions.OutOfRange,
-                exceptions.NotFound,
-                exceptions.MethodNotAllowed,
-                exceptions.Conflict,
-                exceptions.AlreadyExists,
-                exceptions.Aborted,
-                exceptions.LengthRequired,
-                exceptions.PreconditionFailed,
-                exceptions.RequestRangeNotSatisfiable,
-                exceptions.Cancelled,
+                errors.ClientError,
+                errors.UnkownFunctionCallArgumentError,
+                errors.UnsupportedFunctionError,
+                errors.FunctionInvocationError,
             ],
         }
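
For reference, a sketch of the typed history that `_format_message_to_glm_content` now builds, replacing the old `{"role": ..., "parts": [...]}` dicts (all values invented for illustration):

```python
from google.genai import types

history = [
    types.Content(role="user", parts=[types.Part.from_text("What is in notes.txt?")]),
    types.Content(
        role="model",
        parts=[types.Part.from_function_call(name="read_file", args={"path": "notes.txt"})],
    ),
    types.Content(
        role="function",
        parts=[types.Part.from_function_response(name="read_file", response={"response": "hello"})],
    ),
]
```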
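On the error mapping: the new SDK collapses the old per-condition exception zoo into `APIError` subclasses keyed by HTTP status, which is why `InvokeRateLimitError` and `InvokeAuthorizationError` now map to nothing and 429/401/403 currently fall through to `InvokeConnectionError` via the broad `APIError` entry. A sketch of status-based dispatch, assuming `errors.APIError` exposes numeric `code` and `message` attributes:

```python
import os

from google import genai
from google.genai import errors

client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

try:
    client.models.generate_content(model="gemini-1.5-flash", contents="hi")
except errors.APIError as e:
    if e.code == 429:
        print("rate limited: back off and retry")
    elif e.code in (401, 403):
        print("check the API key or its permissions")
    elif isinstance(e, errors.ServerError):  # 5xx family
        print("server-side failure: retry later")
    else:
        print(f"client error {e.code}: {e.message}")
```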