Commit
clean up
horheynm committed Aug 31, 2024
1 parent 79a7194 commit 60a83f4
Showing 8 changed files with 72 additions and 140 deletions.
20 changes: 20 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,20 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"python": ".venv/bin/python",
"justMyCode": false,
"env": {
"CUDA_VISIBLE_DEVICES": "7"
}
}
]
}
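The launch configuration above pins debugging to the project virtualenv and to physical GPU 7. As a rough Python equivalent of the `env` entry (a sketch, assuming a CUDA-enabled PyTorch build and that GPU index 7 exists on the machine), the restriction has to be in place before CUDA is initialized:

```python
import os

# Mirror the "env" block in the launch configuration: expose only physical
# GPU 7 to this process. Must happen before the first CUDA call initializes
# the driver, hence before touching torch.cuda.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch  # imported after setting the variable so only one device is visible

if torch.cuda.is_available():
    print(torch.cuda.device_count())      # 1 -- physical GPU 7, exposed as cuda:0
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available on this machine")
```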
53 changes: 0 additions & 53 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -239,59 +239,6 @@ def compress(
compressed_state_dict = self.sparsity_compressor.compress(
compressed_state_dict
)
# breakpoint()
# # HACK (mgoin): Post-process step for kv cache scales to take the
# # k/v_proj module `output_scale` parameters, and store them in the
# # parent attention module as `k_scale` and `v_scale`
# #
# # Example:
# # Replace `model.layers.0.self_attn.k_proj.output_scale`
# # with `model.layers.0.self_attn.k_scale`
# if (
# self.quantization_config is not None
# and self.quantization_config.kv_cache_scheme is not None
# ):
# # HACK (mgoin): We assume the quantized modules in question
# # will be k_proj and v_proj since those are the default targets.
# # We check that both of these modules have output activation
# # quantization, and additionally check that q_proj doesn't.
# q_proj_has_no_quant_output = 0
# k_proj_has_quant_output = 0
# v_proj_has_quant_output = 0
# for name, module in model.named_modules():
# if not hasattr(module, "quantization_scheme"):
# continue
# out_act = module.quantization_scheme.output_activations
# if name.endswith(".q_proj") and out_act is None:
# q_proj_has_no_quant_output += 1
# elif name.endswith(".k_proj") and out_act is not None:
# k_proj_has_quant_output += 1
# elif name.endswith(".v_proj") and out_act is not None:
# v_proj_has_quant_output += 1

# assert (
# q_proj_has_no_quant_output > 0
# and k_proj_has_quant_output > 0
# and v_proj_has_quant_output > 0
# )
# assert (
# q_proj_has_no_quant_output
# == k_proj_has_quant_output
# == v_proj_has_quant_output
# )

# # Move all .k/v_proj.output_scale parameters to .k/v_scale
# working_state_dict = {}
# for key in compressed_state_dict.keys():
# if key.endswith(".k_proj.output_scale"):
# new_key = key.replace(".k_proj.output_scale", ".k_scale")
# working_state_dict[new_key] = compressed_state_dict[key]
# elif key.endswith(".v_proj.output_scale"):
# new_key = key.replace(".v_proj.output_scale", ".v_scale")
# working_state_dict[new_key] = compressed_state_dict[key]
# else:
# working_state_dict[key] = compressed_state_dict[key]
# compressed_state_dict = working_state_dict

# HACK: Override the dtype_byte_size function in transformers to
# support float8 types. Fix is posted upstream
58 changes: 28 additions & 30 deletions src/compressed_tensors/quantization/cache.py
@@ -13,17 +13,15 @@
# limitations under the License.


from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
from compressed_tensors.quantization.observers import Observer
from compressed_tensors.quantization.quant_args import QuantizationArgs
from torch import Tensor
from transformers import DynamicCache as HFDyanmicCache


# from compressed_tensors.quantization.observers import Observer
# from compressed_tensors.quantization.quant_args import QuantizationArgs


class QuantizedCache(HFDyanmicCache):
"""
Quantized KV cache used in the forward call based on HF's dynamic cache.
@@ -49,30 +49,25 @@ class QuantizedCache(HFDyanmicCache):
"""

_instance = None
_initialized = False

def __new__(cls, *args, **kwargs):
"""Singleton"""
if cls._instance is None:
cls._instance = super(QuantizedCache, cls).__new__(cls)
return cls._instance

# def __init__(self, quantization_args: QuantizationArgs):
def __init__(self, quantization_args):
if not hasattr(
self, "_initialized"
): # Ensure attributes are initialized only once
def __init__(self, quantization_args: QuantizationArgs):
if not self._initialized:
super().__init__()

self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []
self._quantized_key_cache: List[Tensor] = []
self._quantized_value_cache: List[Tensor] = []

self.quantization_args = quantization_args

# self.k_observers: List[Observer] = []
# self.v_observers: List[Observer] = []

self.k_observers: List = []
self.v_observers: List = []
self.k_observers: List[Observer] = []
self.v_observers: List[Observer] = []

self.k_scales: List[
Tensor
@@ -82,15 +75,15 @@ def __init__(self, quantization_args):
self.k_zps: List[Tensor] = []
self.v_zps: List[Tensor] = []

self._initialized = True # Mark the instance as initialized
self._initialized = True

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
key_states: Tensor,
value_states: Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
@@ -159,23 +152,28 @@ def update(
return keys_to_return, values_to_return

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
"""
Returns the sequence length of the cached states.
A layer index can be optionally passed.
"""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
# updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to verify attn_weight shape in some models
# since we cannot get the seq_length of each layer directly and
# rely on `_seen_tokens` which is updated every "layer_idx" == 0,
# this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to
# verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

def reset(self):
def reset_states(self):
"""reset the kv states (used in calibration)"""
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.key_cache: List[Tensor] = []
self.value_cache: List[Tensor] = []
self._seen_tokens = (
0 # Used in `generate` to keep tally of how many tokens the cache has seen
)
self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []
self._quantized_key_cache: List[Tensor] = []
self._quantized_value_cache: List[Tensor] = []

def _quantize(self, tensor, kv_type, layer_idx):
"""Quantizes a key/value using a defined quantization method."""
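The `__new__` override together with the `_initialized` flag above makes `QuantizedCache` a process-wide singleton: every call site receives the same cache object, and its state is set up only once. A minimal standalone sketch of that pattern (the `Singleton` class below is hypothetical, not part of the library):

```python
class Singleton:
    _instance = None      # shared across all constructor calls
    _initialized = False  # guards the one-time setup in __init__

    def __new__(cls, *args, **kwargs):
        # create the instance on first use, then always hand back the same one
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, value):
        # __init__ still runs on every call, so skip re-initialization
        if not self._initialized:
            self.value = value
            self._initialized = True


a = Singleton(1)
b = Singleton(2)
assert a is b        # same object both times
assert a.value == 1  # the second call did not re-run the setup
```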
16 changes: 5 additions & 11 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -128,12 +128,11 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
# list of submodules to ignore
ignored_submodules = defaultdict(list)
# mark appropriate layers for quantization by setting their quantization schemes
# for name, submodule in iter_named_leaf_modules(model):
for name, submodule in iter_named_modules(
model,
include_children=True,
include_attn=True,
): # all the modules, self_attn + leaf
): # child modules and attention modules
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
if matches := find_name_or_class_matches(name, submodule, config.ignore):
@@ -143,8 +142,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict

targets = find_name_or_class_matches(name, submodule, target_to_scheme)

if targets: # if targets and leaf
# target matched - add layer and scheme to target list
if targets:
# mark modules to be quantized by adding
# quant scheme to the matching layers
submodule.quantization_scheme = _scheme_from_targets(
target_to_scheme, targets, name
)
@@ -157,8 +157,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
"not found in the model: "
f"{set(config.ignore) - set(ignored_submodules)}"
)
# apply current quantization status across all targeted layers

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
return names_to_scheme

@@ -207,8 +207,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):

if status >= QuantizationStatus.INITIALIZED > current_status:
model.apply(initialize_module_for_quantization)
# add self_attn mapper here
# model.apply(initialize_module_for_attn_quantization)

if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
# only quantize weights up front when our end goal state is calibration,
@@ -316,10 +314,6 @@ def _scheme_from_targets(
# if `targets` iterable contains a single element
# use it as the key

# if name.endswith("self_attn"):
# scheme = target_to_scheme[targets[0]]
# args = scheme.output_activations
# args.set_kv_cache()
return target_to_scheme[targets[0]]

# otherwise, we need to merge QuantizationSchemes corresponding
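For context on the hunks above: `apply_quantization_config` walks the modules yielded by `iter_named_modules`, skips ignored names, and attaches the matched scheme directly to each submodule as `quantization_scheme`. A simplified, hedged sketch of that marking step, with plain suffix/class-name checks standing in for `find_name_or_class_matches`:

```python
from torch import nn


def mark_matching_modules(model: nn.Module, target_to_scheme: dict, ignore: list) -> dict:
    """Simplified sketch of the marking loop in apply_quantization_config."""
    names_to_scheme = {}
    for name, submodule in model.named_modules():
        if any(name.endswith(pattern) for pattern in ignore):
            continue  # explicitly ignored layer
        for pattern, scheme in target_to_scheme.items():
            if name.endswith(pattern) or submodule.__class__.__name__ == pattern:
                # mark the layer so later lifecycle steps can find its scheme
                submodule.quantization_scheme = scheme
                names_to_scheme[name] = scheme
                break
    return names_to_scheme


# usage sketch: tag every Linear layer with a placeholder scheme object
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
print(mark_matching_modules(model, {"Linear": "w8a8-placeholder"}, ignore=[]))
# {'0': 'w8a8-placeholder', '2': 'w8a8-placeholder'}
```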
25 changes: 12 additions & 13 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -14,7 +14,7 @@

from functools import wraps
from math import ceil
from typing import Optional
from typing import Callable, Optional

import torch
from compressed_tensors.quantization.observers.helpers import calculate_range
@@ -295,14 +295,15 @@ def wrapped_forward(self, *args, **kwargs):
input_, *args[1:], **kwargs
)

if scheme.output_activations is not None:
if scheme.output_activations is not None and not is_kv_cache_quant_scheme(
scheme
):
# calibrate and (fake) quantize output activations when applicable

# kv_cache observers updated on model's forward call
if not is_kv_cache_quant_scheme(scheme):
output = maybe_calibrate_or_quantize(
module, output, "output", scheme.output_activations
)
# kv_cache scales updated on model self_attn forward call in
# wrap_module_forward_quantized_attn
output = maybe_calibrate_or_quantize(
module, output, "output", scheme.output_activations
)

# restore back to unquantized_value
if scheme.weights is not None:
@@ -328,16 +329,14 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationSchem
def wrapped_forward(self, *args, **kwargs):

past_key_value = scheme.output_activations.get_kv_cache()
# print(hex(id(past_key_value))) # 0x7f74669e6500
# breakpoint()
kwargs["past_key_value"] = past_key_value
kwargs["use_cache"] = past_key_value is not None

attn_module = forward_func_orig.__get__(module, module.__class__)
attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)

past_key_value.reset()
past_key_value.reset_states()

rtn = attn_module(*args, **kwargs)
rtn = attn_forward(*args, **kwargs)

self.k_scale = past_key_value.k_scales[module.layer_idx]
self.v_scale = past_key_value.v_scales[module.layer_idx]
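In the hunk above, `forward_func_orig.__get__(module, module.__class__)` is descriptor binding: it turns the saved plain function back into a method bound to that specific module instance. A minimal sketch of the same wrap-and-rebind technique on a plain `nn.Linear` (the doubling is a stand-in for the real calibration/quantization logic):

```python
import torch
from torch import nn


def wrap_forward(module: nn.Module):
    forward_func_orig = module.forward.__func__  # the plain function off the class

    def wrapped_forward(self, *args, **kwargs):
        # __get__ rebinds the original function to `self`, like a normal method call
        output = forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
        return output * 2  # placeholder for output post-processing

    # attach the wrapper to this one instance only, leaving the class untouched
    module.forward = wrapped_forward.__get__(module, module.__class__)


layer = nn.Linear(4, 4)
x = torch.randn(1, 4)
baseline = layer(x)
wrap_forward(layer)
assert torch.allclose(layer(x), baseline * 2)
```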
19 changes: 4 additions & 15 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -61,7 +61,8 @@ def initialize_module_for_quantization(
return

if is_attention_module(module):
# wrap forward call of module to perform quantized actions based on calltime status
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized_attn(module, scheme)
else:

@@ -120,7 +121,8 @@ def initialize_module_for_quantization(
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
remove_hook_from_module(module)

# wrap forward call of module to perform quantized actions based on calltime status
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)

if offloaded:
@@ -186,19 +188,6 @@ def _initialize_scale_zero_point_observer(
module.register_parameter(f"{base_name}_g_idx", init_g_idx)


def _register_kv_cache_to_registry(args: QuantizationArgs):
"""
Register KV Cache in registry. Should only be called if
kv_cache_scheme is defined in the recipe
"""
from compressed_tensors.quantization.cache import QuantizedCache

cache = QuantizedCache(args)
name = "kv-cache"
if name not in cache.registered_names():
QuantizedCache.register_value(value=cache, name=name)


def is_attention_module(module: Module):
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj")
19 changes: 3 additions & 16 deletions src/compressed_tensors/quantization/quant_args.py
@@ -16,7 +16,6 @@
from typing import Any, Dict, Optional

import torch
from compressed_tensors.quantization.cache import QuantizedCache
from pydantic import BaseModel, Field, field_validator, model_validator


@@ -95,13 +94,6 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
"Observers constructor excluding quantization range or symmetry"
),
)
# kv_cache: Optional[
# QuantizedCache
# ] = None # Singleton, only relevant for output_activations

# model_config = {
# "arbitrary_types_allowed": True
# }

def get_observer(self):
"""
@@ -116,15 +108,10 @@ def get_observer(self):

return Observer.load_from_registry(self.observer, quantization_args=self)

def get_kv_cache(self) -> QuantizedCache:
# """Lazy initialization of kv_cache. Singleton instantiation"""
return QuantizedCache(self)
# if self.kv_cache is None:
# self.kv_cache = QuantizedCache(self)
def get_kv_cache(self):
from compressed_tensors.quantization.cache import QuantizedCache

# @kv_cache.setter
# def kv_cache(self, value: QuantizedCache):
# self.kv_cache = value
return QuantizedCache(self)

@field_validator("group_size", mode="before")
def validate_group(cls, value) -> int:
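`get_kv_cache` now imports `QuantizedCache` inside the method body instead of at module level, presumably because `cache.py` itself imports `QuantizationArgs`, so top-level imports in both directions would form a cycle. A minimal sketch of the call-time import pattern (a stdlib module stands in for the real `QuantizedCache` import so the snippet runs anywhere):

```python
import sys


def get_cache_class():
    # deferred import: the module is resolved only when this function runs,
    # not when the enclosing module is loaded -- which is what breaks cycles
    from fractions import Fraction  # stand-in for the real QuantizedCache import
    return Fraction


print("fractions" in sys.modules)  # typically False on a fresh interpreter
cache_cls = get_cache_class()
print("fractions" in sys.modules)  # True: the import happened at call time
print(cache_cls(1, 2))             # 1/2
```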

0 comments on commit 60a83f4
