Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KV-Cache] Make k_scale, v_scale as attributes of self_attn using HFCache #148

Merged
merged 21 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 0 additions & 56 deletions src/compressed_tensors/compressors/model_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,62 +252,6 @@ def compress(
compressed_state_dict
)

# HACK (mgoin): Post-process step for kv cache scales to take the
# k/v_proj module `output_scale` parameters, and store them in the
# parent attention module as `k_scale` and `v_scale`
#
# Example:
# Replace `model.layers.0.self_attn.k_proj.output_scale`
# with `model.layers.0.self_attn.k_scale`
if (
self.quantization_config is not None
and self.quantization_config.kv_cache_scheme is not None
):
# HACK (mgoin): We assume the quantized modules in question
# will be k_proj and v_proj since those are the default targets.
# We check that both of these modules have output activation
# quantization, and additionally check that q_proj doesn't.
q_proj_has_no_quant_output = 0
k_proj_has_quant_output = 0
v_proj_has_quant_output = 0
for name, module in model.named_modules():
if not hasattr(module, "quantization_scheme"):
# We still want to count non-quantized q_proj
if name.endswith(".q_proj"):
q_proj_has_no_quant_output += 1
continue
out_act = module.quantization_scheme.output_activations
if name.endswith(".q_proj") and out_act is None:
q_proj_has_no_quant_output += 1
elif name.endswith(".k_proj") and out_act is not None:
k_proj_has_quant_output += 1
elif name.endswith(".v_proj") and out_act is not None:
v_proj_has_quant_output += 1

assert (
q_proj_has_no_quant_output > 0
and k_proj_has_quant_output > 0
and v_proj_has_quant_output > 0
)
assert (
q_proj_has_no_quant_output
== k_proj_has_quant_output
== v_proj_has_quant_output
)

# Move all .k/v_proj.output_scale parameters to .k/v_scale
working_state_dict = {}
for key in compressed_state_dict.keys():
if key.endswith(".k_proj.output_scale"):
new_key = key.replace(".k_proj.output_scale", ".k_scale")
working_state_dict[new_key] = compressed_state_dict[key]
elif key.endswith(".v_proj.output_scale"):
new_key = key.replace(".v_proj.output_scale", ".v_scale")
working_state_dict[new_key] = compressed_state_dict[key]
else:
working_state_dict[key] = compressed_state_dict[key]
compressed_state_dict = working_state_dict

# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
# https://github.com/huggingface/transformers/pull/30488
Expand Down
1 change: 1 addition & 0 deletions src/compressed_tensors/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .quant_config import *
from .quant_scheme import *
from .lifecycle import *
from .cache import QuantizedKVParameterCache
201 changes: 201 additions & 0 deletions src/compressed_tensors/quantization/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from compressed_tensors.quantization.observers import Observer
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import Tensor
from transformers import DynamicCache as HFDyanmicCache


class KVCacheScaleType(Enum):
KEY = "k_scale"
VALUE = "v_scale"


class QuantizedKVParameterCache(HFDyanmicCache):
horheynm marked this conversation as resolved.
Show resolved Hide resolved

"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
Singleton, so that the same cache gets reused in all forward call of self_attn.
Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
gets called appropriately.
The size of tensor is
`[batch_size, num_heads, seq_len - residual_length, head_dim]`.


Triggered by adding kv_cache_scheme in the recipe.

Example:

