diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 2c721d9ba7609..faada2ce64bcd 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -50,12 +50,3 @@ async def test_check_health(client: openai.AsyncOpenAI): response = requests.get(base_url + "/health") assert response.status_code == HTTPStatus.OK - - -@pytest.mark.asyncio -async def test_log_metrics(client: openai.AsyncOpenAI): - base_url = str(client.base_url)[:-3].strip("/") - - response = requests.get(base_url + "/metrics") - - assert response.status_code == HTTPStatus.OK diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py new file mode 100644 index 0000000000000..cbe601e623056 --- /dev/null +++ b/tests/entrypoints/openai/test_metrics.py @@ -0,0 +1,179 @@ +from http import HTTPStatus + +import openai +import pytest +import requests +from prometheus_client.parser import text_string_to_metric_families +from transformers import AutoTokenizer + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "1024", + "--enforce-eager", + "--max-num-seqs", + "128", + ] + + +@pytest.fixture(scope="module", + params=[ + "", + "--enable-chunked-prefill", + "--disable-frontend-multiprocessing", + ]) +def client(default_server_args, request): + if request.param: + default_server_args.append(request.param) + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server.get_async_client() + + +_PROMPT = "Hello my name is Robert and I love magic" +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"] + +_NUM_REQUESTS = 10 +_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT) +_NUM_GENERATION_TOKENS_PER_REQUEST = 10 + +# {metric_family: [(suffix, expected_value)]} +EXPECTED_VALUES = { + "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], + "vllm:time_per_output_token_seconds": + [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], + "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_prompt_tokens": + [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS)], + "vllm:request_generation_tokens": + [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS)], + "vllm:request_params_n": [("_count", _NUM_REQUESTS)], + "vllm:request_params_best_of": [("_count", _NUM_REQUESTS)], + "vllm:prompt_tokens": [("_total", + _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], + "vllm:generation_tokens": + [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], + "vllm:request_success": [("_total", _NUM_REQUESTS)], +} + + +@pytest.mark.asyncio +async def test_metrics_counts(client: openai.AsyncOpenAI): + base_url = str(client.base_url)[:-3].strip("/") + + for _ in range(_NUM_REQUESTS): + # sending a request triggers the metrics to be logged. + await client.completions.create( + model=MODEL_NAME, + prompt=_TOKENIZED_PROMPT, + max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST) + + response = requests.get(base_url + "/metrics") + print(response.text) + assert response.status_code == HTTPStatus.OK + + # Loop over all expected metric_families + for metric_family, suffix_values_list in EXPECTED_VALUES.items(): + found_metric = False + + # Check to see if the metric_family is found in the prom endpoint. + for family in text_string_to_metric_families(response.text): + if family.name == metric_family: + found_metric = True + + # Check that each suffix is found in the prom endpoint. + for suffix, expected_value in suffix_values_list: + metric_name_w_suffix = f"{metric_family}{suffix}" + found_suffix = False + + for sample in family.samples: + if sample.name == metric_name_w_suffix: + found_suffix = True + + # For each suffix, value sure the value matches + # what we expect. + assert sample.value == expected_value, ( + f"{metric_name_w_suffix} expected value of " + f"{expected_value} did not match found value " + f"{sample.value}") + break + assert found_suffix, ( + f"Did not find {metric_name_w_suffix} in prom endpoint" + ) + break + + assert found_metric, (f"Did not find {metric_family} in prom endpoint") + + +EXPECTED_METRICS = [ + "vllm:num_requests_running", + "vllm:num_requests_swapped", + "vllm:num_requests_waiting", + "vllm:gpu_cache_usage_perc", + "vllm:cpu_cache_usage_perc", + "vllm:time_to_first_token_seconds_sum", + "vllm:time_to_first_token_seconds_bucket", + "vllm:time_to_first_token_seconds_count", + "vllm:time_per_output_token_seconds_sum", + "vllm:time_per_output_token_seconds_bucket", + "vllm:time_per_output_token_seconds_count", + "vllm:e2e_request_latency_seconds_sum", + "vllm:e2e_request_latency_seconds_bucket", + "vllm:e2e_request_latency_seconds_count", + "vllm:request_prompt_tokens_sum", + "vllm:request_prompt_tokens_bucket", + "vllm:request_prompt_tokens_count", + "vllm:request_generation_tokens_sum", + "vllm:request_generation_tokens_bucket", + "vllm:request_generation_tokens_count", + "vllm:request_params_n_sum", + "vllm:request_params_n_bucket", + "vllm:request_params_n_count", + "vllm:request_params_best_of_sum", + "vllm:request_params_best_of_bucket", + "vllm:request_params_best_of_count", + "vllm:num_preemptions_total", + "vllm:prompt_tokens_total", + "vllm:generation_tokens_total", + "vllm:request_success_total", + "vllm:cache_config_info", + # labels in cache_config_info + "block_size", + "cache_dtype", + "cpu_offload_gb", + "enable_prefix_caching", + "gpu_memory_utilization", + "num_cpu_blocks", + "num_gpu_blocks", + "num_gpu_blocks_override", + "sliding_window", + "swap_space_bytes", +] + + +@pytest.mark.asyncio +async def test_metrics_exist(client: openai.AsyncOpenAI): + base_url = str(client.base_url)[:-3].strip("/") + + # sending a request triggers the metrics to be logged. + await client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + response = requests.get(base_url + "/metrics") + assert response.status_code == HTTPStatus.OK + + for metric in EXPECTED_METRICS: + assert metric in response.text diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a28b20fcbbcd8..dced804fccca9 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -15,7 +15,7 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, PromptComponents) -from vllm.engine.metrics import StatLoggerBase +from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4ddb80ff7de1a..021f4f2484307 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -16,8 +16,7 @@ from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger, - StatLoggerBase, Stats) +from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker @@ -339,6 +338,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: if stat_loggers is not None: self.stat_loggers = stat_loggers else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + self.stat_loggers = { "logging": LoggingStatLogger( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 2f105b9cd2fb6..1071786c27cd6 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,13 +1,12 @@ -import time -from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING from typing import Counter as CollectionsCounter -from typing import Dict, List, Optional, Protocol, Union +from typing import Dict, List, Optional, Union import numpy as np import prometheus_client +from vllm.engine.metrics_types import (StatLoggerBase, Stats, + SupportsMetricsInfo) from vllm.executor.ray_utils import ray from vllm.logger import init_logger @@ -29,41 +28,49 @@ # begin-metrics-definitions class Metrics: + """ + vLLM uses a multiprocessing-based frontend for the OpenAI server. + This means that we need to run prometheus_client in multiprocessing mode + See https://prometheus.github.io/client_python/multiprocess/ for more + details on limitations. + """ labelname_finish_reason = "finished_reason" _gauge_cls = prometheus_client.Gauge _counter_cls = prometheus_client.Counter _histogram_cls = prometheus_client.Histogram def __init__(self, labelnames: List[str], max_model_len: int): - # Unregister any existing vLLM collectors + # Unregister any existing vLLM collectors (for CI/CD) self._unregister_vllm_metrics() - # Config Information - self._create_info_cache_config() - # System stats # Scheduler State self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") self.gauge_scheduler_swapped = self._gauge_cls( name="vllm:num_requests_swapped", documentation="Number of requests swapped to CPU.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") # KV Cache Usage in % self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") self.gauge_cpu_cache_usage = self._gauge_cls( name="vllm:cpu_cache_usage_perc", documentation="CPU KV-cache usage. 1 means 100 percent usage.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") # Iteration stats self.counter_num_preemption = self._counter_cls( @@ -137,11 +144,13 @@ def __init__(self, labelnames: List[str], max_model_len: int): self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( name="vllm:spec_decode_draft_acceptance_rate", documentation="Speulative token acceptance rate.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") self.gauge_spec_decode_efficiency = self._gauge_cls( name="vllm:spec_decode_efficiency", documentation="Speculative decoding system efficiency.", - labelnames=labelnames) + labelnames=labelnames, + multiprocess_mode="sum") self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( name="vllm:spec_decode_num_accepted_tokens_total", documentation="Number of accepted tokens.", @@ -160,19 +169,18 @@ def __init__(self, labelnames: List[str], max_model_len: int): name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, + multiprocess_mode="sum", ) # Deprecated in favor of vllm:generation_tokens_total self.gauge_avg_generation_throughput = self._gauge_cls( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", labelnames=labelnames, + multiprocess_mode="sum", ) - def _create_info_cache_config(self) -> None: - # Config Information - self.info_cache_config = prometheus_client.Info( - name='vllm:cache_config', - documentation='information of cache_config') + +# end-metrics-definitions def _unregister_vllm_metrics(self) -> None: for collector in list(prometheus_client.REGISTRY._collector_to_names): @@ -180,9 +188,6 @@ def _unregister_vllm_metrics(self) -> None: prometheus_client.REGISTRY.unregister(collector) -# end-metrics-definitions - - class _RayGaugeWrapper: """Wraps around ray.util.metrics.Gauge to provide same API as prometheus_client.Gauge""" @@ -190,7 +195,9 @@ class _RayGaugeWrapper: def __init__(self, name: str, documentation: str = "", - labelnames: Optional[List[str]] = None): + labelnames: Optional[List[str]] = None, + multiprocess_mode: str = ""): + del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None self._gauge = ray_metrics.Gauge(name=name, description=documentation, @@ -268,10 +275,6 @@ def _unregister_vllm_metrics(self) -> None: # No-op on purpose pass - def _create_info_cache_config(self) -> None: - # No-op on purpose - pass - def build_1_2_5_buckets(max_value: int) -> List[int]: """ @@ -295,46 +298,6 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: exponent += 1 -@dataclass -class Stats: - """Created by LLMEngine for use by StatLogger.""" - now: float - - # System stats (should have _sys suffix) - # Scheduler State - num_running_sys: int - num_waiting_sys: int - num_swapped_sys: int - # KV Cache Usage in % - gpu_cache_usage_sys: float - cpu_cache_usage_sys: float - - # Iteration stats (should have _iter suffix) - num_prompt_tokens_iter: int - num_generation_tokens_iter: int - time_to_first_tokens_iter: List[float] - time_per_output_tokens_iter: List[float] - num_preemption_iter: int - - # Request stats (should have _requests suffix) - # Latency - time_e2e_requests: List[float] - # Metadata - num_prompt_tokens_requests: List[int] - num_generation_tokens_requests: List[int] - best_of_requests: List[int] - n_requests: List[int] - finished_reason_requests: List[str] - - spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None - - -class SupportsMetricsInfo(Protocol): - - def metrics_info(self) -> Dict[str, str]: - ... - - def local_interval_elapsed(now: float, last_log: float, local_interval: float) -> bool: elapsed_time = now - last_log @@ -346,38 +309,9 @@ def get_throughput(tracked_stats: List[int], now: float, return float(np.sum(tracked_stats) / (now - last_log)) -class StatLoggerBase(ABC): - """Base class for StatLogger.""" - - def __init__(self, local_interval: float) -> None: - # Tracked stats over current local logging interval. - self.num_prompt_tokens: List[int] = [] - self.num_generation_tokens: List[int] = [] - self.last_local_log = time.time() - self.local_interval = local_interval - self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None - - @abstractmethod - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - raise NotImplementedError - - @abstractmethod - def log(self, stats: Stats) -> None: - raise NotImplementedError - - def maybe_update_spec_decode_metrics(self, stats: Stats): - """Save spec decode metrics (since they are unlikely - to be emitted at same time as log interval).""" - if stats.spec_decode_metrics is not None: - self.spec_decode_metrics = stats.spec_decode_metrics - - class LoggingStatLogger(StatLoggerBase): """LoggingStatLogger is used in LLMEngine to log to Stdout.""" - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - raise NotImplementedError - def log(self, stats: Stats) -> None: """Called by LLMEngine. Logs to Stdout every self.local_interval seconds.""" @@ -440,10 +374,14 @@ def _format_spec_decode_metrics_str( f"Number of draft tokens: {metrics.draft_tokens}, " f"Number of emitted tokens: {metrics.emitted_tokens}.") + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + class PrometheusStatLogger(StatLoggerBase): """PrometheusStatLogger is used LLMEngine to log to Promethus.""" _metrics_cls = Metrics + _gauge_cls = prometheus_client.Gauge def __init__(self, local_interval: float, labels: Dict[str, str], max_model_len: int) -> None: @@ -453,10 +391,6 @@ def __init__(self, local_interval: float, labels: Dict[str, str], self.metrics = self._metrics_cls(labelnames=list(labels.keys()), max_model_len=max_model_len) - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - if type == "cache_config": - self.metrics.info_cache_config.info(obj.metrics_info()) - def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) @@ -586,6 +520,19 @@ def log(self, stats: Stats): self.last_local_log = stats.now self.spec_decode_metrics = None + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + if type == "cache_config": + metrics_info = obj.metrics_info() + info_gauge = self._gauge_cls( + name="vllm:cache_config_info", + documentation="Information of the LLMEngine CacheConfig", + labelnames=metrics_info.keys(), + multiprocess_mode="mostrecent") + info_gauge.labels(**metrics_info).set(1) + class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py new file mode 100644 index 0000000000000..7449aafc5aecb --- /dev/null +++ b/vllm/engine/metrics_types.py @@ -0,0 +1,85 @@ +""" +These types are defined in this file to avoid importing vllm.engine.metrics +and therefore importing prometheus_client. + +This is required due to usage of Prometheus multiprocess mode to enable +metrics after splitting out the uvicorn process from the engine process. + +Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR +before prometheus_client is imported. Typically, this is done by setting +the env variable before launch, but since we are a library, we need to +do this in Python code and lazily import prometheus_client. +""" + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List, Optional, Protocol + +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + num_preemption_iter: int + + # Request stats (should have _requests suffix) + # Latency + time_e2e_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + best_of_requests: List[int] + n_requests: List[int] + finished_reason_requests: List[str] + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... + + +class StatLoggerBase(ABC): + """Base class for StatLogger.""" + + def __init__(self, local_interval: float) -> None: + # Tracked stats over current local logging interval. + self.num_prompt_tokens: List[int] = [] + self.num_generation_tokens: List[int] = [] + self.last_local_log = time.time() + self.local_interval = local_interval + self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + @abstractmethod + def log(self, stats: Stats) -> None: + raise NotImplementedError + + @abstractmethod + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + def maybe_update_spec_decode_metrics(self, stats: Stats): + """Save spec decode metrics (since they are unlikely + to be emitted at same time as log interval).""" + if stats.spec_decode_metrics is not None: + self.spec_decode_metrics = stats.spec_decode_metrics diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a641dcc24aaae..d79238e08d540 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,7 +2,9 @@ import importlib import inspect import multiprocessing +import os import re +import tempfile from argparse import Namespace from contextlib import asynccontextmanager from http import HTTPStatus @@ -12,7 +14,6 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import make_asgi_app from starlette.routing import Mount import vllm.envs as envs @@ -54,6 +55,7 @@ openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding openai_serving_tokenization: OpenAIServingTokenization +prometheus_multiproc_dir: tempfile.TemporaryDirectory logger = init_logger('vllm.entrypoints.openai.api_server') @@ -109,6 +111,21 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # Otherwise, use the multiprocessing AsyncLLMEngine. else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + # Select random path for IPC. rpc_path = get_open_zmq_ipc_path() logger.info("Multiprocessing frontend to use %s for RPC Path.", @@ -149,13 +166,38 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # Wait for server process to join rpc_server_process.join() + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(rpc_server_process.pid) + router = APIRouter() def mount_metrics(app: FastAPI): - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app()) + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + # Workaround for 307 Redirect for /metrics metrics_route.path_regex = re.compile('^/metrics(?P.*)$') app.routes.append(metrics_route)