Inference to use provider resource id to register and validate #428

Merged · 13 commits · Nov 13, 2024
50 changes: 25 additions & 25 deletions docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
 "info": {
 "title": "[DRAFT] Llama Stack Specification",
 "version": "0.0.1",
-"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
+"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
 },
 "servers": [
 {
@@ -2856,7 +2856,7 @@
 "ChatCompletionRequest": {
 "type": "object",
 "properties": {
-"model": {
+"model_id": {
 "type": "string"
 },
 "messages": {
@@ -2993,7 +2993,7 @@
 },
 "additionalProperties": false,
 "required": [
-"model",
+"model_id",
 "messages"
 ]
 },
@@ -3120,7 +3120,7 @@
 "CompletionRequest": {
 "type": "object",
 "properties": {
-"model": {
+"model_id": {
 "type": "string"
 },
 "content": {
@@ -3249,7 +3249,7 @@
 },
 "additionalProperties": false,
 "required": [
-"model",
+"model_id",
 "content"
 ]
 },
@@ -4552,7 +4552,7 @@
 "EmbeddingsRequest": {
 "type": "object",
 "properties": {
-"model": {
+"model_id": {
 "type": "string"
 },
 "contents": {
@@ -4584,7 +4584,7 @@
 },
 "additionalProperties": false,
 "required": [
-"model",
+"model_id",
 "contents"
 ]
 },
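Taken together, ChatCompletionRequest, CompletionRequest, and EmbeddingsRequest now all require model_id in place of model. As a minimal sketch, request bodies that validate against the updated schemas look like this (the model identifiers are examples, not values from this PR):

```python
# Hypothetical payloads illustrating the renamed required field.
chat_completion_request = {
    "model_id": "Llama3.1-8B-Instruct",  # previously "model"
    "messages": [{"role": "user", "content": "Hello"}],
}

completion_request = {
    "model_id": "Llama3.1-8B-Instruct",  # previously "model"
    "content": "The moon is",
}

embeddings_request = {
    "model_id": "Llama3.1-8B-Instruct",  # previously "model"
    "contents": ["The moon is bright tonight."],
}
```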
@@ -7837,58 +7837,58 @@
 ],
 "tags": [
 {
-"name": "MemoryBanks"
+"name": "Safety"
 },
 {
-"name": "BatchInference"
+"name": "EvalTasks"
 },
 {
-"name": "Agents"
+"name": "Shields"
 },
 {
-"name": "Inference"
+"name": "Telemetry"
 },
 {
-"name": "DatasetIO"
+"name": "Memory"
 },
 {
-"name": "Eval"
+"name": "Scoring"
 },
 {
-"name": "Models"
+"name": "ScoringFunctions"
 },
 {
-"name": "PostTraining"
+"name": "SyntheticDataGeneration"
 },
 {
-"name": "ScoringFunctions"
+"name": "Models"
 },
 {
-"name": "Datasets"
+"name": "Agents"
 },
 {
-"name": "Shields"
+"name": "MemoryBanks"
 },
 {
-"name": "Telemetry"
+"name": "DatasetIO"
 },
 {
-"name": "Inspect"
+"name": "Inference"
 },
 {
-"name": "Safety"
+"name": "Datasets"
 },
 {
-"name": "SyntheticDataGeneration"
+"name": "PostTraining"
 },
 {
-"name": "Memory"
+"name": "BatchInference"
 },
 {
-"name": "Scoring"
+"name": "Eval"
 },
 {
-"name": "EvalTasks"
+"name": "Inspect"
 },
 {
 "name": "BuiltinTool",
42 changes: 21 additions & 21 deletions docs/resources/llama-stack-spec.yaml
@@ -396,7 +396,7 @@ components:
 - $ref: '#/components/schemas/ToolResponseMessage'
 - $ref: '#/components/schemas/CompletionMessage'
 type: array
-model:
+model_id:
 type: string
 response_format:
 oneOf:
@@ -453,7 +453,7 @@
 $ref: '#/components/schemas/ToolDefinition'
 type: array
 required:
-- model
+- model_id
 - messages
 type: object
 ChatCompletionResponse:
@@ -577,7 +577,7 @@
 default: 0
 type: integer
 type: object
-model:
+model_id:
 type: string
 response_format:
 oneOf:
@@ -626,7 +626,7 @@
 stream:
 type: boolean
 required:
-- model
+- model_id
 - content
 type: object
 CompletionResponse:
@@ -903,10 +903,10 @@
 - $ref: '#/components/schemas/ImageMedia'
 type: array
 type: array
-model:
+model_id:
 type: string
 required:
-- model
+- model_id
 - contents
 type: object
 EmbeddingsResponse:
@@ -3384,7 +3384,7 @@ info:
 description: "This is the specification of the llama stack that provides\n \
 \ a set of endpoints and their corresponding interfaces that are tailored\
 \ to\n best leverage Llama Models. The specification is still in
-\ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
+\ draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
 title: '[DRAFT] Llama Stack Specification'
 version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4748,24 +4748,24 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: MemoryBanks
-- name: BatchInference
-- name: Agents
-- name: Inference
-- name: DatasetIO
-- name: Eval
-- name: Models
-- name: PostTraining
-- name: ScoringFunctions
-- name: Datasets
+- name: Safety
+- name: EvalTasks
 - name: Shields
 - name: Telemetry
-- name: Inspect
-- name: Safety
-- name: SyntheticDataGeneration
 - name: Memory
 - name: Scoring
-- name: EvalTasks
+- name: ScoringFunctions
+- name: SyntheticDataGeneration
+- name: Models
+- name: Agents
+- name: MemoryBanks
+- name: DatasetIO
+- name: Inference
+- name: Datasets
+- name: PostTraining
+- name: BatchInference
+- name: Eval
+- name: Inspect
 - description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
   name: BuiltinTool
 - description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
2 changes: 1 addition & 1 deletion docs/source/getting_started/index.md
@@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working correctly.
 $ curl http://localhost:5000/inference/chat_completion \
 -H "Content-Type: application/json" \
 -d '{
-  "model": "Llama3.1-8B-Instruct",
+  "model_id": "Llama3.1-8B-Instruct",
   "messages": [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "Write me a 2 sentence poem about the moon"}
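The same request from Python, as a small sketch using the requests package (same local server and model as the curl example above; not part of the diff):

```python
import requests

# Mirrors the curl call above after the rename: the body key is "model_id".
response = requests.post(
    "http://localhost:5000/inference/chat_completion",
    headers={"Content-Type": "application/json"},
    json={
        "model_id": "Llama3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write me a 2 sentence poem about the moon"},
        ],
    },
)
print(response.text)
```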
6 changes: 3 additions & 3 deletions llama_stack/apis/inference/inference.py
@@ -226,7 +226,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -237,7 +237,7 @@ async def completion(
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
Inline review thread on this line (yanxi0830, Contributor, Nov 13, 2024):

Contributor: llama-stack-apps and llama-stack-client-python also need to be updated to reflect the model -> model_id change.
https://github.com/meta-llama/llama-stack-apps/blob/0dc9c42fb42bf21d35e6d231afc4e0360a9eac61/examples/inference/client.py#L46-L49

Author: done.

         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
@@ -254,6 +254,6 @@ async def chat_completion(
     @webmethod(route="/inference/embeddings")
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
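For callers, the visible effect of this change is the keyword: every Inference entry point now takes model_id. A sketch of driving the updated protocol (it assumes impl is any object satisfying the protocol; the import path follows the file above, and the model identifier is illustrative):

```python
from llama_stack.apis.inference import Inference


async def smoke_test(impl: Inference) -> None:
    # Both calls now pass `model_id` where `model` was accepted before.
    completion = await impl.completion(
        model_id="Llama3.1-8B-Instruct",
        content="The moon is",
    )
    embeddings = await impl.embeddings(
        model_id="Llama3.1-8B-Instruct",
        contents=["The moon is bright tonight."],
    )
    print(completion, embeddings)
```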
18 changes: 9 additions & 9 deletions llama_stack/distribution/routers/routers.py
@@ -95,7 +95,7 @@ async def register_model(

     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -106,7 +106,7 @@ async def chat_completion(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         params = dict(
-            model=model,
+            model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -116,24 +116,24 @@ async def chat_completion(
             stream=stream,
             logprobs=logprobs,
         )
-        provider = self.routing_table.get_provider_impl(model)
+        provider = self.routing_table.get_provider_impl(model_id)
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
             return await provider.chat_completion(**params)

     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        provider = self.routing_table.get_provider_impl(model)
+        provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
-            model=model,
+            model_id=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -147,11 +147,11 @@ async def completion(

     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        return await self.routing_table.get_provider_impl(model).embeddings(
-            model=model,
+        return await self.routing_table.get_provider_impl(model_id).embeddings(
+            model_id=model_id,
             contents=contents,
         )

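The router treats model_id purely as a routing key: it resolves the provider that registered that id and forwards the call. A self-contained sketch of the contract get_provider_impl relies on (a hypothetical stand-in, not the actual RoutingTable implementation):

```python
from typing import Any, Dict


class ToyRoutingTable:
    """Hypothetical stand-in illustrating get_provider_impl(model_id)."""

    def __init__(self) -> None:
        self._impls: Dict[str, Any] = {}

    def register_model(self, model_id: str, provider_impl: Any) -> None:
        # The provider resource id is the key used for routing.
        self._impls[model_id] = provider_impl

    def get_provider_impl(self, model_id: str) -> Any:
        if model_id not in self._impls:
            raise ValueError(f"No provider registered for model_id={model_id!r}")
        return self._impls[model_id]
```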