```python3
recipe = '''
quant_stage:
quant_modifiers:
QuantizationModifier:
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
'''

"""

_instance = None
_initialized = False

def __new__(cls, *args, **kwargs):
"""Singleton"""
if cls._instance is None:
cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
return cls._instance

def __init__(self, quantization_args: QuantizationArgs):
if not self._initialized:
super().__init__()

self.quantization_args = quantization_args

self.k_observers: List[Observer] = []
self.v_observers: List[Observer] = []

# each index corresponds to layer_idx of the attention layer
self.k_scales: List[Tensor] = []
self.v_scales: List[Tensor] = []

self.k_zps: List[Tensor] = []
self.v_zps: List[Tensor] = []

self._initialized = True

def update(
self,
key_states: Tensor,
value_states: Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the k_scale and v_scale and output the
fakequant-ed key_states and value_states
"""

if len(self.k_observers) <= layer_idx:
k_observer = self.quantization_args.get_observer()
v_observer = self.quantization_args.get_observer()

self.k_observers.append(k_observer)
self.v_observers.append(v_observer)

q_key_states = self._quantize(
key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
)
q_value_states = self._quantize(
value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
)

qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
qdq_value_states = self._dequantize(
q_value_states, KVCacheScaleType.VALUE, layer_idx
)

keys_to_return, values_to_return = qdq_key_states, qdq_value_states

return keys_to_return, values_to_return

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states.
A layer index can be optionally passed.
"""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and
# rely on `_seen_tokens` which is updated every "layer_idx" == 0,
# this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to
# verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

def reset_states(self):
"""reset the kv states (used in calibration)"""
self.key_cache: List[Tensor] = []
self.value_cache: List[Tensor] = []
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0
self._quantized_key_cache: List[Tensor] = []
self._quantized_value_cache: List[Tensor] = []

def reset(self):
"""
Reset the instantiation, create new instance on init
"""
QuantizedKVParameterCache._instance = None
QuantizedKVParameterCache._initialized = False

def _quantize(self, tensor, kv_type, layer_idx):
"""Quantizes a key/value using a defined quantization method."""
from compressed_tensors.quantization.lifecycle.forward import quantize

if kv_type == KVCacheScaleType.KEY: # key type
observer = self.k_observers[layer_idx]
scales = self.k_scales
zps = self.k_zps
else:
assert kv_type == KVCacheScaleType.VALUE
observer = self.v_observers[layer_idx]
horheynm marked this conversation as resolved.
Show resolved Hide resolved
scales = self.v_scales
zps = self.v_zps

scale, zp = observer(tensor)
if len(scales) <= layer_idx:
scales.append(scale)
zps.append(zp)
else:
scales[layer_idx] = scale
zps[layer_idx] = scale

q_tensor = quantize(
x=tensor,
scale=scale,
zero_point=zp,
args=self.quantization_args,
)
return q_tensor

def _dequantize(self, qtensor, kv_type, layer_idx):
"""Dequantizes back the tensor that was quantized by `self._quantize()`"""
from compressed_tensors.quantization.lifecycle.forward import dequantize

if kv_type == KVCacheScaleType.KEY:
scale = self.k_scales[layer_idx]
zp = self.k_zps[layer_idx]
else:
assert kv_type == KVCacheScaleType.VALUE
scale = self.v_scales[layer_idx]
horheynm marked this conversation as resolved.
Show resolved Hide resolved
zp = self.v_zps[layer_idx]

qdq_tensor = dequantize(
x_q=qtensor,
scale=scale,
zero_point=zp,
args=self.quantization_args,
)
return qdq_tensor
14 changes: 13 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
infer_quantization_status,
is_kv_cache_quant_scheme,
iter_named_leaf_modules,
iter_named_quantizable_modules,
)
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
from compressed_tensors.utils.offload import update_parameter_data
Expand Down Expand Up @@ -135,15 +136,23 @@ def apply_quantization_config(
# list of submodules to ignore
ignored_submodules = defaultdict(list)
# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in iter_named_leaf_modules(model):
for name, submodule in iter_named_quantizable_modules(
model,
include_children=True,
include_attn=True,
): # child modules and attention modules
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
if matches := find_name_or_class_matches(name, submodule, config.ignore):
for match in matches:
ignored_submodules[match].append(name)
continue # layer matches ignore list, continue

targets = find_name_or_class_matches(name, submodule, target_to_scheme)

if targets:
# mark modules to be quantized by adding
# quant scheme to the matching layers
scheme = _scheme_from_targets(target_to_scheme, targets, name)
if run_compressed:
format = config.format
Expand Down Expand Up @@ -200,6 +209,9 @@ def process_kv_cache_config(
:param config: the QuantizationConfig
:return: the QuantizationConfig with additional "kv_cache" group
"""
if targets == KV_CACHE_TARGETS:
_LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")

kv_cache_dict = config.kv_cache_scheme.model_dump()
kv_cache_scheme = QuantizationScheme(
output_activations=QuantizationArgs(**kv_cache_dict),
Expand Down
Loading
Loading