Add inline vLLM inference provider to regression tests and fix regressions #662

Open · wants to merge 7 commits into main
llama_stack/providers/inline/inference/vllm/vllm.py (39 changes: 31 additions & 8 deletions)
@@ -50,7 +50,7 @@ def __init__(self, config: VLLMConfig):
self.formatter = ChatFormat(Tokenizer.get_instance())

async def initialize(self):
log.info("Initializing vLLM inference adapter")
log.info("Initializing vLLM inference provider.")

# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
@@ -78,15 +78,36 @@ async def initialize(self):
self.engine = AsyncLLMEngine.from_engine_args(engine_args)

async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
"""Shut down the vLLM inference adapter."""
log.info("Shutting down vLLM inference provider.")
if self.engine:
self.engine.shutdown_background_loop()

-async def register_model(self, model: Model) -> None:
-raise ValueError(
-"You cannot dynamically add a model to a running vllm instance"
-)
+# Note that the return type of the superclass method is WRONG
+async def register_model(self, model: Model) -> Model:
+"""
+Callback that is called when the server associates an inference endpoint
+with an inference provider.
+
+:param model: Object that encapsulates parameters necessary for identifying
+a specific LLM.
+
+:returns: The input ``Model`` object. It may or may not be permissible
+to change fields before returning this object.
+"""
+log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+# The current version of this provider is hard-coded to serve only
+# the model specified in the YAML config file.
+configured_model = resolve_model(self.config.model)
+registered_model = resolve_model(model.model_id)
+
+if configured_model.core_model_id != registered_model.core_model_id:
+raise ValueError(
+f"Requested model '{model.identifier}' is different from "
+f"model '{self.config.model}' that this provider "
+f"is configured to serve"
+)
+return model

def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
@@ -150,7 +171,9 @@ async def chat_completion(
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()

-prompt = await chat_completion_request_to_prompt(request, self.formatter)
+prompt = await chat_completion_request_to_prompt(
+request, self.config.model, self.formatter
+)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
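
For readers unfamiliar with resolve_model, here is a standalone sketch (not code from this PR) of the identity check that the new register_model performs. It assumes resolve_model is importable from llama_models.sku_list, as elsewhere in the Llama Stack codebase, and the model descriptors below are purely illustrative:

from llama_models.sku_list import resolve_model


def same_core_model(configured: str, requested: str) -> bool:
    # Mirrors the guard in register_model above: two identifiers are considered
    # compatible when they resolve to the same core model ID. Assumes both
    # strings are valid descriptors that resolve_model can find.
    return (
        resolve_model(configured).core_model_id
        == resolve_model(requested).core_model_id
    )


# Hypothetical usage:
#   same_core_model("Llama3.1-8B-Instruct", "Llama3.1-8B-Instruct")  -> True
#   same_core_model("Llama3.1-8B-Instruct", "Llama3.2-3B-Instruct")  -> False
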
llama_stack/providers/tests/inference/fixtures.py (28 changes: 27 additions & 1 deletion)
@@ -15,6 +15,7 @@
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
from llama_stack.providers.inline.inference.vllm import VLLMConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig

from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
@@ -104,6 +105,26 @@ def inference_ollama(inference_model) -> ProviderFixture:
)


@pytest_asyncio.fixture(scope="session")
def inference_vllm(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
return ProviderFixture(
providers=[
Provider(
provider_id=f"vllm-{i}",
provider_type="inline::vllm",
config=VLLMConfig(
model=m,
enforce_eager=True, # Make test run faster
).model_dump(),
)
for i, m in enumerate(inference_model)
]
)


@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
return ProviderFixture(
@@ -236,6 +257,7 @@ def model_id(inference_model) -> str:
"ollama",
"fireworks",
"together",
"vllm",
"vllm_remote",
"remote",
"bedrock",
@@ -268,4 +290,8 @@ async def inference_stack(request, inference_model):
],
)

-return test_stack.impls[Api.inference], test_stack.impls[Api.models]
+# Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
+yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
+
+# Cleanup code that runs after test case completion
+await test_stack.impls[Api.inference].shutdown()
llama_stack/providers/tests/inference/test_text_inference.py (20 changes: 11 additions & 9 deletions)
@@ -67,7 +67,9 @@ def sample_tool_definition():


class TestInference:
-@pytest.mark.asyncio
+# Session scope for asyncio because the tests in this class all
+# share the same provider instance.
Contributor commented:
@frreiss can you explain a bit what this change does (or rather, what not having it causes)?

Author (frreiss) replied:
Sure, happy to explain.

The fixtures in llama_stack/providers/tests/inference/fixtures.py are all decorated with @pytest.fixture(scope="session"). This scope means that each fixture is initialized once and then reused for the rest of the test session, which makes the tests run faster; that reuse matters here because the fixtures for inline providers need to load models from disk.

The default behavior of the pytest-asyncio plugin is to spawn a new asyncio event loop for every test case.

The result of these two different scoping policies is that the fixtures are initialized under one event loop, while the test cases interact with them from a different event loop (one event loop per test case). This change of event loops happens not to break the existing tests because the inference providers they exercise are stateless at the asyncio layer.

vLLM is not stateless at the asyncio layer. An idle vLLM instance has an event handler waiting for new inference requests to be added to the current batch. Switching to a different event loop drops this event handler, preventing vLLM's batching mechanism from functioning. When I added the inline vLLM provider to the test cases, the change in event loops caused the inference tests to hang.

Adding @pytest.mark.asyncio(loop_scope="session") to a test case prevents pytest-asyncio from switching event loops when running the test case. This change ensures that the asyncio event loop used during the test case is the same as the event loop that was present when any session-scoped fixtures were initialized.

The primary potential downside of scoping the event loop in this way is that, if a misbehaving test case were to leave orphan event handlers, those event handlers could cause errors in later test cases instead of causing an error in the misbehaving test case. This risk seemed acceptable to me.
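
For reference, a minimal, self-contained sketch of the behavior described above. The test names are hypothetical (not part of this PR), and it assumes a pytest-asyncio release that accepts the loop_scope argument (0.24 or later) and that pytest runs the tests in definition order:

import asyncio

import pytest

_loops = []


@pytest.mark.asyncio(loop_scope="session")
async def test_records_loop():
    # Runs on the session-scoped event loop and remembers it.
    _loops.append(asyncio.get_running_loop())


@pytest.mark.asyncio(loop_scope="session")
async def test_sees_same_loop():
    # Runs on the same session-scoped loop, so the two recorded loops are the
    # same object. With the default function-scoped loop, each test would get
    # a fresh event loop and this identity check would fail.
    _loops.append(asyncio.get_running_loop())
    assert _loops[0] is _loops[1]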

+@pytest.mark.asyncio(loop_scope="session")
async def test_model_list(self, inference_model, inference_stack):
_, models_impl = inference_stack
response = await models_impl.list_models()
@@ -83,7 +85,7 @@ async def test_model_list(self, inference_model, inference_stack):

assert model_def is not None

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_completion(self, inference_model, inference_stack):
inference_impl, _ = inference_stack

@@ -128,7 +130,7 @@ async def test_completion(self, inference_model, inference_stack):
last = chunks[-1]
assert last.stop_reason == StopReason.out_of_tokens

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_completion_logprobs(self, inference_model, inference_stack):
inference_impl, _ = inference_stack

@@ -183,7 +185,7 @@ async def test_completion_logprobs(self, inference_model, inference_stack):
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.skip("This test is not quite robust")
async def test_completion_structured_output(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
@@ -227,7 +229,7 @@ class Output(BaseModel):
assert answer.year_born == "1963"
assert answer.year_retired == "2003"

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_non_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
@@ -244,7 +246,7 @@ async def test_chat_completion_non_streaming(
assert isinstance(response.completion_message.content, str)
assert len(response.completion_message.content) > 0

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_structured_output(
self, inference_model, inference_stack, common_params
):
@@ -314,7 +316,7 @@ class AnswerFormat(BaseModel):
with pytest.raises(ValidationError):
AnswerFormat.model_validate_json(response.completion_message.content)

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_streaming(
self, inference_model, inference_stack, common_params, sample_messages
):
@@ -341,7 +343,7 @@ async def test_chat_completion_streaming(
end = grouped[ChatCompletionResponseEventType.complete][0]
assert end.event.stop_reason == StopReason.end_of_turn

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling(
self,
inference_model,
@@ -380,7 +382,7 @@ async def test_chat_completion_with_tool_calling(
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]

-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="session")
async def test_chat_completion_with_tool_calling_streaming(
self,
inference_model,