From d9db9a01bf77e5aa6b015c72072df59b4eb50481 Mon Sep 17 00:00:00 2001 From: Aidan Do Date: Sun, 15 Dec 2024 13:48:55 +1100 Subject: [PATCH] Formatting --- .../remote/inference/groq/groq_utils.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index 98ecbe2f2a..2055399da8 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json import warnings from typing import AsyncGenerator, Generator, Literal -import json + from groq import Stream from groq.types.chat.chat_completion import ChatCompletion from groq.types.chat.chat_completion_assistant_message_param import ( @@ -14,19 +15,19 @@ ) from groq.types.chat.chat_completion_chunk import ChatCompletionChunk from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from groq.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) from groq.types.chat.chat_completion_system_message_param import ( ChatCompletionSystemMessageParam, ) +from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam from groq.types.chat.chat_completion_user_message_param import ( ChatCompletionUserMessageParam, ) from groq.types.chat.completion_create_params import CompletionCreateParams -from groq.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, -) -from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam from groq.types.shared.function_definition import FunctionDefinition -from groq.types.shared.function_parameters import FunctionParameters + from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -38,13 +39,14 @@ Role, StopReason, ToolCall, + ToolCallDelta, + ToolCallParseStatus, ToolDefinition, ToolParamDefinition, - ToolCallParseStatus, - ToolCallDelta, ToolPromptFormat, ) + def convert_chat_completion_request( request: ChatCompletionRequest, ) -> CompletionCreateParams: @@ -85,6 +87,7 @@ def convert_chat_completion_request( tool_choice=request.tool_choice.value if request.tool_choice else None, ) + def _convert_message(message: Message) -> ChatCompletionMessageParam: if message.role == Role.system.value: return ChatCompletionSystemMessageParam(role="system", content=message.content) @@ -98,7 +101,6 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam: raise ValueError(f"Invalid message role: {message.role}") - def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict: # Groq requires a description for function tools if tool_definition.description is None: @@ -114,13 +116,11 @@ def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict: key: _convert_groq_tool_parameter(param) for key, param in tool_parameters.items() }, - ) + ), ) -def _convert_groq_tool_parameter( - tool_parameter: ToolParamDefinition -) -> dict: +def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict: param = { "type": tool_parameter.param_type, } @@ -211,7 +211,9 @@ def _event_type_generator() -> ( elif choice.delta.tool_calls: # We assume there is only one tool call per chunk, but emit a warning in case we're wrong if len(choice.delta.tool_calls) > 1: - warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.") + warnings.warn( + "Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest." + ) # We assume Groq produces fully formed tool calls for each chunk tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0]) @@ -233,6 +235,7 @@ def _event_type_generator() -> ( ) ) + def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall: return ToolCall( call_id=tool_call.id,