Skip to content

Commit

Permalink
Feat (examples/llm): add packed 3/5/6b export
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius committed Sep 4, 2023
1 parent 4ff62c7 commit 91bd476
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/brevitas_examples/llm/llm_quant/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
import math
import warnings

import numpy as np
import torch
from torch.nn import Module

Expand Down Expand Up @@ -128,6 +131,41 @@ def pack_int_weights(self, bit_width, int_weights):
packed_int_weights[:, column] |= int_weights[:, j] << shift_factor
i += 8 // bit_width
return packed_int_weights

# pack 3b values into 3 bytes, 5b values into 5 bytes, 6b values into 4 bytes
elif bit_width == 3 or bit_width == 5 or bit_width == 6:
padding = (int_weights.shape[1] * bit_width) % 8
if padding > 0:
warnings.warn(
f"Weight tensor does not divide by {bit_width}, zero-padding columns by {padding}."
)
packed_int_weights = torch.zeros(
(int_weights.shape[0], int_weights.shape[1] * bit_width // 8 + padding),
device=int_weights.device,
dtype=int_weights.dtype)

def lcm(x, y):
from fractions import gcd
return x * y // gcd(x, y)

num_packed_bits = lcm(bit_width, 8)
num_packed_bytes = num_packed_bits // 8
num_packed_elems = num_packed_bits // bit_width

i = 0
for column in range(0, packed_int_weights.shape[1], num_packed_bytes):
# cast to uint8 since it's the only dtype supported by unpackbits
# the bit-wise representation of int8 values isn't affected
bits_to_unpack = int_weights[:, i:i + num_packed_elems].numpy().astype(np.uint8)
unpacked_bits = np.unpackbits(bits_to_unpack, axis=1)
unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1, 8)
unpacked_bits = unpacked_bits[:, :, -bit_width:]
unpacked_bits = unpacked_bits.reshape(unpacked_bits.shape[0], -1)
packed_bits = np.packbits(unpacked_bits, axis=1)
packed_int_weights[:, column:column +
num_packed_bytes] |= torch.from_numpy(packed_bits)
i += num_packed_elems
return packed_int_weights
else:
raise ValueError(f"Bit width {bit_width} not supported.")

Expand Down

0 comments on commit 91bd476

Please sign in to comment.