Enable mypy checking on V1 code (vllm-project#11105)
Signed-off-by: Mark McLoughlin <[email protected]>
markmc authored Dec 14, 2024
1 parent 93abf23 commit 6d917d0
Showing 21 changed files with 160 additions and 121 deletions.
1 change: 1 addition & 0 deletions tools/mypy.sh
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -135,6 +135,8 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

assert output is not None, "Output tensor must be provided."

if attn_metadata is None:
# Profiling run.
return output
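For context, a minimal sketch (not the vLLM code) of the narrowing pattern the added assert relies on: after `assert x is not None`, mypy treats `x` as non-Optional on the lines below, so later uses type-check without a cast.

```python
from typing import Optional

def forward(output: Optional[list]) -> list:
    # `output` is Optional here; without the assert, returning it
    # would fail mypy ("Optional[list]" is not "list").
    assert output is not None, "Output tensor must be provided."
    return output
```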
10 changes: 5 additions & 5 deletions vllm/v1/core/kv_cache_manager.py
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional

from vllm.logger import init_logger
from vllm.utils import cdiv
@@ -263,12 +263,13 @@ def free(self, request: Request) -> None:
"""
# Default to [] in case a request is freed (aborted) before alloc.
blocks = self.req_to_blocks.pop(request.request_id, [])
ordered_blocks: Iterable[KVCacheBlock] = blocks
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks = reversed(blocks)
ordered_blocks = reversed(blocks)

for block in blocks:
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
self.free_block_queue.append(block)
@@ -396,8 +397,7 @@ def _cache_full_blocks(
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens))
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)

# Update and add the full block to the cache.
blk.block_hash = block_hash
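The `ordered_blocks` change works around a common mypy complaint: rebinding a `List` variable to the iterator returned by `reversed()` changes its inferred type. A toy sketch of the pattern (names hypothetical, not the vLLM code):

```python
from typing import Iterable, List

def free_blocks(blocks: List[int], reverse: bool) -> None:
    # Rebinding `blocks = reversed(blocks)` would change its type from
    # List[int] to an iterator and fail mypy; a second name annotated
    # with the wider Iterable type accepts either ordering.
    ordered_blocks: Iterable[int] = blocks
    if reverse:
        ordered_blocks = reversed(blocks)
    for block in ordered_blocks:
        print("freeing", block)
```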
17 changes: 9 additions & 8 deletions vllm/v1/core/kv_cache_utils.py
@@ -1,4 +1,5 @@
"""KV-Cache Utilities."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple

@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
collision happens when the hash value is the same.
"""
hash_value: int
token_ids: Tuple[int]
token_ids: Tuple[int, ...]


@dataclass
@@ -79,8 +80,8 @@ def __init__(self, blocks: List[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)

# Initialize the doubly linked list of free blocks.
self.free_list_head = blocks[0]
self.free_list_tail = blocks[-1]
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
@@ -159,7 +160,7 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:


def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Tuple[int]) -> BlockHashType:
curr_block_token_ids: Sequence[int]) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
@@ -171,19 +172,19 @@ def hash_block_tokens(parent_block_hash: Optional[int],
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A tuple of token ids in the current
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
curr_block_token_ids)
tuple(curr_block_token_ids))


def hash_request_tokens(block_size: int,
token_ids: List[int]) -> List[BlockHashType]:
token_ids: Sequence[int]) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
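Two typing details drive this file's changes: `Tuple[int]` means a fixed-length 1-tuple (the variable-length spelling is `Tuple[int, ...]`), and `Sequence[int]` lets callers pass either a list or a tuple without converting. A standalone illustration (not the vLLM helper):

```python
from typing import Sequence, Tuple

def hash_tokens(token_ids: Sequence[int]) -> Tuple[int, ...]:
    # Sequence[int] accepts both List[int] and Tuple[int, ...];
    # Tuple[int, ...] (unlike Tuple[int]) is a tuple of any length.
    return tuple(token_ids)

# Both call forms now type-check identically.
assert hash_tokens([1, 2, 3]) == hash_tokens((1, 2, 3))
```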
1 change: 1 addition & 0 deletions vllm/v1/core/scheduler.py
@@ -152,6 +152,7 @@ def schedule(self) -> "SchedulerOutput":
break
if not can_schedule:
break
assert new_blocks is not None

# Schedule the request.
scheduled_running_reqs.append(request)
23 changes: 14 additions & 9 deletions vllm/v1/engine/__init__.py
@@ -36,18 +36,19 @@ class EngineCoreRequest:
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[Optional[str]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[MultiModalPlaceholderDict]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]


class EngineCoreOutput(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

request_id: str
new_token_ids: List[int]
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
stop_reason: Union[int, str, None] = None


class EngineCoreOutputs(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False):
class EngineCoreOutputs(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'


EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
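The new `EngineCoreRequestUnion` alias names the three request payloads the engine accepts, so the handler signatures in core.py and core_client.py below stay short. A sketch of the pattern with stubbed-out classes (the real ones live in this package):

```python
from typing import List, Union

class EngineCoreRequest: ...
class EngineCoreProfile: ...

# Naming the union once avoids repeating the three-way type in every
# signature that handles add, profile, or abort (a list of request ids).
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]

def handle(request: EngineCoreRequestUnion) -> None:
    if isinstance(request, EngineCoreRequest):
        ...  # add path
    elif isinstance(request, EngineCoreProfile):
        ...  # profile path
    else:
        ...  # abort path
```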
11 changes: 6 additions & 5 deletions vllm/v1/engine/async_llm.py
@@ -81,7 +81,7 @@ def __init__(
asyncio_mode=True,
)

self.output_handler = None
self.output_handler: Optional[asyncio.Task] = None

def __del__(self):
self.shutdown()
@@ -126,7 +126,8 @@ def shutdown(self):
handler.cancel()

@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
executor_class: Type[Executor]
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
@@ -361,10 +362,10 @@ async def check_health(self) -> None:
logger.debug("Called check_health.")

async def start_profile(self) -> None:
await self.engine_core.profile(True)
await self.engine_core.profile_async(True)

async def stop_profile(self) -> None:
await self.engine_core.profile(False)
await self.engine_core.profile_async(False)

@property
def is_running(self) -> bool:
@@ -380,7 +381,7 @@ def errored(self) -> bool:

@property
def dead_error(self) -> BaseException:
return Exception
return Exception() # TODO: implement


# Retain V0 name for backwards compatibility.
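The `executor_class: Type[Executor]` pre-declaration stops mypy from inferring the variable's type from whichever branch assigns first. A minimal sketch with hypothetical executor names (the real classes differ):

```python
from typing import Type

class Executor: ...
class MultiprocExecutor(Executor): ...
class UniprocExecutor(Executor): ...

def get_executor_cls(backend: str) -> Type[Executor]:
    # Without the up-front annotation, mypy infers Type[MultiprocExecutor]
    # from the first assignment and rejects the second branch.
    executor_class: Type[Executor]
    if backend == "mp":
        executor_class = MultiprocExecutor
    else:
        executor_class = UniprocExecutor
    return executor_class

assert get_executor_cls("mp") is MultiprocExecutor
```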
20 changes: 10 additions & 10 deletions vllm/v1/engine/core.py
@@ -5,7 +5,7 @@
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from typing import List, Tuple, Type, Union
from typing import List, Tuple, Type

import zmq
import zmq.asyncio
@@ -20,7 +20,7 @@
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
@@ -97,8 +97,10 @@ def add_request(self, request: EngineCoreRequest):
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))

req = Request.from_engine_core_request(request)

@@ -128,7 +130,7 @@ def step(self) -> List[EngineCoreOutput]:
def shutdown(self):
self.model_executor.shutdown()

def profile(self, is_start=True):
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)


@@ -161,8 +163,8 @@ def __init__(
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = queue.Queue()
self.output_queue = queue.Queue()
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
daemon=True).start()
@@ -318,9 +320,7 @@ def _log_stats(self):

self._last_logging_time = now

def _handle_client_request(
self, request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""

if isinstance(request, EngineCoreRequest):
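The queue annotations give mypy element types to check producers and consumers against; a bare `queue.Queue()` is `Queue[Any]`. A self-contained sketch of the pattern:

```python
import queue
from typing import List

# Parameterizing the annotation (quoted for compatibility with older
# Pythons) lets mypy verify what is put on and taken off each queue.
input_queue: "queue.Queue[str]" = queue.Queue()
output_queue: "queue.Queue[List[int]]" = queue.Queue()

input_queue.put("request-1")
output_queue.put([101, 102])
```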
43 changes: 24 additions & 19 deletions vllm/v1/engine/core_client.py
@@ -1,6 +1,6 @@
import atexit
import os
from typing import List, Union
from typing import List, Optional

import msgspec
import zmq
@@ -10,8 +10,9 @@
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
EngineCoreProcHandle)
from vllm.v1.serial_utils import PickleEncoder

logger = init_logger(__name__)
@@ -59,7 +60,7 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
@@ -71,6 +72,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def profile_async(self, is_start: bool = True) -> None:
raise NotImplementedError

async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError

@@ -105,7 +109,7 @@ def shutdown(self):
def __del__(self):
self.shutdown()

def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)


@@ -133,7 +137,10 @@ def __init__(
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

# ZMQ setup.
self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context())
if asyncio_mode:
self.ctx = zmq.asyncio.Context()
else:
self.ctx = zmq.Context() # type: ignore[attr-defined]

# Path for IPC.
ready_path = get_open_zmq_ipc_path()
@@ -149,11 +156,13 @@ def __init__(
self.input_socket.bind(input_path)

# Start EngineCore in background process.
self.proc_handle: Optional[EngineCoreProcHandle]
self.proc_handle = EngineCoreProc.make_engine_core_process(
*args,
input_path=input_path,
output_path=output_path,
ready_path=ready_path,
input_path=
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path=output_path, # type: ignore[misc]
ready_path=ready_path, # type: ignore[misc]
**kwargs,
)
atexit.register(self.shutdown)
@@ -204,10 +213,8 @@ def get_output(self) -> List[EngineCoreOutput]:
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs

def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
@@ -219,7 +226,7 @@ def add_request(self, request: EngineCoreRequest) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)

def profile(self, is_start=True) -> None:
def profile(self, is_start: bool = True) -> None:
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))

@@ -237,10 +244,8 @@ async def get_output_async(self) -> List[EngineCoreOutput]:

return engine_core_outputs

async def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
@@ -252,6 +257,6 @@ async def abort_requests_async(self, request_ids: List[str]) -> None:
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile(self, is_start=True) -> None:
async def profile_async(self, is_start: bool = True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
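Splitting `profile` (sync) from `profile_async` keeps subclass overrides signature-compatible: under mypy, overriding an `async def` with a plain `def` (as the in-process sync client previously did) is an incompatible override. A sketch of the resulting base-class shape (simplified, not the full client):

```python
class EngineCoreClient:
    # Separate sync and async entry points; sync subclasses override
    # profile(), async subclasses override profile_async(), and neither
    # override changes the inherited method's kind or return type.
    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError
```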