diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 2cc080608c7a9..1b0453c2bd6f8 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -126,8 +126,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, - has_zp=has_zp, - min_capability=cls.get_min_capability()) + has_zp=has_zp) class AWQMarlinLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 066102f3a01c0..b92697531c299 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -136,8 +136,7 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): return False return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], - group_size=group_size, - min_capability=cls.get_min_capability()) + group_size=group_size) class GPTQMarlinLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 6e84d36219361..0ec68ac5b0f21 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -26,12 +26,13 @@ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types(has_zp: bool, - min_capability: Optional[int] = None): - if min_capability is None: + device_capability: Optional[int] = None + ): + if device_capability is None: major, minor = current_platform.get_device_capability() - min_capability = major * 10 + minor + device_capability = major * 10 + minor - if min_capability < 80: + if device_capability < 80: return [] if has_zp: @@ -48,20 +49,20 @@ def _check_marlin_supported( quant_type: ScalarType, group_size: Optional[int], has_zp: bool, - min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: - if min_capability is None: + if device_capability is None: major, minor = current_platform.get_device_capability() - min_capability = major * 10 + minor + device_capability = major * 10 + minor supported_types = query_marlin_supported_quant_types( - has_zp, min_capability) + has_zp, device_capability) if quant_type not in supported_types: return (False, f"Marlin does not support weight_bits = {quant_type}. " f"Only types = {supported_types} " f"are supported (for group_size = {group_size}, " - f"min_capability = {min_capability}, zp = {has_zp}).") + f"device_capability = {device_capability}, zp = {has_zp}).") if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): return (False, f"Marlin does not support group_size = {group_size}. " f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " @@ -73,9 +74,9 @@ def _check_marlin_supported( def check_marlin_supported(quant_type: ScalarType, group_size: int, has_zp: bool = False, - min_capability: Optional[int] = None) -> bool: + device_capability: Optional[int] = None) -> bool: cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, - min_capability) + device_capability) return cond