Skip to content

Commit

Permalink
Add Groq Provider - chat completions
Browse files Browse the repository at this point in the history
  • Loading branch information
aidando73 committed Dec 12, 2024
1 parent e2054d5 commit 98e3563
Show file tree
Hide file tree
Showing 9 changed files with 666 additions and 0 deletions.
9 changes: 9 additions & 0 deletions llama_stack/providers/registry/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
Expand Down
15 changes: 15 additions & 0 deletions llama_stack/providers/remote/inference/groq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

from llama_stack.apis.inference import Inference

from .config import GroqConfig
from .groq import GroqInferenceAdapter

async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .groq import GroqInferenceAdapter

if not isinstance(config, GroqConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")

adapter = GroqInferenceAdapter(config)
return adapter
13 changes: 13 additions & 0 deletions llama_stack/providers/remote/inference/groq/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class GroqConfig(BaseModel):
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("GROQ_API_KEY"),
description="The Groq API key",
)
144 changes: 144 additions & 0 deletions llama_stack/providers/remote/inference/groq/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import warnings
from typing import AsyncIterator, List, Optional, Union, AsyncGenerator
import json
from .groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
SamplingStrategy,
ToolParamDefinition,
)
from llama_models.datatypes import SamplingParams
from llama_models.sku_list import CoreModelId
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
ChatCompletionResponseEventType,
CompletionResponse,
CompletionMessage,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
ResponseFormat,
ToolCallDelta,
ToolCall,
ToolCallParseStatus,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from groq import Groq
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)

_MODEL_ALIASES = [
build_model_alias(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias_with_just_provider_model_id(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_model_alias(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_model_alias(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
]

class GroqInferenceAdapter(Inference, ModelRegistryHelper):
_client: Groq

def __init__(self, config: GroqConfig):
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
self._client = Groq(api_key=config.api_key)

def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
# Groq doesn't support non-chat completion as of time of writing
raise NotImplementedError()

async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:

if model_id == "llama-3.2-3b-preview":
warnings.warn(
"Groq only contains a preview version for llama-3.2-3b-instruct. "
"Preview models aren't recommended for production use. "
"They can be discontinued on short notice."
)

model_id = self.get_provider_model_id(model_id)
request = convert_chat_completion_request(
request=ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
)

response = self._client.chat.completions.create(**request)

if stream:
return convert_chat_completion_response_stream(response)
else:
return convert_chat_completion_response(response)

async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
154 changes: 154 additions & 0 deletions llama_stack/providers/remote/inference/groq/groq_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import warnings
from typing import Literal, AsyncGenerator, Generator

from llama_stack.apis.inference import (
ChatCompletionRequest,
Message,
Role,
ChatCompletionResponse,
CompletionMessage,
StopReason,
ChatCompletionResponseStreamChunk,
ChatCompletionResponseEventType,
ChatCompletionResponseEvent,
)

from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.chat_completion_assistant_message_param import (
ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq import Stream


def convert_chat_completion_request(
request: ChatCompletionRequest,
) -> CompletionCreateParams:
"""
Convert a ChatCompletionRequest to a Groq API-compatible dictionary.
Warns client if request contains unsupported features.
"""

if request.logprobs:
# Groq doesn't support logprobs at the time of writing
warnings.warn("logprobs are not supported yet")

if request.response_format:
# Groq's JSON mode is beta at the time of writing
warnings.warn("response_format is not supported yet")

if request.sampling_params.repetition_penalty:
# groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty
# seem to have different semantics
# frequency_penalty defaults to 0 is a float between -2.0 and 2.0
# repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")

if request.tools:
warnings.warn("tools are not supported yet")

return CompletionCreateParams(
model=request.model,
messages=[_convert_message(message) for message in request.messages],
logprobs=None,
frequency_penalty=None,
stream=request.stream,
# Groq only supports n=1 at the time of writing
n=1,
max_tokens=request.sampling_params.max_tokens or None,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
)


def _convert_message(message: Message) -> ChatCompletionMessageParam:
if message.role == Role.system.value:
return ChatCompletionSystemMessageParam(role="system", content=message.content)
elif message.role == Role.user.value:
return ChatCompletionUserMessageParam(role="user", content=message.content)
elif message.role == Role.assistant.value:
return ChatCompletionAssistantMessageParam(
role="assistant", content=message.content
)
else:
raise ValueError(f"Invalid message role: {message.role}")


def convert_chat_completion_response(
response: ChatCompletion,
) -> ChatCompletionResponse:
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)

def _map_finish_reason_to_stop_reason(finish_reason: Literal["stop", "length", "tool_calls"]) -> StopReason:
"""
Convert a Groq chat completion finish_reason to a StopReason.
finish_reason: Literal["stop", "length", "tool_calls"]
- stop -> model hit a natural stop point or a provided stop sequence
- length -> maximum number of tokens specified in the request was reached
- tool_calls -> model called a tool
"""
if finish_reason == "stop":
return StopReason.end_of_turn
elif finish_reason == "length":
return StopReason.end_of_message
elif finish_reason == "tool_calls":
raise NotImplementedError("tool_calls is not supported yet")
else:
raise ValueError(f"Invalid finish reason: {finish_reason}")


async def convert_chat_completion_response_stream(
stream: Stream[ChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:

# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> (
Generator[ChatCompletionResponseEventType, None, None]
):
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress

event_types = _event_type_generator()

for chunk in stream:
choice = chunk.choices[0]

# We assume there's only one finish_reason for the entire stream.
# We collect the last finish_reason
if choice.finish_reason:
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_types),
delta=choice.delta.content or "",
logprobs=None,
)
)

yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
logprobs=None,
stop_reason=stop_reason,
)
)
16 changes: 16 additions & 0 deletions llama_stack/providers/tests/inference/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
Expand Down Expand Up @@ -146,6 +147,20 @@ def inference_together() -> ProviderFixture:
),
)

@pytest.fixture(scope="session")
def inference_groq() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="groq",
provider_type="remote::groq",
config=GroqConfig().model_dump(),
)
],
provider_data=dict(
groq_api_key=get_env_or_fail("GROQ_API_KEY"),
),
)

@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
Expand Down Expand Up @@ -219,6 +234,7 @@ def model_id(inference_model) -> str:
"ollama",
"fireworks",
"together",
"groq",
"vllm_remote",
"remote",
"bedrock",
Expand Down
Loading

0 comments on commit 98e3563

Please sign in to comment.