
Commit f899cbf

add config
Signed-off-by: Yi Liu <[email protected]>
wenhuach21 authored and Yi4Liu committed Dec 24, 2024
1 parent 9f281d7 commit f899cbf
Showing 3 changed files with 13 additions and 6 deletions.
7 changes: 5 additions & 2 deletions auto_round/config.py
@@ -15,9 +15,12 @@ class GlobalConfig:
 global_config.FP8_INPUT_BACKOFF = float(os.environ.get("AR_FP8_INPUT_BACKOFF", 1.0))
 global_config.FP8_WEIGHT_BACKOFF = float(os.environ.get("AR_FP8_WEIGHT_BACKOFF", 1.0))
 
-from loguru import logger
 
-logger.info(f"Global config: {global_config}")
+# AR_FP8_INPUT_BACKOFF=0.5 AR_FP8_WEIGHT_BACKOFF=1
+import logging
+logger = logging.getLogger(__name__)
+
+logger.warning(f"Global config: {global_config}")
 
 inc_default_config = GlobalConfig(FP8_INPUT_BACKOFF=0.25, FP8_WEIGHT_BACKOFF=0.5)
 config4_in_result_table = inc_default_config
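Note for readers: below is a self-contained sketch of what auto_round/config.py does after this change. The environment-variable handling, the backoff values, and the logger.warning call come from the diff above; the @dataclass definition and its 1.0 defaults are assumptions for illustration only.

```python
# Sketch only: a stand-in GlobalConfig showing how the AR_FP8_* environment
# variables override the default FP8 backoff factors (the real class may differ).
import logging
import os
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class GlobalConfig:
    FP8_INPUT_BACKOFF: float = 1.0   # assumed default
    FP8_WEIGHT_BACKOFF: float = 1.0  # assumed default


global_config = GlobalConfig()
# Example override, as hinted by the diff comment:
#   AR_FP8_INPUT_BACKOFF=0.5 AR_FP8_WEIGHT_BACKOFF=1 python your_script.py
global_config.FP8_INPUT_BACKOFF = float(os.environ.get("AR_FP8_INPUT_BACKOFF", 1.0))
global_config.FP8_WEIGHT_BACKOFF = float(os.environ.get("AR_FP8_WEIGHT_BACKOFF", 1.0))

logger.warning(f"Global config: {global_config}")

# Reference configuration that also appears in the file:
inc_default_config = GlobalConfig(FP8_INPUT_BACKOFF=0.25, FP8_WEIGHT_BACKOFF=0.5)
config4_in_result_table = inc_default_config
```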
10 changes: 6 additions & 4 deletions auto_round/data_type/fp8.py
@@ -14,7 +14,7 @@
 from functools import lru_cache
 
 import torch
-
+from auto_round.utils import logger
 from auto_round.config import global_config
 STANDARD_FP8E4M3FN_MAX = torch.finfo(torch.float8_e4m3fn).max
 
@@ -27,10 +27,12 @@
 def get_gaudi2_fp8_ste_func():
     from auto_round.utils import is_hpu_supported
     if is_hpu_supported():
-        return float8_e4m3fn_hpu_ste
+        fn = float8_e4m3fn_hpu_ste
+        logger.warning("Using HPU STE for FP8")
     else:
-        return float8_e4m3fn_ste
-
+        fn = float8_e4m3fn_ste
+        logger.warning("Using CUDA/CPU STE for FP8")
+    return fn
 
 def float8_e4m3fn_ste(x: torch.Tensor):
     """Straight-Through Estimator (STE) for float8.
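The two functions chosen above are straight-through estimators (STEs) for torch.float8_e4m3fn, and the diff truncates float8_e4m3fn_ste right after its docstring. Below is a hedged sketch of the usual quantize-dequantize-with-identity-gradient pattern such a function implements; the name float8_e4m3fn_ste_sketch and the clamping details are illustrative, not the repository's exact code.

```python
# Illustrative float8 STE: the forward pass does an FP8 round-trip, the backward
# pass lets gradients through unchanged. Requires PyTorch with float8_e4m3fn support.
import torch

STANDARD_FP8E4M3FN_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def float8_e4m3fn_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        # Clamp to the representable range, then round-trip through float8.
        q = x.clamp(-STANDARD_FP8E4M3FN_MAX, STANDARD_FP8E4M3FN_MAX)
        q = q.to(torch.float8_e4m3fn).to(x.dtype)
    # Straight-through estimator: the output equals q, the gradient flows through x.
    return x + (q - x).detach()


x = torch.randn(4, requires_grad=True)
y = float8_e4m3fn_ste_sketch(x)
y.sum().backward()
print(x.grad)  # all ones: the STE acts as an identity in the backward pass
```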
2 changes: 2 additions & 0 deletions auto_round/quantizer.py
@@ -322,6 +322,8 @@ def __init__(self, orig_layer):
         self.act_quant_func = self.orig_layer.act_quant_func
 
     def forward(self, x):
+        # FIXME: (Yi) for static quant, remove it later
+        assert hasattr(self.orig_layer, "act_max"), f"For static quant, expecting act_max in {self.orig_layer}"
         act_max = self.orig_layer.act_max if hasattr(self.orig_layer, "act_max") else None
         x, _, _ = self.orig_layer.act_quant_func(x, bits=self.orig_layer.act_bits,
                                                  group_size=self.orig_layer.group_size,
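Context for the quantizer.py change: previously a missing act_max silently fell back to None; the new assert makes static activation quantization fail fast instead. The sketch below illustrates that behavior with hypothetical stand-in classes; apart from the attribute names visible in the diff (act_max, act_bits, group_size, act_quant_func), none of these names are auto_round's real API.

```python
# Hypothetical stand-ins: show how the new assert changes behavior when
# act_max is missing versus present on the wrapped layer.
import torch


class _StubLayer:
    act_bits = 8
    group_size = -1

    @staticmethod
    def act_quant_func(x, bits, group_size):
        # Placeholder: the real function returns (quantized_x, scale, zero_point).
        return x, None, None


def forward(orig_layer, x):
    # Check added in this commit: static quant requires a calibrated act_max.
    assert hasattr(orig_layer, "act_max"), \
        f"For static quant, expecting act_max in {orig_layer}"
    act_max = orig_layer.act_max  # old code fell back to None; used later in the real forward
    x, _, _ = orig_layer.act_quant_func(x, bits=orig_layer.act_bits,
                                        group_size=orig_layer.group_size)
    return x


layer = _StubLayer()
layer.act_max = torch.tensor(6.0)   # calibrated max present: forward succeeds
forward(layer, torch.randn(2, 4))

del layer.act_max                   # missing act_max: the assert now fires
try:
    forward(layer, torch.randn(2, 4))
except AssertionError as err:
    print(err)
```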
