
Commit c166e7e

[Bugfix] Allow ScalarType to be compiled with pytorch 2.3 and add checks for registering FakeScalarType and dynamo support. (vllm-project#7886)
bnellnm authored Aug 28, 2024
1 parent bc6e42a commit c166e7e
Showing 4 changed files with 84 additions and 67 deletions.
3 changes: 2 additions & 1 deletion csrc/core/scalar_type.hpp
@@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   // This needs to be implemented and throw a TypeError in order for
   // PyTorch's opcheck to work on ops that use ScalarTypes.
   int64_t len() const {
-    throw c10::TypeError("__len__ not implemented");
+    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+                         "__len__ not implemented");
     return 0;
   }
 
134 changes: 70 additions & 64 deletions vllm/_core_ext.py
@@ -181,92 +181,98 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
 
     ScalarType = torch.classes._core_C.ScalarType
 
-    # Needed for dynamo support of ScalarType.
-    @torch._library.register_fake_class("_core_C::ScalarType")
-    class FakeScalarType:
+    if (hasattr(torch, "_library")
+            and hasattr(torch._library, "register_fake_class")):
+        # Needed for dynamo support of ScalarType.
+        @torch._library.register_fake_class("_core_C::ScalarType")
+        class FakeScalarType:
 
-        def __init__(self, scalar_type):
-            self.ScalarType = scalar_type
+            def __init__(self, scalar_type):
+                self.ScalarType = scalar_type
 
-        def bias_getter(self) -> int:
-            return self.ScalarType.bias
+            def bias_getter(self) -> int:
+                return self.ScalarType.bias
 
-        def exponent_getter(self) -> int:
-            return self.ScalarType.exponent
+            def exponent_getter(self) -> int:
+                return self.ScalarType.exponent
 
-        def mantissa_getter(self) -> int:
-            return self.ScalarType.mantissa
+            def mantissa_getter(self) -> int:
+                return self.ScalarType.mantissa
 
-        def signed_getter(self) -> bool:
-            return self.ScalarType.signed
+            def signed_getter(self) -> bool:
+                return self.ScalarType.signed
 
-        def size_bits_getter(self) -> int:
-            return self.ScalarType.size_bits
+            def size_bits_getter(self) -> int:
+                return self.ScalarType.size_bits
 
-        @property
-        def size_bits(self) -> int:
-            return self.ScalarType.size_bits
+            @property
+            def size_bits(self) -> int:
+                return self.ScalarType.size_bits
 
-        def min(self) -> Union[int, float]:
-            return self.ScalarType.min()
+            def min(self) -> Union[int, float]:
+                return self.ScalarType.min()
 
-        def max(self) -> Union[int, float]:
-            return self.ScalarType.max()
+            def max(self) -> Union[int, float]:
+                return self.ScalarType.max()
 
-        def is_signed(self) -> bool:
-            return self.ScalarType.is_signed()
+            def is_signed(self) -> bool:
+                return self.ScalarType.is_signed()
 
-        def is_floating_point(self) -> bool:
-            return self.ScalarType.is_floating_point()
+            def is_floating_point(self) -> bool:
+                return self.ScalarType.is_floating_point()
 
-        def is_integer(self) -> bool:
-            return self.ScalarType.is_integer()
+            def is_integer(self) -> bool:
+                return self.ScalarType.is_integer()
 
-        def has_bias(self) -> bool:
-            return self.ScalarType.has_bias()
+            def has_bias(self) -> bool:
+                return self.ScalarType.has_bias()
 
-        def has_infs(self) -> bool:
-            return self.ScalarType.has_infs()
+            def has_infs(self) -> bool:
+                return self.ScalarType.has_infs()
 
-        def has_nans(self) -> bool:
-            return self.ScalarType.has_nans()
+            def has_nans(self) -> bool:
+                return self.ScalarType.has_nans()
 
-        def is_ieee_754(self) -> bool:
-            return self.ScalarType.is_ieee_754()
+            def is_ieee_754(self) -> bool:
+                return self.ScalarType.is_ieee_754()
 
-        def __str__(self) -> str:
-            return self.ScalarType.__str__()
+            def __str__(self) -> str:
+                return self.ScalarType.__str__()
 
-        def __repr__(self) -> str:
-            return self.ScalarType.__repr__()
+            def __repr__(self) -> str:
+                return self.ScalarType.__repr__()
 
-        def __len__(self) -> int:
-            return self.ScalarType.__len__()
+            def __len__(self) -> int:
+                return self.ScalarType.__len__()
 
-        def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
-            return torch.classes._core_C.ScalarType.__obj_flatten__(
-                self.ScalarType)
+            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
+                return torch.classes._core_C.ScalarType.__obj_flatten__(
+                    self.ScalarType)
 
-        @classmethod
-        def __obj_unflatten__(
-                cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
-            return cls(
-                torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))
+            @classmethod
+            def __obj_unflatten__(
+                    cls, flat_type: Tuple[Tuple[str, Any],
+                                          ...]) -> 'ScalarType':
+                return cls(
+                    torch.classes._core_C.ScalarType.__obj_unflatten__(
+                        flat_type))
 
-        @classmethod
-        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            return ScalarType.int_(size_bits, bias)
+            @classmethod
+            def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+                return ScalarType.int_(size_bits, bias)
 
-        @classmethod
-        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
-            return ScalarType.uint(size_bits, bias)
+            @classmethod
+            def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+                return ScalarType.uint(size_bits, bias)
 
-        @classmethod
-        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
-            return ScalarType.float_IEEE754(exponent, mantissa)
+            @classmethod
+            def float_IEEE754(cls, exponent: int,
+                              mantissa: int) -> 'ScalarType':
+                return ScalarType.float_IEEE754(exponent, mantissa)
 
-        @classmethod
-        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
-                   nan_repr: int) -> 'ScalarType':
-            return ScalarType.float_(exponent, mantissa, finite_values_only,
-                                     nan_repr)
+            @classmethod
+            def float_(cls, exponent: int, mantissa: int,
+                       finite_values_only: bool,
+                       nan_repr: int) -> 'ScalarType':
+                return ScalarType.float_(exponent, mantissa,
+                                         finite_values_only, nan_repr)
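
Note on the hunk above: per the commit message and the supports_dynamo comment added below, PyTorch 2.3 does not expose torch._library.register_fake_class, so decorating FakeScalarType unconditionally would fail at import time; the hasattr guard simply skips the registration on older builds. A minimal standalone sketch of the same feature-detection pattern with a hypothetical no-op fallback (register_fake_if_supported is not vLLM code):

import torch


def register_fake_if_supported(qualname: str):
    """Return the real fake-class decorator when this torch build has one,
    otherwise a no-op decorator so imports still succeed on PyTorch < 2.4."""
    if (hasattr(torch, "_library")
            and hasattr(torch._library, "register_fake_class")):
        return torch._library.register_fake_class(qualname)
    return lambda cls: cls  # older torch: leave the class unregistered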
9 changes: 9 additions & 0 deletions vllm/utils.py
@@ -25,6 +25,7 @@
 import psutil
 import torch
 import torch.types
+from packaging.version import Version
 from typing_extensions import ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
@@ -1114,3 +1115,11 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
+# In particular, the FakeScalarType is not supported for earlier versions of
+# PyTorch which breaks dynamo for any ops registered using ScalarType.
+def supports_dynamo() -> bool:
+    base_torch_version = Version(Version(torch.__version__).base_version)
+    return base_torch_version >= Version("2.4.0")
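
Aside on the version check just added: parsing base_version strips pre-release, dev, and local-build suffixes before the comparison, so nightly or source builds of 2.4 and newer are still treated as dynamo-capable. An illustrative snippet with made-up version strings (not taken from the commit):

from packaging.version import Version

for raw in ("2.3.1", "2.4.0", "2.4.0a0+gitabc123", "2.5.0.dev20240828+cu121"):
    base = Version(Version(raw).base_version)  # drops a0/dev/+local segments
    print(raw, "->", base, base >= Version("2.4.0"))
# 2.3.1 -> 2.3.1 False
# 2.4.0 -> 2.4.0 True
# 2.4.0a0+gitabc123 -> 2.4.0 True
# 2.5.0.dev20240828+cu121 -> 2.5.0 True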
5 changes: 3 additions & 2 deletions vllm/worker/model_runner.py
@@ -44,7 +44,8 @@
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
-                        flatten_2d_lists, is_hip, is_pin_memory_available)
+                        flatten_2d_lists, is_hip, is_pin_memory_available,
+                        supports_dynamo)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -946,7 +947,7 @@ def load_model(self) -> None:
                 "provided. Defaulting to scaling factors of 1.0. "
                 "This may lead to less accurate results!")
 
-        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
             self.model = torch.compile(self.model,
                                        fullgraph=True,
                                        backend="eager")
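
A condensed sketch of the gating applied in load_model() above; maybe_compile is a hypothetical helper name, while the torch.compile arguments and supports_dynamo() come from the diff itself:

import torch

from vllm.utils import supports_dynamo  # helper added by this commit


def maybe_compile(model: torch.nn.Module,
                  capture_enabled: bool) -> torch.nn.Module:
    """Wrap the model with torch.compile only when dynamo can handle it."""
    if capture_enabled and supports_dynamo():
        # Same arguments as the diff: capture the full graph but run it with
        # the eager backend, i.e. no Inductor compilation.
        return torch.compile(model, fullgraph=True, backend="eager")
    return model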
