
Commit

add v2
Signed-off-by: Yi Liu <[email protected]>
Yi4Liu committed Jan 1, 2025
1 parent 2e27ca3 commit 14a59a2
Showing 2 changed files with 59 additions and 3 deletions.
1 change: 1 addition & 0 deletions auto_round/autoround.py
@@ -1075,6 +1075,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
        total_loss = 0

        for i in range(self.iters):
            logger.info(f"iter {i} / {self.iters}")
            total_loss = 0
            if self.sampler == "rand":
                whole_indices = torch.randperm(nsamples)[:pick_samples]
61 changes: 58 additions & 3 deletions auto_round/data_type/fp8.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import lru_cache

from loguru import logger as rich_logger
import torch
from auto_round.utils import logger
from auto_round.config import global_config
@@ -67,7 +67,6 @@ def float8_e4m3fn_hpu_ste(x: torch.Tensor):
return fp8



@register_dtype("fp8_dynamic_per_token_sym")
def fp8_dynamic_per_token_sym(tensor, max_scale=1.0, **kwargs):
"""Dynamic per-token symmetric quantization using float8.
@@ -200,7 +199,6 @@ def progressive_quant_fp8_int4_bas(tensor, bits=4, group_size=-1, v=0, min_scale
return qdq_tensor, scale_fp8_to_int4 * scale_bf16_to_fp8, None



##ugly code, need to refine later

@register_dtype("fp8_gaudi2_sym")
@@ -293,3 +291,60 @@ def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0
    qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8

    return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4

@register_dtype("fp8_gaudi2_to_int_sym_v2")
def progressive_quant_fp8_int4_v2(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
weight_fp8_max_scale=1.0,**kwargs):
"""Two-stage quantization: quantize tensor to fp8 by per tensor, then quantize fp8 to w4g128
This method first quantizes the input tensor into float8 format and then performs
a secondary quantization to int4 with grouping.
Args:
tensor (torch.Tensor): Input tensor to quantize.
bits (int, optional): Bit precision for secondary quantization. Defaults to 4.
group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping).
v (float, optional): Optional parameter for variance tuning. Defaults to 0.
min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0.
max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0.
q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5.
weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0.
**kwargs: Additional arguments for compatibility.
Returns:
tuple:
- Quantized and dequantized tensor (torch.Tensor).
- Combined scaling factor (torch.Tensor).
- Placeholder for zp (None).
"""
# convert to int4
from auto_round.data_type.int import quant_tensor_sym
qdq_int4_tensor, scale_bf16_to_int4, zp_fp8_to_int4 = quant_tensor_sym(
tensor,
bits=bits,
group_size=group_size,
v=v,
min_scale=min_scale,
max_scale=max_scale,
scale_dtype=torch.bfloat16,
q_scale_thresh=q_scale_thresh,
)
# FIXME(Yi): some fuse error here
torch._dynamo.graph_break()
fp8_max = STANDARD_FP8E4M3FN_MAX * global_config.FP8_WEIGHT_BACKOFF
tensor_max = torch.max(torch.abs(qdq_int4_tensor)).to(torch.float32) * weight_fp8_max_scale ## better train a ratio
scale = tensor_max.to(torch.float32) / fp8_max
min_scaling_factor = 1.0 / (fp8_max* 512.0) ##copy from vllm
scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
fp8_res = qdq_int4_tensor / scale_bf16_to_fp8
fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
float8_e4m3fn_ste_gaudi2 = get_gaudi2_fp8_ste_func()
fp8_res = float8_e4m3fn_ste_gaudi2(fp8_res)

##convert to bf16
fp8_res_using_16bit = fp8_res.to(tensor.dtype)

qdq_tensor = fp8_res_using_16bit * scale_bf16_to_fp8

# return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
return qdq_tensor, (scale_bf16_to_int4, scale_bf16_to_fp8), zp_fp8_to_int4
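
To make the int4-then-fp8 ordering of the v2 path easier to follow, here is a minimal, self-contained sketch of the same two-stage flow in plain PyTorch. It is an illustration, not the auto_round implementation: the helper names (int4_qdq_per_group, fp8_qdq_per_tensor), the simple round-to-nearest int4 step, and the use of torch.float8_e4m3fn (PyTorch >= 2.1) in place of the Gaudi2 STE kernel are all assumptions.

import torch

FP8_E4M3_MAX = 448.0      # assumed standard float8_e4m3fn max; the Gaudi2 kernel may use its own limit
FP8_WEIGHT_BACKOFF = 1.0  # assumed; auto_round reads this from global_config


def int4_qdq_per_group(t: torch.Tensor, group_size: int = 128):
    """Symmetric per-group int4 quantize-dequantize (round-to-nearest sketch)."""
    orig_shape = t.shape
    t = t.reshape(-1, group_size)
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 7.0  # symmetric int4 range [-8, 7]
    q = torch.clamp(torch.round(t / scale), -8, 7)
    return (q * scale).reshape(orig_shape), scale


def fp8_qdq_per_tensor(t: torch.Tensor, max_scale: float = 1.0):
    """Per-tensor fp8 quantize-dequantize via torch.float8_e4m3fn."""
    fp8_max = FP8_E4M3_MAX * FP8_WEIGHT_BACKOFF
    scale = (t.abs().max().float() * max_scale / fp8_max).clamp(min=1.0 / (fp8_max * 512.0))
    res = torch.clamp(t / scale, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return res.to(t.dtype) * scale, scale


w = torch.randn(256, 512, dtype=torch.bfloat16)
qdq_int4, scale_int4 = int4_qdq_per_group(w, group_size=128)   # stage 1: int4 per group
qdq_w, scale_fp8 = fp8_qdq_per_tensor(qdq_int4)                # stage 2: fp8 per tensor
print(qdq_w.dtype, scale_int4.shape, scale_fp8.item())

As the commented-out return above shows, the visible difference from the v1-style output is that v2 reports the int4 group scales and the per-tensor fp8 scale separately instead of folding them into a single combined scale.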
