From b74a05114fb5bdb069d2b44285a4fdd630bc5140 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Thu, 19 Dec 2024 11:27:59 -0800 Subject: [PATCH 1/5] Add inline vLLM provider to regression tests --- .../providers/tests/inference/fixtures.py | 28 ++++++++++++++++++- .../tests/inference/test_text_inference.py | 20 +++++++------ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d9c0cb1889..e935cd8c24 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -15,6 +15,7 @@ from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) +from llama_stack.providers.inline.inference.vllm import VLLMConfig from llama_stack.providers.remote.inference.bedrock import BedrockConfig from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig @@ -104,6 +105,26 @@ def inference_ollama(inference_model) -> ProviderFixture: ) +@pytest_asyncio.fixture(scope="session") +def inference_vllm(inference_model) -> ProviderFixture: + inference_model = ( + [inference_model] if isinstance(inference_model, str) else inference_model + ) + return ProviderFixture( + providers=[ + Provider( + provider_id=f"vllm-{i}", + provider_type="inline::vllm", + config=VLLMConfig( + model=m, + enforce_eager=True, # Make test run faster + ).model_dump(), + ) + for i, m in enumerate(inference_model) + ] + ) + + @pytest.fixture(scope="session") def inference_vllm_remote() -> ProviderFixture: return ProviderFixture( @@ -222,6 +243,7 @@ def model_id(inference_model) -> str: "ollama", "fireworks", "together", + "vllm", "vllm_remote", "remote", "bedrock", @@ -254,4 +276,8 @@ async def inference_stack(request, inference_model): ], ) - return test_stack.impls[Api.inference], test_stack.impls[Api.models] + # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended + yield test_stack.impls[Api.inference], test_stack.impls[Api.models] + + # Cleanup code that runs after test case completion + await test_stack.impls[Api.inference].shutdown() diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 99a62ac080..55a3bd0a8f 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -67,7 +67,9 @@ def sample_tool_definition(): class TestInference: - @pytest.mark.asyncio + # Session scope for asyncio because the tests in this class all + # share the same provider instance. 
+ @pytest.mark.asyncio(loop_scope="session") async def test_model_list(self, inference_model, inference_stack): _, models_impl = inference_stack response = await models_impl.list_models() @@ -83,7 +85,7 @@ async def test_model_list(self, inference_model, inference_stack): assert model_def is not None - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_completion(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -128,7 +130,7 @@ async def test_completion(self, inference_model, inference_stack): last = chunks[-1] assert last.stop_reason == StopReason.out_of_tokens - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_completion_logprobs(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -183,7 +185,7 @@ async def test_completion_logprobs(self, inference_model, inference_stack): else: # no token, no logprobs assert not chunk.logprobs, "Logprobs should be empty" - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.skip("This test is not quite robust") async def test_completion_structured_output(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -227,7 +229,7 @@ class Output(BaseModel): assert answer.year_born == "1963" assert answer.year_retired == "2003" - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_non_streaming( self, inference_model, inference_stack, common_params, sample_messages ): @@ -244,7 +246,7 @@ async def test_chat_completion_non_streaming( assert isinstance(response.completion_message.content, str) assert len(response.completion_message.content) > 0 - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_structured_output( self, inference_model, inference_stack, common_params ): @@ -314,7 +316,7 @@ class AnswerFormat(BaseModel): with pytest.raises(ValidationError): AnswerFormat.model_validate_json(response.completion_message.content) - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_streaming( self, inference_model, inference_stack, common_params, sample_messages ): @@ -341,7 +343,7 @@ async def test_chat_completion_streaming( end = grouped[ChatCompletionResponseEventType.complete][0] assert end.event.stop_reason == StopReason.end_of_turn - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_with_tool_calling( self, inference_model, @@ -380,7 +382,7 @@ async def test_chat_completion_with_tool_calling( assert "location" in call.arguments assert "San Francisco" in call.arguments["location"] - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_with_tool_calling_streaming( self, inference_model, From 9d23c063d52db2c9dcb99c95f99fc950670c981d Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Thu, 19 Dec 2024 11:28:15 -0800 Subject: [PATCH 2/5] Fix regressions in inline vLLM provider --- .../providers/inline/inference/vllm/vllm.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 0e7ba872c9..72aa2200a9 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -50,7 +50,7 @@ def __init__(self, config: VLLMConfig): self.formatter = ChatFormat(Tokenizer.get_instance()) async def 
initialize(self):
- log.info("Initializing vLLM inference adapter")
+ log.info("Initializing vLLM inference provider.")
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
@@ -79,14 +79,33 @@ async def initialize(self):
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
- log.info("Shutting down vLLM inference adapter")
+ log.info("Shutting down vLLM inference provider.")
if self.engine:
self.engine.shutdown_background_loop()
- async def register_model(self, model: Model) -> None:
- raise ValueError(
- "You cannot dynamically add a model to a running vllm instance"
- )
+ # Note that the return type of the superclass method is WRONG
+ async def register_model(self, model: Model) -> Model:
+ """
+ Callback that is called when the server associates an inference endpoint
+ with an inference provider.
+
+ :param model: Object that encapsulates parameters necessary for identifying
+ a specific LLM.
+
+ :returns: The input ``Model`` object. It may or may not be permissible
+ to change fields before returning this object.
+ """
+ log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+ # The current version of this provider is hard-coded to serve only
+ # the model specified in the YAML config file.
+ configured_model = resolve_model(self.config.model)
+ registered_model = resolve_model(model.model_id)
+
+ if configured_model.core_model_id != registered_model.core_model_id:
+ raise ValueError(f"Requested model '{model.identifier}' is different from "
+ f"model '{self.config.model}' that this provider "
+ f"is configured to serve")
+ return model

def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
@@ -160,7 +179,7 @@ async def chat_completion(
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
- prompt = chat_completion_request_to_prompt(request, self.formatter)
+ prompt = chat_completion_request_to_prompt(request, self.config.model, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
@@ -216,7 +235,7 @@ async def _generate_and_convert_to_openai_compat():
stream, self.formatter
):
yield chunk
-
+
async def embeddings(
self, model_id: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:

From 6ec9eabbeb76fd0938c47aa72753b87950e611b8 Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Thu, 19 Dec 2024 11:35:25 -0800
Subject: [PATCH 3/5] Redo code change after merge

---
 llama_stack/providers/inline/inference/vllm/vllm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index ac38e13b52..12c6c03700 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -169,7 +169,8 @@ async def chat_completion(
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
- prompt = await chat_completion_request_to_prompt(request, self.formatter)
+ prompt = await chat_completion_request_to_prompt(request, self.config.model,
+ self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id

From c8580d3b0c44f3b178555a1ff9d1a5d72012b47e Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Thu, 19 Dec 2024 11:46:31 -0800
Subject: [PATCH 4/5] Apply formatting to source files

---
 .../providers/inline/inference/vllm/vllm.py | 25 +++++++++++--------
 .../providers/tests/inference/fixtures.py | 2 +-
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 12c6c03700..1caae96872 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -88,23 +88,25 @@ async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint
with an inference provider.
-
+
:param model: Object that encapsulates parameters necessary for identifying
a specific LLM.
-
+
:returns: The input ``Model`` object. It may or may not be permissible
to change fields before returning this object.
"""
- log.info(f"Registering model {model.identifier} with vLLM inference provider.")
- # The current version of this provider is hard-coded to serve only
+ log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+ # The current version of this provider is hard-coded to serve only
# the model specified in the YAML config file.
configured_model = resolve_model(self.config.model)
registered_model = resolve_model(model.model_id)
-
+
if configured_model.core_model_id != registered_model.core_model_id:
- raise ValueError(f"Requested model '{model.identifier}' is different from "
- f"model '{self.config.model}' that this provider "
- f"is configured to serve")
+ raise ValueError(
+ f"Requested model '{model.identifier}' is different from "
+ f"model '{self.config.model}' that this provider "
+ f"is configured to serve"
+ )
return model

def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
@@ -169,8 +171,9 @@ async def chat_completion(
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
- prompt = await chat_completion_request_to_prompt(request, self.config.model,
- self.formatter)
+ prompt = await chat_completion_request_to_prompt(
+ request, self.config.model, self.formatter
+ )
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
@@ -226,7 +229,7 @@ async def _generate_and_convert_to_openai_compat():
stream, self.formatter
):
yield chunk
-
+
async def embeddings(
self, model_id: str, contents: List[InterleavedContent]
) -> EmbeddingsResponse:

diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index f3c7df404d..524bc69dbb 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -292,6 +292,6 @@ async def inference_stack(request, inference_model):
# Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
-
+
# Cleanup code that runs after test case completion
await test_stack.impls[Api.inference].shutdown()

From 82c10c917faadb761cf7167c30d4935133dc66e7 Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Thu, 19 Dec 2024 15:06:47 -0800
Subject: [PATCH 5/5] Minor change to force rerun of automatic jobs

---
 llama_stack/providers/inline/inference/vllm/vllm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py 
b/llama_stack/providers/inline/inference/vllm/vllm.py index 1caae96872..2672c3dbb6 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -78,7 +78,7 @@ async def initialize(self): self.engine = AsyncLLMEngine.from_engine_args(engine_args) async def shutdown(self): - """Shutdown the vLLM inference adapter.""" + """Shut down the vLLM inference adapter.""" log.info("Shutting down vLLM inference provider.") if self.engine: self.engine.shutdown_background_loop()
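
Note on the testing pattern used in patch 1: the session-scoped yield fixture combined with @pytest.mark.asyncio(loop_scope="session") is what lets a single provider instance serve every test in the class and still be shut down exactly once after the last test. The sketch below is a minimal, self-contained illustration of that pattern only; it is not code from this series. FakeEngine is a hypothetical stand-in for the inline vLLM engine, and the loop_scope marker argument assumes a recent pytest-asyncio (0.24 or later).

import asyncio

import pytest
import pytest_asyncio


class FakeEngine:
    """Hypothetical stand-in for a heavyweight engine such as inline vLLM."""

    def __init__(self) -> None:
        self.running = False

    async def start(self) -> "FakeEngine":
        await asyncio.sleep(0)  # pretend to do async initialization
        self.running = True
        return self

    async def shutdown(self) -> None:
        await asyncio.sleep(0)  # pretend to stop background loops
        self.running = False


@pytest_asyncio.fixture(scope="session")
async def engine():
    # Built once per test session and shared by every test that requests it,
    # mirroring the inference_stack fixture in fixtures.py.
    eng = await FakeEngine().start()
    yield eng
    # Teardown runs after the last test in the session, mirroring the
    # fixture's call to shutdown().
    await eng.shutdown()


@pytest.mark.asyncio(loop_scope="session")
async def test_engine_is_shared(engine):
    # loop_scope="session" keeps tests on one event loop across the session,
    # so objects created during session-level setup remain usable here.
    assert engine.running

Running such a module with pytest exercises setup once, every test against the same engine, and teardown once at the end, which is the behavior the fixture and marker changes in patch 1 rely on.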