From c0757fd16971cf0c3e5c8e52e511e3a90563bc64 Mon Sep 17 00:00:00 2001 From: Aidan Do Date: Thu, 12 Dec 2024 21:15:09 +1100 Subject: [PATCH] Add Groq provider - chat completions --- README.md | 1 + llama_stack/providers/registry/inference.py | 10 + .../remote/inference/groq/__init__.py | 26 ++ .../providers/remote/inference/groq/config.py | 19 ++ .../providers/remote/inference/groq/groq.py | 150 ++++++++++ .../remote/inference/groq/groq_utils.py | 153 ++++++++++ .../providers/tests/inference/fixtures.py | 18 ++ .../tests/inference/groq/test_groq_utils.py | 271 ++++++++++++++++++ .../tests/inference/groq/test_init.py | 29 ++ .../tests/inference/test_text_inference.py | 15 + 10 files changed, 692 insertions(+) create mode 100644 llama_stack/providers/remote/inference/groq/__init__.py create mode 100644 llama_stack/providers/remote/inference/groq/config.py create mode 100644 llama_stack/providers/remote/inference/groq/groq.py create mode 100644 llama_stack/providers/remote/inference/groq/groq_utils.py create mode 100644 llama_stack/providers/tests/inference/groq/test_groq_utils.py create mode 100644 llama_stack/providers/tests/inference/groq/test_init.py diff --git a/README.md b/README.md index 16ca48ecb0..c0e15d05f5 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ Additionally, we have designed every element of the Stack such that APIs as well | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | +| Groq | Hosted | | :heavy_check_mark: | | | | | Ollama | Single Node | | :heavy_check_mark: | | | | | TGI | Hosted and Single Node | | :heavy_check_mark: | | | | | [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | | diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0ff557b9f9..258212f663 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -149,6 +149,16 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="groq", + pip_packages=["groq"], + module="llama_stack.providers.remote.inference.groq", + config_class="llama_stack.providers.remote.inference.groq.GroqConfig", + provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/groq/__init__.py b/llama_stack/providers/remote/inference/groq/__init__.py new file mode 100644 index 0000000000..923c356969 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from llama_stack.apis.inference import Inference + +from .config import GroqConfig + + +class GroqProviderDataValidator(BaseModel): + groq_api_key: str + + +async def get_adapter_impl(config: GroqConfig, _deps) -> Inference: + # import dynamically so the import is used only when it is needed + from .groq import GroqInferenceAdapter + + if not isinstance(config, GroqConfig): + raise RuntimeError(f"Unexpected config type: {type(config)}") + + adapter = GroqInferenceAdapter(config) + return adapter diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py new file mode 100644 index 0000000000..7c50234103 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class GroqConfig(BaseModel): + api_key: Optional[str] = Field( + # The Groq client library loads the GROQ_API_KEY environment variable by default + default=None, + description="The Groq API key", + ) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py new file mode 100644 index 0000000000..1a19b4d79e --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import AsyncIterator, List, Optional, Union + +from groq import Groq +from llama_models.datatypes import SamplingParams +from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat +from llama_models.sku_list import CoreModelId + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + InterleavedContent, + LogProbConfig, + Message, + ResponseFormat, + ToolChoice, +) +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + build_model_alias_with_just_provider_model_id, + ModelRegistryHelper, +) +from .groq_utils import ( + convert_chat_completion_request, + convert_chat_completion_response, + convert_chat_completion_response_stream, +) + +_MODEL_ALIASES = [ + build_model_alias( + "llama3-8b-8192", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias_with_just_provider_model_id( + "llama-3.1-8b-instant", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3-70b-8192", + CoreModelId.llama3_70b_instruct.value, + ), + build_model_alias( + "llama-3.3-70b-versatile", + CoreModelId.llama3_3_70b_instruct.value, + ), + # Groq only contains a preview version for llama-3.2-3b + # Preview models aren't recommended for production use, but we include this one + # to pass the test fixture + # TODO(aidand): Replace this with a stable model once Groq supports it + build_model_alias( + "llama-3.2-3b-preview", + CoreModelId.llama3_2_3b_instruct.value, + ), +] + + +class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData): + _config: GroqConfig + + def __init__(self, config: GroqConfig): + ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) + self._config = config + + def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: + # Groq doesn't support non-chat completion as of time of writing + raise NotImplementedError() + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ + ToolPromptFormat + ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: + model_id = self.get_provider_model_id(model_id) + if model_id == "llama-3.2-3b-preview": + warnings.warn( + "Groq only contains a preview version for llama-3.2-3b-instruct. " + "Preview models aren't recommended for production use. " + "They can be discontinued on short notice." + ) + + request = convert_chat_completion_request( + request=ChatCompletionRequest( + model=model_id, + messages=messages, + sampling_params=sampling_params, + response_format=response_format, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ) + ) + + response = self._get_client().chat.completions.create(**request) + + if stream: + return convert_chat_completion_response_stream(response) + else: + return convert_chat_completion_response(response) + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedContent], + ) -> EmbeddingsResponse: + raise NotImplementedError() + + def _get_client(self) -> Groq: + if self._config.api_key is not None: + return Groq(api_key=self.config.api_key) + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.groq_api_key: + raise ValueError( + 'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "" }' + ) + return Groq(api_key=provider_data.groq_api_key) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py new file mode 100644 index 0000000000..74c6178a39 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import AsyncGenerator, Literal + +from groq import Stream +from groq.types.chat.chat_completion import ChatCompletion +from groq.types.chat.chat_completion_assistant_message_param import ( + ChatCompletionAssistantMessageParam, +) +from groq.types.chat.chat_completion_chunk import ChatCompletionChunk +from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from groq.types.chat.chat_completion_system_message_param import ( + ChatCompletionSystemMessageParam, +) +from groq.types.chat.chat_completion_user_message_param import ( + ChatCompletionUserMessageParam, +) + +from groq.types.chat.completion_create_params import CompletionCreateParams + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, + Message, + StopReason, +) + + +def convert_chat_completion_request( + request: ChatCompletionRequest, +) -> CompletionCreateParams: + """ + Convert a ChatCompletionRequest to a Groq API-compatible dictionary. + Warns client if request contains unsupported features. + """ + + if request.logprobs: + # Groq doesn't support logprobs at the time of writing + warnings.warn("logprobs are not supported yet") + + if request.response_format: + # Groq's JSON mode is beta at the time of writing + warnings.warn("response_format is not supported yet") + + if request.sampling_params.repetition_penalty != 1.0: + # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty + # seem to have different semantics + # frequency_penalty defaults to 0 is a float between -2.0 and 2.0 + # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0 + # so we exclude it for now + warnings.warn("repetition_penalty is not supported") + + if request.tools: + warnings.warn("tools are not supported yet") + + return CompletionCreateParams( + model=request.model, + messages=[_convert_message(message) for message in request.messages], + logprobs=None, + frequency_penalty=None, + stream=request.stream, + max_tokens=request.sampling_params.max_tokens or None, + temperature=request.sampling_params.temperature, + top_p=request.sampling_params.top_p, + ) + + +def _convert_message(message: Message) -> ChatCompletionMessageParam: + if message.role == "system": + return ChatCompletionSystemMessageParam(role="system", content=message.content) + elif message.role == "user": + return ChatCompletionUserMessageParam(role="user", content=message.content) + elif message.role == "assistant": + return ChatCompletionAssistantMessageParam( + role="assistant", content=message.content + ) + else: + raise ValueError(f"Invalid message role: {message.role}") + + +def convert_chat_completion_response( + response: ChatCompletion, +) -> ChatCompletionResponse: + # groq only supports n=1 at time of writing, so there is only one choice + choice = response.choices[0] + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=choice.message.content, + stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), + ), + ) + + +def _map_finish_reason_to_stop_reason( + finish_reason: Literal["stop", "length", "tool_calls"] +) -> StopReason: + """ + Convert a Groq chat completion finish_reason to a StopReason. + + finish_reason: Literal["stop", "length", "tool_calls"] + - stop -> model hit a natural stop point or a provided stop sequence + - length -> maximum number of tokens specified in the request was reached + - tool_calls -> model called a tool + """ + if finish_reason == "stop": + return StopReason.end_of_turn + elif finish_reason == "length": + return StopReason.out_of_tokens + elif finish_reason == "tool_calls": + raise NotImplementedError("tool_calls is not supported yet") + else: + raise ValueError(f"Invalid finish reason: {finish_reason}") + + +async def convert_chat_completion_response_stream( + stream: Stream[ChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + + event_type = ChatCompletionResponseEventType.start + for chunk in stream: + choice = chunk.choices[0] + + # We assume there's only one finish_reason for the entire stream. + # We collect the last finish_reason + if choice.finish_reason: + stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=choice.delta.content or "", + logprobs=None, + ) + ) + event_type = ChatCompletionResponseEventType.progress + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + logprobs=None, + stop_reason=stop_reason, + ) + ) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 7cc15bd9dd..d956caa930 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -19,6 +19,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig +from llama_stack.providers.remote.inference.groq import GroqConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig @@ -151,6 +152,22 @@ def inference_together() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_groq() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="groq", + provider_type="remote::groq", + config=GroqConfig().model_dump(), + ) + ], + provider_data=dict( + groq_api_key=get_env_or_fail("GROQ_API_KEY"), + ), + ) + + @pytest.fixture(scope="session") def inference_bedrock() -> ProviderFixture: return ProviderFixture( @@ -236,6 +253,7 @@ def model_id(inference_model) -> str: "ollama", "fireworks", "together", + "groq", "vllm_remote", "remote", "bedrock", diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py new file mode 100644 index 0000000000..53b5c29cb0 --- /dev/null +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from groq.types.chat.chat_completion import ChatCompletion, Choice +from groq.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice as StreamChoice, + ChoiceDelta, +) +from groq.types.chat.chat_completion_message import ChatCompletionMessage + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponseEventType, + CompletionMessage, + StopReason, + SystemMessage, + UserMessage, +) +from llama_stack.providers.remote.inference.groq.groq_utils import ( + convert_chat_completion_request, + convert_chat_completion_response, + convert_chat_completion_response_stream, +) + + +class TestConvertChatCompletionRequest: + def test_sets_model(self): + request = self._dummy_chat_completion_request() + request.model = "Llama-3.2-3B" + + converted = convert_chat_completion_request(request) + + assert converted["model"] == "Llama-3.2-3B" + + def test_converts_user_message(self): + request = self._dummy_chat_completion_request() + request.messages = [UserMessage(content="Hello World")] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "user", "content": "Hello World"}, + ] + + def test_converts_system_message(self): + request = self._dummy_chat_completion_request() + request.messages = [SystemMessage(content="You are a helpful assistant.")] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "system", "content": "You are a helpful assistant."}, + ] + + def test_converts_completion_message(self): + request = self._dummy_chat_completion_request() + request.messages = [ + UserMessage(content="Hello World"), + CompletionMessage( + content="Hello World! How can I help you today?", + stop_reason=StopReason.end_of_message, + ), + ] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "user", "content": "Hello World"}, + {"role": "assistant", "content": "Hello World! How can I help you today?"}, + ] + + def test_does_not_include_logprobs(self): + request = self._dummy_chat_completion_request() + request.logprobs = True + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "logprobs are not supported yet" in warnings[0].message.args[0] + assert converted.get("logprobs") is None + + def test_does_not_include_response_format(self): + request = self._dummy_chat_completion_request() + request.response_format = { + "type": "json_object", + "json_schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"}, + }, + }, + } + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "response_format is not supported yet" in warnings[0].message.args[0] + assert converted.get("response_format") is None + + def test_does_not_include_repetition_penalty(self): + request = self._dummy_chat_completion_request() + request.sampling_params.repetition_penalty = 1.5 + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "repetition_penalty is not supported" in warnings[0].message.args[0] + assert converted.get("repetition_penalty") is None + assert converted.get("frequency_penalty") is None + + def test_includes_stream(self): + request = self._dummy_chat_completion_request() + request.stream = True + + converted = convert_chat_completion_request(request) + + assert converted["stream"] is True + + def test_if_max_tokens_is_0_then_it_is_not_included(self): + request = self._dummy_chat_completion_request() + # 0 is the default value for max_tokens + # So we assume that if it's 0, the user didn't set it + request.sampling_params.max_tokens = 0 + + converted = convert_chat_completion_request(request) + + assert converted.get("max_tokens") is None + + def test_includes_max_tokens_if_set(self): + request = self._dummy_chat_completion_request() + request.sampling_params.max_tokens = 100 + + converted = convert_chat_completion_request(request) + + assert converted["max_tokens"] == 100 + + def _dummy_chat_completion_request(self): + return ChatCompletionRequest( + model="Llama-3.2-3B", + messages=[UserMessage(content="Hello World")], + ) + + def test_includes_temperature(self): + request = self._dummy_chat_completion_request() + request.sampling_params.temperature = 0.5 + + converted = convert_chat_completion_request(request) + + assert converted["temperature"] == 0.5 + + def test_includes_top_p(self): + request = self._dummy_chat_completion_request() + request.sampling_params.top_p = 0.95 + + converted = convert_chat_completion_request(request) + + assert converted["top_p"] == 0.95 + + +class TestConvertNonStreamChatCompletionResponse: + def test_returns_response(self): + response = self._dummy_chat_completion_response() + response.choices[0].message.content = "Hello World" + + converted = convert_chat_completion_response(response) + + assert converted.completion_message.content == "Hello World" + + def test_maps_stop_to_end_of_message(self): + response = self._dummy_chat_completion_response() + response.choices[0].finish_reason = "stop" + + converted = convert_chat_completion_response(response) + + assert converted.completion_message.stop_reason == StopReason.end_of_turn + + def test_maps_length_to_end_of_message(self): + response = self._dummy_chat_completion_response() + response.choices[0].finish_reason = "length" + + converted = convert_chat_completion_response(response) + + assert converted.completion_message.stop_reason == StopReason.out_of_tokens + + def _dummy_chat_completion_response(self): + return ChatCompletion( + id="chatcmpl-123", + model="Llama-3.2-3B", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", content="Hello World" + ), + finish_reason="stop", + ) + ], + created=1729382400, + object="chat.completion", + ) + + +class TestConvertStreamChatCompletionResponse: + @pytest.mark.asyncio + async def test_returns_stream(self): + def chat_completion_stream(): + messages = ["Hello ", "World ", " !"] + for i, message in enumerate(messages): + chunk = self._dummy_chat_completion_chunk() + chunk.choices[0].delta.content = message + if i == len(messages) - 1: + chunk.choices[0].finish_reason = "stop" + else: + chunk.choices[0].finish_reason = None + yield chunk + + chunk = self._dummy_chat_completion_chunk() + chunk.choices[0].delta.content = None + chunk.choices[0].finish_reason = "stop" + yield chunk + + stream = chat_completion_stream() + converted = convert_chat_completion_response_stream(stream) + + iter = converted.__aiter__() + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.start + assert chunk.event.delta == "Hello " + + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.progress + assert chunk.event.delta == "World " + + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.progress + assert chunk.event.delta == " !" + + # Dummy chunk to ensure the last chunk is really the end of the stream + # This one technically maps to Groq's final "stop" chunk + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.progress + assert chunk.event.delta == "" + + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.complete + assert chunk.event.delta == "" + assert chunk.event.stop_reason == StopReason.end_of_turn + + with pytest.raises(StopAsyncIteration): + await iter.__anext__() + + def _dummy_chat_completion_chunk(self): + return ChatCompletionChunk( + id="chatcmpl-123", + model="Llama-3.2-3B", + choices=[ + StreamChoice( + index=0, + delta=ChoiceDelta(role="assistant", content="Hello World"), + ) + ], + created=1729382400, + object="chat.completion.chunk", + x_groq=None, + ) diff --git a/llama_stack/providers/tests/inference/groq/test_init.py b/llama_stack/providers/tests/inference/groq/test_init.py new file mode 100644 index 0000000000..d23af5934d --- /dev/null +++ b/llama_stack/providers/tests/inference/groq/test_init.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from llama_stack.apis.inference import Inference +from llama_stack.providers.remote.inference.groq import get_adapter_impl +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter + +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig + + +class TestGroqInit: + @pytest.mark.asyncio + async def test_raises_runtime_error_if_config_is_not_groq_config(self): + config = OllamaImplConfig(model="llama3.1-8b-8192") + + with pytest.raises(RuntimeError): + await get_adapter_impl(config, None) + + @pytest.mark.asyncio + async def test_returns_groq_adapter(self): + config = GroqConfig() + adapter = await get_adapter_impl(config, None) + assert type(adapter) is GroqInferenceAdapter + assert isinstance(adapter, Inference) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 99a62ac080..02851830b8 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -350,6 +350,14 @@ async def test_chat_completion_with_tool_calling( sample_messages, sample_tool_definition, ): + inference_impl, _ = inference_stack + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type in ("remote::groq",): + pytest.skip( + provider.__provider_spec__.provider_type + + " doesn't support tool calling yet" + ) + inference_impl, _ = inference_stack messages = sample_messages + [ UserMessage( @@ -390,6 +398,13 @@ async def test_chat_completion_with_tool_calling_streaming( sample_tool_definition, ): inference_impl, _ = inference_stack + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type in ("remote::groq",): + pytest.skip( + provider.__provider_spec__.provider_type + + " doesn't support tool calling yet" + ) + messages = sample_messages + [ UserMessage( content="What's the weather like in San Francisco?",