async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
    """Create the Groq inference adapter for ``config``.

    ``_deps`` is accepted for signature compatibility and is not used here.

    Raises:
        RuntimeError: if ``config`` is not a ``GroqConfig`` instance.
    """
    # import dynamically so `llama stack build` does not fail due to missing dependencies
    from .groq import GroqInferenceAdapter

    if isinstance(config, GroqConfig):
        return GroqInferenceAdapter(config)

    raise RuntimeError(f"Unexpected config type: {type(config)}")
default_factory=lambda: os.getenv("GROQ_API_KEY"), + description="The Groq API key", + ) \ No newline at end of file diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py new file mode 100644 index 0000000000..b7e5a00b30 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -0,0 +1,144 @@ +import warnings +from typing import AsyncIterator, List, Optional, Union, AsyncGenerator +import json +from .groq_utils import ( + convert_chat_completion_request, + convert_chat_completion_response, + convert_chat_completion_response_stream, +) +from llama_models.llama3.api.datatypes import ( + InterleavedTextMedia, + Message, + ToolChoice, + ToolDefinition, + ToolPromptFormat, + SamplingStrategy, + ToolParamDefinition, +) +from llama_models.datatypes import SamplingParams +from llama_models.sku_list import CoreModelId +from llama_models.llama3.api.datatypes import StopReason +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseStreamChunk, + ChatCompletionResponseEventType, + CompletionResponse, + CompletionMessage, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + ResponseFormat, + ToolCallDelta, + ToolCall, + ToolCallParseStatus, +) +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + build_model_alias_with_just_provider_model_id, + ModelRegistryHelper, +) +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from groq import Groq +from groq.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) + +_MODEL_ALIASES = [ + build_model_alias( + "llama3-8b-8192", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias_with_just_provider_model_id( + "llama-3.1-8b-instant", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3-70b-8192", + 
class GroqInferenceAdapter(Inference, ModelRegistryHelper):
    """Inference adapter that sends chat completions to Groq.

    Only ``chat_completion`` is implemented; ``completion`` and ``embeddings``
    raise ``NotImplementedError``.
    """

    _client: Groq

    def __init__(self, config: GroqConfig):
        """Register the Groq model aliases and build the client from ``config.api_key``."""
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
        self._client = Groq(api_key=config.api_key)

    def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        """Not available for this provider; always raises ``NotImplementedError``."""
        # Groq doesn't support non-chat completion as of time of writing
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[
            ToolPromptFormat
        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        """Run a chat completion against Groq.

        The requested ``model_id`` is translated with ``get_provider_model_id``,
        the request is converted to Groq's parameters, and the Groq response is
        converted back. When ``stream`` is truthy the converted chunk stream is
        returned, otherwise a single converted response.

        A warning is emitted when the preview model is selected.
        """
        requested_model_id = model_id
        model_id = self.get_provider_model_id(requested_model_id)

        # Check the resolved id as well as the requested one: comparing only the
        # caller's value (as before) skips the warning whenever the preview
        # model is reached through one of the registered aliases.
        if "llama-3.2-3b-preview" in (requested_model_id, model_id):
            warnings.warn(
                "Groq only contains a preview version for llama-3.2-3b-instruct. "
                "Preview models aren't recommended for production use. "
                "They can be discontinued on short notice."
            )

        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=model_id,
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
            )
        )

        # NOTE(review): this call is not awaited, so it appears to run
        # synchronously inside the coroutine; an async client may be preferable
        # -- confirm against the Groq SDK before changing.
        response = self._client.chat.completions.create(**request)

        if stream:
            return convert_chat_completion_response_stream(response)
        return convert_chat_completion_response(response)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Not implemented for this provider; always raises ``NotImplementedError``."""
        raise NotImplementedError()
def convert_chat_completion_request(
    request: ChatCompletionRequest,
) -> CompletionCreateParams:
    """
    Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
    Warns client if request contains unsupported features.

    The unsupported settings (logprobs, response_format, repetition_penalty,
    tools) are not forwarded; ``n`` is always 1 and a falsy ``max_tokens`` is
    sent as ``None``.
    """
    sampling_params = request.sampling_params

    if request.logprobs:
        # Groq doesn't support logprobs at the time of writing
        warnings.warn("logprobs are not supported yet")

    if request.response_format:
        # Groq's JSON mode is beta at the time of writing
        warnings.warn("response_format is not supported yet")

    # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
    # seem to have different semantics
    # frequency_penalty defaults to 0 is a float between -2.0 and 2.0
    # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
    # so we exclude it for now
    #
    # Because 1 is the default, a plain truthiness check warned on requests
    # that never asked for a penalty; only warn for a non-neutral value.
    repetition_penalty = sampling_params.repetition_penalty
    if repetition_penalty and repetition_penalty != 1.0:
        warnings.warn("repetition_penalty is not supported")

    if request.tools:
        warnings.warn("tools are not supported yet")

    return CompletionCreateParams(
        model=request.model,
        messages=[_convert_message(message) for message in request.messages],
        logprobs=None,
        frequency_penalty=None,
        stream=request.stream,
        # Groq only supports n=1 at the time of writing
        n=1,
        max_tokens=sampling_params.max_tokens or None,
        temperature=sampling_params.temperature,
        top_p=sampling_params.top_p,
    )


def _convert_message(message: Message) -> ChatCompletionMessageParam:
    """Translate one message into the matching Groq message param.

    Raises:
        ValueError: if the role is not system, user or assistant.
    """
    if message.role == Role.system.value:
        return ChatCompletionSystemMessageParam(role="system", content=message.content)
    elif message.role == Role.user.value:
        return ChatCompletionUserMessageParam(role="user", content=message.content)
    elif message.role == Role.assistant.value:
        return ChatCompletionAssistantMessageParam(
            role="assistant", content=message.content
        )
    else:
        raise ValueError(f"Invalid message role: {message.role}")


def convert_chat_completion_response(
    response: ChatCompletion,
) -> ChatCompletionResponse:
    """Convert a non-streaming Groq completion into a ChatCompletionResponse."""
    # groq only supports n=1 at time of writing, so there is only one choice
    choice = response.choices[0]
    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=choice.message.content,
            stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
        ),
    )


def _map_finish_reason_to_stop_reason(
    finish_reason: Literal["stop", "length", "tool_calls"]
) -> StopReason:
    """
    Convert a Groq chat completion finish_reason to a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls"]
        - stop -> model hit a natural stop point or a provided stop sequence
        - length -> maximum number of tokens specified in the request was reached
        - tool_calls -> model called a tool

    Raises:
        NotImplementedError: for "tool_calls".
        ValueError: for any other unrecognised value.
    """
    if finish_reason == "stop":
        return StopReason.end_of_turn
    elif finish_reason == "length":
        # NOTE(review): a truncated reply is reported as end_of_message; if
        # StopReason has a dedicated out-of-tokens member it may be the more
        # precise choice -- confirm before changing, the tests pin this value.
        return StopReason.end_of_message
    elif finish_reason == "tool_calls":
        raise NotImplementedError("tool_calls is not supported yet")
    else:
        raise ValueError(f"Invalid finish reason: {finish_reason}")


async def _iterate_chunks(stream) -> AsyncGenerator[ChatCompletionChunk, None]:
    """Yield chunks from either a synchronous or an asynchronous stream."""
    if hasattr(stream, "__aiter__"):
        async for chunk in stream:
            yield chunk
    else:
        for chunk in stream:
            yield chunk


async def convert_chat_completion_response_stream(
    stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    """Convert a Groq chunk stream into ChatCompletionResponseStreamChunk events.

    The first forwarded chunk is emitted as a ``start`` event, every later one
    as ``progress``, and a final ``complete`` event with an empty delta closes
    the stream. The ``complete`` event carries the stop reason mapped from the
    last ``finish_reason`` seen; if no chunk reported one, no stop reason is set.

    ``stream`` may be iterable synchronously or asynchronously.
    """
    event_type = ChatCompletionResponseEventType.start
    # Stays None when the stream never reports a finish_reason (e.g. an empty
    # stream); previously that case failed with an unbound local variable.
    stop_reason = None

    async for chunk in _iterate_chunks(stream):
        if not chunk.choices:
            # A chunk without choices carries no delta to forward.
            continue
        choice = chunk.choices[0]

        # We assume there's only one finish_reason for the entire stream.
        # We collect the last finish_reason
        if choice.finish_reason:
            stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)

        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=event_type,
                delta=choice.delta.content or "",
                logprobs=None,
            )
        )
        # start -> progress -> progress -> ...
        event_type = ChatCompletionResponseEventType.progress

    final_event_fields = {
        "event_type": ChatCompletionResponseEventType.complete,
        "delta": "",
        "logprobs": None,
    }
    # Only pass stop_reason when one was observed; the events above show the
    # field may be omitted.
    if stop_reason is not None:
        final_event_fields["stop_reason"] = stop_reason

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(**final_event_fields)
    )
b/llama_stack/providers/tests/inference/groq/test_groq_utils.py new file mode 100644 index 0000000000..3155a8c664 --- /dev/null +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -0,0 +1,272 @@ +import pytest +from typing import AsyncIterator, AsyncGenerator + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + UserMessage, + SystemMessage, + CompletionMessage, + StopReason, + SamplingStrategy, + ChatCompletionResponseEventType, +) +from llama_stack.providers.remote.inference.groq.groq_utils import ( + convert_chat_completion_request, + convert_chat_completion_response, + convert_chat_completion_response_stream, +) +from groq.types.chat.chat_completion import ChatCompletion, Choice +from groq.types.chat.chat_completion_message import ChatCompletionMessage +from groq.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice as StreamChoice, ChoiceDelta + + + +class TestConvertChatCompletionRequest: + def test_sets_model(self): + request = self._dummy_chat_completion_request() + request.model = "Llama-3.2-3B" + + converted = convert_chat_completion_request(request) + + assert converted["model"] == "Llama-3.2-3B" + + def test_converts_user_message(self): + request = self._dummy_chat_completion_request() + request.messages = [UserMessage(content="Hello World")] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "user", "content": "Hello World"}, + ] + + def test_converts_system_message(self): + request = self._dummy_chat_completion_request() + request.messages = [SystemMessage(content="You are a helpful assistant.")] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "system", "content": "You are a helpful assistant."}, + ] + + def test_converts_completion_message(self): + request = self._dummy_chat_completion_request() + request.messages = [ + UserMessage(content="Hello World"), + CompletionMessage( + 
content="Hello World! How can I help you today?", + stop_reason=StopReason.end_of_message, + ), + ] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "user", "content": "Hello World"}, + {"role": "assistant", "content": "Hello World! How can I help you today?"}, + ] + + def test_does_not_include_logprobs(self): + request = self._dummy_chat_completion_request() + request.logprobs = True + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "logprobs are not supported yet" in warnings[0].message.args[0] + assert converted.get("logprobs") is None + + def test_does_not_include_response_format(self): + request = self._dummy_chat_completion_request() + request.response_format = { + "type": "json_object", + "json_schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"}, + }, + }, + } + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "response_format is not supported yet" in warnings[0].message.args[0] + assert converted.get("response_format") is None + + def test_does_not_include_repetition_penalty(self): + request = self._dummy_chat_completion_request() + request.sampling_params.repetition_penalty = 1.5 + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "repetition_penalty is not supported" in warnings[0].message.args[0] + assert converted.get("repetition_penalty") is None + assert converted.get("frequency_penalty") is None + + + def test_includes_stream(self): + request = self._dummy_chat_completion_request() + request.stream = True + + converted = convert_chat_completion_request(request) + + assert converted["stream"] is True + + def test_n_is_1(self): + request = self._dummy_chat_completion_request() + + converted = convert_chat_completion_request(request) + + assert converted["n"] == 1 + + 
class TestConvertNonStreamChatCompletionResponse:
    """Tests for ``convert_chat_completion_response`` on non-streaming replies."""

    def _dummy_chat_completion_response(self):
        """Build a minimal single-choice completion that finished with "stop"."""
        return ChatCompletion(
            id="chatcmpl-123",
            model="Llama-3.2-3B",
            choices=[
                Choice(
                    index=0,
                    message=ChatCompletionMessage(
                        role="assistant", content="Hello World"
                    ),
                    finish_reason="stop",
                )
            ],
            created=1729382400,
            object="chat.completion",
        )

    def test_returns_response(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].message.content = "Hello World"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.content == "Hello World"

    # Renamed from test_maps_stop_to_end_of_message: the assertion has always
    # checked for end_of_turn, so the old name contradicted the test.
    def test_maps_stop_to_end_of_turn(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "stop"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_turn

    def test_maps_length_to_end_of_message(self):
        response = self._dummy_chat_completion_response()
        response.choices[0].finish_reason = "length"

        converted = convert_chat_completion_response(response)

        assert converted.completion_message.stop_reason == StopReason.end_of_message
class TestGroqInit:
    """Tests for ``get_adapter_impl`` of the Groq provider."""

    @pytest.mark.asyncio
    async def test_raises_runtime_error_if_config_is_not_groq_config(self):
        config = OllamaImplConfig(model="llama3.1-8b-8192")

        with pytest.raises(RuntimeError):
            await get_adapter_impl(config, None)

    @pytest.mark.asyncio
    async def test_returns_groq_adapter(self):
        # GroqConfig() would take its key from GROQ_API_KEY, making the test
        # depend on the environment; pass one explicitly instead.
        config = GroqConfig(api_key="test-api-key")
        adapter = await get_adapter_impl(config, None)
        assert type(adapter) is GroqInferenceAdapter
        assert isinstance(adapter, Inference)
class TestGroqConfig:
    """Tests for the defaults of ``GroqConfig``."""

    def test_api_key_defaults_to_env_var(self, monkeypatch):
        # monkeypatch restores the variable after the test; assigning to
        # os.environ directly leaked "test" into every later test.
        monkeypatch.setenv("GROQ_API_KEY", "test")
        config = GroqConfig()
        assert config.api_key == "test"