From 95b7f57d92d4d96110e06aeeafdd96bc455109c8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 08:21:37 -0800 Subject: [PATCH 01/13] use provider resource id to validate for models --- llama_stack/apis/inference/inference.py | 6 +-- llama_stack/distribution/routers/routers.py | 21 ++++++---- .../remote/inference/fireworks/fireworks.py | 10 ++--- .../remote/inference/together/together.py | 10 ++--- .../providers/tests/inference/fixtures.py | 29 ++++++++++++- .../tests/inference/test_text_inference.py | 41 ++++++++++--------- .../utils/inference/model_registry.py | 4 +- 7 files changed, 75 insertions(+), 46 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 1e7b29722..b2681e578 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -226,7 +226,7 @@ class Inference(Protocol): @webmethod(route="/inference/completion") async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -237,7 +237,7 @@ async def completion( @webmethod(route="/inference/chat_completion") async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), # zero-shot tool definitions as input to the model @@ -254,6 +254,6 @@ async def chat_completion( @webmethod(route="/inference/embeddings") async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 220dfdb56..7d4e43f60 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -95,7 +95,7 @@ async def register_model( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -105,8 +105,9 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.routing_table.get_model(model_id) params = dict( - model=model, + model_id=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -116,7 +117,7 @@ async def chat_completion( stream=stream, logprobs=logprobs, ) - provider = self.routing_table.get_provider_impl(model) + provider = self.routing_table.get_provider_impl(model_id) if stream: return (chunk async for chunk in await provider.chat_completion(**params)) else: @@ -124,16 +125,17 @@ async def chat_completion( async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - provider = self.routing_table.get_provider_impl(model) + model = await self.routing_table.get_model(model_id) + provider = self.routing_table.get_provider_impl(model_id) params = dict( - model=model, + model_id=model.provider_resource_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -147,11 +149,12 @@ async def completion( async def embeddings( self, - model: str, + model_id: str, 
contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - return await self.routing_table.get_provider_impl(model).embeddings( - model=model, + model = await self.routing_table.get_model(model_id) + return await self.routing_table.get_provider_impl(model_id).embeddings( + model_id=model.provider_resource_id, contents=contents, ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 57e851c5b..67bf1cf47 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -74,7 +74,7 @@ def _get_client(self) -> Fireworks: async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -82,7 +82,7 @@ async def completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = CompletionRequest( - model=model, + model=model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -138,7 +138,7 @@ def _build_options( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), tools: Optional[List[ToolDefinition]] = None, @@ -149,7 +149,7 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -229,7 +229,7 @@ async def _get_params( async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 28a566415..1b04ae556 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -63,7 +63,7 @@ async def shutdown(self) -> None: async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -71,7 +71,7 @@ async def completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = CompletionRequest( - model=model, + model=model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -135,7 +135,7 @@ def _build_options( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), tools: Optional[List[ToolDefinition]] = None, @@ -146,7 +146,7 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -221,7 +221,7 @@ async def _get_params( async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d35ebab28..2a4eba3ad 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ 
b/llama_stack/providers/tests/inference/fixtures.py @@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture: ) +def get_model_short_name(model_name: str) -> str: + """Convert model name to a short test identifier. + + Args: + model_name: Full model name like "Llama3.1-8B-Instruct" + + Returns: + Short name like "llama_8b" suitable for test markers + """ + model_name = model_name.lower() + if "vision" in model_name: + return "llama_vision" + elif "3b" in model_name: + return "llama_3b" + elif "8b" in model_name: + return "llama_8b" + else: + return model_name.replace(".", "_").replace("-", "_") + + +@pytest.fixture(scope="session") +def model_id(inference_model) -> str: + return get_model_short_name(inference_model) + + INFERENCE_FIXTURES = [ "meta_reference", "ollama", @@ -154,7 +179,7 @@ def inference_bedrock() -> ProviderFixture: @pytest_asyncio.fixture(scope="session") -async def inference_stack(request, inference_model): +async def inference_stack(request, inference_model, model_id): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") impls = await resolve_impls_for_test_v2( @@ -163,7 +188,7 @@ async def inference_stack(request, inference_model): inference_fixture.provider_data, models=[ ModelInput( - model_id=inference_model, + model_id=model_id, ) ], ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index e7bfbc135..9850b328e 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -64,7 +64,7 @@ def sample_tool_definition(): class TestInference: @pytest.mark.asyncio - async def test_model_list(self, inference_model, inference_stack): + async def test_model_list(self, inference_model, inference_stack, model_id): _, models_impl = inference_stack response = await models_impl.list_models() assert isinstance(response, list) @@ -73,17 +73,16 @@ async def test_model_list(self, inference_model, inference_stack): model_def = None for model in response: - if model.identifier == inference_model: + if model.identifier == model_id: model_def = model break assert model_def is not None @pytest.mark.asyncio - async def test_completion(self, inference_model, inference_stack): + async def test_completion(self, inference_model, inference_stack, model_id): inference_impl, _ = inference_stack - - provider = inference_impl.routing_table.get_provider_impl(inference_model) + provider = inference_impl.routing_table.get_provider_impl(model_id) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::ollama", @@ -96,7 +95,7 @@ async def test_completion(self, inference_model, inference_stack): response = await inference_impl.completion( content="Micheael Jordan is born in ", stream=False, - model=inference_model, + model_id=model_id, sampling_params=SamplingParams( max_tokens=50, ), @@ -110,7 +109,7 @@ async def test_completion(self, inference_model, inference_stack): async for r in await inference_impl.completion( content="Roses are red,", stream=True, - model=inference_model, + model_id=model_id, sampling_params=SamplingParams( max_tokens=50, ), @@ -125,11 +124,11 @@ async def test_completion(self, inference_model, inference_stack): @pytest.mark.asyncio @pytest.mark.skip("This test is not quite robust") async def test_completions_structured_output( - self, inference_model, inference_stack + self, inference_model, inference_stack, model_id ): 
inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(inference_model) + provider = inference_impl.routing_table.get_provider_impl(model_id) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::tgi", @@ -149,7 +148,7 @@ class Output(BaseModel): response = await inference_impl.completion( content=user_input, stream=False, - model=inference_model, + model_id=model_id, sampling_params=SamplingParams( max_tokens=50, ), @@ -167,11 +166,11 @@ class Output(BaseModel): @pytest.mark.asyncio async def test_chat_completion_non_streaming( - self, inference_model, inference_stack, common_params, sample_messages + self, inference_model, inference_stack, common_params, sample_messages, model_id ): inference_impl, _ = inference_stack response = await inference_impl.chat_completion( - model=inference_model, + model_id=model_id, messages=sample_messages, stream=False, **common_params, @@ -184,11 +183,11 @@ async def test_chat_completion_non_streaming( @pytest.mark.asyncio async def test_structured_output( - self, inference_model, inference_stack, common_params + self, inference_model, inference_stack, common_params, model_id ): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(inference_model) + provider = inference_impl.routing_table.get_provider_impl(model_id) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::fireworks", @@ -204,7 +203,7 @@ class AnswerFormat(BaseModel): num_seasons_in_nba: int response = await inference_impl.chat_completion( - model=inference_model, + model_id=model_id, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -227,7 +226,7 @@ class AnswerFormat(BaseModel): assert answer.num_seasons_in_nba == 15 response = await inference_impl.chat_completion( - model=inference_model, + model_id=model_id, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -244,13 +243,13 @@ class AnswerFormat(BaseModel): @pytest.mark.asyncio async def test_chat_completion_streaming( - self, inference_model, inference_stack, common_params, sample_messages + self, inference_model, inference_stack, common_params, sample_messages, model_id ): inference_impl, _ = inference_stack response = [ r async for r in await inference_impl.chat_completion( - model=inference_model, + model_id=model_id, messages=sample_messages, stream=True, **common_params, @@ -277,6 +276,7 @@ async def test_chat_completion_with_tool_calling( common_params, sample_messages, sample_tool_definition, + model_id, ): inference_impl, _ = inference_stack messages = sample_messages + [ @@ -286,7 +286,7 @@ async def test_chat_completion_with_tool_calling( ] response = await inference_impl.chat_completion( - model=inference_model, + model_id=model_id, messages=messages, tools=[sample_tool_definition], stream=False, @@ -316,6 +316,7 @@ async def test_chat_completion_with_tool_calling_streaming( common_params, sample_messages, sample_tool_definition, + model_id, ): inference_impl, _ = inference_stack messages = sample_messages + [ @@ -327,7 +328,7 @@ async def test_chat_completion_with_tool_calling_streaming( response = [ r async for r in await inference_impl.chat_completion( - model=inference_model, + model_id=model_id, messages=messages, tools=[sample_tool_definition], stream=True, diff --git 
a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 141e4af31..bdc5af0f9 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -29,7 +29,7 @@ def map_to_provider_model(self, identifier: str) -> str: return self.stack_to_provider_models_map[identifier] async def register_model(self, model: Model) -> None: - if model.identifier not in self.stack_to_provider_models_map: + if model.provider_resource_id not in self.stack_to_provider_models_map: raise ValueError( - f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}" + f"Unsupported model {model.provider_resource_id}. Supported models: {self.stack_to_provider_models_map.keys()}" ) From d69f4f8635b058e1a903e79477d94e6552d8a936 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 10:13:43 -0800 Subject: [PATCH 02/13] fix model provider validation and inference params --- .../inference/meta_reference/inference.py | 12 ++++++------ .../providers/inline/inference/vllm/vllm.py | 10 +++++----- .../remote/inference/ollama/ollama.py | 16 +++++++++------- .../providers/remote/inference/vllm/vllm.py | 19 +++++++++++++------ .../providers/tests/inference/fixtures.py | 2 +- 5 files changed, 34 insertions(+), 25 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 2fdc8f2d5..1e668b183 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -46,7 +46,7 @@ async def initialize(self) -> None: self.generator = Llama.build(self.config) async def register_model(self, model: Model) -> None: - if model.identifier != self.model.descriptor(): + if model.provider_resource_id != self.model.descriptor(): raise ValueError( f"Model mismatch: {model.identifier} != {self.model.descriptor()}" ) @@ -68,7 +68,7 @@ def check_model(self, request) -> None: async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -79,7 +79,7 @@ async def completion( assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" request = CompletionRequest( - model=model, + model=model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -186,7 +186,7 @@ def impl(): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -201,7 +201,7 @@ async def chat_completion( # wrapper request to make it easier to pass around (internal only, not exposed to API) request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -386,7 +386,7 @@ def impl(): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 3b1a0dd50..8869cc07f 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -110,7 +110,7 @@ def 
_sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParam async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -120,7 +120,7 @@ async def completion( log.info("vLLM completion") messages = [UserMessage(content=content)] return self.chat_completion( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, stream=stream, @@ -129,7 +129,7 @@ async def completion( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), tools: Optional[List[ToolDefinition]] = None, @@ -144,7 +144,7 @@ async def chat_completion( assert self.engine is not None request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -215,7 +215,7 @@ async def _generate_and_convert_to_openai_compat(): yield chunk async def embeddings( - self, model: str, contents: list[InterleavedTextMedia] + self, model_id: str, contents: list[InterleavedTextMedia] ) -> EmbeddingsResponse: log.info("vLLM embeddings") # TODO diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 938d05c08..f5750e0cf 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -66,8 +66,10 @@ async def shutdown(self) -> None: pass async def register_model(self, model: Model) -> None: - if model.identifier not in OLLAMA_SUPPORTED_MODELS: - raise ValueError(f"Model {model.identifier} is not supported by Ollama") + if model.provider_resource_id not in OLLAMA_SUPPORTED_MODELS: + raise ValueError( + f"Model {model.provider_resource_id} is not supported by Ollama" + ) async def list_models(self) -> List[Model]: ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()} @@ -94,7 +96,7 @@ async def list_models(self) -> List[Model]: async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -102,7 +104,7 @@ async def completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = CompletionRequest( - model=model, + model=model_id, content=content, sampling_params=sampling_params, stream=stream, @@ -148,7 +150,7 @@ async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenera async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -159,7 +161,7 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -271,7 +273,7 @@ async def _generate_and_convert_to_openai_compat(): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index bd7f5073c..3a8b8c326 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ 
b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -45,8 +45,15 @@ async def initialize(self) -> None: self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) async def register_model(self, model: Model) -> None: - for running_model in self.client.models.list(): - repo = running_model.id + pass + + async def shutdown(self) -> None: + pass + + async def list_models(self) -> List[Model]: + models = [] + for model in self.client.models.list(): + repo = model.id if repo not in self.huggingface_repo_to_llama_model_id: print(f"Unknown model served by vllm: {repo}") continue @@ -67,7 +74,7 @@ async def shutdown(self) -> None: async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -78,7 +85,7 @@ async def completion( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -89,7 +96,7 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -173,7 +180,7 @@ async def _get_params( async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 2a4eba3ad..59bd492b9 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -49,7 +49,7 @@ def inference_meta_reference(inference_model) -> ProviderFixture: providers=[ Provider( provider_id=f"meta-reference-{i}", - provider_type="meta-reference", + provider_type="inline::meta-reference", config=MetaReferenceInferenceConfig( model=m, max_seq_len=4096, From 25d8ab0e14861d8dc477ad50b3a25ff6ab07abf2 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 10:17:51 -0800 Subject: [PATCH 03/13] fix bedrock --- llama_stack/providers/remote/inference/bedrock/bedrock.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index d9f82c611..7900e096d 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -49,7 +49,7 @@ async def shutdown(self) -> None: async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -286,7 +286,7 @@ def _tools_to_tool_config( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -299,7 +299,7 @@ async def chat_completion( ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ]: request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -433,7 +433,7 @@ def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dic async def embeddings( self, - 
model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() From 8de4cee3735e46507fc4eec87b17c2b0871ed446 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 13:07:35 -0800 Subject: [PATCH 04/13] working fireworks and together --- llama_stack/distribution/routers/routers.py | 9 +-- .../distribution/routers/routing_tables.py | 27 +++++-- .../remote/inference/bedrock/bedrock.py | 28 ++++--- .../remote/inference/databricks/databricks.py | 2 +- .../remote/inference/fireworks/fireworks.py | 79 ++++++++++++++----- .../remote/inference/together/together.py | 73 ++++++++++++----- .../utils/inference/model_registry.py | 60 +++++++++----- .../utils/inference/prompt_adapter.py | 13 +-- 8 files changed, 205 insertions(+), 86 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 7d4e43f60..5a62b6d64 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -105,9 +105,8 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - model = await self.routing_table.get_model(model_id) params = dict( - model_id=model.provider_resource_id, + model_id=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -132,10 +131,9 @@ async def completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - model = await self.routing_table.get_model(model_id) provider = self.routing_table.get_provider_impl(model_id) params = dict( - model_id=model.provider_resource_id, + model_id=model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -152,9 +150,8 @@ async def embeddings( model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - model = await self.routing_table.get_model(model_id) return await self.routing_table.get_provider_impl(model_id).embeddings( - model_id=model.provider_resource_id, + model_id=model_id, contents=contents, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d6fb5d662..249d3a144 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api: return p.__provider_spec__.api -async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: +# TODO: this should return the registered object for all APIs +async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: + api = get_impl_api(p) if obj.provider_id == "remote": @@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: obj.provider_id = "" if api == Api.inference: - await p.register_model(obj) + return await p.register_model(obj) elif api == Api.safety: await p.register_shield(obj) elif api == Api.memory: @@ -167,7 +169,9 @@ async def get_object_by_identifier( assert len(objects) == 1 return objects[0] - async def register_object(self, obj: RoutableObjectWithProvider): + async def register_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: # Get existing objects from registry existing_objects = await self.dist_registry.get(obj.type, obj.identifier) @@ -177,7 +181,7 @@ async def register_object(self, obj: RoutableObjectWithProvider): print( f"`{obj.identifier}` already 
registered with `{existing_obj.provider_id}`" ) - return + return existing_obj # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: @@ -188,8 +192,15 @@ async def register_object(self, obj: RoutableObjectWithProvider): p = self.impls_by_provider_id[obj.provider_id] - await register_object_with_provider(obj, p) - await self.dist_registry.register(obj) + registered_obj = await register_object_with_provider(obj, p) + # TODO: This needs to be fixed for all APIs once they return the registered object + if obj.type == ResourceType.model.value: + await self.dist_registry.register(registered_obj) + return registered_obj + + else: + await self.dist_registry.register(obj) + return obj async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() @@ -228,8 +239,8 @@ async def register_model( provider_id=provider_id, metadata=metadata, ) - await self.register_object(model) - return model + registered_model = await self.register_object(model) + return registered_model class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 7900e096d..2f1378696 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -11,7 +11,10 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + ModelAlias, + ModelRegistryHelper, +) from llama_stack.apis.inference import * # noqa: F403 @@ -19,19 +22,26 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client -BEDROCK_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", - "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", - "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", -} +model_aliases = [ + ModelAlias( + provider_model_id="meta.llama3-1-8b-instruct-v1:0", + aliases=["Llama3.1-8B"], + ), + ModelAlias( + provider_model_id="meta.llama3-1-70b-instruct-v1:0", + aliases=["Llama3.1-70B"], + ), + ModelAlias( + provider_model_id="meta.llama3-1-405b-instruct-v1:0", + aliases=["Llama3.1-405B"], + ), +] # NOTE: this is not quite tested after the recent refactors class BedrockInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: BedrockConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self._config = config self._client = create_bedrock_client(config) diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index f12ecb7f5..8e1f7693a 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -37,7 +37,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: DatabricksImplConfig) -> None: ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS + self, provider_to_common_model_aliases_map=DATABRICKS_SUPPORTED_MODELS ) self.config = config 
self.formatter = ChatFormat(Tokenizer.get_instance()) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 67bf1cf47..ce9639cbd 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -7,14 +7,17 @@ from typing import AsyncGenerator from fireworks.client import Fireworks +from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer - from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + ModelAlias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -31,25 +34,61 @@ from .config import FireworksImplConfig -FIREWORKS_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", - "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", - "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct", - "Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct", - "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct", - "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct", - "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct", - "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b", -} + +model_aliases = [ + ModelAlias( + provider_model_id="fireworks/llama-v3p1-8b-instruct", + aliases=["Llama3.1-8B-Instruct"], + llama_model=CoreModelId.llama3_1_8b_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-v3p1-70b-instruct", + aliases=["Llama3.1-70B-Instruct"], + llama_model=CoreModelId.llama3_1_70b_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-v3p1-405b-instruct", + aliases=["Llama3.1-405B-Instruct"], + llama_model=CoreModelId.llama3_1_405b_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-v3p2-1b-instruct", + aliases=["Llama3.2-1B-Instruct"], + llama_model=CoreModelId.llama3_2_3b_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-v3p2-3b-instruct", + aliases=["Llama3.2-3B-Instruct"], + llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-v3p2-11b-vision-instruct", + aliases=["Llama3.2-11B-Vision-Instruct"], + llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-v3p2-90b-vision-instruct", + aliases=["Llama3.2-90B-Vision-Instruct"], + llama_model=CoreModelId.llama3_2_90b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-guard-3-8b", + aliases=["Llama-Guard-3-8B"], + llama_model=CoreModelId.llama_guard_3_8b.value, + ), + ModelAlias( + provider_model_id="fireworks/llama-guard-3-11b-vision", + aliases=["Llama-Guard-3-11B-Vision"], + llama_model=CoreModelId.llama_guard_3_11b_vision.value, + ), +] class FireworksInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): def __init__(self, config: FireworksImplConfig) -> None: - ModelRegistryHelper.__init__( - self, 
stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -81,8 +120,9 @@ async def completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model_id, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -148,8 +188,9 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model_id, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -207,7 +248,7 @@ async def _get_params( ] else: input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( @@ -221,7 +262,7 @@ async def _get_params( input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :] return { - "model": self.map_to_provider_model(request.model), + "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format), diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 1b04ae556..75f93f64f 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -6,6 +6,8 @@ from typing import AsyncGenerator +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message @@ -15,7 +17,10 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + ModelAlias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -33,25 +38,55 @@ from .config import TogetherImplConfig -TOGETHER_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo", - "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", - "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", - "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", - "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", -} +model_aliases = [ + ModelAlias( + provider_model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + aliases=["Llama3.1-8B-Instruct"], + llama_model=CoreModelId.llama3_1_8b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + aliases=["Llama3.1-70B-Instruct"], + llama_model=CoreModelId.llama3_1_70b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + 
aliases=["Llama3.1-405B-Instruct"], + llama_model=CoreModelId.llama3_1_405b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-3.2-3B-Instruct-Turbo", + aliases=["Llama3.2-3B-Instruct"], + llama_model=CoreModelId.llama3_2_3b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", + aliases=["Llama3.2-11B-Vision-Instruct"], + llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", + aliases=["Llama3.2-90B-Vision-Instruct"], + llama_model=CoreModelId.llama3_2_90b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Meta-Llama-Guard-3-8B", + aliases=["Llama-Guard-3-8B"], + llama_model=CoreModelId.llama_guard_3_8b.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-Guard-3-11B-Vision-Turbo", + aliases=["Llama-Guard-3-11B-Vision"], + llama_model=CoreModelId.llama_guard_3_11b_vision.value, + ), +] class TogetherInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -70,8 +105,9 @@ async def completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model_id, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -145,8 +181,9 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model_id, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -204,7 +241,7 @@ async def _get_params( ] else: input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( @@ -213,7 +250,7 @@ async def _get_params( input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) return { - "model": self.map_to_provider_model(request.model), + "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format), diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index bdc5af0f9..b3401d8f5 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -4,32 +4,54 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Dict - -from llama_models.sku_list import resolve_model +from collections import namedtuple +from typing import List from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) + + +class ModelLookup: + def __init__( + self, + model_aliases: List[ModelAlias], + ): + self.alias_to_provider_id_map = {} + self.provider_id_to_llama_model_map = {} + for alias_obj in model_aliases: + for alias in alias_obj.aliases: + self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id + # also add a mapping from provider model id to itself for easy lookup + self.alias_to_provider_id_map[alias_obj.provider_model_id] = ( + alias_obj.provider_model_id + ) + self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = ( + alias_obj.llama_model + ) + + def get_provider_model_id(self, identifier: str) -> str: + if identifier in self.alias_to_provider_id_map: + return self.alias_to_provider_id_map[identifier] + else: + raise ValueError(f"Unknown model: `{identifier}`") + class ModelRegistryHelper(ModelsProtocolPrivate): - def __init__(self, stack_to_provider_models_map: Dict[str, str]): - self.stack_to_provider_models_map = stack_to_provider_models_map + def __init__(self, model_aliases: List[ModelAlias]): + self.model_lookup = ModelLookup(model_aliases) - def map_to_provider_model(self, identifier: str) -> str: - model = resolve_model(identifier) - if not model: - raise ValueError(f"Unknown model: `{identifier}`") + def get_llama_model(self, provider_model_id: str) -> str: + return self.model_lookup.provider_id_to_llama_model_map[provider_model_id] - if identifier not in self.stack_to_provider_models_map: - raise ValueError( - f"Model {identifier} not found in map {self.stack_to_provider_models_map}" - ) + async def register_model(self, model: Model) -> Model: + provider_model_id = self.model_lookup.get_provider_model_id( + model.provider_resource_id + ) + if not provider_model_id: + raise ValueError(f"Unknown model: `{model.provider_resource_id}`") - return self.stack_to_provider_models_map[identifier] + model.provider_resource_id = provider_model_id - async def register_model(self, model: Model) -> None: - if model.provider_resource_id not in self.stack_to_provider_models_map: - raise ValueError( - f"Unsupported model {model.provider_resource_id}. 
Supported models: {self.stack_to_provider_models_map.keys()}" - ) + return model diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 45e43c898..2df04664f 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content): def chat_completion_request_to_prompt( - request: ChatCompletionRequest, formatter: ChatFormat + request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat ) -> str: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, llama_model) model_input = formatter.encode_dialog_prompt(messages) return formatter.tokenizer.decode(model_input.tokens) def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, formatter: ChatFormat + request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat ) -> Tuple[str, int]: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, llama_model) model_input = formatter.encode_dialog_prompt(messages) return ( formatter.tokenizer.decode(model_input.tokens), @@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info( def chat_completion_request_to_messages( request: ChatCompletionRequest, + llama_model: str, ) -> List[Message]: """Reads chat completion request and augments the messages to handle tools. For eg. for llama_3_1, add system message with the appropriate tools or add user messsage for custom tools, etc. """ - model = resolve_model(request.model) + model = resolve_model(llama_model) if model is None: - cprint(f"Could not resolve model {request.model}", color="red") + cprint(f"Could not resolve model {llama_model}", color="red") return request.messages if model.descriptor() not in supported_inference_models(): From 5b2282afd452483143007f6216b27375ac62ffc5 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 13:17:27 -0800 Subject: [PATCH 05/13] ollama and databricks --- .../remote/inference/databricks/databricks.py | 32 ++++++-- .../remote/inference/ollama/ollama.py | 76 ++++++++++++++----- 2 files changed, 79 insertions(+), 29 deletions(-) diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 8e1f7693a..fedea0f86 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -6,6 +6,8 @@ from typing import AsyncGenerator +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message @@ -15,7 +17,10 @@ from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + ModelAlias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -28,16 +33,25 @@ from .config import DatabricksImplConfig -DATABRICKS_SUPPORTED_MODELS = { - "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct", - "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct", -} +model_aliases = [ + ModelAlias( + 
provider_model_id="databricks-meta-llama-3-1-70b-instruct", + aliases=["Llama3.1-70B-Instruct"], + llama_model=CoreModelId.llama3_1_70b_instruct.value, + ), + ModelAlias( + provider_model_id="databricks-meta-llama-3-1-405b-instruct", + aliases=["Llama3.1-405B-Instruct"], + llama_model=CoreModelId.llama3_1_405b_instruct.value, + ), +] class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: DatabricksImplConfig) -> None: ModelRegistryHelper.__init__( - self, provider_to_common_model_aliases_map=DATABRICKS_SUPPORTED_MODELS + self, + model_aliases=model_aliases, ) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -113,8 +127,10 @@ async def _to_async_generator(): def _get_params(self, request: ChatCompletionRequest) -> dict: return { - "model": self.map_to_provider_model(request.model), - "prompt": chat_completion_request_to_prompt(request, self.formatter), + "model": request.model, + "prompt": chat_completion_request_to_prompt( + request, self.get_llama_model(request.model), self.formatter + ), "stream": request.stream, **get_sampling_options(request.sampling_params), } diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f5750e0cf..bc80c7db2 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -7,13 +7,18 @@ from typing import AsyncGenerator import httpx +from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer - from ollama import AsyncClient +from llama_stack.providers.utils.inference.model_registry import ( + ModelAlias, + ModelRegistryHelper, +) + from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate @@ -33,19 +38,52 @@ request_has_media, ) -OLLAMA_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", - "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", - "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16", - "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16", - "Llama-Guard-3-8B": "llama-guard3:8b", - "Llama-Guard-3-1B": "llama-guard3:1b", - "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16", -} - -class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): +model_aliases = [ + ModelAlias( + provider_model_id="llama3.1:8b-instruct-fp16", + aliases=["Llama3.1-8B-Instruct"], + llama_model=CoreModelId.llama3_1_8b_instruct.value, + ), + ModelAlias( + provider_model_id="llama3.1:70b-instruct-fp16", + aliases=["Llama3.1-70B-Instruct"], + llama_model=CoreModelId.llama3_1_70b_instruct.value, + ), + ModelAlias( + provider_model_id="llama3.2:1b-instruct-fp16", + aliases=["Llama3.2-1B-Instruct"], + llama_model=CoreModelId.llama3_2_1b_instruct.value, + ), + ModelAlias( + provider_model_id="llama3.2:3b-instruct-fp16", + aliases=["Llama3.2-3B-Instruct"], + llama_model=CoreModelId.llama3_2_3b_instruct.value, + ), + ModelAlias( + provider_model_id="llama-guard3:8b", + aliases=["Llama-Guard-3-8B"], + llama_model=CoreModelId.llama_guard_3_8b.value, + ), + ModelAlias( + provider_model_id="llama-guard3:1b", + aliases=["Llama-Guard-3-1B"], + llama_model=CoreModelId.llama_guard_3_1b.value, + ), + ModelAlias( + provider_model_id="x/llama3.2-vision:11b-instruct-fp16", + aliases=["Llama3.2-11B-Vision-Instruct"], 
+ llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + ), +] + + +class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate): def __init__(self, url: str) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=model_aliases, + ) self.url = url self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -65,12 +103,6 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - async def register_model(self, model: Model) -> None: - if model.provider_resource_id not in OLLAMA_SUPPORTED_MODELS: - raise ValueError( - f"Model {model.provider_resource_id} is not supported by Ollama" - ) - async def list_models(self) -> List[Model]: ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()} @@ -103,8 +135,9 @@ async def completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model_id, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, stream=stream, @@ -160,8 +193,9 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model_id, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -199,7 +233,7 @@ async def _get_params( else: input_dict["raw"] = True input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( From 71219b493717478ba135fd831f675be3bd555ee8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 13:23:02 -0800 Subject: [PATCH 06/13] ollama --- .../remote/inference/ollama/ollama.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index bc80c7db2..4a7f548a6 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -20,7 +20,7 @@ ) from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, @@ -103,29 +103,6 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - async def list_models(self) -> List[Model]: - ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()} - - ret = [] - res = await self.client.ps() - for r in res["models"]: - if r["model"] not in ollama_to_llama: - print(f"Ollama is running a model unknown to Llama Stack: {r['model']}") - continue - - llama_model = ollama_to_llama[r["model"]] - print(f"Found model {llama_model} in Ollama") - ret.append( - Model( - identifier=llama_model, - metadata={ - "ollama_model": r["model"], - }, - ) - ) - - return ret - async def completion( self, model_id: str, @@ -243,7 +220,7 @@ async def _get_params( input_dict["raw"] = True return { - "model": OLLAMA_SUPPORTED_MODELS[request.model], + "model": request.model, **input_dict, "options": sampling_options, "stream": request.stream, From 92ee627e89edc2070259194477674696b6af524e Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 
Nov 2024 13:59:46 -0800 Subject: [PATCH 07/13] vllm --- .../providers/remote/inference/vllm/vllm.py | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 3a8b8c326..c49541fd9 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -8,13 +8,17 @@ from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import all_registered_models, resolve_model +from llama_models.sku_list import all_registered_models from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import ( + ModelAlias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -30,8 +34,24 @@ from .config import VLLMInferenceAdapterConfig -class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): +def build_model_aliases(): + return [ + ModelAlias( + provider_model_id=model.huggingface_repo, + aliases=[model.descriptor()], + llama_model=model.descriptor(), + ) + for model in all_registered_models() + if model.huggingface_repo + ] + + +class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate): def __init__(self, config: VLLMInferenceAdapterConfig) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=build_model_aliases(), + ) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None @@ -44,31 +64,6 @@ def __init__(self, config: VLLMInferenceAdapterConfig) -> None: async def initialize(self) -> None: self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) - async def register_model(self, model: Model) -> None: - pass - - async def shutdown(self) -> None: - pass - - async def list_models(self) -> List[Model]: - models = [] - for model in self.client.models.list(): - repo = model.id - if repo not in self.huggingface_repo_to_llama_model_id: - print(f"Unknown model served by vllm: {repo}") - continue - - identifier = self.huggingface_repo_to_llama_model_id[repo] - if identifier == model.provider_resource_id: - print( - f"Verified that model {model.provider_resource_id} is being served by vLLM" - ) - return - - raise ValueError( - f"Model {model.provider_resource_id} is not being served by vLLM" - ) - async def shutdown(self) -> None: pass @@ -95,8 +90,9 @@ async def chat_completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model_id, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -148,10 +144,6 @@ async def _get_params( if "max_tokens" not in options: options["max_tokens"] = self.config.max_tokens - model = resolve_model(request.model) - if model is None: - raise ValueError(f"Unknown model: {request.model}") - input_dict = {} media_present = request_has_media(request) if isinstance(request, ChatCompletionRequest): @@ -163,16 +155,20 @@ async def _get_params( ] else: 
input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( not media_present ), "Together does not support media for Completion requests" - input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) + input_dict["prompt"] = completion_request_to_prompt( + request, + self.get_llama_model(request.model), + self.formatter, + ) return { - "model": model.huggingface_repo, + "model": request.model, **input_dict, "stream": request.stream, **options, From d5874735eada4c4a0d91b0c8d1cd639cd0941292 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 14:08:47 -0800 Subject: [PATCH 08/13] bedrock --- .../providers/remote/inference/bedrock/bedrock.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 2f1378696..47abff689 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -7,6 +7,7 @@ from typing import * # noqa: F403 from botocore.client import BaseClient +from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -25,15 +26,18 @@ model_aliases = [ ModelAlias( provider_model_id="meta.llama3-1-8b-instruct-v1:0", - aliases=["Llama3.1-8B"], + aliases=["Llama3.1-8B-Instruct"], + llama_model=CoreModelId.llama3_1_8b_instruct, ), ModelAlias( provider_model_id="meta.llama3-1-70b-instruct-v1:0", - aliases=["Llama3.1-70B"], + aliases=["Llama3.1-70B-Instruct"], + llama_model=CoreModelId.llama3_1_70b_instruct, ), ModelAlias( provider_model_id="meta.llama3-1-405b-instruct-v1:0", - aliases=["Llama3.1-405B"], + aliases=["Llama3.1-405B-Instruct"], + llama_model=CoreModelId.llama3_1_405b_instruct, ), ] @@ -308,8 +312,9 @@ async def chat_completion( ) -> Union[ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ]: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model_id, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -414,7 +419,7 @@ async def _stream_chat_completion( pass def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: - bedrock_model = self.map_to_provider_model(request.model) + bedrock_model = request.model inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( request.sampling_params ) From 948f6ece6ec70c8567bf3bc4917e586958c44c8b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 14:25:28 -0800 Subject: [PATCH 09/13] fixes for all providers --- .../inference/meta_reference/inference.py | 17 ++++- .../remote/inference/bedrock/bedrock.py | 23 +++---- .../remote/inference/databricks/databricks.py | 16 ++--- .../remote/inference/fireworks/fireworks.py | 65 ++++++++----------- .../remote/inference/ollama/ollama.py | 51 +++++++-------- .../remote/inference/together/together.py | 58 +++++++---------- .../providers/remote/inference/vllm/vllm.py | 14 ++-- .../utils/inference/model_registry.py | 24 ++++++- 8 files changed, 133 insertions(+), 135 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 1e668b183..844cf6939 100644 --- 
a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -11,9 +11,11 @@ from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import build_model_alias +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_media_to_url, request_has_media, @@ -28,10 +30,19 @@ SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): +class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config model = resolve_model(config.model) + ModelRegistryHelper.__init__( + self, + [ + build_model_alias( + model.descriptor(), + model.core_model_id, + ) + ], + ) if model is None: raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") self.model = model diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 47abff689..8762a6c95 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -13,7 +13,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.providers.utils.inference.model_registry import ( - ModelAlias, + build_model_alias, ModelRegistryHelper, ) @@ -24,20 +24,17 @@ model_aliases = [ - ModelAlias( - provider_model_id="meta.llama3-1-8b-instruct-v1:0", - aliases=["Llama3.1-8B-Instruct"], - llama_model=CoreModelId.llama3_1_8b_instruct, + build_model_alias( + "meta.llama3-1-8b-instruct-v1:0", + CoreModelId.llama3_1_8b_instruct, ), - ModelAlias( - provider_model_id="meta.llama3-1-70b-instruct-v1:0", - aliases=["Llama3.1-70B-Instruct"], - llama_model=CoreModelId.llama3_1_70b_instruct, + build_model_alias( + "meta.llama3-1-70b-instruct-v1:0", + CoreModelId.llama3_1_70b_instruct, ), - ModelAlias( - provider_model_id="meta.llama3-1-405b-instruct-v1:0", - aliases=["Llama3.1-405B-Instruct"], - llama_model=CoreModelId.llama3_1_405b_instruct, + build_model_alias( + "meta.llama3-1-405b-instruct-v1:0", + CoreModelId.llama3_1_405b_instruct, ), ] diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index fedea0f86..1337b6c09 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -18,7 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.model_registry import ( - ModelAlias, + build_model_alias, ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( @@ -34,15 +34,13 @@ model_aliases = [ - ModelAlias( - provider_model_id="databricks-meta-llama-3-1-70b-instruct", - aliases=["Llama3.1-70B-Instruct"], - llama_model=CoreModelId.llama3_1_70b_instruct.value, + build_model_alias( + "databricks-meta-llama-3-1-70b-instruct", + 
CoreModelId.llama3_1_70b_instruct, ), - ModelAlias( - provider_model_id="databricks-meta-llama-3-1-405b-instruct", - aliases=["Llama3.1-405B-Instruct"], - llama_model=CoreModelId.llama3_1_405b_instruct.value, + build_model_alias( + "databricks-meta-llama-3-1-405b-instruct", + CoreModelId.llama3_1_405b_instruct, ), ] diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ce9639cbd..e0d42c721 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -15,7 +15,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( - ModelAlias, + build_model_alias, ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( @@ -36,50 +36,41 @@ model_aliases = [ - ModelAlias( - provider_model_id="fireworks/llama-v3p1-8b-instruct", - aliases=["Llama3.1-8B-Instruct"], - llama_model=CoreModelId.llama3_1_8b_instruct.value, + build_model_alias( + "fireworks/llama-v3p1-8b-instruct", + CoreModelId.llama3_1_8b_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-v3p1-70b-instruct", - aliases=["Llama3.1-70B-Instruct"], - llama_model=CoreModelId.llama3_1_70b_instruct.value, + build_model_alias( + "fireworks/llama-v3p1-70b-instruct", + CoreModelId.llama3_1_70b_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-v3p1-405b-instruct", - aliases=["Llama3.1-405B-Instruct"], - llama_model=CoreModelId.llama3_1_405b_instruct.value, + build_model_alias( + "fireworks/llama-v3p1-405b-instruct", + CoreModelId.llama3_1_405b_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-v3p2-1b-instruct", - aliases=["Llama3.2-1B-Instruct"], - llama_model=CoreModelId.llama3_2_3b_instruct.value, + build_model_alias( + "fireworks/llama-v3p2-1b-instruct", + CoreModelId.llama3_2_3b_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-v3p2-3b-instruct", - aliases=["Llama3.2-3B-Instruct"], - llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + build_model_alias( + "fireworks/llama-v3p2-3b-instruct", + CoreModelId.llama3_2_11b_vision_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-v3p2-11b-vision-instruct", - aliases=["Llama3.2-11B-Vision-Instruct"], - llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + build_model_alias( + "fireworks/llama-v3p2-11b-vision-instruct", + CoreModelId.llama3_2_11b_vision_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-v3p2-90b-vision-instruct", - aliases=["Llama3.2-90B-Vision-Instruct"], - llama_model=CoreModelId.llama3_2_90b_vision_instruct.value, + build_model_alias( + "fireworks/llama-v3p2-90b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct, ), - ModelAlias( - provider_model_id="fireworks/llama-guard-3-8b", - aliases=["Llama-Guard-3-8B"], - llama_model=CoreModelId.llama_guard_3_8b.value, + build_model_alias( + "fireworks/llama-guard-3-8b", + CoreModelId.llama_guard_3_8b, ), - ModelAlias( - provider_model_id="fireworks/llama-guard-3-11b-vision", - aliases=["Llama-Guard-3-11B-Vision"], - llama_model=CoreModelId.llama_guard_3_11b_vision.value, + build_model_alias( + "fireworks/llama-guard-3-11b-vision", + CoreModelId.llama_guard_3_11b_vision, ), ] diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py 
b/llama_stack/providers/remote/inference/ollama/ollama.py index 4a7f548a6..34af95b50 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -15,7 +15,7 @@ from ollama import AsyncClient from llama_stack.providers.utils.inference.model_registry import ( - ModelAlias, + build_model_alias, ModelRegistryHelper, ) @@ -40,40 +40,33 @@ model_aliases = [ - ModelAlias( - provider_model_id="llama3.1:8b-instruct-fp16", - aliases=["Llama3.1-8B-Instruct"], - llama_model=CoreModelId.llama3_1_8b_instruct.value, + build_model_alias( + "llama3.1:8b-instruct-fp16", + CoreModelId.llama3_1_8b_instruct, ), - ModelAlias( - provider_model_id="llama3.1:70b-instruct-fp16", - aliases=["Llama3.1-70B-Instruct"], - llama_model=CoreModelId.llama3_1_70b_instruct.value, + build_model_alias( + "llama3.1:70b-instruct-fp16", + CoreModelId.llama3_1_70b_instruct, ), - ModelAlias( - provider_model_id="llama3.2:1b-instruct-fp16", - aliases=["Llama3.2-1B-Instruct"], - llama_model=CoreModelId.llama3_2_1b_instruct.value, + build_model_alias( + "llama3.2:1b-instruct-fp16", + CoreModelId.llama3_2_1b_instruct, ), - ModelAlias( - provider_model_id="llama3.2:3b-instruct-fp16", - aliases=["Llama3.2-3B-Instruct"], - llama_model=CoreModelId.llama3_2_3b_instruct.value, + build_model_alias( + "llama3.2:3b-instruct-fp16", + CoreModelId.llama3_2_3b_instruct, ), - ModelAlias( - provider_model_id="llama-guard3:8b", - aliases=["Llama-Guard-3-8B"], - llama_model=CoreModelId.llama_guard_3_8b.value, + build_model_alias( + "llama-guard3:8b", + CoreModelId.llama_guard_3_8b, ), - ModelAlias( - provider_model_id="llama-guard3:1b", - aliases=["Llama-Guard-3-1B"], - llama_model=CoreModelId.llama_guard_3_1b.value, + build_model_alias( + "llama-guard3:1b", + CoreModelId.llama_guard_3_1b, ), - ModelAlias( - provider_model_id="x/llama3.2-vision:11b-instruct-fp16", - aliases=["Llama3.2-11B-Vision-Instruct"], - llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + build_model_alias( + "x/llama3.2-vision:11b-instruct-fp16", + CoreModelId.llama3_2_11b_vision_instruct, ), ] diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 75f93f64f..644302a0f 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -18,7 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( - ModelAlias, + build_model_alias, ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( @@ -39,45 +39,37 @@ model_aliases = [ - ModelAlias( - provider_model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - aliases=["Llama3.1-8B-Instruct"], - llama_model=CoreModelId.llama3_1_8b_instruct.value, + build_model_alias( + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + CoreModelId.llama3_1_8b_instruct, ), - ModelAlias( - provider_model_id="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - aliases=["Llama3.1-70B-Instruct"], - llama_model=CoreModelId.llama3_1_70b_instruct.value, + build_model_alias( + "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + CoreModelId.llama3_1_70b_instruct, ), - ModelAlias( - provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - aliases=["Llama3.1-405B-Instruct"], - llama_model=CoreModelId.llama3_1_405b_instruct.value, + 
build_model_alias( + "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + CoreModelId.llama3_1_405b_instruct, ), - ModelAlias( - provider_model_id="meta-llama/Llama-3.2-3B-Instruct-Turbo", - aliases=["Llama3.2-3B-Instruct"], - llama_model=CoreModelId.llama3_2_3b_instruct.value, + build_model_alias( + "meta-llama/Llama-3.2-3B-Instruct-Turbo", + CoreModelId.llama3_2_3b_instruct, ), - ModelAlias( - provider_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", - aliases=["Llama3.2-11B-Vision-Instruct"], - llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + build_model_alias( + "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", + CoreModelId.llama3_2_11b_vision_instruct, ), - ModelAlias( - provider_model_id="meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", - aliases=["Llama3.2-90B-Vision-Instruct"], - llama_model=CoreModelId.llama3_2_90b_vision_instruct.value, + build_model_alias( + "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", + CoreModelId.llama3_2_90b_vision_instruct, ), - ModelAlias( - provider_model_id="meta-llama/Meta-Llama-Guard-3-8B", - aliases=["Llama-Guard-3-8B"], - llama_model=CoreModelId.llama_guard_3_8b.value, + build_model_alias( + "meta-llama/Meta-Llama-Guard-3-8B", + CoreModelId.llama_guard_3_8b, ), - ModelAlias( - provider_model_id="meta-llama/Llama-Guard-3-11B-Vision-Turbo", - aliases=["Llama-Guard-3-11B-Vision"], - llama_model=CoreModelId.llama_guard_3_11b_vision.value, + build_model_alias( + "meta-llama/Llama-Guard-3-11B-Vision-Turbo", + CoreModelId.llama_guard_3_11b_vision, ), ] diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index c49541fd9..9bf25c5ad 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -16,7 +16,7 @@ from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( - ModelAlias, + build_model_alias, ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( @@ -36,10 +36,9 @@ def build_model_aliases(): return [ - ModelAlias( - provider_model_id=model.huggingface_repo, - aliases=[model.descriptor()], - llama_model=model.descriptor(), + build_model_alias( + model.huggingface_repo, + model.core_model_id, ) for model in all_registered_models() if model.huggingface_repo @@ -55,11 +54,6 @@ def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None - self.huggingface_repo_to_llama_model_id = { - model.huggingface_repo: model.descriptor() - for model in all_registered_models() - if model.huggingface_repo - } async def initialize(self) -> None: self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index b3401d8f5..35d67a4cc 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -5,13 +5,35 @@ # the root directory of this source tree. 
from collections import namedtuple -from typing import List +from typing import List, Optional + +from llama_models.datatypes import CoreModelId +from llama_models.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) +def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]: + """Get the Hugging Face repository for a given CoreModelId.""" + for model in all_registered_models(): + if model.core_model_id == core_model_id: + return model.huggingface_repo + return None + + +def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias: + return ModelAlias( + provider_model_id=provider_model_id, + aliases=[ + core_model_id.value, + get_huggingface_repo(core_model_id), + ], + llama_model=core_model_id.value, + ) + + class ModelLookup: def __init__( self, From 919d421bcf3bcd73123caedc08c741566e1f0d14 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 15:37:07 -0800 Subject: [PATCH 10/13] fixes after rebase --- .../inference/meta_reference/generation.py | 14 +++++-- .../inference/meta_reference/inference.py | 8 +--- .../remote/inference/bedrock/bedrock.py | 6 +-- .../remote/inference/databricks/databricks.py | 4 +- .../remote/inference/fireworks/fireworks.py | 18 ++++---- .../remote/inference/ollama/ollama.py | 15 +++---- .../remote/inference/together/together.py | 16 ++++---- .../providers/remote/inference/vllm/vllm.py | 3 +- .../providers/tests/inference/fixtures.py | 4 +- .../tests/inference/test_text_inference.py | 41 +++++++++---------- .../utils/inference/model_registry.py | 13 +++--- 11 files changed, 72 insertions(+), 70 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 2f296c7c2..38c982473 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -86,6 +86,7 @@ def build( and loads the pre-trained model and tokenizer. 
""" model = resolve_model(config.model) + llama_model = model.core_model_id.value if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") @@ -186,13 +187,20 @@ def build( model.load_state_dict(state_dict, strict=False) print(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args) + return Llama(model, tokenizer, model_args, llama_model) - def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + def __init__( + self, + model: Transformer, + tokenizer: Tokenizer, + args: ModelArgs, + llama_model: str, + ): self.args = args self.model = model self.tokenizer = tokenizer self.formatter = ChatFormat(tokenizer) + self.llama_model = llama_model @torch.inference_mode() def generate( @@ -369,7 +377,7 @@ def chat_completion( self, request: ChatCompletionRequest, ) -> Generator: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, self.llama_model) sampling_params = request.sampling_params max_gen_len = sampling_params.max_tokens diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 844cf6939..4f5c0c8c2 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -39,7 +39,7 @@ def __init__(self, config: MetaReferenceInferenceConfig) -> None: [ build_model_alias( model.descriptor(), - model.core_model_id, + model.core_model_id.value, ) ], ) @@ -56,12 +56,6 @@ async def initialize(self) -> None: else: self.generator = Llama.build(self.config) - async def register_model(self, model: Model) -> None: - if model.provider_resource_id != self.model.descriptor(): - raise ValueError( - f"Model mismatch: {model.identifier} != {self.model.descriptor()}" - ) - async def shutdown(self) -> None: if self.config.create_distributed_process_group: self.generator.stop() diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 8762a6c95..f575d9dc3 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -26,15 +26,15 @@ model_aliases = [ build_model_alias( "meta.llama3-1-8b-instruct-v1:0", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "meta.llama3-1-70b-instruct-v1:0", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "meta.llama3-1-405b-instruct-v1:0", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), ] diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 1337b6c09..0ebb625bc 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -36,11 +36,11 @@ model_aliases = [ build_model_alias( "databricks-meta-llama-3-1-70b-instruct", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "databricks-meta-llama-3-1-405b-instruct", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), ] diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py 
b/llama_stack/providers/remote/inference/fireworks/fireworks.py index e0d42c721..42075eff7 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -38,39 +38,39 @@ model_aliases = [ build_model_alias( "fireworks/llama-v3p1-8b-instruct", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "fireworks/llama-v3p1-70b-instruct", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "fireworks/llama-v3p1-405b-instruct", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-1b-instruct", - CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_3b_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-3b-instruct", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( "fireworks/llama-v3p2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct.value, ), build_model_alias( "fireworks/llama-guard-3-8b", - CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_8b.value, ), build_model_alias( "fireworks/llama-guard-3-11b-vision", - CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_11b_vision.value, ), ] diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 34af95b50..99f74572e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -42,31 +42,31 @@ model_aliases = [ build_model_alias( "llama3.1:8b-instruct-fp16", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), build_model_alias( "llama3.1:70b-instruct-fp16", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "llama3.2:1b-instruct-fp16", - CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_1b_instruct.value, ), build_model_alias( "llama3.2:3b-instruct-fp16", - CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_3b_instruct.value, ), build_model_alias( "llama-guard3:8b", - CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_8b.value, ), build_model_alias( "llama-guard3:1b", - CoreModelId.llama_guard_3_1b, + CoreModelId.llama_guard_3_1b.value, ), build_model_alias( "x/llama3.2-vision:11b-instruct-fp16", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), ] @@ -164,6 +164,7 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) + print(f"model={model}") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 644302a0f..aae34bb87 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -41,35 +41,35 @@ model_aliases = [ build_model_alias( "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_8b_instruct.value, ), 
build_model_alias( "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_70b_instruct.value, ), build_model_alias( "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - CoreModelId.llama3_1_405b_instruct, + CoreModelId.llama3_1_405b_instruct.value, ), build_model_alias( "meta-llama/Llama-3.2-3B-Instruct-Turbo", - CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_3b_instruct.value, ), build_model_alias( "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", - CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_11b_vision_instruct.value, ), build_model_alias( "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", - CoreModelId.llama3_2_90b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct.value, ), build_model_alias( "meta-llama/Meta-Llama-Guard-3-8B", - CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_8b.value, ), build_model_alias( "meta-llama/Llama-Guard-3-11B-Vision-Turbo", - CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_11b_vision.value, ), ] diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 9bf25c5ad..2d03a9ef8 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -38,7 +38,7 @@ def build_model_aliases(): return [ build_model_alias( model.huggingface_repo, - model.core_model_id, + model.descriptor(), ) for model in all_registered_models() if model.huggingface_repo @@ -85,6 +85,7 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) + print(f"model={model}") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 59bd492b9..f6f2a30e8 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -179,7 +179,7 @@ def model_id(inference_model) -> str: @pytest_asyncio.fixture(scope="session") -async def inference_stack(request, inference_model, model_id): +async def inference_stack(request, inference_model): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") impls = await resolve_impls_for_test_v2( @@ -188,7 +188,7 @@ async def inference_stack(request, inference_model, model_id): inference_fixture.provider_data, models=[ ModelInput( - model_id=model_id, + model_id=inference_model, ) ], ) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 9850b328e..70047a61f 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -64,7 +64,7 @@ def sample_tool_definition(): class TestInference: @pytest.mark.asyncio - async def test_model_list(self, inference_model, inference_stack, model_id): + async def test_model_list(self, inference_model, inference_stack): _, models_impl = inference_stack response = await models_impl.list_models() assert isinstance(response, list) @@ -73,16 +73,17 @@ async def test_model_list(self, inference_model, inference_stack, model_id): model_def = None for model in response: - if model.identifier == model_id: + if model.identifier == inference_model: model_def = model break assert model_def is not None 
@pytest.mark.asyncio - async def test_completion(self, inference_model, inference_stack, model_id): + async def test_completion(self, inference_model, inference_stack): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(model_id) + + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::ollama", @@ -95,7 +96,7 @@ async def test_completion(self, inference_model, inference_stack, model_id): response = await inference_impl.completion( content="Micheael Jordan is born in ", stream=False, - model_id=model_id, + model_id=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -109,7 +110,7 @@ async def test_completion(self, inference_model, inference_stack, model_id): async for r in await inference_impl.completion( content="Roses are red,", stream=True, - model_id=model_id, + model_id=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -124,11 +125,11 @@ async def test_completion(self, inference_model, inference_stack, model_id): @pytest.mark.asyncio @pytest.mark.skip("This test is not quite robust") async def test_completions_structured_output( - self, inference_model, inference_stack, model_id + self, inference_model, inference_stack ): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(model_id) + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::tgi", @@ -148,7 +149,7 @@ class Output(BaseModel): response = await inference_impl.completion( content=user_input, stream=False, - model_id=model_id, + model=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -166,11 +167,11 @@ class Output(BaseModel): @pytest.mark.asyncio async def test_chat_completion_non_streaming( - self, inference_model, inference_stack, common_params, sample_messages, model_id + self, inference_model, inference_stack, common_params, sample_messages ): inference_impl, _ = inference_stack response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=sample_messages, stream=False, **common_params, @@ -183,11 +184,11 @@ async def test_chat_completion_non_streaming( @pytest.mark.asyncio async def test_structured_output( - self, inference_model, inference_stack, common_params, model_id + self, inference_model, inference_stack, common_params ): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl(model_id) + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::fireworks", @@ -203,7 +204,7 @@ class AnswerFormat(BaseModel): num_seasons_in_nba: int response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -226,7 +227,7 @@ class AnswerFormat(BaseModel): assert answer.num_seasons_in_nba == 15 response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -243,13 +244,13 @@ class AnswerFormat(BaseModel): @pytest.mark.asyncio async def 
test_chat_completion_streaming( - self, inference_model, inference_stack, common_params, sample_messages, model_id + self, inference_model, inference_stack, common_params, sample_messages ): inference_impl, _ = inference_stack response = [ r async for r in await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=sample_messages, stream=True, **common_params, @@ -276,7 +277,6 @@ async def test_chat_completion_with_tool_calling( common_params, sample_messages, sample_tool_definition, - model_id, ): inference_impl, _ = inference_stack messages = sample_messages + [ @@ -286,7 +286,7 @@ async def test_chat_completion_with_tool_calling( ] response = await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=messages, tools=[sample_tool_definition], stream=False, @@ -316,7 +316,6 @@ async def test_chat_completion_with_tool_calling_streaming( common_params, sample_messages, sample_tool_definition, - model_id, ): inference_impl, _ = inference_stack messages = sample_messages + [ @@ -328,7 +327,7 @@ async def test_chat_completion_with_tool_calling_streaming( response = [ r async for r in await inference_impl.chat_completion( - model_id=model_id, + model_id=inference_model, messages=messages, tools=[sample_tool_definition], stream=True, diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 35d67a4cc..c44c641a2 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -7,7 +7,6 @@ from collections import namedtuple from typing import List, Optional -from llama_models.datatypes import CoreModelId from llama_models.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate @@ -15,22 +14,22 @@ ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) -def get_huggingface_repo(core_model_id: CoreModelId) -> Optional[str]: +def get_huggingface_repo(model_descriptor: str) -> Optional[str]: """Get the Hugging Face repository for a given CoreModelId.""" for model in all_registered_models(): - if model.core_model_id == core_model_id: + if model.descriptor() == model_descriptor: return model.huggingface_repo return None -def build_model_alias(provider_model_id: str, core_model_id: CoreModelId) -> ModelAlias: +def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias: return ModelAlias( provider_model_id=provider_model_id, aliases=[ - core_model_id.value, - get_huggingface_repo(core_model_id), + model_descriptor, + get_huggingface_repo(model_descriptor), ], - llama_model=core_model_id.value, + llama_model=model_descriptor, ) From 55d66ca918a192f365a3776be8c4401694ce61e6 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 15:47:41 -0800 Subject: [PATCH 11/13] run openapi gen --- docs/resources/llama-stack-spec.html | 50 ++++++++++++++-------------- docs/resources/llama-stack-spec.yaml | 42 +++++++++++------------ 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7ef4ece21..f87cb5590 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and 
their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543" }, "servers": [ { @@ -2856,7 +2856,7 @@ "ChatCompletionRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "messages": { @@ -2993,7 +2993,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "messages" ] }, @@ -3120,7 +3120,7 @@ "CompletionRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "content": { @@ -3249,7 +3249,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "content" ] }, @@ -4552,7 +4552,7 @@ "EmbeddingsRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "contents": { @@ -4584,7 +4584,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "contents" ] }, @@ -7837,58 +7837,58 @@ ], "tags": [ { - "name": "MemoryBanks" + "name": "Safety" }, { - "name": "BatchInference" + "name": "EvalTasks" }, { - "name": "Agents" + "name": "Shields" }, { - "name": "Inference" + "name": "Telemetry" }, { - "name": "DatasetIO" + "name": "Memory" }, { - "name": "Eval" + "name": "Scoring" }, { - "name": "Models" + "name": "ScoringFunctions" }, { - "name": "PostTraining" + "name": "SyntheticDataGeneration" }, { - "name": "ScoringFunctions" + "name": "Models" }, { - "name": "Datasets" + "name": "Agents" }, { - "name": "Shields" + "name": "MemoryBanks" }, { - "name": "Telemetry" + "name": "DatasetIO" }, { - "name": "Inspect" + "name": "Inference" }, { - "name": "Safety" + "name": "Datasets" }, { - "name": "SyntheticDataGeneration" + "name": "PostTraining" }, { - "name": "Memory" + "name": "BatchInference" }, { - "name": "Scoring" + "name": "Eval" }, { - "name": "EvalTasks" + "name": "Inspect" }, { "name": "BuiltinTool", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index b86c0df61..87268ff47 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -396,7 +396,7 @@ components: - $ref: '#/components/schemas/ToolResponseMessage' - $ref: '#/components/schemas/CompletionMessage' type: array - model: + model_id: type: string response_format: oneOf: @@ -453,7 +453,7 @@ components: $ref: '#/components/schemas/ToolDefinition' type: array required: - - model + - model_id - messages type: object ChatCompletionResponse: @@ -577,7 +577,7 @@ components: default: 0 type: integer type: object - model: + model_id: type: string response_format: oneOf: @@ -626,7 +626,7 @@ components: stream: type: boolean required: - - model + - model_id - content type: object CompletionResponse: @@ -903,10 +903,10 @@ components: - $ref: '#/components/schemas/ImageMedia' type: array type: array - model: + model_id: type: string required: - - model + - model_id - contents type: object EmbeddingsResponse: @@ -3384,7 +3384,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" + \ draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4748,24 +4748,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: MemoryBanks -- name: BatchInference -- name: Agents -- name: Inference -- name: DatasetIO -- name: Eval -- name: Models -- name: PostTraining -- name: ScoringFunctions -- name: Datasets +- name: Safety +- name: EvalTasks - name: Shields - name: Telemetry -- name: Inspect -- name: Safety -- name: SyntheticDataGeneration - name: Memory - name: Scoring -- name: EvalTasks +- name: ScoringFunctions +- name: SyntheticDataGeneration +- name: Models +- name: Agents +- name: MemoryBanks +- name: DatasetIO +- name: Inference +- name: Datasets +- name: PostTraining +- name: BatchInference +- name: Eval +- name: Inspect - description: name: BuiltinTool - description: Date: Tue, 12 Nov 2024 18:14:58 -0800 Subject: [PATCH 12/13] fix evals and scoring --- llama_stack/providers/inline/eval/meta_reference/eval.py | 2 +- .../scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py | 2 +- llama_stack/providers/remote/inference/vllm/vllm.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index ba2fc7c95..58241eb42 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -150,7 +150,7 @@ async def evaluate_rows( messages.append(candidate.system_message) messages += input_messages response = await self.inference_api.chat_completion( - model=candidate.model, + model_id=candidate.model, messages=messages, sampling_params=candidate.sampling_params, ) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index a950f35f9..4b43de93f 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -62,7 +62,7 @@ async def score_row( ) judge_response = await self.inference_api.chat_completion( - model=fn_def.params.judge_model, + model_id=fn_def.params.judge_model, messages=[ { "role": "user", diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 2d03a9ef8..e5eb6e1ea 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -85,7 +85,6 @@ async def chat_completion( logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) - print(f"model={model}") request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, From 1bb01f934608672736716f65438c7b7f5ee79cd2 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 20:00:48 -0800 Subject: [PATCH 13/13] remove model lookup class --- docs/source/getting_started/index.md | 2 +- .../utils/inference/model_registry.py | 22 ++++--------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md 
index d1d61d770..eb95db7cc 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working c $ curl http://localhost:5000/inference/chat_completion \ -H "Content-Type: application/json" \ -d '{ - "model": "Llama3.1-8B-Instruct", + "model_id": "Llama3.1-8B-Instruct", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Write me a 2 sentence poem about the moon"} diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index c44c641a2..7120e9e97 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -15,7 +15,6 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]: - """Get the Hugging Face repository for a given CoreModelId.""" for model in all_registered_models(): if model.descriptor() == model_descriptor: return model.huggingface_repo @@ -33,11 +32,8 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli ) -class ModelLookup: - def __init__( - self, - model_aliases: List[ModelAlias], - ): +class ModelRegistryHelper(ModelsProtocolPrivate): + def __init__(self, model_aliases: List[ModelAlias]): self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} for alias_obj in model_aliases: @@ -57,22 +53,12 @@ def get_provider_model_id(self, identifier: str) -> str: else: raise ValueError(f"Unknown model: `{identifier}`") - -class ModelRegistryHelper(ModelsProtocolPrivate): - - def __init__(self, model_aliases: List[ModelAlias]): - self.model_lookup = ModelLookup(model_aliases) - def get_llama_model(self, provider_model_id: str) -> str: - return self.model_lookup.provider_id_to_llama_model_map[provider_model_id] + return self.provider_id_to_llama_model_map[provider_model_id] async def register_model(self, model: Model) -> Model: - provider_model_id = self.model_lookup.get_provider_model_id( + model.provider_resource_id = self.get_provider_model_id( model.provider_resource_id ) - if not provider_model_id: - raise ValueError(f"Unknown model: `{model.provider_resource_id}`") - - model.provider_resource_id = provider_model_id return model
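
Appendix (illustrative, not part of the patch series): the end state of the model-alias registry that [PATCH 09/13] through [PATCH 13/13] converge on can be exercised in isolation with a short, self-contained sketch. The names below mirror build_model_alias and ModelRegistryHelper from model_registry.py, but the alias-population loop, the provider-id self-mapping, and the example id strings are assumptions filled in for the demo, since the hunks above elide those lines; the real build_model_alias also appends the Hugging Face repo as an extra alias via llama_models.sku_list, which is omitted here to keep the sketch dependency-free.

from collections import namedtuple
from typing import List

ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])


def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
    # Simplified stand-in: the Hugging Face repo lookup from the real helper is omitted.
    return ModelAlias(
        provider_model_id=provider_model_id,
        aliases=[model_descriptor],
        llama_model=model_descriptor,
    )


class ModelRegistryHelper:
    def __init__(self, model_aliases: List[ModelAlias]):
        self.alias_to_provider_id_map = {}
        self.provider_id_to_llama_model_map = {}
        for alias_obj in model_aliases:
            # Map every alias (e.g. the Llama descriptor) to the provider-native id.
            for alias in alias_obj.aliases:
                self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
            # Assumed: provider ids also resolve to themselves, so either form can be passed in.
            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model

    def get_provider_model_id(self, identifier: str) -> str:
        if identifier in self.alias_to_provider_id_map:
            return self.alias_to_provider_id_map[identifier]
        raise ValueError(f"Unknown model: `{identifier}`")

    def get_llama_model(self, provider_model_id: str) -> str:
        return self.provider_id_to_llama_model_map[provider_model_id]


if __name__ == "__main__":
    # Ollama-style alias, matching the tables added in [PATCH 09/13] and [PATCH 10/13]
    # (the descriptor string for CoreModelId.llama3_1_8b_instruct.value is assumed).
    helper = ModelRegistryHelper(
        [build_model_alias("llama3.1:8b-instruct-fp16", "Llama3.1-8B-Instruct")]
    )
    provider_id = helper.get_provider_model_id("Llama3.1-8B-Instruct")
    print(provider_id)                          # llama3.1:8b-instruct-fp16
    print(helper.get_llama_model(provider_id))  # Llama3.1-8B-Instruct

In the actual registry ([PATCH 13/13]), register_model() rewrites model.provider_resource_id through get_provider_model_id() before returning the model, which is what the adapter completion and chat_completion paths above rely on when they fetch the model from model_store and build requests from model.provider_resource_id.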