POC: Support for tool_choice=required #657

Draft · wants to merge 3 commits into base: main
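In short, this proof of concept sends Fireworks chat completion requests to the OpenAI-compatible chat completions endpoint and maps Llama Stack's ToolChoice.required onto the tool_choice="any" value used to require a tool call. A minimal sketch of that mapping, outside the provider class; the helper name is hypothetical and the ToolChoice import path is an assumption:

from typing import Any, Dict, Optional

# Assumption: ToolChoice is the enum compared against in _get_params below
# (request.tool_choice == ToolChoice.required); its import location may differ.
from llama_stack.apis.inference import ToolChoice


def required_tool_choice_params(tool_choice: Optional[ToolChoice]) -> Dict[str, Any]:
    """Hypothetical helper: the extra request params this PR adds when a tool call is required."""
    if tool_choice == ToolChoice.required:
        # "any" is the value this PR sends to Fireworks' OpenAI-compatible API
        # when the caller asked for tool_choice=required.
        return {"tool_choice": "any"}
    return {}
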
llama_stack/providers/remote/inference/fireworks/config.py

@@ -22,7 +22,7 @@ class FireworksImplConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
         return {
             "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
27 changes: 19 additions & 8 deletions llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -32,6 +32,7 @@
     interleaved_content_as_str,
     request_has_media,
 )
+from ..nvidia.openai_utils import _convert_tooldef_to_openai_tool, convert_openai_chat_completion_choice

 from .config import FireworksImplConfig

@@ -65,6 +66,10 @@
"fireworks/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"fireworks/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
@@ -205,10 +210,12 @@ async def _nonstream_chat_completion(
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
+            print(params)
             r = await self._get_client().chat.completions.acreate(**params)
+            return convert_openai_chat_completion_choice(r.choices[0])
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter)
+            return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -236,14 +243,18 @@ async def _get_params(
         media_present = request_has_media(request)

         if isinstance(request, ChatCompletionRequest):
-            if media_present:
-                input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m) for m in request.messages
+            input_dict["messages"] = [
+                await convert_message_to_openai_dict(m) for m in request.messages
+            ]
+            # print(input_dict["messages"])
+            if request.tool_choice == ToolChoice.required:
+                input_dict["tool_choice"] = "any"
+
+            if request.tools:
+                input_dict["tools"] = [
+                    _convert_tooldef_to_openai_tool(t) for t in request.tools
                 ]
-            else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
-                )

@aidando73 (Contributor, Author) commented on Dec 19, 2024:

We were using the completions API before; now we hit the chat completions API directly.

+            # print(input_dict)
         else:
             assert (
                 not media_present
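To make the comment above concrete, here is a rough sketch of the two request shapes _get_params can now hand to _nonstream_chat_completion. Only the dict keys and the client calls named in the comments come from the diff; the message, tool, and prompt values are invented for illustration:

# New path: every ChatCompletionRequest now carries structured messages, the
# converted tool definitions, and tool_choice="any" when the request asked for
# ToolChoice.required. These params go to chat.completions.acreate(**params).
chat_params = {
    "messages": [{"role": "user", "content": "What is the weather in Tokyo?"}],
    "tools": [
        # illustrative stand-in for _convert_tooldef_to_openai_tool(t) output
        {"type": "function", "function": {"name": "get_weather", "parameters": {}}},
    ],
    "tool_choice": "any",
}

# Old path, now only taken for non-chat completion requests: a single rendered
# prompt string, sent to completion.acreate(**params) as before.
completion_params = {
    "prompt": "<rendered Llama prompt from chat_completion_request_to_prompt>",
}
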
llama_stack/providers/remote/inference/nvidia/openai_utils.py

@@ -10,12 +10,13 @@

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
-    CompletionMessage,
+    # CompletionMessage,
     StopReason,
-    TokenLogProbs,
+    # TokenLogProbs,
     ToolCall,
     ToolDefinition,
 )
+from llama_stack.apis.inference import CompletionMessage, TokenLogProbs
 from openai import AsyncStream
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
@@ -339,7 +340,7 @@ def _convert_openai_tool_calls(

 def _convert_openai_logprobs(
     logprobs: OpenAIChoiceLogprobs,
-) -> Optional[List[TokenLogProbs]]:
+) -> Optional[List[Any]]:
     """
     Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.

5 changes: 5 additions & 0 deletions llama_stack/templates/fireworks/run.yaml
@@ -110,6 +110,11 @@ models:
   provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-90b-vision-instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.3-70B-Instruct
+  provider_id: fireworks
+  provider_model_id: fireworks/llama-v3p3-70b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-8B
   provider_id: fireworks