From 2dd8c4bcb6216daebeaafac282add176cb7b5047 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 22 Oct 2024 14:31:11 -0400 Subject: [PATCH 1/9] add NVIDIA NIM inference adapter --- .../adapters/inference/nvidia/__init__.py | 18 + .../adapters/inference/nvidia/_config.py | 52 +++ .../adapters/inference/nvidia/_nvidia.py | 176 ++++++++++ .../adapters/inference/nvidia/_utils.py | 328 ++++++++++++++++++ llama_stack/providers/registry/inference.py | 9 + tests/nvidia/README.md | 26 ++ tests/nvidia/integration/conftest.py | 67 ++++ tests/nvidia/integration/test_inference.py | 117 +++++++ tests/nvidia/unit/conftest.py | 73 ++++ tests/nvidia/unit/test_chat_completion.py | 203 +++++++++++ tests/nvidia/unit/test_health.py | 35 ++ tests/nvidia/unit/test_import.py | 11 + 12 files changed, 1115 insertions(+) create mode 100644 llama_stack/providers/adapters/inference/nvidia/__init__.py create mode 100644 llama_stack/providers/adapters/inference/nvidia/_config.py create mode 100644 llama_stack/providers/adapters/inference/nvidia/_nvidia.py create mode 100644 llama_stack/providers/adapters/inference/nvidia/_utils.py create mode 100644 tests/nvidia/README.md create mode 100644 tests/nvidia/integration/conftest.py create mode 100644 tests/nvidia/integration/test_inference.py create mode 100644 tests/nvidia/unit/conftest.py create mode 100644 tests/nvidia/unit/test_chat_completion.py create mode 100644 tests/nvidia/unit/test_health.py create mode 100644 tests/nvidia/unit/test_import.py diff --git a/llama_stack/providers/adapters/inference/nvidia/__init__.py b/llama_stack/providers/adapters/inference/nvidia/__init__.py new file mode 100644 index 0000000000..63b4669333 --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from ._config import NVIDIAConfig +from ._nvidia import NVIDIAInferenceAdapter + + +async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter: + if not isinstance(config, NVIDIAConfig): + raise RuntimeError(f"Unexpected config type: {type(config)}") + adapter = NVIDIAInferenceAdapter(config) + return adapter + + +__all__ = ["get_adapter_impl", "NVIDIAConfig"] diff --git a/llama_stack/providers/adapters/inference/nvidia/_config.py b/llama_stack/providers/adapters/inference/nvidia/_config.py new file mode 100644 index 0000000000..46ac3fa5ba --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/_config.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class NVIDIAConfig(BaseModel): + """ + Configuration for the NVIDIA NIM inference endpoint. + + Attributes: + base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 + api_key (str): The access key for the hosted NIM endpoints + + There are two ways to access NVIDIA NIMs - + 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com + 1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure + + By default the configuration is set to use the hosted APIs. 
This requires + an API key which can be obtained from https://ngc.nvidia.com/. + + By default the configuration will attempt to read the NVIDIA_API_KEY environment + variable to set the api_key. Please do not put your API key in code. + + If you are using a self-hosted NVIDIA NIM, you can set the base_url to the + URL of your running NVIDIA NIM and do not need to set the api_key. + """ + + base_url: str = Field( + default="https://integrate.api.nvidia.com", + description="A base url for accessing the NVIDIA NIM", + ) + api_key: Optional[str] = Field( + default_factory=lambda: os.getenv("NVIDIA_API_KEY"), + description="The NVIDIA API key, only needed of using the hosted service", + ) + timeout: int = Field( + default=60, + description="Timeout for the HTTP requests", + ) + + @property + def is_hosted(self) -> bool: + return "integrate.api.nvidia.com" in self.base_url diff --git a/llama_stack/providers/adapters/inference/nvidia/_nvidia.py b/llama_stack/providers/adapters/inference/nvidia/_nvidia.py new file mode 100644 index 0000000000..621e3e0db0 --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/_nvidia.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import Dict, List, Optional, Union + +import httpx +from llama_models.datatypes import SamplingParams +from llama_models.llama3.api.datatypes import ( + InterleavedTextMedia, + Message, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_models.sku_list import CoreModelId + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + ModelDef, + ResponseFormat, +) + +from ._config import NVIDIAConfig +from ._utils import check_health, convert_chat_completion_request, parse_completion + +SUPPORTED_MODELS: Dict[CoreModelId, str] = { + CoreModelId.llama3_8b_instruct: "meta/llama3-8b-instruct", + CoreModelId.llama3_70b_instruct: "meta/llama3-70b-instruct", + CoreModelId.llama3_1_8b_instruct: "meta/llama-3.1-8b-instruct", + CoreModelId.llama3_1_70b_instruct: "meta/llama-3.1-70b-instruct", + CoreModelId.llama3_1_405b_instruct: "meta/llama-3.1-405b-instruct", + # TODO(mf): how do we handle Nemotron models? + # "Llama3.1-Nemotron-51B-Instruct": "meta/llama-3.1-nemotron-51b-instruct", + CoreModelId.llama3_2_1b_instruct: "meta/llama-3.2-1b-instruct", + CoreModelId.llama3_2_3b_instruct: "meta/llama-3.2-3b-instruct", + CoreModelId.llama3_2_11b_vision_instruct: "meta/llama-3.2-11b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct: "meta/llama-3.2-90b-vision-instruct", +} + + +class NVIDIAInferenceAdapter(Inference): + def __init__(self, config: NVIDIAConfig) -> None: + + print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...") + + if config.is_hosted: + if not config.api_key: + raise RuntimeError( + "API key is required for hosted NVIDIA NIM. " + "Either provide an API key or use a self-hosted NIM." + ) + # elif self._config.api_key: + # + # we don't raise this warning because a user may have deployed their + # self-hosted NIM with an API key requirement. + # + # warnings.warn( + # "API key is not required for self-hosted NVIDIA NIM. " + # "Consider removing the api_key from the configuration." 
+ # ) + + self._config = config + + @property + def _headers(self) -> dict: + return { + b"User-Agent": b"llama-stack: nvidia-inference-adapter", + **( + {b"Authorization": f"Bearer {self._config.api_key}"} + if self._config.api_key + else {} + ), + } + + async def list_models(self) -> List[ModelDef]: + # TODO(mf): filter by available models + return [ + ModelDef(identifier=model, llama_model=id_) + for model, id_ in SUPPORTED_MODELS.items() + ] + + def completion( + self, + model: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + raise NotImplementedError() + + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() + + async def chat_completion( + self, + model: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ + ToolPromptFormat + ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: + if tool_prompt_format: + warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") + + if stream: + raise ValueError("Streamed completions are not supported") + + await check_health(self._config) # this raises errors + + request = ChatCompletionRequest( + model=SUPPORTED_MODELS[CoreModelId(model)], + messages=messages, + sampling_params=sampling_params, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ) + + async with httpx.AsyncClient(timeout=self._config.timeout) as client: + try: + response = await client.post( + f"{self._config.base_url}/v1/chat/completions", + headers=self._headers, + json=convert_chat_completion_request(request, n=1), + ) + except httpx.ReadTimeout as e: + raise TimeoutError( + f"Request timed out. timeout set to {self._config.timeout}. Use `llama stack configure ...` to adjust it." + ) from e + + if response.status_code == 401: + raise PermissionError( + "Unauthorized. Please check your API key, reconfigure, and try again." + ) + + if response.status_code == 400: + raise ValueError( + f"Bad request. Please check the request and try again. Detail: {response.text}" + ) + + if response.status_code == 404: + raise ValueError( + "Model not found. Please check the model name and try again." + ) + + assert ( + response.status_code == 200 + ), f"Failed to get completion: {response.text}" + + # we pass n=1 to get only one completion + return parse_completion(response.json()["choices"][0]) diff --git a/llama_stack/providers/adapters/inference/nvidia/_utils.py b/llama_stack/providers/adapters/inference/nvidia/_utils.py new file mode 100644 index 0000000000..6b90750500 --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/_utils.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import Any, Dict, List, Optional, Tuple + +import httpx +from llama_models.llama3.api.datatypes import ( + CompletionMessage, + StopReason, + TokenLogProbs, + ToolCall, +) + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + Message, +) + +from ._config import NVIDIAConfig + + +def convert_message(message: Message) -> dict: + """ + Convert a Message to an OpenAI API-compatible dictionary. + """ + out_dict = message.dict() + # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool" + if out_dict["role"] == "ipython": + out_dict.update(role="tool") + + if "stop_reason" in out_dict: + out_dict.update(stop_reason=out_dict["stop_reason"].value) + + # TODO(mf): tool_calls + + return out_dict + + +async def _get_health(url: str) -> Tuple[bool, bool]: + """ + Query {url}/v1/health/{live,ready} to check if the server is running and ready + + Args: + url (str): URL of the server + + Returns: + Tuple[bool, bool]: (is_live, is_ready) + """ + async with httpx.AsyncClient() as client: + live = await client.get(f"{url}/v1/health/live") + ready = await client.get(f"{url}/v1/health/ready") + return live.status_code == 200, ready.status_code == 200 + + +async def check_health(config: NVIDIAConfig) -> None: + """ + Check if the server is running and ready + + Args: + url (str): URL of the server + + Raises: + RuntimeError: If the server is not running or ready + """ + if not config.is_hosted: + print("Checking NVIDIA NIM health...") + try: + is_live, is_ready = await _get_health(config.base_url) + if not is_live: + raise ConnectionError("NVIDIA NIM is not running") + if not is_ready: + raise ConnectionError("NVIDIA NIM is not ready") + # TODO(mf): should we wait for the server to be ready? + except httpx.ConnectError as e: + raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e + + +def convert_chat_completion_request( + request: ChatCompletionRequest, + n: int = 1, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
+ """ + # model -> model + # messages -> messages + # sampling_params TODO(mattf): review strategy + # strategy=greedy -> nvext.top_k = -1, temperature = temperature + # strategy=top_p -> nvext.top_k = -1, top_p = top_p + # strategy=top_k -> nvext.top_k = top_k + # temperature -> temperature + # top_p -> top_p + # top_k -> nvext.top_k + # max_tokens -> max_tokens + # repetition_penalty -> nvext.repetition_penalty + # tools -> tools + # tool_choice ("auto", "required") -> tool_choice + # tool_prompt_format -> TBD + # stream -> stream + # logprobs -> logprobs + + print(f"sampling_params: {request.sampling_params}") + + payload: Dict[str, Any] = dict( + model=request.model, + messages=[convert_message(message) for message in request.messages], + stream=request.stream, + nvext={}, + n=n, + ) + nvext = payload["nvext"] + + if request.tools: + payload.update(tools=request.tools) + if request.tool_choice: + payload.update( + tool_choice=request.tool_choice.value + ) # we cannot include tool_choice w/o tools, server will complain + + if request.logprobs: + payload.update(logprobs=True) + payload.update(top_logprobs=request.logprobs.top_k) + + if request.sampling_params: + nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) + + if request.sampling_params.max_tokens: + payload.update(max_tokens=request.sampling_params.max_tokens) + + if request.sampling_params.strategy == "top_p": + nvext.update(top_k=-1) + payload.update(top_p=request.sampling_params.top_p) + elif request.sampling_params.strategy == "top_k": + if ( + request.sampling_params.top_k != -1 + and request.sampling_params.top_k < 1 + ): + warnings.warn("top_k must be -1 or >= 1") + nvext.update(top_k=request.sampling_params.top_k) + elif request.sampling_params.strategy == "greedy": + nvext.update(top_k=-1) + payload.update(temperature=request.sampling_params.temperature) + + return payload + + +def _parse_content(completion: dict) -> str: + """ + Get the content from an OpenAI completion response. + + OpenAI completion response format - + { + ... + "message": {"role": "assistant", "content": ..., ...}, + ... + } + """ + # content is nullable in the OpenAI response, common for tool calls + return completion["message"]["content"] or "" + + +def _parse_stop_reason(completion: dict) -> StopReason: + """ + Get the StopReason from an OpenAI completion response. + + OpenAI completion response format - + { + ... + "finish_reason": "length" or "stop" or "tool_calls", + ... + } + """ + + # StopReason options are end_of_turn, end_of_message, out_of_tokens + # TODO(mf): is end_of_turn and end_of_message usage correct? + stop_reason = StopReason.end_of_turn + if completion["finish_reason"] == "length": + stop_reason = StopReason.out_of_tokens + elif completion["finish_reason"] == "stop": + stop_reason = StopReason.end_of_message + elif completion["finish_reason"] == "tool_calls": + stop_reason = StopReason.end_of_turn + return stop_reason + + +def _parse_tool_calls(completion: dict) -> List[ToolCall]: + """ + Get the tool calls from an OpenAI completion response. + + OpenAI completion response format - + { + ..., + "message": { + ..., + "tool_calls": [ + { + "id": X, + "type": "function", + "function": { + "name": Y, + "arguments": Z, + }, + }* + ], + }, + } + -> + [ + ToolCall(call_id=X, tool_name=Y, arguments=Z), + ... 
+ ] + """ + tool_calls = [] + if "tool_calls" in completion["message"]: + assert isinstance( + completion["message"]["tool_calls"], list + ), "error in server response: tool_calls not a list" + for call in completion["message"]["tool_calls"]: + assert "id" in call, "error in server response: tool call id not found" + assert ( + "function" in call + ), "error in server response: tool call function not found" + assert ( + "name" in call["function"] + ), "error in server response: tool call function name not found" + assert ( + "arguments" in call["function"] + ), "error in server response: tool call function arguments not found" + tool_calls.append( + ToolCall( + call_id=call["id"], + tool_name=call["function"]["name"], + arguments=call["function"]["arguments"], + ) + ) + + return tool_calls + + +def _parse_logprobs(completion: dict) -> Optional[List[TokenLogProbs]]: + """ + Extract logprobs from OpenAI as a list of TokenLogProbs. + + OpenAI completion response format - + { + ... + "logprobs": { + content: [ + { + ..., + top_logprobs: [{token: X, logprob: Y, bytes: [...]}+] + }+ + ] + }, + ... + } + -> + [ + TokenLogProbs( + logprobs_by_token={X: Y, ...} + ), + ... + ] + """ + if not (logprobs := completion.get("logprobs")): + return None + + return [ + TokenLogProbs( + logprobs_by_token={ + logprobs["token"]: logprobs["logprob"] + for logprobs in content["top_logprobs"] + } + ) + for content in logprobs["content"] + ] + + +def parse_completion( + completion: dict, +) -> ChatCompletionResponse: + """ + Parse an OpenAI completion response into a CompletionMessage and logprobs. + + OpenAI completion response format - + { + "message": { + "role": "assistant", + "content": ..., + "tool_calls": [ + { + ... + "id": ..., + "function": { + "name": ..., + "arguments": ..., + }, + }* + ]?, + "finish_reason": ..., + "logprobs": { + "content": [ + { + ..., + "top_logprobs": [{"token": ..., "logprob": ..., ...}+] + }+ + ] + }? + } + """ + assert "message" in completion, "error in server response: message not found" + assert ( + "finish_reason" in completion + ), "error in server response: finish_reason not found" + + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=_parse_content(completion), + stop_reason=_parse_stop_reason(completion), + tool_calls=_parse_tool_calls(completion), + ), + logprobs=_parse_logprobs(completion), + ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 88265f1b46..18397a08d7 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -140,6 +140,15 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="nvidia", + pip_packages=[], # TODO(mf): need to specify httpx if it's already a llama-stack dep? + module="llama_stack.providers.adapters.inference.nvidia", + config_class="llama_stack.providers.adapters.inference.nvidia.NVIDIAConfig", + ), + ), InlineProviderSpec( api=Api.inference, provider_type="vllm", diff --git a/tests/nvidia/README.md b/tests/nvidia/README.md new file mode 100644 index 0000000000..939a998d70 --- /dev/null +++ b/tests/nvidia/README.md @@ -0,0 +1,26 @@ +# NVIDIA tests + +## Running tests + +**Install the required dependencies:** + ```bash + pip install pytest pytest-asyncio pytest-httpx + ``` + +There are three modes for testing: + +1. 
Unit tests - this mode checks the provider functionality and does not require a network connection or running distribution + + ```bash + pytest tests/nvidia/unit + ``` + +2. Integration tests against hosted preview APIs - this mode checks the provider functionality against a live system and requires an API key. Get an API key by 0. going to https://build.nvidia.com, 1. selecting a Llama model, e.g. https://build.nvidia.com/meta/llama-3_1-8b-instruct, and 2. clicking "Get API Key". Store the API key in the `NVIDIA_API_KEY` environment variable. + + ```bash + export NVIDIA_API_KEY=... + + pytest tests/nvidia/integration --base-url https://integrate.api.nvidia.com + ``` + +3. Integration tests against a running distribution - this mode checks the provider functionality in the context of a running distribution. This involves running a local NIM, see https://build.nvidia.com/meta/llama-3_1-8b-instruct?snippet_tab=Docker, and creating & configuring a distribution to use it. Details to come. diff --git a/tests/nvidia/integration/conftest.py b/tests/nvidia/integration/conftest.py new file mode 100644 index 0000000000..0691b74537 --- /dev/null +++ b/tests/nvidia/integration/conftest.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os + +import pytest + +from llama_stack.apis.inference import Inference +from llama_stack.providers.adapters.inference.nvidia import ( + get_adapter_impl, + NVIDIAConfig, +) + + +def pytest_collection_modifyitems(config, items): + """ + Skip all integration tests if NVIDIA_API_KEY is not set and --base-url + includes "https://integrate.api.nvidia.com". It is needed to access the + hosted preview APIs. + """ + if "integrate.api.nvidia.com" in config.getoption( + "--base-url" + ) and not os.environ.get("NVIDIA_API_KEY"): + skip_nvidia = pytest.mark.skip( + reason="NVIDIA_API_KEY environment variable must be set to access integrate.api.nvidia.com" + ) + for item in items: + item.add_marker(skip_nvidia) + + +def pytest_addoption(parser): + parser.addoption( + "--base-url", + action="store", + default="http://localhost:8000", + help="Base URL for the tests", + ) + parser.addoption( + "--model", + action="store", + default="Llama-3-8B-Instruct", + help="Model option for the tests", + ) + + +@pytest.fixture +def base_url(request): + return request.config.getoption("--base-url") + + +@pytest.fixture +def model(request): + return request.config.getoption("--model") + + +@pytest.fixture +def client(base_url: str) -> Inference: + return get_adapter_impl( + NVIDIAConfig( + base_url=base_url, + api_key=os.environ.get("NVIDIA_API_KEY"), + ), + {}, + ) diff --git a/tests/nvidia/integration/test_inference.py b/tests/nvidia/integration/test_inference.py new file mode 100644 index 0000000000..2e7b33e4f4 --- /dev/null +++ b/tests/nvidia/integration/test_inference.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import itertools +from typing import Generator, List, Tuple + +import pytest + +from llama_stack.apis.inference import ( + ChatCompletionResponse, + CompletionMessage, + Inference, + Message, + StopReason, + SystemMessage, + ToolResponseMessage, + UserMessage, +) +from llama_stack.providers.adapters.inference.nvidia import ( + get_adapter_impl, + NVIDIAConfig, +) + +pytestmark = pytest.mark.asyncio + + +# TODO(mf): test bad creds raises PermissionError +# TODO(mf): test bad params, e.g. max_tokens=0 raises ValidationError +# TODO(mf): test bad model name raises ValueError +# TODO(mf): test short timeout raises TimeoutError +# TODO(mf): new file, test cli model listing +# TODO(mf): test streaming +# TODO(mf): test tool calls w/ tool_choice + + +def message_combinations( + length: int, +) -> Generator[Tuple[List[Message], str], None, None]: + """ + Generate all possible combinations of message types of given length. + """ + message_types = [ + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, + ] + for count in range(1, length + 1): + for combo in itertools.product(message_types, repeat=count): + messages = [] + for i, msg in enumerate(combo): + if msg == ToolResponseMessage: + messages.append( + msg( + content=f"Message {i + 1}", + call_id=f"call_{i + 1}", + tool_name=f"tool_{i + 1}", + ) + ) + elif msg == CompletionMessage: + messages.append( + msg(content=f"Message {i + 1}", stop_reason="end_of_message") + ) + else: + messages.append(msg(content=f"Message {i + 1}")) + id_str = "-".join([msg.__name__ for msg in combo]) + yield messages, id_str + + +@pytest.mark.parametrize("combo", message_combinations(3), ids=lambda x: x[1]) +async def test_chat_completion_messages( + client: Inference, + model: str, + combo: Tuple[List[Message], str], +): + """ + Test the chat completion endpoint with different message combinations. + """ + client = await client + messages, _ = combo + + response = await client.chat_completion( + model=model, + messages=messages, + stream=False, + ) + + assert isinstance(response, ChatCompletionResponse) + assert isinstance(response.completion_message.content, str) + # we're not testing accuracy, so no assertions on the result.completion_message.content + assert response.completion_message.role == "assistant" + assert isinstance(response.completion_message.stop_reason, StopReason) + assert response.completion_message.tool_calls == [] + + +async def test_bad_base_url( + model: str, +): + """ + Test that a bad base_url raises a ConnectionError. + """ + client = await get_adapter_impl( + NVIDIAConfig( + base_url="http://localhost:32123", + ), + {}, + ) + + with pytest.raises(ConnectionError): + await client.chat_completion( + model=model, + messages=[UserMessage(content="Hello")], + stream=False, + ) diff --git a/tests/nvidia/unit/conftest.py b/tests/nvidia/unit/conftest.py new file mode 100644 index 0000000000..cdc0c50d7c --- /dev/null +++ b/tests/nvidia/unit/conftest.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os + +import pytest + +from llama_stack.apis.inference import Inference +from llama_stack.providers.adapters.inference.nvidia import ( + get_adapter_impl, + NVIDIAConfig, +) +from pytest_httpx import HTTPXMock + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def base_url(): + return "http://endpoint.mocked" + + +@pytest.fixture +def client(base_url: str) -> Inference: + return get_adapter_impl( + NVIDIAConfig( + base_url=base_url, + api_key=os.environ.get("NVIDIA_API_KEY"), + ), + {}, + ) + + +@pytest.fixture +def mock_health( + httpx_mock: HTTPXMock, + base_url: str, +) -> HTTPXMock: + for path in [ + "/v1/health/live", + "/v1/health/ready", + ]: + httpx_mock.add_response( + url=f"{base_url}{path}", + status_code=200, + ) + return httpx_mock + + +@pytest.fixture +def mock_chat_completion(httpx_mock: HTTPXMock, base_url: str) -> HTTPXMock: + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "created": 1234567890, + "object": "chat.completion", + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "WORKED"}, + "finish_reason": "length", + } + ], + }, + status_code=200, + ) + + return httpx_mock diff --git a/tests/nvidia/unit/test_chat_completion.py b/tests/nvidia/unit/test_chat_completion.py new file mode 100644 index 0000000000..1608ad39a6 --- /dev/null +++ b/tests/nvidia/unit/test_chat_completion.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from llama_models.llama3.api.datatypes import TokenLogProbs, ToolCall + +from llama_stack.apis.inference import Inference +from pytest_httpx import HTTPXMock + +pytestmark = pytest.mark.asyncio + + +async def test_content( + mock_health: HTTPXMock, + httpx_mock: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that response content makes it through to the completion message. + """ + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "created": 1234567890, + "object": "chat.completion", + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "RESPONSE"}, + "finish_reason": "length", + } + ], + }, + status_code=200, + ) + + client = await client + + response = await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "BOGUS"}], + stream=False, + ) + assert response.completion_message.content == "RESPONSE" + + +async def test_logprobs( + mock_health: HTTPXMock, + httpx_mock: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that logprobs are parsed correctly. 
+ """ + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "object": "chat.completion", + "created": 1234567890, + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello there"}, + "logprobs": { + "content": [ + { + "token": "Hello", + "logprob": -0.1, + "bytes": [72, 101, 108, 108, 111], + "top_logprobs": [ + {"token": "Hello", "logprob": -0.1}, + {"token": "Hi", "logprob": -1.2}, + {"token": "Greetings", "logprob": -2.1}, + ], + }, + { + "token": "there", + "logprob": -0.2, + "bytes": [116, 104, 101, 114, 101], + "top_logprobs": [ + {"token": "there", "logprob": -0.2}, + {"token": "here", "logprob": -1.3}, + {"token": "where", "logprob": -2.2}, + ], + }, + ] + }, + "finish_reason": "length", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }, + status_code=200, + ) + + client = await client + + response = await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "Hello"}], + logprobs={"top_k": 3}, + stream=False, + ) + + assert response.logprobs == [ + TokenLogProbs( + logprobs_by_token={ + "Hello": -0.1, + "Hi": -1.2, + "Greetings": -2.1, + } + ), + TokenLogProbs( + logprobs_by_token={ + "there": -0.2, + "here": -1.3, + "where": -2.2, + } + ), + ] + + +async def test_tools( + mock_health: HTTPXMock, + httpx_mock: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that tools are passed correctly. + """ + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "object": "chat.completion", + "created": 1234567890, + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tool-id", + "type": "function", + "function": { + "name": "magic", + "arguments": {"input": 3}, + }, + }, + { + "id": "tool-id!", + "type": "function", + "function": { + "name": "magic!", + "arguments": {"input": 42}, + }, + }, + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + }, + status_code=200, + ) + + client = await client + + response = await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "Hello"}], + stream=False, + ) + + assert response.completion_message.tool_calls == [ + ToolCall( + call_id="tool-id", + tool_name="magic", + arguments={"input": 3}, + ), + ToolCall( + call_id="tool-id!", + tool_name="magic!", + arguments={"input": 42}, + ), + ] + + +# TODO(mf): test stream=True for each case diff --git a/tests/nvidia/unit/test_health.py b/tests/nvidia/unit/test_health.py new file mode 100644 index 0000000000..0e3d146a34 --- /dev/null +++ b/tests/nvidia/unit/test_health.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.inference import Inference +from pytest_httpx import HTTPXMock + +pytestmark = pytest.mark.asyncio + + +async def test_chat_completion( + mock_health: HTTPXMock, + mock_chat_completion: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that health endpoints are checked when chat_completion is called. 
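+
+    The mock_health fixture registers responses for /v1/health/live and
+    /v1/health/ready, and mock_chat_completion registers the chat completions
+    endpoint, so the call below is expected to exercise all three mocked URLs.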
+ """ + client = await client + + await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "BOGUS"}], + stream=False, + ) + + +# TODO(mf): test stream=True for each case +# TODO(mf): test completion +# TODO(mf): test embedding diff --git a/tests/nvidia/unit/test_import.py b/tests/nvidia/unit/test_import.py new file mode 100644 index 0000000000..87e6672396 --- /dev/null +++ b/tests/nvidia/unit/test_import.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.adapters.inference.nvidia import __all__ + + +def test_import(): + assert set(__all__) == {"get_adapter_impl", "NVIDIAConfig"} From dbe665ed1974dff5679e8b62457f96822b6d3160 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 4 Nov 2024 10:22:29 -0500 Subject: [PATCH 2/9] enable streaming support, use openai-python instead of httpx --- .../adapters/inference/nvidia/_nvidia.py | 99 ++-- .../inference/nvidia/_openai_utils.py | 430 +++++++++++++++ .../adapters/inference/nvidia/_utils.py | 280 +--------- llama_stack/providers/registry/inference.py | 4 +- tests/nvidia/integration/test_inference.py | 68 +++ tests/nvidia/unit/test_chat_completion.py | 4 +- tests/nvidia/unit/test_openai_utils.py | 493 ++++++++++++++++++ 7 files changed, 1037 insertions(+), 341 deletions(-) create mode 100644 llama_stack/providers/adapters/inference/nvidia/_openai_utils.py create mode 100644 tests/nvidia/unit/test_openai_utils.py diff --git a/llama_stack/providers/adapters/inference/nvidia/_nvidia.py b/llama_stack/providers/adapters/inference/nvidia/_nvidia.py index 621e3e0db0..05ac92cd2c 100644 --- a/llama_stack/providers/adapters/inference/nvidia/_nvidia.py +++ b/llama_stack/providers/adapters/inference/nvidia/_nvidia.py @@ -5,9 +5,8 @@ # the root directory of this source tree. 
import warnings -from typing import Dict, List, Optional, Union +from typing import AsyncIterator, Dict, List, Optional, Union -import httpx from llama_models.datatypes import SamplingParams from llama_models.llama3.api.datatypes import ( InterleavedTextMedia, @@ -17,6 +16,7 @@ ToolPromptFormat, ) from llama_models.sku_list import CoreModelId +from openai import APIConnectionError, AsyncOpenAI from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -32,7 +32,12 @@ ) from ._config import NVIDIAConfig -from ._utils import check_health, convert_chat_completion_request, parse_completion +from ._openai_utils import ( + convert_chat_completion_request, + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, +) +from ._utils import check_health SUPPORTED_MODELS: Dict[CoreModelId, str] = { CoreModelId.llama3_8b_instruct: "meta/llama3-8b-instruct", @@ -71,17 +76,12 @@ def __init__(self, config: NVIDIAConfig) -> None: # ) self._config = config - - @property - def _headers(self) -> dict: - return { - b"User-Agent": b"llama-stack: nvidia-inference-adapter", - **( - {b"Authorization": f"Bearer {self._config.api_key}"} - if self._config.api_key - else {} - ), - } + # make sure the client lives longer than any async calls + self._client = AsyncOpenAI( + base_url=f"{self._config.base_url}/v1", + api_key=self._config.api_key or "NO KEY", + timeout=self._config.timeout, + ) async def list_models(self) -> List[ModelDef]: # TODO(mf): filter by available models @@ -98,7 +98,7 @@ def completion( response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: raise NotImplementedError() async def embeddings( @@ -121,56 +121,37 @@ async def chat_completion( ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: if tool_prompt_format: warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") - if stream: - raise ValueError("Streamed completions are not supported") - await check_health(self._config) # this raises errors - request = ChatCompletionRequest( - model=SUPPORTED_MODELS[CoreModelId(model)], - messages=messages, - sampling_params=sampling_params, - tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - stream=stream, - logprobs=logprobs, + request = convert_chat_completion_request( + request=ChatCompletionRequest( + model=SUPPORTED_MODELS[CoreModelId(model)], + messages=messages, + sampling_params=sampling_params, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ), + n=1, ) - async with httpx.AsyncClient(timeout=self._config.timeout) as client: - try: - response = await client.post( - f"{self._config.base_url}/v1/chat/completions", - headers=self._headers, - json=convert_chat_completion_request(request, n=1), - ) - except httpx.ReadTimeout as e: - raise TimeoutError( - f"Request timed out. timeout set to {self._config.timeout}. Use `llama stack configure ...` to adjust it." - ) from e - - if response.status_code == 401: - raise PermissionError( - "Unauthorized. 
Please check your API key, reconfigure, and try again." - ) - - if response.status_code == 400: - raise ValueError( - f"Bad request. Please check the request and try again. Detail: {response.text}" - ) - - if response.status_code == 404: - raise ValueError( - "Model not found. Please check the model name and try again." - ) - - assert ( - response.status_code == 200 - ), f"Failed to get completion: {response.text}" + try: + response = await self._client.chat.completions.create(**request) + except APIConnectionError as e: + raise ConnectionError( + f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}" + ) from e + if stream: + return convert_openai_chat_completion_stream(response) + else: # we pass n=1 to get only one completion - return parse_completion(response.json()["choices"][0]) + return convert_openai_chat_completion_choice(response.choices[0]) diff --git a/llama_stack/providers/adapters/inference/nvidia/_openai_utils.py b/llama_stack/providers/adapters/inference/nvidia/_openai_utils.py new file mode 100644 index 0000000000..998b4c275c --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/_openai_utils.py @@ -0,0 +1,430 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import warnings +from typing import Any, AsyncGenerator, Dict, Generator, List, Optional + +from llama_models.llama3.api.datatypes import ( + CompletionMessage, + StopReason, + TokenLogProbs, + ToolCall, +) +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall, +) + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + Message, + ToolCallDelta, + ToolCallParseStatus, +) + + +def _convert_message(message: Message) -> Dict: + """ + Convert a Message to an OpenAI API-compatible dictionary. + """ + out_dict = message.dict() + # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool" + if out_dict["role"] == "ipython": + out_dict.update(role="tool") + + if "stop_reason" in out_dict: + out_dict.update(stop_reason=out_dict["stop_reason"].value) + + # TODO(mf): tool_calls + + return out_dict + + +def convert_chat_completion_request( + request: ChatCompletionRequest, + n: int = 1, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
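+
+    As an illustrative sketch (the values below are hypothetical, not taken from
+    a real request), a request using the greedy strategy with max_tokens=64 is
+    converted to roughly:
+
+        {
+            "model": "meta/llama-3.1-8b-instruct",
+            "messages": [{"role": "user", "content": "..."}],
+            "stream": False,
+            "n": 1,
+            "max_tokens": 64,
+            "temperature": <sampling temperature>,
+            "extra_body": {"nvext": {"top_k": -1, "repetition_penalty": <penalty>}},
+            "extra_headers": {b"User-Agent": b"llama-stack: nvidia-inference-adapter"},
+        }
+
+    extra_body is merged into the request JSON and extra_headers become HTTP
+    headers when this dict is passed to the openai client's
+    chat.completions.create(**...).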
+ """ + # model -> model + # messages -> messages + # sampling_params TODO(mattf): review strategy + # strategy=greedy -> nvext.top_k = -1, temperature = temperature + # strategy=top_p -> nvext.top_k = -1, top_p = top_p + # strategy=top_k -> nvext.top_k = top_k + # temperature -> temperature + # top_p -> top_p + # top_k -> nvext.top_k + # max_tokens -> max_tokens + # repetition_penalty -> nvext.repetition_penalty + # tools -> tools + # tool_choice ("auto", "required") -> tool_choice + # tool_prompt_format -> TBD + # stream -> stream + # logprobs -> logprobs + + nvext = {} + payload: Dict[str, Any] = dict( + model=request.model, + messages=[_convert_message(message) for message in request.messages], + stream=request.stream, + n=n, + extra_body=dict(nvext=nvext), + extra_headers={ + b"User-Agent": b"llama-stack: nvidia-inference-adapter", + }, + ) + + if request.tools: + payload.update(tools=request.tools) + if request.tool_choice: + payload.update( + tool_choice=request.tool_choice.value + ) # we cannot include tool_choice w/o tools, server will complain + + if request.logprobs: + payload.update(logprobs=True) + payload.update(top_logprobs=request.logprobs.top_k) + + if request.sampling_params: + nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) + + if request.sampling_params.max_tokens: + payload.update(max_tokens=request.sampling_params.max_tokens) + + if request.sampling_params.strategy == "top_p": + nvext.update(top_k=-1) + payload.update(top_p=request.sampling_params.top_p) + elif request.sampling_params.strategy == "top_k": + if ( + request.sampling_params.top_k != -1 + and request.sampling_params.top_k < 1 + ): + warnings.warn("top_k must be -1 or >= 1") + nvext.update(top_k=request.sampling_params.top_k) + elif request.sampling_params.strategy == "greedy": + nvext.update(top_k=-1) + payload.update(temperature=request.sampling_params.temperature) + + return payload + + +def _convert_openai_finish_reason(finish_reason: str) -> StopReason: + """ + Convert an OpenAI chat completion finish_reason to a StopReason. + + finish_reason: Literal["stop", "length", "tool_calls", ...] + - stop: model hit a natural stop point or a provided stop sequence + - length: maximum number of tokens specified in the request was reached + - tool_calls: model called a tool + + -> + + class StopReason(Enum): + end_of_turn = "end_of_turn" + end_of_message = "end_of_message" + out_of_tokens = "out_of_tokens" + """ + + # TODO(mf): are end_of_turn and end_of_message semantics correct? + return { + "stop": StopReason.end_of_turn, + "length": StopReason.out_of_tokens, + "tool_calls": StopReason.end_of_message, + }.get(finish_reason, StopReason.end_of_turn) + + +def _convert_openai_tool_calls( + tool_calls: List[OpenAIChatCompletionMessageToolCall], +) -> List[ToolCall]: + """ + Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. + + OpenAI ChatCompletionMessageToolCall: + id: str + function: Function + type: Literal["function"] + + OpenAI Function: + arguments: str + name: str + + -> + + ToolCall: + call_id: str + tool_name: str + arguments: Dict[str, ...] 
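+
+    For example (hypothetical values), an OpenAI tool call with
+    function.name="get_weather" and function.arguments='{"city": "Paris"}'
+    (a JSON-encoded string) is converted to
+    ToolCall(call_id="call-1", tool_name="get_weather", arguments={"city": "Paris"}),
+    i.e. the arguments string is decoded with json.loads into a dict.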
+ """ + if not tool_calls: + return [] # CompletionMessage tool_calls is not optional + + return [ + ToolCall( + call_id=call.id, + tool_name=call.function.name, + arguments=json.loads(call.function.arguments), + ) + for call in tool_calls + ] + + +def _convert_openai_logprobs( + logprobs: OpenAIChoiceLogprobs, +) -> Optional[List[TokenLogProbs]]: + """ + Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. + + OpenAI ChoiceLogprobs: + content: Optional[List[ChatCompletionTokenLogprob]] + + OpenAI ChatCompletionTokenLogprob: + token: str + logprob: float + top_logprobs: List[TopLogprob] + + OpenAI TopLogprob: + token: str + logprob: float + + -> + + TokenLogProbs: + logprobs_by_token: Dict[str, float] + - token, logprob + + """ + if not logprobs: + return None + + return [ + TokenLogProbs( + logprobs_by_token={ + logprobs.token: logprobs.logprob for logprobs in content.top_logprobs + } + ) + for content in logprobs.content + ] + + +def convert_openai_chat_completion_choice( + choice: OpenAIChoice, +) -> ChatCompletionResponse: + """ + Convert an OpenAI Choice into a ChatCompletionResponse. + + OpenAI Choice: + message: ChatCompletionMessage + finish_reason: str + logprobs: Optional[ChoiceLogprobs] + + OpenAI ChatCompletionMessage: + role: Literal["assistant"] + content: Optional[str] + tool_calls: Optional[List[ChatCompletionMessageToolCall]] + + -> + + ChatCompletionResponse: + completion_message: CompletionMessage + logprobs: Optional[List[TokenLogProbs]] + + CompletionMessage: + role: Literal["assistant"] + content: str | ImageMedia | List[str | ImageMedia] + stop_reason: StopReason + tool_calls: List[ToolCall] + + class StopReason(Enum): + end_of_turn = "end_of_turn" + end_of_message = "end_of_message" + out_of_tokens = "out_of_tokens" + """ + assert ( + hasattr(choice, "message") and choice.message + ), "error in server response: message not found" + assert ( + hasattr(choice, "finish_reason") and choice.finish_reason + ), "error in server response: finish_reason not found" + + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=choice.message.content + or "", # CompletionMessage content is not optional + stop_reason=_convert_openai_finish_reason(choice.finish_reason), + tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), + ), + logprobs=_convert_openai_logprobs(choice.logprobs), + ) + + +async def convert_openai_chat_completion_stream( + stream: AsyncStream[OpenAIChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + """ + Convert a stream of OpenAI chat completion chunks into a stream + of ChatCompletionResponseStreamChunk. 
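+    The emitted chunks follow a start -> progress* -> complete event ordering;
+    NIM-specific tool-call handling is described in the inline comments of the
+    implementation below.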
+ + OpenAI ChatCompletionChunk: + choices: List[Choice] + + OpenAI Choice: # different from the non-streamed Choice + delta: ChoiceDelta + finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] + logprobs: Optional[ChoiceLogprobs] + + OpenAI ChoiceDelta: + content: Optional[str] + role: Optional[Literal["system", "user", "assistant", "tool"]] + tool_calls: Optional[List[ChoiceDeltaToolCall]] + + OpenAI ChoiceDeltaToolCall: + index: int + id: Optional[str] + function: Optional[ChoiceDeltaToolCallFunction] + type: Optional[Literal["function"]] + + OpenAI ChoiceDeltaToolCallFunction: + name: Optional[str] + arguments: Optional[str] + + -> + + ChatCompletionResponseStreamChunk: + event: ChatCompletionResponseEvent + + ChatCompletionResponseEvent: + event_type: ChatCompletionResponseEventType + delta: Union[str, ToolCallDelta] + logprobs: Optional[List[TokenLogProbs]] + stop_reason: Optional[StopReason] + + ChatCompletionResponseEventType: + start = "start" + progress = "progress" + complete = "complete" + + ToolCallDelta: + content: Union[str, ToolCall] + parse_status: ToolCallParseStatus + + ToolCall: + call_id: str + tool_name: str + arguments: str + + ToolCallParseStatus: + started = "started" + in_progress = "in_progress" + failure = "failure" + success = "success" + + TokenLogProbs: + logprobs_by_token: Dict[str, float] + - token, logprob + + StopReason: + end_of_turn = "end_of_turn" + end_of_message = "end_of_message" + out_of_tokens = "out_of_tokens" + """ + + # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... + def _event_type_generator() -> ( + Generator[ChatCompletionResponseEventType, None, None] + ): + yield ChatCompletionResponseEventType.start + while True: + yield ChatCompletionResponseEventType.progress + + event_type = _event_type_generator() + + # we implement NIM specific semantics, the main difference from OpenAI + # is that tool_calls are always produced as a complete call. there is no + # intermediate / partial tool call streamed. because of this, we can + # simplify the logic and not concern outselves with parse_status of + # started/in_progress/failed. we can always assume success. + # + # a stream of ChatCompletionResponseStreamChunk consists of + # 0. a start event + # 1. zero or more progress events + # - each progress event has a delta + # - each progress event may have a stop_reason + # - each progress event may have logprobs + # - each progress event may have tool_calls + # if a progress event has tool_calls, + # it is fully formed and + # can be emitted with a parse_status of success + # 2. a complete event + + stop_reason = None + + async for chunk in stream: + choice = chunk.choices[0] # assuming only one choice per chunk + + # we assume there's only one finish_reason in the stream + stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason + + # if there's a tool call, emit an event for each tool in the list + # if tool call and content, emit both separately + + if choice.delta.tool_calls: + # the call may have content and a tool call. 
ChatCompletionResponseEvent + # does not support both, so we emit the content first + if choice.delta.content: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=choice.delta.content, + logprobs=_convert_openai_logprobs(choice.logprobs), + ) + ) + + # it is possible to have parallel tool calls in stream, but + # ChatCompletionResponseEvent only supports one per stream + if len(choice.delta.tool_calls) > 1: + warnings.warn( + "multiple tool calls found in a single delta, using the first, ignoring the rest" + ) + + # NIM only produces fully formed tool calls, so we can assume success + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=ToolCallDelta( + content=_convert_openai_tool_calls(choice.delta.tool_calls)[0], + parse_status=ToolCallParseStatus.success, + ), + logprobs=_convert_openai_logprobs(choice.logprobs), + ) + ) + else: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_type), + delta=choice.delta.content or "", # content is not optional + logprobs=_convert_openai_logprobs(choice.logprobs), + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) diff --git a/llama_stack/providers/adapters/inference/nvidia/_utils.py b/llama_stack/providers/adapters/inference/nvidia/_utils.py index 6b90750500..6f52bdc4b7 100644 --- a/llama_stack/providers/adapters/inference/nvidia/_utils.py +++ b/llama_stack/providers/adapters/inference/nvidia/_utils.py @@ -4,43 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Tuple import httpx -from llama_models.llama3.api.datatypes import ( - CompletionMessage, - StopReason, - TokenLogProbs, - ToolCall, -) - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - Message, -) from ._config import NVIDIAConfig -def convert_message(message: Message) -> dict: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - out_dict = message.dict() - # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool" - if out_dict["role"] == "ipython": - out_dict.update(role="tool") - - if "stop_reason" in out_dict: - out_dict.update(stop_reason=out_dict["stop_reason"].value) - - # TODO(mf): tool_calls - - return out_dict - - async def _get_health(url: str) -> Tuple[bool, bool]: """ Query {url}/v1/health/{live,ready} to check if the server is running and ready @@ -78,251 +48,3 @@ async def check_health(config: NVIDIAConfig) -> None: # TODO(mf): should we wait for the server to be ready? except httpx.ConnectError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e - - -def convert_chat_completion_request( - request: ChatCompletionRequest, - n: int = 1, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
- """ - # model -> model - # messages -> messages - # sampling_params TODO(mattf): review strategy - # strategy=greedy -> nvext.top_k = -1, temperature = temperature - # strategy=top_p -> nvext.top_k = -1, top_p = top_p - # strategy=top_k -> nvext.top_k = top_k - # temperature -> temperature - # top_p -> top_p - # top_k -> nvext.top_k - # max_tokens -> max_tokens - # repetition_penalty -> nvext.repetition_penalty - # tools -> tools - # tool_choice ("auto", "required") -> tool_choice - # tool_prompt_format -> TBD - # stream -> stream - # logprobs -> logprobs - - print(f"sampling_params: {request.sampling_params}") - - payload: Dict[str, Any] = dict( - model=request.model, - messages=[convert_message(message) for message in request.messages], - stream=request.stream, - nvext={}, - n=n, - ) - nvext = payload["nvext"] - - if request.tools: - payload.update(tools=request.tools) - if request.tool_choice: - payload.update( - tool_choice=request.tool_choice.value - ) # we cannot include tool_choice w/o tools, server will complain - - if request.logprobs: - payload.update(logprobs=True) - payload.update(top_logprobs=request.logprobs.top_k) - - if request.sampling_params: - nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) - - if request.sampling_params.max_tokens: - payload.update(max_tokens=request.sampling_params.max_tokens) - - if request.sampling_params.strategy == "top_p": - nvext.update(top_k=-1) - payload.update(top_p=request.sampling_params.top_p) - elif request.sampling_params.strategy == "top_k": - if ( - request.sampling_params.top_k != -1 - and request.sampling_params.top_k < 1 - ): - warnings.warn("top_k must be -1 or >= 1") - nvext.update(top_k=request.sampling_params.top_k) - elif request.sampling_params.strategy == "greedy": - nvext.update(top_k=-1) - payload.update(temperature=request.sampling_params.temperature) - - return payload - - -def _parse_content(completion: dict) -> str: - """ - Get the content from an OpenAI completion response. - - OpenAI completion response format - - { - ... - "message": {"role": "assistant", "content": ..., ...}, - ... - } - """ - # content is nullable in the OpenAI response, common for tool calls - return completion["message"]["content"] or "" - - -def _parse_stop_reason(completion: dict) -> StopReason: - """ - Get the StopReason from an OpenAI completion response. - - OpenAI completion response format - - { - ... - "finish_reason": "length" or "stop" or "tool_calls", - ... - } - """ - - # StopReason options are end_of_turn, end_of_message, out_of_tokens - # TODO(mf): is end_of_turn and end_of_message usage correct? - stop_reason = StopReason.end_of_turn - if completion["finish_reason"] == "length": - stop_reason = StopReason.out_of_tokens - elif completion["finish_reason"] == "stop": - stop_reason = StopReason.end_of_message - elif completion["finish_reason"] == "tool_calls": - stop_reason = StopReason.end_of_turn - return stop_reason - - -def _parse_tool_calls(completion: dict) -> List[ToolCall]: - """ - Get the tool calls from an OpenAI completion response. - - OpenAI completion response format - - { - ..., - "message": { - ..., - "tool_calls": [ - { - "id": X, - "type": "function", - "function": { - "name": Y, - "arguments": Z, - }, - }* - ], - }, - } - -> - [ - ToolCall(call_id=X, tool_name=Y, arguments=Z), - ... 
- ] - """ - tool_calls = [] - if "tool_calls" in completion["message"]: - assert isinstance( - completion["message"]["tool_calls"], list - ), "error in server response: tool_calls not a list" - for call in completion["message"]["tool_calls"]: - assert "id" in call, "error in server response: tool call id not found" - assert ( - "function" in call - ), "error in server response: tool call function not found" - assert ( - "name" in call["function"] - ), "error in server response: tool call function name not found" - assert ( - "arguments" in call["function"] - ), "error in server response: tool call function arguments not found" - tool_calls.append( - ToolCall( - call_id=call["id"], - tool_name=call["function"]["name"], - arguments=call["function"]["arguments"], - ) - ) - - return tool_calls - - -def _parse_logprobs(completion: dict) -> Optional[List[TokenLogProbs]]: - """ - Extract logprobs from OpenAI as a list of TokenLogProbs. - - OpenAI completion response format - - { - ... - "logprobs": { - content: [ - { - ..., - top_logprobs: [{token: X, logprob: Y, bytes: [...]}+] - }+ - ] - }, - ... - } - -> - [ - TokenLogProbs( - logprobs_by_token={X: Y, ...} - ), - ... - ] - """ - if not (logprobs := completion.get("logprobs")): - return None - - return [ - TokenLogProbs( - logprobs_by_token={ - logprobs["token"]: logprobs["logprob"] - for logprobs in content["top_logprobs"] - } - ) - for content in logprobs["content"] - ] - - -def parse_completion( - completion: dict, -) -> ChatCompletionResponse: - """ - Parse an OpenAI completion response into a CompletionMessage and logprobs. - - OpenAI completion response format - - { - "message": { - "role": "assistant", - "content": ..., - "tool_calls": [ - { - ... - "id": ..., - "function": { - "name": ..., - "arguments": ..., - }, - }* - ]?, - "finish_reason": ..., - "logprobs": { - "content": [ - { - ..., - "top_logprobs": [{"token": ..., "logprob": ..., ...}+] - }+ - ] - }? - } - """ - assert "message" in completion, "error in server response: message not found" - assert ( - "finish_reason" in completion - ), "error in server response: finish_reason not found" - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=_parse_content(completion), - stop_reason=_parse_stop_reason(completion), - tool_calls=_parse_tool_calls(completion), - ), - logprobs=_parse_logprobs(completion), - ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 18397a08d7..38ca94ed51 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -144,7 +144,9 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="nvidia", - pip_packages=[], # TODO(mf): need to specify httpx if it's already a llama-stack dep? 
+ pip_packages=[ + "openai", + ], module="llama_stack.providers.adapters.inference.nvidia", config_class="llama_stack.providers.adapters.inference.nvidia.NVIDIAConfig", ), diff --git a/tests/nvidia/integration/test_inference.py b/tests/nvidia/integration/test_inference.py index 2e7b33e4f4..df4c74d85a 100644 --- a/tests/nvidia/integration/test_inference.py +++ b/tests/nvidia/integration/test_inference.py @@ -8,11 +8,15 @@ from typing import Generator, List, Tuple import pytest +from llama_models.datatypes import SamplingParams from llama_stack.apis.inference import ( ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, CompletionMessage, Inference, + # LogProbConfig, Message, StopReason, SystemMessage, @@ -96,6 +100,70 @@ async def test_chat_completion_messages( assert response.completion_message.tool_calls == [] +async def test_chat_completion_basic( + client: Inference, + model: str, +): + """ + Test the chat completion endpoint with basic messages, with and without streaming. + """ + client = await client + messages = [ + UserMessage(content="How are you?"), + ] + + response = await client.chat_completion( + model=model, + messages=messages, + stream=False, + ) + + assert isinstance(response, ChatCompletionResponse) + assert isinstance(response.completion_message.content, str) + # we're not testing accuracy, so no assertions on the result.completion_message.content + assert response.completion_message.role == "assistant" + assert isinstance(response.completion_message.stop_reason, StopReason) + assert response.completion_message.tool_calls == [] + + +async def test_chat_completion_stream_basic( + client: Inference, + model: str, +): + """ + Test the chat completion endpoint with basic messages, with and without streaming. 
+ """ + client = await client + messages = [ + UserMessage(content="How are you?"), + ] + + response = await client.chat_completion( + model=model, + messages=messages, + stream=True, + sampling_params=SamplingParams(max_tokens=5), + # logprobs=LogProbConfig(top_k=3), + ) + + chunks = [chunk async for chunk in response] + assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in chunks) + assert all(isinstance(chunk.event.delta, str) for chunk in chunks) + assert chunks[0].event.event_type == ChatCompletionResponseEventType.start + assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete + if len(chunks) > 2: + assert all( + chunk.event.event_type == ChatCompletionResponseEventType.progress + for chunk in chunks[1:-1] + ) + # we're not testing accuracy, so no assertions on the result.completion_message.content + assert all( + chunk.event.stop_reason is None + or isinstance(chunk.event.stop_reason, StopReason) + for chunk in chunks + ) + + async def test_bad_base_url( model: str, ): diff --git a/tests/nvidia/unit/test_chat_completion.py b/tests/nvidia/unit/test_chat_completion.py index 1608ad39a6..b8c91f2449 100644 --- a/tests/nvidia/unit/test_chat_completion.py +++ b/tests/nvidia/unit/test_chat_completion.py @@ -157,7 +157,7 @@ async def test_tools( "type": "function", "function": { "name": "magic", - "arguments": {"input": 3}, + "arguments": '{"input": 3}', }, }, { @@ -165,7 +165,7 @@ async def test_tools( "type": "function", "function": { "name": "magic!", - "arguments": {"input": 42}, + "arguments": '{"input": 42}', }, }, ], diff --git a/tests/nvidia/unit/test_openai_utils.py b/tests/nvidia/unit/test_openai_utils.py new file mode 100644 index 0000000000..7acf3f6ccf --- /dev/null +++ b/tests/nvidia/unit/test_openai_utils.py @@ -0,0 +1,493 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import AsyncGenerator, List + +import pytest +from llama_models.llama3.api.datatypes import StopReason + +from llama_stack.apis.inference import ( + ChatCompletionResponse, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, +) +from llama_stack.providers.adapters.inference.nvidia._openai_utils import ( + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, +) +from openai.types.chat import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionMessageToolCall, + ChatCompletionTokenLogprob, +) +from openai.types.chat.chat_completion import Choice, ChoiceLogprobs +from openai.types.chat.chat_completion_chunk import ( + Choice as ChoiceChunk, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) +from openai.types.chat.chat_completion_token_logprob import TopLogprob + + +def test_convert_openai_chat_completion_choice_basic(): + response = Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Hello, world!", + ), + finish_reason="stop", + ) + result = convert_openai_chat_completion_choice(response) + assert isinstance(result, ChatCompletionResponse) + assert result.completion_message.content == "Hello, world!" 
+ assert result.completion_message.stop_reason == StopReason.end_of_turn + assert result.completion_message.tool_calls == [] + assert result.logprobs is None + + +def test_convert_openai_chat_completion_choice_basic_with_tool_calls(): + response = Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Hello, world!", + tool_calls=[ + ChatCompletionMessageToolCall( + id="tool_call_id", + type="function", + function={ + "name": "test_function", + "arguments": '{"test_args": "test_value"}', + }, + ) + ], + ), + finish_reason="tool_calls", + ) + + result = convert_openai_chat_completion_choice(response) + assert isinstance(result, ChatCompletionResponse) + assert result.completion_message.content == "Hello, world!" + assert result.completion_message.stop_reason == StopReason.end_of_message + assert len(result.completion_message.tool_calls) == 1 + assert result.completion_message.tool_calls[0].tool_name == "test_function" + assert result.completion_message.tool_calls[0].arguments == { + "test_args": "test_value" + } + assert result.logprobs is None + + +def test_convert_openai_chat_completion_choice_basic_with_logprobs(): + response = Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Hello world", + ), + finish_reason="stop", + logprobs=ChoiceLogprobs( + content=[ + ChatCompletionTokenLogprob( + token="Hello", + logprob=-1.0, + bytes=[72, 101, 108, 108, 111], + top_logprobs=[ + TopLogprob( + token="Hello", logprob=-1.0, bytes=[72, 101, 108, 108, 111] + ), + TopLogprob( + token="Greetings", + logprob=-1.5, + bytes=[71, 114, 101, 101, 116, 105, 110, 103, 115], + ), + ], + ), + ChatCompletionTokenLogprob( + token="world", + logprob=-1.5, + bytes=[119, 111, 114, 108, 100], + top_logprobs=[ + TopLogprob( + token="world", logprob=-1.5, bytes=[119, 111, 114, 108, 100] + ), + TopLogprob( + token="planet", + logprob=-2.0, + bytes=[112, 108, 97, 110, 101, 116], + ), + ], + ), + ] + ), + ) + + result = convert_openai_chat_completion_choice(response) + assert isinstance(result, ChatCompletionResponse) + assert result.completion_message.content == "Hello world" + assert result.completion_message.stop_reason == StopReason.end_of_turn + assert result.completion_message.tool_calls == [] + assert result.logprobs is not None + assert len(result.logprobs) == 2 + assert len(result.logprobs[0].logprobs_by_token) == 2 + assert result.logprobs[0].logprobs_by_token["Hello"] == -1.0 + assert result.logprobs[0].logprobs_by_token["Greetings"] == -1.5 + assert len(result.logprobs[1].logprobs_by_token) == 2 + assert result.logprobs[1].logprobs_by_token["world"] == -1.5 + assert result.logprobs[1].logprobs_by_token["planet"] == -2.0 + + +def test_convert_openai_chat_completion_choice_missing_message(): + response = Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Hello, world!", + ), + finish_reason="stop", + ) + + response.message = None + with pytest.raises( + AssertionError, match="error in server response: message not found" + ): + convert_openai_chat_completion_choice(response) + + del response.message + with pytest.raises( + AssertionError, match="error in server response: message not found" + ): + convert_openai_chat_completion_choice(response) + + +def test_convert_openai_chat_completion_choice_missing_finish_reason(): + response = Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", + content="Hello, world!", + ), + finish_reason="stop", + ) + + response.finish_reason = None + with pytest.raises( + 
AssertionError, match="error in server response: finish_reason not found" + ): + convert_openai_chat_completion_choice(response) + + del response.finish_reason + with pytest.raises( + AssertionError, match="error in server response: finish_reason not found" + ): + convert_openai_chat_completion_choice(response) + + +# we want to test convert_openai_chat_completion_stream +# we need to produce a stream of OpenAIChatCompletionChunk +# streams to produce - +# 0. basic stream with one chunk, should produce 3 (start, progress, complete) +# 1. stream with 3 chunks, should produce 5 events (start, progress, progress, progress, complete) +# 2. stream with a tool call, should produce 4 events (start, progress w/ tool_call, complete) + + +@pytest.mark.asyncio +async def test_convert_openai_chat_completion_stream_basic(): + chunks = [ + OpenAIChatCompletionChunk( + id="1", + created=1234567890, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + content="Hello, world!", + ), + finish_reason="stop", + ) + ], + ) + ] + + async def async_generator_from_list(items: List) -> AsyncGenerator: + for item in items: + yield item + + results = [ + result + async for result in convert_openai_chat_completion_stream( + async_generator_from_list(chunks) + ) + ] + + assert len(results) == 2 + assert all( + isinstance(result, ChatCompletionResponseStreamChunk) for result in results + ) + assert results[0].event.event_type == ChatCompletionResponseEventType.start + assert results[0].event.delta == "Hello, world!" + assert results[1].event.event_type == ChatCompletionResponseEventType.complete + assert results[1].event.stop_reason == StopReason.end_of_turn + + +@pytest.mark.asyncio +async def test_convert_openai_chat_completion_stream_basic_empty(): + chunks = [ + OpenAIChatCompletionChunk( + id="1", + created=1234567890, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + ), + finish_reason="stop", + ) + ], + ), + OpenAIChatCompletionChunk( + id="1", + created=1234567890, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + content="Hello, world!", + ), + finish_reason="stop", + ) + ], + ), + ] + + async def async_generator_from_list(items: List) -> AsyncGenerator: + for item in items: + yield item + + results = [ + result + async for result in convert_openai_chat_completion_stream( + async_generator_from_list(chunks) + ) + ] + + print(results) + + assert len(results) == 3 + assert all( + isinstance(result, ChatCompletionResponseStreamChunk) for result in results + ) + assert results[0].event.event_type == ChatCompletionResponseEventType.start + assert results[1].event.event_type == ChatCompletionResponseEventType.progress + assert results[1].event.delta == "Hello, world!" 
+ assert results[2].event.event_type == ChatCompletionResponseEventType.complete + assert results[2].event.stop_reason == StopReason.end_of_turn + + +@pytest.mark.asyncio +async def test_convert_openai_chat_completion_stream_multiple_chunks(): + chunks = [ + OpenAIChatCompletionChunk( + id="1", + created=1234567890, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + content="Hello, world!", + ), + # finish_reason="continue", + ) + ], + ), + OpenAIChatCompletionChunk( + id="2", + created=1234567891, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + content="How are you?", + ), + # finish_reason="continue", + ) + ], + ), + OpenAIChatCompletionChunk( + id="3", + created=1234567892, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + content="I'm good, thanks!", + ), + finish_reason="stop", + ) + ], + ), + ] + + async def async_generator_from_list(items: List) -> AsyncGenerator: + for item in items: + yield item + + results = [ + result + async for result in convert_openai_chat_completion_stream( + async_generator_from_list(chunks) + ) + ] + + assert len(results) == 4 + assert all( + isinstance(result, ChatCompletionResponseStreamChunk) for result in results + ) + assert results[0].event.event_type == ChatCompletionResponseEventType.start + assert results[0].event.delta == "Hello, world!" + assert not results[0].event.stop_reason + assert results[1].event.event_type == ChatCompletionResponseEventType.progress + assert results[1].event.delta == "How are you?" + assert not results[1].event.stop_reason + assert results[2].event.event_type == ChatCompletionResponseEventType.progress + assert results[2].event.delta == "I'm good, thanks!" + assert not results[2].event.stop_reason + assert results[3].event.event_type == ChatCompletionResponseEventType.complete + assert results[3].event.stop_reason == StopReason.end_of_turn + + +@pytest.mark.asyncio +async def test_convert_openai_chat_completion_stream_with_tool_call_and_content(): + chunks = [ + OpenAIChatCompletionChunk( + id="1", + created=1234567890, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + content="Hello, world!", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tool_call_id", + type="function", + function=ChoiceDeltaToolCallFunction( + name="test_function", + arguments='{"test_args": "test_value"}', + ), + ) + ], + ), + finish_reason="tool_calls", + ) + ], + ) + ] + + async def async_generator_from_list(items: List) -> AsyncGenerator: + for item in items: + yield item + + results = [ + result + async for result in convert_openai_chat_completion_stream( + async_generator_from_list(chunks) + ) + ] + + assert len(results) == 3 + assert all( + isinstance(result, ChatCompletionResponseStreamChunk) for result in results + ) + assert results[0].event.event_type == ChatCompletionResponseEventType.start + assert results[0].event.delta == "Hello, world!" 
+ assert not results[0].event.stop_reason + assert results[1].event.event_type == ChatCompletionResponseEventType.progress + assert not isinstance(results[1].event.delta, str) + assert results[1].event.delta.content.tool_name == "test_function" + assert results[1].event.delta.content.arguments == {"test_args": "test_value"} + assert not results[1].event.stop_reason + assert results[2].event.event_type == ChatCompletionResponseEventType.complete + assert results[2].event.stop_reason == StopReason.end_of_message + + +@pytest.mark.asyncio +async def test_convert_openai_chat_completion_stream_with_tool_call_and_no_content(): + chunks = [ + OpenAIChatCompletionChunk( + id="1", + created=1234567890, + model="mock-model", + object="chat.completion.chunk", + choices=[ + ChoiceChunk( + index=0, + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tool_call_id", + type="function", + function=ChoiceDeltaToolCallFunction( + name="test_function", + arguments='{"test_args": "test_value"}', + ), + ) + ], + ), + finish_reason="tool_calls", + ) + ], + ) + ] + + async def async_generator_from_list(items: List) -> AsyncGenerator: + for item in items: + yield item + + results = [ + result + async for result in convert_openai_chat_completion_stream( + async_generator_from_list(chunks) + ) + ] + + assert len(results) == 2 + assert all( + isinstance(result, ChatCompletionResponseStreamChunk) for result in results + ) + assert results[0].event.event_type == ChatCompletionResponseEventType.start + assert not isinstance(results[0].event.delta, str) + assert results[0].event.delta.content.tool_name == "test_function" + assert results[0].event.delta.content.arguments == {"test_args": "test_value"} + assert not results[0].event.stop_reason + assert results[1].event.event_type == ChatCompletionResponseEventType.complete + assert results[1].event.stop_reason == StopReason.end_of_message From 2980a1892060ea7cbeb7c2a91bff68e5197560c0 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 19 Nov 2024 12:49:14 -0500 Subject: [PATCH 3/9] map llama model -> provider model id in ModelRegistryHelper --- llama_stack/providers/utils/inference/model_registry.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 07225fac0e..8dbfab14aa 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -29,7 +29,6 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli return ModelAlias( provider_model_id=provider_model_id, aliases=[ - model_descriptor, get_huggingface_repo(model_descriptor), ], llama_model=model_descriptor, @@ -57,6 +56,10 @@ def __init__(self, model_aliases: List[ModelAlias]): self.alias_to_provider_id_map[alias_obj.provider_model_id] = ( alias_obj.provider_model_id ) + # ensure we can go from llama model to provider model id + self.alias_to_provider_id_map[alias_obj.llama_model] = ( + alias_obj.provider_model_id + ) self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = ( alias_obj.llama_model ) From a5d413045c5ab0141b91f08eb003b6c304cc8354 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Tue, 19 Nov 2024 21:02:20 +0000 Subject: [PATCH 4/9] Add nvidia remote distro --- .../remote_hosted_distro/nvidia.md | 60 +++++++++++++++++ .../remote/inference/nvidia/_config.py | 9 ++- llama_stack/templates/nvidia/__init__.py | 7 
++ llama_stack/templates/nvidia/build.yaml | 19 ++++++ llama_stack/templates/nvidia/doc_template.md | 60 +++++++++++++++++ llama_stack/templates/nvidia/nvidia.py | 64 +++++++++++++++++++ llama_stack/templates/nvidia/run.yaml | 55 ++++++++++++++++ 7 files changed, 273 insertions(+), 1 deletion(-) create mode 100644 docs/source/getting_started/distributions/remote_hosted_distro/nvidia.md create mode 100644 llama_stack/templates/nvidia/__init__.py create mode 100644 llama_stack/templates/nvidia/build.yaml create mode 100644 llama_stack/templates/nvidia/doc_template.md create mode 100644 llama_stack/templates/nvidia/nvidia.py create mode 100644 llama_stack/templates/nvidia/run.yaml diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/nvidia.md b/docs/source/getting_started/distributions/remote_hosted_distro/nvidia.md new file mode 100644 index 0000000000..b670c73453 --- /dev/null +++ b/docs/source/getting_started/distributions/remote_hosted_distro/nvidia.md @@ -0,0 +1,60 @@ +# NVIDIA Distribution + +The `llamastack/distribution-nvidia` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| inference | `remote::nvidia` | +| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | +| safety | `inline::llama-guard` | +| telemetry | `inline::meta-reference` | + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) + +### Models + +The following models are available by default: + +- `${env.INFERENCE_MODEL} (None)` + + +### Prerequisite: API Keys + +Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). + + +## Running Llama Stack with NVIDIA + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-nvidia \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env FIREWORKS_API_KEY=$FIREWORKS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template fireworks --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env FIREWORKS_API_KEY=$FIREWORKS_API_KEY +``` \ No newline at end of file diff --git a/llama_stack/providers/remote/inference/nvidia/_config.py b/llama_stack/providers/remote/inference/nvidia/_config.py index 46ac3fa5ba..fab636c8d4 100644 --- a/llama_stack/providers/remote/inference/nvidia/_config.py +++ b/llama_stack/providers/remote/inference/nvidia/_config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
import os -from typing import Optional +from typing import Any, Dict, Optional from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -50,3 +50,10 @@ class NVIDIAConfig(BaseModel): @property def is_hosted(self) -> bool: return "integrate.api.nvidia.com" in self.base_url + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "url": "https://integrate.api.nvidia.com", + "api_key": "${env.NVIDIA_API_KEY}", + } diff --git a/llama_stack/templates/nvidia/__init__.py b/llama_stack/templates/nvidia/__init__.py new file mode 100644 index 0000000000..24e2fbd216 --- /dev/null +++ b/llama_stack/templates/nvidia/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .nvidia import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml new file mode 100644 index 0000000000..9a735c2203 --- /dev/null +++ b/llama_stack/templates/nvidia/build.yaml @@ -0,0 +1,19 @@ +version: '2' +name: nvidia +distribution_spec: + description: Use NVIDIA NIM for running LLM inference + docker_image: null + providers: + inference: + - remote::nvidia + memory: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference +image_type: conda diff --git a/llama_stack/templates/nvidia/doc_template.md b/llama_stack/templates/nvidia/doc_template.md new file mode 100644 index 0000000000..a9db770558 --- /dev/null +++ b/llama_stack/templates/nvidia/doc_template.md @@ -0,0 +1,60 @@ +# NVIDIA Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} ({{ model.provider_model_id }})` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). + + +## Running Llama Stack with NVIDIA + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. 
+ +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env FIREWORKS_API_KEY=$FIREWORKS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template fireworks --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env FIREWORKS_API_KEY=$FIREWORKS_API_KEY +``` diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py new file mode 100644 index 0000000000..0f15511804 --- /dev/null +++ b/llama_stack/templates/nvidia/nvidia.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_models.sku_list import all_registered_models + +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia._nvidia import _MODEL_ALIASES + +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::nvidia"], + "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + } + + inference_provider = Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + config=NVIDIAConfig.sample_run_config(), + ) + + inference_model = ModelInput( + model_id="${env.INFERENCE_MODEL}", + provider_id="nvidia", + ) + + return DistributionTemplate( + name="nvidia", + distro_type="remote_hosted", + description="Use NVIDIA NIM for running LLM inference", + docker_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + default_models=[inference_model], + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=[inference_model], + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "NVIDIA_API_KEY": ( + "", + "NVIDIA API Key", + ), + }, + ) diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml new file mode 100644 index 0000000000..f4953852f5 --- /dev/null +++ b/llama_stack/templates/nvidia/run.yaml @@ -0,0 +1,55 @@ +version: '2' +image_name: nvidia +docker_image: null +conda_env: nvidia +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: nvidia + provider_type: remote::nvidia + config: + url: https://integrate.api.nvidia.com + api_key: ${env.NVIDIA_API_KEY} + memory: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + telemetry: + - provider_id: meta-reference + 
provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +models: +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: nvidia + provider_model_id: null +shields: [] +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] From 4ccf4ef6416e2f1f85ebc1c1736420c4cd865c33 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 19 Nov 2024 17:36:08 -0500 Subject: [PATCH 5/9] align with other remote adapters, rename config base_url -> url --- llama_stack/providers/remote/inference/nvidia/_config.py | 8 ++++---- llama_stack/providers/remote/inference/nvidia/_nvidia.py | 6 +++--- llama_stack/providers/remote/inference/nvidia/_utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/_config.py b/llama_stack/providers/remote/inference/nvidia/_config.py index 46ac3fa5ba..0b45ecf54d 100644 --- a/llama_stack/providers/remote/inference/nvidia/_config.py +++ b/llama_stack/providers/remote/inference/nvidia/_config.py @@ -17,7 +17,7 @@ class NVIDIAConfig(BaseModel): Configuration for the NVIDIA NIM inference endpoint. Attributes: - base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 + url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 api_key (str): The access key for the hosted NIM endpoints There are two ways to access NVIDIA NIMs - @@ -30,11 +30,11 @@ class NVIDIAConfig(BaseModel): By default the configuration will attempt to read the NVIDIA_API_KEY environment variable to set the api_key. Please do not put your API key in code. - If you are using a self-hosted NVIDIA NIM, you can set the base_url to the + If you are using a self-hosted NVIDIA NIM, you can set the url to the URL of your running NVIDIA NIM and do not need to set the api_key. 
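The hosted/self-hosted split described in this docstring is driven entirely by the renamed `url` field: the `is_hosted` property checks for the integrate.api.nvidia.com host, and the adapter only insists on an API key in the hosted case. A minimal sketch of the two configurations (not part of the patch, assuming the package layout introduced in this series):

```python
# Illustrative only; assumes llama_stack with the NVIDIA provider from this
# series is importable.
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig

# Hosted preview API: default url, api_key is read from NVIDIA_API_KEY.
hosted = NVIDIAConfig()
assert hosted.is_hosted
assert hosted.url == "https://integrate.api.nvidia.com"

# Self-hosted NIM: point url at the local deployment; no api_key is needed.
local = NVIDIAConfig(url="http://localhost:8000")
assert not local.is_hosted
```

The adapter enforces the hosted-mode key requirement at construction time, so a missing NVIDIA_API_KEY fails fast rather than on the first request.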
""" - base_url: str = Field( + url: str = Field( default="https://integrate.api.nvidia.com", description="A base url for accessing the NVIDIA NIM", ) @@ -49,4 +49,4 @@ class NVIDIAConfig(BaseModel): @property def is_hosted(self) -> bool: - return "integrate.api.nvidia.com" in self.base_url + return "integrate.api.nvidia.com" in self.url diff --git a/llama_stack/providers/remote/inference/nvidia/_nvidia.py b/llama_stack/providers/remote/inference/nvidia/_nvidia.py index c5bfa0f259..92c4e1cfba 100644 --- a/llama_stack/providers/remote/inference/nvidia/_nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/_nvidia.py @@ -89,7 +89,7 @@ def __init__(self, config: NVIDIAConfig) -> None: # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) - print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...") + print(f"Initializing NVIDIAInferenceAdapter({config.url})...") if config.is_hosted: if not config.api_key: @@ -110,7 +110,7 @@ def __init__(self, config: NVIDIAConfig) -> None: self._config = config # make sure the client lives longer than any async calls self._client = AsyncOpenAI( - base_url=f"{self._config.base_url}/v1", + base_url=f"{self._config.url}/v1", api_key=self._config.api_key or "NO KEY", timeout=self._config.timeout, ) @@ -172,7 +172,7 @@ async def chat_completion( response = await self._client.chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError( - f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}" + f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}" ) from e if stream: diff --git a/llama_stack/providers/remote/inference/nvidia/_utils.py b/llama_stack/providers/remote/inference/nvidia/_utils.py index 6f52bdc4b7..c66cf75f43 100644 --- a/llama_stack/providers/remote/inference/nvidia/_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/_utils.py @@ -40,7 +40,7 @@ async def check_health(config: NVIDIAConfig) -> None: if not config.is_hosted: print("Checking NVIDIA NIM health...") try: - is_live, is_ready = await _get_health(config.base_url) + is_live, is_ready = await _get_health(config.url) if not is_live: raise ConnectionError("NVIDIA NIM is not running") if not is_ready: From 67597442356be88d6c53da47d9561ef68a048247 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Wed, 20 Nov 2024 23:04:48 +0000 Subject: [PATCH 6/9] Added distributions for inline and remote --- distributions/inline-nvidia/build.yaml | 1 + distributions/inline-nvidia/compose.yaml | 58 ++++++++++++++++++++++++ distributions/inline-nvidia/run.yaml | 56 +++++++++++++++++++++++ distributions/remote-nvidia/build.yaml | 1 + distributions/remote-nvidia/compose.yaml | 19 ++++++++ distributions/remote-nvidia/run.yaml | 1 + 6 files changed, 136 insertions(+) create mode 120000 distributions/inline-nvidia/build.yaml create mode 100644 distributions/inline-nvidia/compose.yaml create mode 100644 distributions/inline-nvidia/run.yaml create mode 120000 distributions/remote-nvidia/build.yaml create mode 100644 distributions/remote-nvidia/compose.yaml create mode 120000 distributions/remote-nvidia/run.yaml diff --git a/distributions/inline-nvidia/build.yaml b/distributions/inline-nvidia/build.yaml new file mode 120000 index 0000000000..8903d2e572 --- /dev/null +++ b/distributions/inline-nvidia/build.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/nvidia/build.yaml \ No newline at end of file diff --git a/distributions/inline-nvidia/compose.yaml b/distributions/inline-nvidia/compose.yaml new 
file mode 100644 index 0000000000..f7320b9686 --- /dev/null +++ b/distributions/inline-nvidia/compose.yaml @@ -0,0 +1,58 @@ +services: + nim: + image: nvcr.io/nim/meta/llama-3.1-8b-instruct:latest + network_mode: "host" + volumes: + - nim-llm-cache:/opt/nim/.cache + ports: + - "8000:8000" + shm_size: 16G + environment: + - CUDA_VISIBLE_DEVICES=0 + - NIM_HTTP_API_PORT=8000 + - NIM_TRITON_LOG_VERBOSE=1 + - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}} + command: [] + deploy: + resources: + reservations: + devices: + - driver: nvidia + # that's the closest analogue to --gpus; provide + # an integer amount of devices or 'all' + count: 1 + # Devices are reserved using a list of capabilities, making + # capabilities the only required field. A device MUST + # satisfy all the requested capabilities for a successful + # reservation. + capabilities: [gpu] + runtime: nvidia + healthcheck: + test: ["CMD", "curl", "http://localhost:8000/v1/health/ready"] + interval: 5s + timeout: 5s + retries: 30 + start_period: 120s + llamastack: + depends_on: + - nim + image: distribution-nvidia:dev + network_mode: "host" + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-nvidia.yaml + ports: + - "5000:5000" + environment: + - INFERENCE_MODEL=${INFERENCE_MODEL:-Llama3.1-8B-Instruct} + - NVIDIA_API_KEY=${NVIDIA_API_KEY:-} + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml-config /root/llamastack-run-nvidia.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s +volumes: + nim-llm-cache: + driver: local \ No newline at end of file diff --git a/distributions/inline-nvidia/run.yaml b/distributions/inline-nvidia/run.yaml new file mode 100644 index 0000000000..81e9e7f1c0 --- /dev/null +++ b/distributions/inline-nvidia/run.yaml @@ -0,0 +1,56 @@ +version: '2' +image_name: nvidia +docker_image: null +conda_env: nvidia +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: nvidia + provider_type: remote::nvidia + config: + url: http://localhost:8000 + api_key: ${env.NVIDIA_API_KEY} + memory: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db +models: +- metadata: {} + model_id: ${env.INFERENCE_MODEL} + provider_id: nvidia + provider_model_id: null +shields: [] +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] + diff --git a/distributions/remote-nvidia/build.yaml b/distributions/remote-nvidia/build.yaml new file mode 120000 index 0000000000..8903d2e572 --- /dev/null +++ b/distributions/remote-nvidia/build.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/nvidia/build.yaml \ No newline at end of file diff --git a/distributions/remote-nvidia/compose.yaml b/distributions/remote-nvidia/compose.yaml new file mode 100644 index 0000000000..04b12d0da2 --- /dev/null +++ 
b/distributions/remote-nvidia/compose.yaml @@ -0,0 +1,19 @@ +services: + llamastack: + image: distribution-nvidia:dev + network_mode: "host" + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-nvidia.yaml + ports: + - "5000:5000" + environment: + - INFERENCE_MODEL=${INFERENCE_MODEL:-Llama3.1-8B-Instruct} + - NVIDIA_API_KEY=${NVIDIA_API_KEY:-} + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml-config /root/llamastack-run-nvidia.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/remote-nvidia/run.yaml b/distributions/remote-nvidia/run.yaml new file mode 120000 index 0000000000..85da3e26bd --- /dev/null +++ b/distributions/remote-nvidia/run.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/nvidia/run.yaml \ No newline at end of file From deecc27b661c614935a99dadd8f200ec5970e721 Mon Sep 17 00:00:00 2001 From: Chantal D Gama Rose Date: Tue, 3 Dec 2024 19:22:05 -0800 Subject: [PATCH 7/9] Reverted outdated changes --- .../remote/inference/nvidia/_config.py | 59 --- .../remote/inference/nvidia/_nvidia.py | 182 -------- .../remote/inference/nvidia/_openai_utils.py | 430 ------------------ .../remote/inference/nvidia/_utils.py | 50 -- llama_stack/templates/nvidia/doc_template.md | 6 +- llama_stack/templates/nvidia/nvidia.py | 6 +- 6 files changed, 5 insertions(+), 728 deletions(-) delete mode 100644 llama_stack/providers/remote/inference/nvidia/_config.py delete mode 100644 llama_stack/providers/remote/inference/nvidia/_nvidia.py delete mode 100644 llama_stack/providers/remote/inference/nvidia/_openai_utils.py delete mode 100644 llama_stack/providers/remote/inference/nvidia/_utils.py diff --git a/llama_stack/providers/remote/inference/nvidia/_config.py b/llama_stack/providers/remote/inference/nvidia/_config.py deleted file mode 100644 index 7934a0f059..0000000000 --- a/llama_stack/providers/remote/inference/nvidia/_config.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os -from typing import Any, Dict, Optional - -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field - - -@json_schema_type -class NVIDIAConfig(BaseModel): - """ - Configuration for the NVIDIA NIM inference endpoint. - - Attributes: - url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 - api_key (str): The access key for the hosted NIM endpoints - - There are two ways to access NVIDIA NIMs - - 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com - 1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure - - By default the configuration is set to use the hosted APIs. This requires - an API key which can be obtained from https://ngc.nvidia.com/. - - By default the configuration will attempt to read the NVIDIA_API_KEY environment - variable to set the api_key. Please do not put your API key in code. - - If you are using a self-hosted NVIDIA NIM, you can set the url to the - URL of your running NVIDIA NIM and do not need to set the api_key. 
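For self-hosted deployments, both the `inline-nvidia` compose service above (via `curl .../v1/health/ready`) and the adapter's health-check helper gate traffic on the NIM health endpoints. A standalone sketch of that probe, assuming a NIM listening on http://localhost:8000:

```python
# Illustrative only, not part of the patches; mirrors the /v1/health/{live,ready}
# endpoints used elsewhere in this series.
import asyncio

import httpx


async def nim_is_ready(url: str = "http://localhost:8000") -> bool:
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            live = await client.get(f"{url}/v1/health/live")
            ready = await client.get(f"{url}/v1/health/ready")
    except httpx.ConnectError:
        return False
    return live.status_code == 200 and ready.status_code == 200


if __name__ == "__main__":
    print(asyncio.run(nim_is_ready()))
```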
- """ - - url: str = Field( - default="https://integrate.api.nvidia.com", - description="A base url for accessing the NVIDIA NIM", - ) - api_key: Optional[str] = Field( - default_factory=lambda: os.getenv("NVIDIA_API_KEY"), - description="The NVIDIA API key, only needed of using the hosted service", - ) - timeout: int = Field( - default=60, - description="Timeout for the HTTP requests", - ) - - @property - def is_hosted(self) -> bool: - return "integrate.api.nvidia.com" in self.url - - @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: - return { - "url": "https://integrate.api.nvidia.com", - "api_key": "${env.NVIDIA_API_KEY}", - } diff --git a/llama_stack/providers/remote/inference/nvidia/_nvidia.py b/llama_stack/providers/remote/inference/nvidia/_nvidia.py deleted file mode 100644 index 92c4e1cfba..0000000000 --- a/llama_stack/providers/remote/inference/nvidia/_nvidia.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import warnings -from typing import AsyncIterator, List, Optional, Union - -from llama_models.datatypes import SamplingParams -from llama_models.llama3.api.datatypes import ( - InterleavedTextMedia, - Message, - ToolChoice, - ToolDefinition, - ToolPromptFormat, -) -from llama_models.sku_list import CoreModelId -from openai import APIConnectionError, AsyncOpenAI - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseStreamChunk, - CompletionResponse, - CompletionResponseStreamChunk, - EmbeddingsResponse, - Inference, - LogProbConfig, - ResponseFormat, -) -from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias_with_just_provider_model_id, - ModelRegistryHelper, -) - -from ._config import NVIDIAConfig -from ._openai_utils import ( - convert_chat_completion_request, - convert_openai_chat_completion_choice, - convert_openai_chat_completion_stream, -) -from ._utils import check_health - -_MODEL_ALIASES = [ - build_model_alias_with_just_provider_model_id( - "meta/llama3-8b-instruct", - CoreModelId.llama3_8b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama3-70b-instruct", - CoreModelId.llama3_70b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.1-8b-instruct", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.1-70b-instruct", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.1-405b-instruct", - CoreModelId.llama3_1_405b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.2-1b-instruct", - CoreModelId.llama3_2_1b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.2-3b-instruct", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_model_alias_with_just_provider_model_id( - "meta/llama-3.2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - # TODO(mf): how do we handle Nemotron models? 
- # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct", -] - - -class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): - def __init__(self, config: NVIDIAConfig) -> None: - # TODO(mf): filter by available models - ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) - - print(f"Initializing NVIDIAInferenceAdapter({config.url})...") - - if config.is_hosted: - if not config.api_key: - raise RuntimeError( - "API key is required for hosted NVIDIA NIM. " - "Either provide an API key or use a self-hosted NIM." - ) - # elif self._config.api_key: - # - # we don't raise this warning because a user may have deployed their - # self-hosted NIM with an API key requirement. - # - # warnings.warn( - # "API key is not required for self-hosted NVIDIA NIM. " - # "Consider removing the api_key from the configuration." - # ) - - self._config = config - # make sure the client lives longer than any async calls - self._client = AsyncOpenAI( - base_url=f"{self._config.url}/v1", - api_key=self._config.api_key or "NO KEY", - timeout=self._config.timeout, - ) - - def completion( - self, - model_id: str, - content: InterleavedTextMedia, - sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: - raise NotImplementedError() - - async def embeddings( - self, - model_id: str, - contents: List[InterleavedTextMedia], - ) -> EmbeddingsResponse: - raise NotImplementedError() - - async def chat_completion( - self, - model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ - ToolPromptFormat - ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - ) -> Union[ - ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] - ]: - if tool_prompt_format: - warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") - - await check_health(self._config) # this raises errors - - request = convert_chat_completion_request( - request=ChatCompletionRequest( - model=self.get_provider_model_id(model_id), - messages=messages, - sampling_params=sampling_params, - tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - stream=stream, - logprobs=logprobs, - ), - n=1, - ) - - try: - response = await self._client.chat.completions.create(**request) - except APIConnectionError as e: - raise ConnectionError( - f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}" - ) from e - - if stream: - return convert_openai_chat_completion_stream(response) - else: - # we pass n=1 to get only one completion - return convert_openai_chat_completion_choice(response.choices[0]) diff --git a/llama_stack/providers/remote/inference/nvidia/_openai_utils.py b/llama_stack/providers/remote/inference/nvidia/_openai_utils.py deleted file mode 100644 index 998b4c275c..0000000000 --- a/llama_stack/providers/remote/inference/nvidia/_openai_utils.py +++ /dev/null @@ -1,430 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import warnings -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional - -from llama_models.llama3.api.datatypes import ( - CompletionMessage, - StopReason, - TokenLogProbs, - ToolCall, -) -from openai import AsyncStream -from openai.types.chat import ChatCompletionChunk as OpenAIChatCompletionChunk -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall, -) - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - Message, - ToolCallDelta, - ToolCallParseStatus, -) - - -def _convert_message(message: Message) -> Dict: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - out_dict = message.dict() - # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool" - if out_dict["role"] == "ipython": - out_dict.update(role="tool") - - if "stop_reason" in out_dict: - out_dict.update(stop_reason=out_dict["stop_reason"].value) - - # TODO(mf): tool_calls - - return out_dict - - -def convert_chat_completion_request( - request: ChatCompletionRequest, - n: int = 1, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. - """ - # model -> model - # messages -> messages - # sampling_params TODO(mattf): review strategy - # strategy=greedy -> nvext.top_k = -1, temperature = temperature - # strategy=top_p -> nvext.top_k = -1, top_p = top_p - # strategy=top_k -> nvext.top_k = top_k - # temperature -> temperature - # top_p -> top_p - # top_k -> nvext.top_k - # max_tokens -> max_tokens - # repetition_penalty -> nvext.repetition_penalty - # tools -> tools - # tool_choice ("auto", "required") -> tool_choice - # tool_prompt_format -> TBD - # stream -> stream - # logprobs -> logprobs - - nvext = {} - payload: Dict[str, Any] = dict( - model=request.model, - messages=[_convert_message(message) for message in request.messages], - stream=request.stream, - n=n, - extra_body=dict(nvext=nvext), - extra_headers={ - b"User-Agent": b"llama-stack: nvidia-inference-adapter", - }, - ) - - if request.tools: - payload.update(tools=request.tools) - if request.tool_choice: - payload.update( - tool_choice=request.tool_choice.value - ) # we cannot include tool_choice w/o tools, server will complain - - if request.logprobs: - payload.update(logprobs=True) - payload.update(top_logprobs=request.logprobs.top_k) - - if request.sampling_params: - nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) - - if request.sampling_params.max_tokens: - payload.update(max_tokens=request.sampling_params.max_tokens) - - if request.sampling_params.strategy == "top_p": - nvext.update(top_k=-1) - payload.update(top_p=request.sampling_params.top_p) - elif request.sampling_params.strategy == "top_k": - if ( - request.sampling_params.top_k != -1 - and request.sampling_params.top_k < 1 - ): - warnings.warn("top_k must be -1 or >= 1") - nvext.update(top_k=request.sampling_params.top_k) - elif request.sampling_params.strategy == "greedy": - nvext.update(top_k=-1) - 
payload.update(temperature=request.sampling_params.temperature) - - return payload - - -def _convert_openai_finish_reason(finish_reason: str) -> StopReason: - """ - Convert an OpenAI chat completion finish_reason to a StopReason. - - finish_reason: Literal["stop", "length", "tool_calls", ...] - - stop: model hit a natural stop point or a provided stop sequence - - length: maximum number of tokens specified in the request was reached - - tool_calls: model called a tool - - -> - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - - # TODO(mf): are end_of_turn and end_of_message semantics correct? - return { - "stop": StopReason.end_of_turn, - "length": StopReason.out_of_tokens, - "tool_calls": StopReason.end_of_message, - }.get(finish_reason, StopReason.end_of_turn) - - -def _convert_openai_tool_calls( - tool_calls: List[OpenAIChatCompletionMessageToolCall], -) -> List[ToolCall]: - """ - Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. - - OpenAI ChatCompletionMessageToolCall: - id: str - function: Function - type: Literal["function"] - - OpenAI Function: - arguments: str - name: str - - -> - - ToolCall: - call_id: str - tool_name: str - arguments: Dict[str, ...] - """ - if not tool_calls: - return [] # CompletionMessage tool_calls is not optional - - return [ - ToolCall( - call_id=call.id, - tool_name=call.function.name, - arguments=json.loads(call.function.arguments), - ) - for call in tool_calls - ] - - -def _convert_openai_logprobs( - logprobs: OpenAIChoiceLogprobs, -) -> Optional[List[TokenLogProbs]]: - """ - Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs. - - OpenAI ChoiceLogprobs: - content: Optional[List[ChatCompletionTokenLogprob]] - - OpenAI ChatCompletionTokenLogprob: - token: str - logprob: float - top_logprobs: List[TopLogprob] - - OpenAI TopLogprob: - token: str - logprob: float - - -> - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - """ - if not logprobs: - return None - - return [ - TokenLogProbs( - logprobs_by_token={ - logprobs.token: logprobs.logprob for logprobs in content.top_logprobs - } - ) - for content in logprobs.content - ] - - -def convert_openai_chat_completion_choice( - choice: OpenAIChoice, -) -> ChatCompletionResponse: - """ - Convert an OpenAI Choice into a ChatCompletionResponse. 
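The conversion helpers above boil down to two small mappings: OpenAI's `finish_reason` string becomes a `StopReason`, and each OpenAI tool call's JSON-encoded `arguments` string is decoded into the dict that `ToolCall` expects. A condensed, standalone sketch of just those two conversions (not the adapter's actual code path):

```python
# Standalone illustration of the mappings documented above; uses only the
# llama_models datatypes already imported elsewhere in this series.
import json

from llama_models.llama3.api.datatypes import StopReason, ToolCall

FINISH_REASON_TO_STOP_REASON = {
    "stop": StopReason.end_of_turn,
    "length": StopReason.out_of_tokens,
    "tool_calls": StopReason.end_of_message,
}


def stop_reason_from_finish_reason(finish_reason: str) -> StopReason:
    # unknown finish_reasons fall back to end_of_turn, as in the helper above
    return FINISH_REASON_TO_STOP_REASON.get(finish_reason, StopReason.end_of_turn)


def tool_call_from_openai(call: dict) -> ToolCall:
    # OpenAI sends arguments as a JSON string; ToolCall wants a parsed dict
    return ToolCall(
        call_id=call["id"],
        tool_name=call["function"]["name"],
        arguments=json.loads(call["function"]["arguments"]),
    )


print(stop_reason_from_finish_reason("tool_calls"))  # StopReason.end_of_message
print(tool_call_from_openai(
    {"id": "call-1", "function": {"name": "magic", "arguments": '{"input": 3}'}}
))
```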
- - OpenAI Choice: - message: ChatCompletionMessage - finish_reason: str - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChatCompletionMessage: - role: Literal["assistant"] - content: Optional[str] - tool_calls: Optional[List[ChatCompletionMessageToolCall]] - - -> - - ChatCompletionResponse: - completion_message: CompletionMessage - logprobs: Optional[List[TokenLogProbs]] - - CompletionMessage: - role: Literal["assistant"] - content: str | ImageMedia | List[str | ImageMedia] - stop_reason: StopReason - tool_calls: List[ToolCall] - - class StopReason(Enum): - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - assert ( - hasattr(choice, "message") and choice.message - ), "error in server response: message not found" - assert ( - hasattr(choice, "finish_reason") and choice.finish_reason - ), "error in server response: finish_reason not found" - - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=choice.message.content - or "", # CompletionMessage content is not optional - stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), - ), - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - - -async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], -) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: - """ - Convert a stream of OpenAI chat completion chunks into a stream - of ChatCompletionResponseStreamChunk. - - OpenAI ChatCompletionChunk: - choices: List[Choice] - - OpenAI Choice: # different from the non-streamed Choice - delta: ChoiceDelta - finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] - logprobs: Optional[ChoiceLogprobs] - - OpenAI ChoiceDelta: - content: Optional[str] - role: Optional[Literal["system", "user", "assistant", "tool"]] - tool_calls: Optional[List[ChoiceDeltaToolCall]] - - OpenAI ChoiceDeltaToolCall: - index: int - id: Optional[str] - function: Optional[ChoiceDeltaToolCallFunction] - type: Optional[Literal["function"]] - - OpenAI ChoiceDeltaToolCallFunction: - name: Optional[str] - arguments: Optional[str] - - -> - - ChatCompletionResponseStreamChunk: - event: ChatCompletionResponseEvent - - ChatCompletionResponseEvent: - event_type: ChatCompletionResponseEventType - delta: Union[str, ToolCallDelta] - logprobs: Optional[List[TokenLogProbs]] - stop_reason: Optional[StopReason] - - ChatCompletionResponseEventType: - start = "start" - progress = "progress" - complete = "complete" - - ToolCallDelta: - content: Union[str, ToolCall] - parse_status: ToolCallParseStatus - - ToolCall: - call_id: str - tool_name: str - arguments: str - - ToolCallParseStatus: - started = "started" - in_progress = "in_progress" - failure = "failure" - success = "success" - - TokenLogProbs: - logprobs_by_token: Dict[str, float] - - token, logprob - - StopReason: - end_of_turn = "end_of_turn" - end_of_message = "end_of_message" - out_of_tokens = "out_of_tokens" - """ - - # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... 
- def _event_type_generator() -> ( - Generator[ChatCompletionResponseEventType, None, None] - ): - yield ChatCompletionResponseEventType.start - while True: - yield ChatCompletionResponseEventType.progress - - event_type = _event_type_generator() - - # we implement NIM specific semantics, the main difference from OpenAI - # is that tool_calls are always produced as a complete call. there is no - # intermediate / partial tool call streamed. because of this, we can - # simplify the logic and not concern outselves with parse_status of - # started/in_progress/failed. we can always assume success. - # - # a stream of ChatCompletionResponseStreamChunk consists of - # 0. a start event - # 1. zero or more progress events - # - each progress event has a delta - # - each progress event may have a stop_reason - # - each progress event may have logprobs - # - each progress event may have tool_calls - # if a progress event has tool_calls, - # it is fully formed and - # can be emitted with a parse_status of success - # 2. a complete event - - stop_reason = None - - async for chunk in stream: - choice = chunk.choices[0] # assuming only one choice per chunk - - # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason - - # if there's a tool call, emit an event for each tool in the list - # if tool call and content, emit both separately - - if choice.delta.tool_calls: - # the call may have content and a tool call. ChatCompletionResponseEvent - # does not support both, so we emit the content first - if choice.delta.content: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=choice.delta.content, - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - ) - - # it is possible to have parallel tool calls in stream, but - # ChatCompletionResponseEvent only supports one per stream - if len(choice.delta.tool_calls) > 1: - warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest" - ) - - # NIM only produces fully formed tool calls, so we can assume success - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=ToolCallDelta( - content=_convert_openai_tool_calls(choice.delta.tool_calls)[0], - parse_status=ToolCallParseStatus.success, - ), - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - ) - else: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=next(event_type), - delta=choice.delta.content or "", # content is not optional - logprobs=_convert_openai_logprobs(choice.logprobs), - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) diff --git a/llama_stack/providers/remote/inference/nvidia/_utils.py b/llama_stack/providers/remote/inference/nvidia/_utils.py deleted file mode 100644 index c66cf75f43..0000000000 --- a/llama_stack/providers/remote/inference/nvidia/_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
-
-from typing import Tuple
-
-import httpx
-
-from ._config import NVIDIAConfig
-
-
-async def _get_health(url: str) -> Tuple[bool, bool]:
-    """
-    Query {url}/v1/health/{live,ready} to check if the server is running and ready
-
-    Args:
-        url (str): URL of the server
-
-    Returns:
-        Tuple[bool, bool]: (is_live, is_ready)
-    """
-    async with httpx.AsyncClient() as client:
-        live = await client.get(f"{url}/v1/health/live")
-        ready = await client.get(f"{url}/v1/health/ready")
-        return live.status_code == 200, ready.status_code == 200
-
-
-async def check_health(config: NVIDIAConfig) -> None:
-    """
-    Check if the server is running and ready
-
-    Args:
-        url (str): URL of the server
-
-    Raises:
-        RuntimeError: If the server is not running or ready
-    """
-    if not config.is_hosted:
-        print("Checking NVIDIA NIM health...")
-        try:
-            is_live, is_ready = await _get_health(config.url)
-            if not is_live:
-                raise ConnectionError("NVIDIA NIM is not running")
-            if not is_ready:
-                raise ConnectionError("NVIDIA NIM is not ready")
-            # TODO(mf): should we wait for the server to be ready?
-        except httpx.ConnectError as e:
-            raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
diff --git a/llama_stack/templates/nvidia/doc_template.md b/llama_stack/templates/nvidia/doc_template.md
index a9db770558..949018f8d6 100644
--- a/llama_stack/templates/nvidia/doc_template.md
+++ b/llama_stack/templates/nvidia/doc_template.md
@@ -47,14 +47,14 @@ docker run \
   llamastack/distribution-{{ name }} \
   --yaml-config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
-  --env FIREWORKS_API_KEY=$FIREWORKS_API_KEY
+  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
 ```

 ### Via Conda

 ```bash
-llama stack build --template fireworks --image-type conda
+llama stack build --template nvidia --image-type conda
 llama stack run ./run.yaml \
   --port 5001 \
-  --env FIREWORKS_API_KEY=$FIREWORKS_API_KEY
+  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
 ```
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index 0f15511804..22aa1f4b08 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -6,11 +6,9 @@

 from pathlib import Path

-from llama_models.sku_list import all_registered_models
-
-from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.distribution.datatypes import ModelInput, Provider
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.nvidia._nvidia import _MODEL_ALIASES
+from llama_stack.providers.remote.inference.nvidia.nvidia import _MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings

From 091d1969794ed7432800f419ad7225b98834749a Mon Sep 17 00:00:00 2001
From: Chantal D Gama Rose
Date: Tue, 3 Dec 2024 19:23:41 -0800
Subject: [PATCH 8/9] Changing fireworks to nvidia in docs

---
 docs/source/distributions/self_hosted_distro/nvidia.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md
index 3ea2200141..2f34997166 100644
--- a/docs/source/distributions/self_hosted_distro/nvidia.md
+++ b/docs/source/distributions/self_hosted_distro/nvidia.md
@@ -53,7 +53,7 @@ docker run \
 ### Via Conda

 ```bash
-llama stack build --template fireworks --image-type conda
+llama stack build --template nvidia --image-type conda
 llama stack run ./run.yaml \
   --port 5001 \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY

From 1658a5fe7534c2d90a8d29452bd5b43f182cac31 Mon Sep 17 00:00:00 2001
From: Chantal D Gama Rose
Date: Wed, 4 Dec 2024 09:32:45 -0800
Subject: [PATCH 9/9] Added nvidia as remote hosted distro

---
 docs/source/distributions/index.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/distributions/index.md b/docs/source/distributions/index.md
index b61e9b28f5..e88960a848 100644
--- a/docs/source/distributions/index.md
+++ b/docs/source/distributions/index.md
@@ -24,6 +24,7 @@ If so, we suggest:
   - {dockerhub}`distribution-remote-vllm` ([Guide](self_hosted_distro/remote-vllm))
   - {dockerhub}`distribution-meta-reference-gpu` ([Guide](self_hosted_distro/meta-reference-gpu))
   - {dockerhub}`distribution-tgi` ([Guide](self_hosted_distro/tgi))
+  - {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia))
 - **Are you running on a "regular" desktop machine?**
   If so, we suggest:
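
Reviewer note: for anyone exercising a self-hosted NIM against this patch series, the sketch below probes the same /v1/health/live and /v1/health/ready endpoints that check_health relies on before inference. It is illustrative only and not part of the changeset; the base URL http://localhost:8000 is an assumed local deployment.

```python
# Illustrative only -- not part of the patch series.
# Assumes a self-hosted NIM at http://localhost:8000 (hypothetical deployment).
import asyncio

import httpx


async def probe(base_url: str = "http://localhost:8000") -> None:
    async with httpx.AsyncClient(timeout=5) as client:
        try:
            live = await client.get(f"{base_url}/v1/health/live")
            ready = await client.get(f"{base_url}/v1/health/ready")
        except httpx.ConnectError as e:
            print(f"Failed to connect to NVIDIA NIM: {e}")
            return
        # A healthy NIM returns HTTP 200 on both endpoints.
        print(f"live={live.status_code == 200} ready={ready.status_code == 200}")


if __name__ == "__main__":
    asyncio.run(probe())
```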