
[TPU][Quantization] TPU W8A8 (#11785)
Co-authored-by: Woosuk Kwon <[email protected]>
robertgshaw2-redhat and WoosukKwon authored Jan 8, 2025
1 parent 47de882 commit 56fe4c2
Showing 18 changed files with 565 additions and 190 deletions.
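This commit wires vLLM's compressed-tensors W8A8 (8-bit weight, 8-bit activation) path up to TPU by routing the INT8 scheme through the new ScaledMM kernel abstraction and adding a GSM8K accuracy test to the TPU CI job. For orientation, here is a minimal sketch (not part of the diff) of loading the W8A8 checkpoint that the new test exercises; the model name and the max_model_len/max_num_seqs values are copied from tests/tpu/test_quantization_accuracy.py below, while the prompt and sampling settings are purely illustrative.

from vllm import LLM, SamplingParams

# Illustrative only: the checkpoint and the max_model_len/max_num_seqs values
# mirror tests/tpu/test_quantization_accuracy.py added in this commit.
llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
          max_model_len=4096,
          max_num_seqs=32)

outputs = llm.generate(["What is 12 * 7?"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)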
11 changes: 10 additions & 1 deletion .buildkite/run-tpu-test.sh
@@ -14,4 +14,13 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"
docker run --privileged --net host --shm-size=16G -it \
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
&& python3 -m pip install pytest \
&& python3 -m pip install lm_eval[api]==0.4.4 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
&& pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
&& python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"
49 changes: 49 additions & 0 deletions tests/tpu/test_quantization_accuracy.py
@@ -0,0 +1,49 @@
from dataclasses import dataclass

import lm_eval
import pytest

TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03


@dataclass
class GSM8KAccuracyTestConfig:
model_name: str
expected_value: float

def get_model_args(self) -> str:
return (f"pretrained={self.model_name},"
"max_model_len=4096,max_num_seqs=32")


# NOTE: Accuracy scores measured on GPUs.
ACCURACY_CONFIGS = [
GSM8KAccuracyTestConfig(
model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
expected_value=0.76), # no bias
# NOTE(rob): We cannot re-initialize VLLM in the same process for TPU,
# so only one of these tests can run in a single call to pytest. As
# a follow up, move this into the LM-EVAL section of the CI.
# GSM8KAccuracyTestConfig(
# model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
# expected_value=0.66), # bias in QKV layers
]


@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):

results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
tasks="gsm8k",
batch_size="auto",
)

EXPECTED_VALUE = config.expected_value
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -1,14 +1,13 @@
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Set

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_int8_linear, convert_to_channelwise)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
@@ -18,6 +17,7 @@


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()

def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
@@ -30,74 +30,25 @@ def get_min_capability(cls) -> int:
# turing and up
return 75

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(self.logical_widths) > 1
if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
else:
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
if self.input_symmetric:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
layer.input_zero_point = None
else:
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = layer.input_zero_point.to(dtype=torch.int32)
range_max = (layer.input_scale *
(int8_traits.max - azps)).max()
range_min = (layer.input_scale *
(int8_traits.min - azps)).min()

scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
layer.input_scale = Parameter(scale, requires_grad=False)

# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
layer.input_zero_point = Parameter(azp, requires_grad=False)

else:
layer.input_scale = None
layer.input_zero_point = None

# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = layer.input_zero_point * azp_adj

layer.azp_adj = azp_adj
else:
layer.azp_adj = None

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
layer.logical_widths = output_partition_sizes

scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric)

kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config)

if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)

# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
@@ -140,12 +91,18 @@ def create_weights(self, layer: torch.nn.Module,
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)

self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj")

# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)
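With this change, CompressedTensorsW8A8Int8 no longer hard-codes the CUTLASS repacking and apply logic; it only describes the layer (channelwise vs. per-tensor scales, static vs. dynamic input, symmetric or not) and delegates weight repacking and the matmul to whichever ScaledMM kernel is chosen for the platform. A minimal sketch of that selection step, using only the config fields and calls visible in the diff above (which backend is returned depends on the platform vLLM is running on):

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)

# Describe a dynamic, symmetric, per-channel W8A8 linear layer and ask for the
# best kernel class available on the current platform.
cfg = ScaledMMLinearLayerConfig(is_channelwise=True,
                                is_static_input_scheme=False,
                                input_symmetric=True)
kernel_cls = choose_scaled_mm_linear_kernel(cfg)
print(f"Using {kernel_cls.__name__} for CompressedTensorsW8A8Int8")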
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -6,7 +6,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -11,7 +11,7 @@
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
74 changes: 0 additions & 74 deletions vllm/model_executor/layers/quantization/kernels/__init__.py
@@ -1,74 +0,0 @@
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.exllama import (
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
ExllamaLinearKernel,
]


def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute
capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
compute_capability = _cc[0] * 10 + _cc[1]

failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue

if kernel.get_min_capability() > compute_capability:
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute capability "
f"is {compute_capability}")
continue

can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)

raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
@@ -0,0 +1,74 @@
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
ExllamaLinearKernel,
]


def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute
capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
compute_capability = _cc[0] * 10 + _cc[1]

failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue

if kernel.get_min_capability() > compute_capability:
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute capability "
f"is {compute_capability}")
continue

can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)

raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
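A short usage sketch for the relocated helper; it assumes mp_config is an MPLinearLayerConfig already built by a calling scheme (as in compressed_tensors_wNa16.py and gptq_marlin.py above), and the capability encoding major * 10 + minor matches the code in this file.

from typing import Type

from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    MPLinearKernel, MPLinearLayerConfig, choose_mp_linear_kernel)

def pick_mp_kernel(mp_config: MPLinearLayerConfig,
                   major: int = 8,
                   minor: int = 0) -> Type[MPLinearKernel]:
    # Passing an explicit capability (e.g. (8, 0) -> 80 for A100) skips the
    # current_platform lookup; kernels listed in VLLM_DISABLED_KERNELS are
    # still skipped inside choose_mp_linear_kernel.
    return choose_mp_linear_kernel(mp_config,
                                   compute_capability=major * 10 + minor)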
