From d847ed79378859948e4be5190d92b10eb359dce4 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Thu, 22 Feb 2024 19:14:50 -0500 Subject: [PATCH] Q4_1 quantization compiling to vmfb megacommit (#2) --- python/turbine_llamacpp/compile.py | 201 +++++++++++++++++---- python/turbine_llamacpp/ggml_structs.py | 116 +++++++++++- python/turbine_llamacpp/llamacpp_runner.py | 126 +++++++++++++ python/turbine_llamacpp/model.py | 48 ++--- python/turbine_llamacpp/params.py | 83 ++++++++- 5 files changed, 501 insertions(+), 73 deletions(-) create mode 100644 python/turbine_llamacpp/llamacpp_runner.py diff --git a/python/turbine_llamacpp/compile.py b/python/turbine_llamacpp/compile.py index 6a2ca90..a98fdb9 100644 --- a/python/turbine_llamacpp/compile.py +++ b/python/turbine_llamacpp/compile.py @@ -9,6 +9,7 @@ import argparse + parser = argparse.ArgumentParser() parser.add_argument( "--gguf_path", @@ -16,13 +17,85 @@ default="ggml-model-q8_0.gguf", help="path to gguf", ) +parser.add_argument( + "--irpa_path", + type=str, + default=None, + help="path to a .irpa file to generate new repacked parameters.", +) +parser.add_argument( + "--compile_to", default="torch", type=str, help="torch, linalg, vmfb" +) +parser.add_argument( + "--vmfb_path", type=str, default=None, help="Path/name to store compiled vmfb." +) +parser.add_argument("--device", type=str, default="llvm-cpu", help="llvm-cpu") +parser.add_argument( + "--quantization", + type=str, + default="", + help="Comma separated list of quantization types. Supported types are [Q4_1].", +) + -def create_direct_predict_internal_kv_module(model: LlamaCPP) -> CompiledModule: +def create_direct_predict_internal_kv_module( + hp: HParams, + compile_to=None, + device=None, + vmfb_path=None, + quantization=None, + irpa_path=None, +): """This compilation performs direct, non-sampled prediction. - It manages its kv cache and step states internally. + It manages its kv kv_cache and step states internally. """ + quant_types = quantization.split(",") + if irpa_path: + import iree.runtime as rt + + dequantize_types = [ + type + for type in [ + "F32", + "F16", + "Q4_0", + "Q4_1", + "Q5_0", + "Q5_1", + "Q8_0", + "Q8_1", + "Q2_K", + "Q3_K", + "Q4_K", + "Q5_K", + "Q6_K", + "Q8_K", + ] + if type not in quant_types + ] + # We can't match on this param yet for the quantization rewrite. 
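Aside (illustration only, not part of the diff): with --quantization=Q4_1, the filter above keeps only Q4_1 tensors in packed form; every other listed type lands in dequantize_types and is therefore stored dequantized (or passed through as plain floats) when the .irpa archive is written. A minimal sketch of that filtering, with an abbreviated type list:

```
quantization = "Q4_1"                                 # value of --quantization
quant_types = quantization.split(",")                 # ["Q4_1"]
supported = ["F32", "F16", "Q4_0", "Q4_1", "Q8_0"]    # abbreviated for the sketch
dequantize_types = [t for t in supported if t not in quant_types]
assert dequantize_types == ["F32", "F16", "Q4_0", "Q8_0"]
```

The dequantize_params list that follows names individual tensors that are dequantized regardless of their type.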
+ dequantize_params = [ + "token_embd.weight", + ] + repacked_params = hp.repack_tensor_params( + dequantize_types=dequantize_types, + dequantize_params=dequantize_params, + dtype=torch.float32, + ) + rt.save_archive_file(repacked_params, irpa_path) + print(f"saved repacked parameters to {irpa_path}") + + # Replace tensor params for tracing with dequantized types for any type not + # listed in args.quantization + replaceable_types = [type for type in hp.supported_types if type not in quant_types] + # Replace Q4_1 tensors because of a rewrite trick for Q4_1 parameters + if "Q4_1" in quant_types: + replaceable_types.append("Q4_1") + hp.replace_quantized_tensors(replaceable_types=replaceable_types) + model = LlamaCPP(hp) + class LlamaDpisModule(CompiledModule): params = export_parameters( model.theta.params, @@ -30,69 +103,56 @@ class LlamaDpisModule(CompiledModule): name_mapper=lambda n: n.removeprefix("params."), ) current_seq_index = export_global(AbstractIndex, mutable=True) - cache_k = export_global( - model.cache_k, name="cache_k", uninitialized=True, mutable=True - ) - cache_v = export_global( - model.cache_v, name="cache_v", uninitialized=True, mutable=True - ) + kv_cache = export_global_tree(model.kv_cache, uninitialized=True, mutable=True) def run_initialize( - self, input_ids=AbstractTensor(model.hp.bs, None, dtype=torch.int32) + self, input_ids=AbstractTensor(model.hp.bs, None, dtype=torch.int64) ): - output_token, cache_k, cache_v = self._initialize( + output_token, *kv_cache = self._initialize( input_ids, - cache_k=self.cache_k, - cache_v=self.cache_v, + *self.kv_cache, constraints=[ input_ids.dynamic_dim(1) <= model.max_seqlen, ], ) self.current_seq_index = IREE.tensor_dim(input_ids, 1) - self.cache_k = cache_k - self.cache_v = cache_v + self.kv_cache = kv_cache return output_token - def run_forward(self, token0=AbstractTensor(1, 1, dtype=torch.int32)): + def run_forward(self, token0=AbstractTensor(1, 1, dtype=torch.int64)): seq_index_0 = self.current_seq_index # TODO: Torch currently has poor support for passing symints across # the tracing boundary, so we box it in a tensor and unbox it on the # inside. Once this restriction is relaxes, just pass it straight through. - seq_index_0_tensor = IREE.tensor_splat(value=seq_index_0, dtype=torch.int32) - output_token, cache_k, cache_v = self._decode_step( - token0, seq_index_0_tensor, self.cache_k, self.cache_v + seq_index_0_tensor = IREE.tensor_splat(value=seq_index_0, dtype=torch.int64) + output_token, *kv_cache = self._decode_step( + token0, seq_index_0_tensor, *self.kv_cache ) # TODO: Emit an assertion of some kind of overflowing max_seqlen. 
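Aside (illustration only, not part of the diff): run_initialize and run_forward pass the whole cache list positionally and reassemble it with starred unpacking, because the jittable helpers return (token, *kv_cache) as one flat tuple. A pure-Python sketch of that calling convention, with stand-in strings instead of cache tensors:

```
def _decode_step_like(token0, index0, *kv_cache):
    # Token output first, then every cache tensor, mirroring _decode_step below.
    return (token0 + 1, *kv_cache)

caches = ["k0", "k1", "v0", "v1"]          # stand-ins for 2 * block_count tensors
out, *new_caches = _decode_step_like(41, 0, *caches)
assert out == 42 and new_caches == caches
```

The state updates for current_seq_index and kv_cache continue directly below.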
self.current_seq_index = seq_index_0 + 1 - self.cache_k = cache_k - self.cache_v = cache_v + self.kv_cache = kv_cache return output_token @jittable - def _initialize( - input_ids: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor - ): + def _initialize(input_ids: torch.Tensor, *kv_cache): return ( model.forward( input_ids, 0, - local_cache_k=cache_k, - local_cache_v=cache_v, + local_kv_cache=kv_cache, ), - cache_k, - cache_v, + *kv_cache, ) @jittable def _decode_step( token0: torch.Tensor, index0: torch.Tensor, - cache_k: torch.Tensor, - cache_v: torch.Tensor, + *kv_cache, ): bs, sl_input = token0.shape - _, _, sl_k, *_ = cache_k.shape - _, _, sl_v, *_ = cache_v.shape + _, sl_k, *_ = kv_cache[0].shape + _, sl_v, *_ = kv_cache[0].shape index0_scalar = index0.item() # Torch is very picky that on the auto-regressive steps it knows # that the index0_scalar value (which is used to slice the caches) @@ -107,23 +167,86 @@ def _decode_step( model.forward( token0, index0_scalar, - local_cache_k=cache_k, - local_cache_v=cache_v, + local_kv_cache=kv_cache, ), - cache_k, - cache_v, + *kv_cache, ) - return LlamaDpisModule(import_to="import") + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = LlamaDpisModule(import_to=import_to) + + quantized_param_names = get_quantized_param_name_dict(hp, quant_types) + # Only supporting rewrite for Q4_1 params right now. + if "Q4_1" in quantized_param_names and not compile_to == "linalg": + from shark_turbine.transforms.quantization import mm_group_quant + + mm_group_quant.MMGroupQuantRewriterPass( + CompiledModule.get_mlir_module(inst).operation, + group_size=32, + param_names=quantized_param_names["Q4_1"], + ).run() + module_str = str(CompiledModule.get_mlir_module(inst)) + if compile_to != "vmfb": + return module_str + else: + flags = [ + "--iree-input-type=torch", + "--mlir-print-debuginfo", + "--mlir-print-op-on-diagnostic=false", + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + ] + if device == "cpu" or device == "llvm-cpu": + flags.extend( + [ + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-target-triple=x86_64-linux-gnu", + "--iree-llvmcpu-enable-ukernels=all", + ] + ) + device = "llvm-cpu" + else: + print("Unknown device kind: ", device) + import iree.compiler as ireec + + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + extra_args=flags, + ) + if vmfb_path is None: + vmfb_path = f"output.vmfb" + with open(vmfb_path, "wb+") as f: + f.write(flatbuffer_blob) + print("saved to output.vmfb") + return module_str + + +def get_quantized_param_name_dict(hp: HParams, allowed_quant_types: list[str]): + quantized_param_names = {} + for tensor_name, quant_type in hp.replaced_quantized_tensors: + if quant_type in allowed_quant_types: + if quant_type in quantized_param_names: + quantized_param_names[quant_type].add(tensor_name) + else: + quantized_param_names[quant_type] = set([tensor_name]) + return quantized_param_names def main(): args = parser.parse_args() hp = HParams(args.gguf_path) - model = LlamaCPP(hp) - cm = create_direct_predict_internal_kv_module(model) + module_str = create_direct_predict_internal_kv_module( + hp, + args.compile_to, + args.device, + args.vmfb_path, + args.quantization, + args.irpa_path, + ) with open(f"output.mlir", "w+") as f: - f.write(str(CompiledModule.get_mlir_module(cm))) + f.write(module_str) + print("saved to output.mlir") if __name__ == "__main__": diff --git a/python/turbine_llamacpp/ggml_structs.py 
b/python/turbine_llamacpp/ggml_structs.py index 6e2c515..3f5236b 100644 --- a/python/turbine_llamacpp/ggml_structs.py +++ b/python/turbine_llamacpp/ggml_structs.py @@ -1,11 +1,13 @@ from typing import Generic, Optional, TypeVar from abc import ABC, abstractmethod from dataclasses import dataclass +import warnings import torch __all__ = [ "Q4_0", + "Q4_1", "Q8_0", "QuantizedTensor", ] @@ -53,6 +55,12 @@ def dequant_blocked(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: scaled = d * qs.to(dtype) return scaled + def repack_for_turbine(self, dtype: Optional[torch.dtype] = None): + warnings.warn( + f"Repacking quantized type Q8_0 not supported. Returning in GGUF format." + ) + return self.dequant(dtype), None, None + def __repr__(self): return f"Q8_0(d[{self.d.shape}]={self.d}, qs[{self.qs.shape}]={self.qs})" @@ -86,6 +94,7 @@ def unpack(self) -> Q8_0Struct: qs = blocks[..., 1:].view(torch.int8) return Q8_0Struct(self.shape, blocks, d, qs) + @dataclass class Q4_0Struct(UnpackedStruct): shape: list[int] @@ -103,16 +112,22 @@ def dequant_blocked(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: d = d.to(dtype) else: dtype = d.dtype - v1 = (qs & 0xF) - v2 = (qs >> 4) + v1 = qs & 0xF + v2 = qs >> 4 # Set up shape for combined unpacked dequants. target_shape = list(v1.shape) target_shape[-1] = v1.shape[-1] + v2.shape[-1] # Combining unpacked quants. - v3 = torch.cat([v1,v2],dim=-1) + v3 = torch.cat([v1, v2], dim=-1) scaled = d * (v3.to(dtype) - 8.0) return scaled + def repack_for_turbine(self, dtype: Optional[torch.dtype] = None): + warnings.warn( + f"Repacking quantized type Q4_0 not supported. Returning in GGUF format." + ) + return self.dequant(dtype), None, None + def __repr__(self): return f"Q4_0(d[{self.d.shape}]={self.d}, qs[{self.qs.shape}]={self.qs})" @@ -148,3 +163,98 @@ def unpack(self) -> Q4_0Struct: d = blocks[..., 0:1].view(torch.float16) qs = blocks[..., 1:].view(torch.uint8) return Q4_0Struct(self.shape, blocks, d, qs) + + +@dataclass +class Q4_1Struct(UnpackedStruct): + shape: list[int] + blocks: torch.Tensor + d: torch.Tensor + m: torch.Tensor + qs: torch.Tensor + + def dequant(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + return self.dequant_blocked(dtype).reshape(self.shape) + + def dequant_blocked(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + d = self.d + m = self.m + qs = self.qs + if dtype: + d = d.to(dtype) + m = m.to(dtype) + else: + dtype = d.dtype + v1 = qs & 0xF + v2 = qs >> 4 + # Set up shape for combined unpacked dequants. + target_shape = list(v1.shape) + target_shape[-1] = v1.shape[-1] + v2.shape[-1] + # Combining unpacked quants. + v3 = torch.cat([v1, v2], dim=-1) + scaled = (d * v3.to(dtype)) + m + return scaled + + # GGML packing of Q4 data is in the order: + # [0, 16, 1, 17, 2, 18, ...] + # We need to repack to the [0, 1, 2, ...] order. 
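Aside (illustration only, not part of the diff): the comment above describes GGML's nibble interleaving, where byte j of a 32-value block stores value j in its low nibble and value j + 16 in its high nibble, so reading nibbles in byte order yields [0, 16, 1, 17, ...]. The sketch below builds one toy block in that layout and applies the same index-select/repack steps as reorder_q4_data (defined just below), checking that the result reads back in sequential [0, 1, 2, ...] order:

```
import torch

vals = torch.randint(0, 16, (32,), dtype=torch.uint8)   # 32 fake 4-bit values
ggml_qs = vals[:16] | (vals[16:] << 4)                   # GGML byte j = (val j) | (val j+16) << 4

v1, v2 = ggml_qs & 0xF, ggml_qs >> 4
even, odd = torch.arange(0, 16, 2), torch.arange(1, 16, 2)
repacked = torch.cat([v1[even] | (v1[odd] << 4), v2[even] | (v2[odd] << 4)], dim=-1)

# After repacking, byte k holds value 2k (low nibble) and value 2k+1 (high nibble),
# so interleaving low/high nibbles recovers the values in their original order.
unpacked = torch.stack([repacked & 0xF, repacked >> 4], dim=-1).reshape(-1)
assert torch.equal(unpacked, vals)
```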
+ def reorder_q4_data(self, q4_tensor: torch.Tensor): + v1 = q4_tensor & 0xF + v2 = q4_tensor >> 4 + block_size = q4_tensor.size(-1) + even_idx = torch.tensor(range(0, block_size, 2)) + odd_idx = torch.tensor(range(1, block_size, 2)) + v1_even = v1.index_select(-1, even_idx) + v1_odd = v1.index_select(-1, odd_idx) + v2_even = v2.index_select(-1, even_idx) + v2_odd = v2.index_select(-1, odd_idx) + v1_packed = torch.bitwise_or(v1_even, v1_odd << 4) + v2_packed = torch.bitwise_or(v2_even, v2_odd << 4) + return torch.cat([v1_packed, v2_packed], dim=-1) + + def repack_for_turbine(self, dtype: Optional[torch.dtype] = None): + if not dtype: + dtype = self.d.dtype + weights = self.reorder_q4_data(self.qs) + scales = self.d + # GGML uses a positive scaled zero point, and turbine uses a negative + # unscaled zero point so we adjust the zero points accordingly. + zps = self.m / -self.d + return weights, scales.to(dtype), zps.to(dtype) + + def __repr__(self): + return f"Q4_1(d[{self.d.shape}]={self.d}, m[{self.m.shape}]={self.m}, qs[{self.qs.shape}]={self.qs})" + + +class Q4_1(QuantizedTensor[Q4_1Struct]): + """ + ``` + #define QK4_1 32 + typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants + } block_q4_1; + ``` + Dequant: + https://github.com/ggerganov/llama.cpp/blob/f026f8120f97090d34a52b3dc023c82e0ede3f7d/ggml-opencl.cpp#L131-L142 + """ + + def __init__(self, linear: torch.Tensor, shape: list[int]): + assert linear.dtype == torch.uint8 + self.linear = linear + self.shape = shape + + def unpack(self) -> Q4_1Struct: + # Blocks are 9 i16s, so start there. + # delta: 1 i16 + # quants: 8 i16s. (32 i4s -> 16 i8s -> 8 i16s) + linear_blocks = self.linear.view(torch.int16).reshape(-1, 10) + # Reblock to the result shape excluding the final dimension, which + # is expanded. + block_shape = self.shape[0:-1] + [-1, 10] + blocks = linear_blocks.reshape(block_shape) + d = blocks[..., 0:1].view(torch.float16) + m = blocks[..., 1:2].view(torch.float16) + qs = blocks[..., 2:].view(torch.uint8) + return Q4_1Struct(self.shape, blocks, d, m, qs) diff --git a/python/turbine_llamacpp/llamacpp_runner.py b/python/turbine_llamacpp/llamacpp_runner.py new file mode 100644 index 0000000..0730b5d --- /dev/null +++ b/python/turbine_llamacpp/llamacpp_runner.py @@ -0,0 +1,126 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from iree import runtime as ireert +import torch +import time +from turbine_llamacpp.params import * +from transformers import LlamaTokenizer + +parser = argparse.ArgumentParser() + +# TODO move common runner flags to generic flag file +parser.add_argument( + "--vmfb_path", + type=str, + default="output.vmfb", + help="path to vmfb containing compiled module", +) +parser.add_argument( + "--external_weight_path", + type=str, + default="reformatted_parameters.irpa", + help="path to external weight parameters", +) +parser.add_argument( + "--gguf_path", + type=str, + default="", + help="path to gguf file used to generate parameters", +) +parser.add_argument( + "--hf_model_path", + type=str, + default="openlm-research/open_llama_3b", + help="path to the hf model. 
Needed for tokenizer right now", +) +parser.add_argument( + "--device", + type=str, + default="local-task", + help="local-sync, local-task", +) +parser.add_argument( + "--prompt", + type=str, + default=" Q: What is the largest animal?\nA:", + help="prompt for llm model", +) + + +class SharkLLM(object): + def __init__(self, device, vmfb_path, external_weight_path): + self.runner = vmfbRunner( + device=device, + vmfb_path=vmfb_path, + external_weight_path=external_weight_path, + ) + self.model = self.runner.ctx.modules.llama_dpis + self.first_input = True + self.num_tokens = 0 + self.last_prompt = None + self.prev_token_len = 0 + + def format_out(self, results): + return results.to_host()[0][0] + + def generate(self, input_ids, tokenizer): + try: + turbine_results = [] + # Only need not seen token for init cache + # Because we have stored the res in KV-cache. + token_len = input_ids.shape[-1] + inputs = [ireert.asdevicearray(self.runner.config.device, input_ids)] + s = time.time() + results = self.model["run_initialize"](*inputs) # example_input_id + e = time.time() + print( + f"num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}" + ) + token_len += 1 + self.first_input = False + s = time.time() + turbine_results.append(self.format_out(results)) + while self.format_out(results) != 2: + results = self.model["run_forward"](results) + # uncomment to see tokens as they are emitted + # print(f"turbine: {tokenizer.decode(self.format_out(results))}") + turbine_results.append(self.format_out(results)) + e = time.time() + decoded_tokens = len(turbine_results) + print( + f"Decode num_tokens: {decoded_tokens}, time_taken={e-s}, tok/second:{decoded_tokens/(e-s)}" + ) + self.prev_token_len = token_len + decoded_tokens + return turbine_results + except KeyboardInterrupt: + return turbine_results + + +def run_llm( + device, + prompt, + vmfb_path, + external_weight_path, + hf_model_path, +): + tokenizer = LlamaTokenizer.from_pretrained(hf_model_path) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + llm = SharkLLM( + device=device, + vmfb_path=vmfb_path, + external_weight_path=external_weight_path, + ) + print("generating turbine output: ") + return tokenizer.decode(llm.generate(input_ids, tokenizer=tokenizer)) + + +if __name__ == "__main__": + args = parser.parse_args() + turbine_output = run_llm( + args.device, + args.prompt, + args.vmfb_path, + args.external_weight_path, + args.hf_model_path, + ) + print(turbine_output) diff --git a/python/turbine_llamacpp/model.py b/python/turbine_llamacpp/model.py index e58898e..6d94168 100644 --- a/python/turbine_llamacpp/model.py +++ b/python/turbine_llamacpp/model.py @@ -16,6 +16,7 @@ ENABLE_DEBUG = False import argparse + parser = argparse.ArgumentParser() parser.add_argument( "--gguf_path", @@ -24,6 +25,7 @@ help="path to gguf", ) + def debug(*args): if ENABLE_DEBUG: print(*args) @@ -66,26 +68,18 @@ def __init__(self, hp: HParams): raise ValueError("Unsupported rotary embedding") # Initialize the KV cache. 
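Aside (illustration only, not part of the diff): the initialization below replaces the two stacked 5-D cache_k / cache_v tensors with a flat list of 2 * transformer_block_count 4-D tensors, which is the shape that export_global_tree and the *kv_cache argument passing in compile.py expect. Block i finds its K cache at index i and its V cache at index transformer_block_count + i, exactly as forward does further down. A tiny sketch with made-up sizes:

```
import torch

block_count, bs, max_seqlen, heads, head_dim = 2, 1, 8, 4, 16   # hypothetical sizes
kv_cache = [
    torch.empty(bs, max_seqlen, heads, head_dim) for _ in range(block_count * 2)
]
for block_idx in range(block_count):
    block_cache_k = kv_cache[block_idx]                  # K cache for this block
    block_cache_v = kv_cache[block_count + block_idx]    # V cache for this block
```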
- self.cache_k = torch.empty( - ( - self.transformer_block_count, - self.hp.bs, - self.max_seqlen, - self.attention_head_count, - self.attention_head_dim, - ), - dtype=self.hp.dtype, - ) - self.cache_v = torch.empty( - ( - self.transformer_block_count, - self.hp.bs, - self.max_seqlen, - self.attention_head_count, - self.attention_head_dim, - ), - dtype=self.hp.dtype, - ) + self.kv_cache = [ + torch.empty( + ( + self.hp.bs, + self.max_seqlen, + self.attention_head_count, + self.attention_head_dim, + ), + dtype=self.hp.dtype, + ) + for i in range(self.transformer_block_count * 2) + ] def forward( self, @@ -93,8 +87,7 @@ def forward( start_index: int, *, return_logits: bool = False, - local_cache_k: Optional[torch.Tensor] = None, - local_cache_v: Optional[torch.Tensor] = None, + local_kv_cache: list[torch.Tensor] = None, ): bs, sl = tokens.shape assert bs == self.hp.bs, "Batch size mismatch vs params" @@ -113,17 +106,15 @@ def forward( ).type_as(h) # Allow either the global cache or a local set passed in parameters. - if local_cache_k is None: - local_cache_k = self.cache_k - if local_cache_v is None: - local_cache_v = self.cache_v + if local_kv_cache is None: + local_kv_cache = self.kv_cache # Transformer blocks. for block_idx in range(self.transformer_block_count): transformer_theta = self.theta("blk", block_idx) # Attention. - block_cache_k = local_cache_k[block_idx, ...] - block_cache_v = local_cache_v[block_idx, ...] + block_cache_k = local_kv_cache[block_idx] + block_cache_v = local_kv_cache[self.transformer_block_count + block_idx] attention_output = self.attention( transformer_theta, h, @@ -310,7 +301,6 @@ def create_rotary_embed_table(max_seqlen: int, dim: int, theta: float = 10000.0) args = parser.parse_args() torch.no_grad().__enter__() hp = HParams(args.gguf_path) - # print(hp) detokenizer = Detokenizer(hp) model = LlamaCPP(hp) start_index = 0 diff --git a/python/turbine_llamacpp/params.py b/python/turbine_llamacpp/params.py index 01c8ff9..4704892 100644 --- a/python/turbine_llamacpp/params.py +++ b/python/turbine_llamacpp/params.py @@ -33,6 +33,8 @@ def as_qtensor(self) -> QuantizedTensor: tn = self.type_name if tn == "Q4_0": return self.as_q4_0() + if tn == "Q4_1": + return self.as_q4_1() if tn == "Q8_0": return self.as_q8_0() raise ValueError(f"Quantized type {tn} not supported") @@ -44,9 +46,11 @@ def as_tensor(self) -> torch.Tensor: raise ValueError(f"Tensor type {tn} not supported") def as_q4_0(self) -> Q4_0: - # import pdb; pdb.set_trace() return Q4_0(torch.tensor(self.data), self.shape) + def as_q4_1(self) -> Q4_1: + return Q4_1(torch.tensor(self.data), self.shape) + def as_q8_0(self) -> Q8_0: return Q8_0(torch.tensor(self.data), self.shape) @@ -72,6 +76,10 @@ def __init__( self.dtype = dtype self.rotary_emb_dtype = dtype + # Quantized tensor replacement + self.replaced_quantized_tensors = [] + self.supported_types = ["Q4_0", "Q4_1", "Q8_0"] + def _load_gguf(self, reader: GGUFReader): # Extract hyper-parameters. Adapted from gguf-dump.py for field in reader.fields.values(): @@ -86,7 +94,7 @@ def _load_gguf(self, reader: GGUFReader): else: self.tables[field.name] = field.parts # from IPython import embed - # embed() + # embed() # Extract tensors. 
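Aside (illustration only, not part of the diff; the tensor-extraction loop continues right below): with the as_q4_1 dispatch added above, a Q4_1 tensor loaded from a GGUF file can be unpacked and dequantized directly. A sketch, assuming a Q4_1-quantized GGUF exists at the given path and contains a tensor with the hypothetical name used here:

```
import torch
from turbine_llamacpp.params import HParams

hp = HParams("ggml-model-q4_1.gguf")              # hypothetical path
mt = hp.tensors["blk.0.attn_q.weight"]            # hypothetical tensor name
if mt.is_quantized and mt.type_name == "Q4_1":
    dense = mt.as_qtensor().unpack().dequant(torch.float32)
    print(dense.shape, dense.dtype)
```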
for tensor in reader.tensors: @@ -111,6 +119,35 @@ def __contains__(self, k: str): def __iter__(self): return self.raw_params.__iter__() + def replace_quantized_tensors(self, replaceable_types: Optional[list[str]] = None): + if not replaceable_types: + replaceable_types = self.supported_types + else: + for type in replaceable_types: + if type not in self.supported_types: + raise ValueError(f"Replacement of type {type} not supported") + if self.dtype == torch.float32: + replacement_type_name = "F32" + elif self.dtype == torch.float16: + replacement_type_name = "F16" + else: + raise ValueError(f"Replacement into tensors of {self.dtype} not supported") + for tensor_name, model_tensor in self.tensors.items(): + if model_tensor.type_name in replaceable_types: + self.replaced_quantized_tensors.append( + (tensor_name, model_tensor.type_name) + ) + replacement_data = torch.zeros( + size=model_tensor.shape, dtype=self.dtype + ) + new_model_tensor = ModelTensor( + name=model_tensor.name, + shape=model_tensor.shape, + type_name=replacement_type_name, + data=replacement_data, + ) + self.tensors[tensor_name] = new_model_tensor + @property def tensor_params( self, @@ -164,6 +201,48 @@ def add_to_dict( add_to_dict(False, hp_tensor.name, hp_tensor.as_tensor()) return params_dict, qparams_dict + def repack_tensor_params( + self, + dequantize_types: list[str] = [], + dequantize_params: list[str] = [], + dtype: Optional[torch.dtype] = None, + dequantize_all: bool = False, + ) -> dict[str, torch.Tensor]: + if dtype is None: + dtype = self.dtype + reformatted_tensors = {} + for tensor_name, tensor in self.tensors.items(): + if not tensor.is_quantized or tensor.type_name not in self.supported_types: + reformatted_tensors[tensor_name] = np.ascontiguousarray( + tensor.as_tensor().detach().numpy() + ) + continue + if ( + dequantize_all + or tensor.type_name in dequantize_types + or tensor_name in dequantize_params + ): + reformatted_tensor = tensor.as_qtensor().unpack().dequant(dtype) + reformatted_tensors[tensor_name] = np.ascontiguousarray( + reformatted_tensor.detach().numpy() + ) + else: + reformatted_tensor, scales, zps = ( + tensor.as_qtensor().unpack().repack_for_turbine(dtype) + ) + reformatted_tensors[tensor_name] = np.ascontiguousarray( + reformatted_tensor.detach().numpy() + ) + if scales is not None: + reformatted_tensors[f"{tensor_name}_scale"] = np.ascontiguousarray( + scales.detach().numpy() + ) + if zps is not None: + reformatted_tensors[f"{tensor_name}_zp"] = np.ascontiguousarray( + zps.detach().numpy() + ) + return reformatted_tensors + def __repr__(self): parts = ["HParams(", " raw_params=["]
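
A closing note on the "_scale" / "_zp" entries that repack_tensor_params emits (illustration only, not part of the diff): GGML's Q4_1 dequantization is d * q + m, while the repacked parameters carry a scale plus an unscaled, negated zero point zp = -m / d, so that scale * (q - zp) reproduces the same values. A toy numeric check, where the consumer-side scale * (q - zp) form is assumed from the sign-convention comment in repack_for_turbine:

```
import torch

d = torch.tensor([0.25])                    # GGML per-block delta
m = torch.tensor([-1.5])                    # GGML per-block min
q = torch.arange(16, dtype=torch.float32)   # possible 4-bit quant values

ggml_dequant = d * q + m                    # GGML Q4_1 formula
zp = m / -d                                 # what repack_for_turbine stores as "<name>_zp"
turbine_dequant = d * (q - zp)              # assumed consumer-side formula
assert torch.allclose(ggml_dequant, turbine_dequant)
```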