diff --git a/README.md b/README.md
index d168627e..428e623b 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ image presents an overview of AutoRound.
## What's New
-
+* [2024/06] AutoRound format supports mixed bit-widths and group sizes for inference, resolving the significant performance drop issue with the asymmetric kernel
* [2024/05] Check out our updated paper on [arxiv](https://arxiv.org/pdf/2309.05516v4)
* [2024/05] AutoRound supports lm-head quantization, saving 0.7G for LLaMA3-8B at W4G128.
* [2024/05] AutoRound performs well
@@ -42,6 +42,8 @@ image presents an overview of AutoRound.
```bash
pip install -r requirements.txt
python setup.py install
+# or install from source in editable mode:
+pip install -vvv --no-build-isolation -e .
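+# -e installs the package in editable (develop) mode; -vvv prints verbose build logs and
+# --no-build-isolation builds against the packages already installed in the current environment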
```
### Install from pypi
@@ -55,7 +57,7 @@ pip install auto-round
### Gaudi2/ CPU/ GPU
We found a significant accuracy discrepancy with the qdq model using the AutoGPTQ GPU backend with asymmetric
-quantization in some scenarios. Please switch to symmetric quantization to alleviate this issue.
+quantization in some scenarios, especially at lower bit widths such as 2-bit. Please save the quantized model to the AutoRound format to work around this issue.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -71,7 +73,7 @@ bits, group_size, sym = 4, 128, False
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, device=None)
autoround.quantize()
output_dir = "./tmp_autoround"
-autoround.save_quantized(output_dir)
+autoround.save_quantized(output_dir)  ## or save_quantized(output_dir, format="auto_round")
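+## the saved model can then be reloaded for inference with AutoModelForCausalLM, see the inference example below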
```
@@ -149,7 +151,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
-##from auto_round.auto_quantizer import AutoHfQuantizer ## uncomment it for models with quantized lm-head
+##from auto_round.auto_quantizer import AutoHfQuantizer ## uncomment this line for models saved in the auto_round format
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto", trust_remote_code=True)
diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index f0de9b04..84aaa5e7 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -37,7 +37,6 @@
from packaging import version
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import Conv1D
-import transformers
from transformers.quantizers import AutoQuantizationConfig, HfQuantizer
from transformers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from transformers.utils.quantization_config import AwqConfig, GPTQConfig, QuantizationConfigMixin, QuantizationMethod
@@ -52,11 +51,17 @@
else:
import importlib.metadata as importlib_metadata
-AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
+AUTOROUND_MINIMUM_VERSION = version.parse("0.2")
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
+    if pkg_name == "auto_round":  ##TODO remove this workaround later
+        try:
+            import auto_round
+            return (True, auto_round.__version__) if return_version else True
+        except ImportError:
+            pass
+
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
@@ -71,26 +76,32 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
return package_exists
-_auto_gptq_available = _is_package_available("auto_gptq")
+_auto_round_available = _is_package_available("auto_round")
-def is_auto_gptq_available():
- if _auto_gptq_available:
- version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
- if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
+def is_auto_round_available():
+ if _auto_round_available:
+ version_autoround = version.parse(importlib_metadata.version("auto_round"))
+ if AUTOROUND_MINIMUM_VERSION < version_autoround:
return True
else:
raise ImportError(
- f"Found an incompatible version of auto-gptq. Found version {version_autogptq},"
- f" but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported"
+ f"Found an incompatible version of auto-round. Found version {version_autoround},"
+ f" but only version above {AUTOROUND_MINIMUM_VERSION} are supported"
)
-if is_auto_gptq_available():
- from auto_gptq import exllama_set_max_input_length
- from auto_gptq.modeling._utils import autogptq_post_init
- from auto_gptq.quantization import GPTQ
- from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+def is_autoround_exllamav2_available():
+    """Check whether the compiled autoround exllamav2 CUDA kernels can be imported."""
+    try:
+        from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix
+    except ImportError:
+        return False
+    return True
+
+
+if is_auto_round_available():
+ from auto_round_extension.cuda.post_init import autoround_post_init
#
@@ -201,15 +212,8 @@ def __init__(
dataset: str = None,
group_size: int = 128,
sym: bool = False,
- backend="gptq:exllamav2",
- iters: int = 200,
+ backend="autoround:exllamav2",
weight_config: dict = None,
- enable_quanted_input=True,
- enable_minmax_tuning=True,
- lr=None,
- minmax_lr=None,
- n_samples=512,
- seqlen=2048,
**kwargs,
):
self.bits = bits
@@ -218,14 +222,7 @@ def __init__(
self.group_size = group_size
self.sym = sym
self.backend = backend
- self.inters = iters
self.weight_config = weight_config
- self.enable_quanted_input = enable_quanted_input
- self.enable_minmax_tuning = enable_minmax_tuning
- self.lr = lr
- self.minmax_lr = minmax_lr
- self.n_samples = n_samples
- self.seqlen = seqlen
if kwargs is not None:
for key in kwargs.keys():
setattr(self, key, kwargs[key])
@@ -233,16 +230,12 @@ def __init__(
self.post_init()
def get_loading_attributes(self):
- pass
- # attibutes_dict = copy.deepcopy(self.__dict__)
- # loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
- # loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
- # return loading_attibutes_dict
+ return {}
def post_init(self):
r"""Safety checker that arguments are correct."""
- if self.bits not in [2, 3, 4, 8]:
- raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
+ if self.bits not in [2, 4, 8]:
+ raise ValueError(f"Only support quantization to [2,4,8] bits but found {self.bits}")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("group_size must be greater than 0 or equal to -1")
##TODO add more check
@@ -254,23 +247,23 @@ def to_dict(self):
class AutoRoundQuantizer(HfQuantizer):
- """Quantizer of the Autoround method, currently only gptq backend has been supported."""
+ """Quantizer of the AutoRound method, currently only triton and exllamav2 backend has been supported."""
requires_calibration = False
- required_packages = ["auto_gptq"]
+ required_packages = ["auto_round"]
optimum_quantizer = None
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
+        self.exllama2_available = is_autoround_exllamav2_available()
def validate_environment(self, *args, **kwargs):
- gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
- if not gptq_supports_cpu and not torch.cuda.is_available():
- raise RuntimeError("GPU is required to quantize or run quantize model.")
- elif not is_auto_gptq_available():
- raise ImportError("Loading a GPTQ quantized model requires auto-gptq library (`pip install auto-gptq`)")
- elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
- raise ImportError("You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`")
+ if not is_auto_round_available():
+ raise ImportError("Loading a AutoRound quantized model requires auto-round library (`pip install "
+ "auto-round`)")
+ elif version.parse(importlib.metadata.version("auto_round")) < version.parse("0.2.0"):
+ raise ImportError("You need a version of auto_round > 0.2.0 to use AutoRound: `pip install --upgrade "
+ "auto-round`")
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
@@ -280,7 +273,7 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
return torch_dtype
def convert_model(self, model: nn.Module):
- """Convert the model to a GPTQ model by getting and replacing the layers.
+ """Convert the model to an AutoRound model by getting and replacing the layers.
Args:
model (`nn.Module`):
@@ -308,15 +301,22 @@ def convert_model(self, model: nn.Module):
layer_configs[layer_name]["data_type"] = data_type
layer_configs[layer_name]["sym"] = sym
else:
- layer_configs[layer_name]["bits"] = extra_config.get("bits", bits)
- layer_configs[layer_name]["group_size"] = extra_config.get("group_size", group_size)
- layer_configs[layer_name]["data_type"] = extra_config.get("data_type", data_type)
- layer_configs[layer_name]["sym"] = extra_config.get("sym", sym)
+ layer_configs[layer_name]["bits"] = extra_config[layer_name].get("bits", bits)
+ layer_configs[layer_name]["group_size"] = extra_config[layer_name].get("group_size", group_size)
+ layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", data_type)
+ layer_configs[layer_name]["sym"] = extra_config[layer_name].get("sym", sym)
backend = quantization_config.backend
self._replace_by_quant_layers(model, layer_configs, backend)
return model
+    def _dynamic_import_inference_linear(self, bits, backend):
+        """Select the inference QuantLinear: the exllamav2 kernel for 4-bit when the compiled
+        extension is available, otherwise the Triton implementation."""
+        if bits == 4 and self.exllama2_available and "exllamav2" in backend:
+            from auto_round_extension.cuda.qliner_exllamav2 import QuantLinear
+        else:
+            from auto_round_extension.cuda.qliner_triton import QuantLinear
+        return QuantLinear
+
def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
"""Replaces linear layers in `module` by `QuantLinear`
@@ -335,21 +335,7 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
data_type = config["data_type"]
if not (bits <= 8 and data_type == "int"):
continue
- from auto_round.export.export_to_autoround.export_to_autoround import get_autogptq_backend_config
-
- use_triton, disable_exllama, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
- backend, bits
- )
- QuantLinear = dynamically_import_QuantLinear(
- use_triton=False,
- desc_act=False,
- group_size=group_size,
- bits=bits,
- disable_exllama=True,
- disable_exllamav2=False,
- use_qigen=use_qigen,
- disable_marlin=disable_marlin,
- )
+ QuantLinear = self._dynamic_import_inference_linear(bits, backend)
layer = get_module(module, layer_name)
device = get_device(layer)
if isinstance(layer, nn.Linear):
@@ -382,23 +368,17 @@ def post_init_model(self, model):
The input model
"""
- # if self.bits == 4 and not self.disable_exllama:
- # if get_device(model) == torch.device("cpu") or (
- # hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
- # ):
- # raise ValueError(
- # "Found modules on cpu/disk. Using Exllama
- # or Exllamav2 backend requires all the modules to be on GPU."
- # "You can deactivate exllama backend by
- # setting `disable_exllama=True` in the quantization config object"
- # )
+        # if self.bits == 4:
+        #     if get_device(model) == torch.device("cpu") or (
+        #         hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
+        #     ):
+        #         raise ValueError(
+        #             "Found modules on cpu/disk. Using Exllamav2 backend requires all the modules to be on GPU. "
+        #             "You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object"
+        #         )
class StoreAttr(object):
pass
model.quantize_config = StoreAttr()
- model.quantize_config.desc_act = False
- model = autogptq_post_init(model, use_act_order=False)
+ model = autoround_post_init(model)
return model
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
@@ -436,4 +416,3 @@ def is_serializable(self):
transformers.quantizers.auto.AutoHfQuantizer = AutoHfQuantizer
transformers.modeling_utils.AutoHfQuantizer = AutoHfQuantizer
-from transformers import AutoModelForCausalLM as AutoModelForCausalLM
diff --git a/auto_round/export/__init__.py b/auto_round/export/__init__.py
index 3797d641..4f4265bb 100644
--- a/auto_round/export/__init__.py
+++ b/auto_round/export/__init__.py
@@ -15,6 +15,6 @@
from .register import EXPORT_FORMAT
from .export_to_autogptq import save_quantized_as_autogptq
from .export_to_itrex import save_quantized_as_itrex, QuantConfig
-from .export_to_autoround.export_to_autoround import save_quantized_as_autoround
+from .export_to_autoround.export import save_quantized_as_autoround
diff --git a/auto_round/export/export_to_autogptq.py b/auto_round/export/export_to_autogptq.py
index 6ef2a08c..4377719e 100644
--- a/auto_round/export/export_to_autogptq.py
+++ b/auto_round/export/export_to_autogptq.py
@@ -52,6 +52,10 @@
@register_format("auto_gptq")
def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwargs):
"""Export the model to autogptq format to easily leverage cuda kernel."""
+ try:
+ import auto_gptq
+ except ImportError:
+ raise ImportError("export to autogptq requires autogptq library. Please run 'pip install auto-gptq'")
model = kwargs["model"]
weight_config = kwargs["weight_config"]
sym = kwargs["sym"]
@@ -95,7 +99,7 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwar
else:
compressed_model = copy.deepcopy(model.to("cpu"))
- from auto_gptq.modeling._utils import pack_model
+ from auto_gptq.modeling._utils import pack_model # pylint: disable=E0401
if bits == 3 or use_triton is False:
if bits == 3 and use_triton is True:
@@ -127,7 +131,7 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, **kwar
info = weight_config[key]
if not check_to_quantized(info):
continue
- quantizers[key] = (None, info["scale"].to(torch.float32), info["zp"].to(torch.float32), info["g_idx"])
+ quantizers[key] = (None, info["scale"], info["zp"].to(torch.float32), info["g_idx"])
pack_model(
compressed_model,
quantizers,
@@ -236,7 +240,7 @@ def _save_quantized_to_autogptq(
model_save_name = model_base_name + ".bin"
torch.save(model.state_dict(), join(save_dir, model_save_name))
- from auto_gptq.modeling._base import BaseQuantizeConfig
+ from auto_gptq.modeling._base import BaseQuantizeConfig # pylint: disable=E0401
quantization_config = BaseQuantizeConfig(
bits=bits,
diff --git a/auto_round/export/export_to_autoround/__init__.py b/auto_round/export/export_to_autoround/__init__.py
index 862e97d5..afa7c3fe 100644
--- a/auto_round/export/export_to_autoround/__init__.py
+++ b/auto_round/export/export_to_autoround/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .export_to_autoround import save_quantized_as_autoround
+from .export import save_quantized_as_autoround
diff --git a/auto_round/export/export_to_autoround/export_to_autoround.py b/auto_round/export/export_to_autoround/export.py
similarity index 87%
rename from auto_round/export/export_to_autoround/export_to_autoround.py
rename to auto_round/export/export_to_autoround/export.py
index 7f2253f7..f89a03aa 100644
--- a/auto_round/export/export_to_autoround/export_to_autoround.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -22,8 +22,7 @@
import transformers
from auto_round.export.register import register_format
-from auto_round.utils import get_layer_names_in_block, get_block_names, get_module, logger, set_module
-
+from auto_round.utils import get_layer_names_in_block, get_module, logger, set_module
def check_neq_config(config, data_type, bits, group_size, sym):
@@ -53,7 +52,7 @@ def get_autogptq_backend_config(backend, bits=4):
if backend == "gptq:marlin":
use_triton = False
disable_marlin = True
- if backend == "gptq:exllamav2":
+ if backend == "gptq:exllamav2": ##need v1 code to export
use_triton = False
disable_marlin = True
if backend == "gptq:exllamav1":
@@ -71,10 +70,34 @@ def get_autogptq_backend_config(backend, bits=4):
return use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin
-@register_format("autoround")
-def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav2", **kwargs):
- from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+def dynamic_QuantLinear_for_packing(backend, bits, group_size):
+ if "gptq" in backend:
+ use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
+ backend, bits
+ )
+ from auto_gptq.utils.import_utils import dynamically_import_QuantLinear # pylint: disable=E0401
+ QuantLinear = dynamically_import_QuantLinear(
+ use_triton=use_triton,
+ desc_act=False,
+ group_size=group_size,
+ bits=bits,
+ disable_exllama=disable_exllamav1,
+ disable_exllamav2=disable_exllamav2,
+ use_qigen=use_qigen,
+ disable_marlin=disable_marlin,
+ )
+ return QuantLinear
+    ## export always packs with the Triton QuantLinear; inference can use the exllamav2 kernel
+ elif "autoround" in backend or "auto-round" in backend or "auto_round" in backend:
+ from auto_round_extension.cuda.qliner_triton import QuantLinear
+ return QuantLinear
+
+ else:
+ assert False, f"only support gptq and autoround backend"
+
+@register_format("auto_round")
+def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exllamav2", **kwargs):
model = kwargs["model"]
if not inplace:
model = copy.deepcopy(model.to("cpu"))
@@ -90,22 +113,11 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav
bits = config["bits"]
group_size = config["group_size"]
- use_triton, disable_exllamav1, disable_exllamav2, use_qigen, disable_marlin = get_autogptq_backend_config(
- backend, bits
- )
layer = get_module(model, name)
- device = "cpu"
- QuantLinear = dynamically_import_QuantLinear(
- use_triton=use_triton,
- desc_act=False,
- group_size=group_size,
- bits=bits,
- disable_exllama=disable_exllamav1,
- disable_exllamav2=disable_exllamav2,
- use_qigen=use_qigen,
- disable_marlin=disable_marlin,
- )
+ device = layer.weight.device
+
+    QuantLinear = dynamic_QuantLinear_for_packing(backend, bits, group_size)
if isinstance(layer, nn.Linear):
in_features = layer.in_features
@@ -138,7 +150,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav
quantization_config["backend"] = backend
extra_config = {}
for layer_name in weight_config:
- if weight_config[layer_name]["data_type"] != "int" and weight_config[layer_name]["bits"] >= 16:
+ if weight_config[layer_name]["bits"] >= 16:
continue
if layer_name not in layer_names_in_block:
extra_config[layer_name] = {}
@@ -190,7 +202,7 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_ser
"""
os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
- config_file = "quantize_config.json"
- if hasattr(model, "config") and hasattr(model.config, "quantize_config"):
+ config_file = "quantization_config.json"
+ if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
json.dump(model.config.quantization_config, f, indent=2)
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 79aa1b5b..f661aeea 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -472,7 +472,6 @@ def block_forward(block, input_ids, input_others, amp=False, amp_dtype=torch.flo
output: The output of the forward pass.
"""
if input_ids.device != device:
- # input_ids, input_others = move_to_device(input_ids, input_others, device)
input_ids = to_device(input_ids, device)
input_others = to_device(input_others, device)
input_tuple = input_others.pop("positional_inputs", None)
diff --git a/auto_round/version.py b/auto_round/version.py
index 69824212..b62be5f3 100644
--- a/auto_round/version.py
+++ b/auto_round/version.py
@@ -14,4 +14,4 @@
"""Intel® auto-round: An open-source Python library
supporting popular model weight only compression based on signround."""
-__version__ = "0.2"
+__version__ = "0.2.1.dev"
diff --git a/auto_round_extension/__init__.py b/auto_round_extension/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/auto_round_extension/cuda/__init__.py b/auto_round_extension/cuda/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/auto_round_extension/cuda/exllamav2/config.h b/auto_round_extension/cuda/exllamav2/config.h
new file mode 100644
index 00000000..86baaf41
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/config.h
@@ -0,0 +1,13 @@
+#ifndef _config_h
+#define _config_h
+
+#define MAX_Q_GEMM_ROWS 50
+
+#define QMODE_2BIT 1
+#define QMODE_3BIT 1
+#define QMODE_4BIT 1
+#define QMODE_5BIT 1
+#define QMODE_6BIT 0
+#define QMODE_8BIT 0
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cpp/util.h b/auto_round_extension/cuda/exllamav2/cpp/util.h
new file mode 100644
index 00000000..919703a8
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cpp/util.h
@@ -0,0 +1,12 @@
+#ifndef _util_h
+#define _util_h
+
+#define DBGS(__x) printf("%s\n", __x)
+#define DBGI(__x) printf("%s: %i\n", #__x, __x)
+#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
+#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
+#define DBGF(__x) printf("%s: %f\n", #__x, __x)
+#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
+#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cuda/compat.cuh b/auto_round_extension/cuda/exllamav2/cuda/compat.cuh
new file mode 100644
index 00000000..12684ff8
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/compat.cuh
@@ -0,0 +1,56 @@
+#ifndef _compat_cuh
+#define _compat_cuh
+
+// atomicAdd for half types, to support CC < 7.x
+
+__device__ __forceinline__ void atomicAdd_half(half* address, half val)
+{
+ unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+
+ do
+ {
+ assumed = old;
+ __half_raw hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ half tmpres = __hadd(hsum, val);
+ hsum = __half_raw(tmpres);
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+ old = atomicCAS(address_as_ui, assumed, old);
+ }
+ while (assumed != old);
+}
+
+// atomicAdd for half2 types
+
+__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
+{
+ unsigned int* address_as_ui = (unsigned int*)address;
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+ do
+ {
+ assumed = old;
+ half2 old_val = *((half2*)&old);
+ half2 new_val = __hadd2(old_val, val);
+ old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
+ }
+ while (assumed != old);
+}
+
+//
+
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
+
+__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
+__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+#endif
+
+#endif
+#endif
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cuda/compat_gemm.cuh b/auto_round_extension/cuda/exllamav2/cuda/compat_gemm.cuh
new file mode 100644
index 00000000..19b1e4a6
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/compat_gemm.cuh
@@ -0,0 +1,38 @@
+#ifndef _compat_gemm_cuh
+#define _compat_gemm_cuh
+
+#if defined(USE_ROCM)
+
+// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
+// for symbols as hipblasHalf.
+#include <hipblas/hipblas.h>
+
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+ hipblasOperation_t transA,
+ hipblasOperation_t transB,
+ int m,
+ int n,
+ int k,
+ const half* alpha,
+ const half* AP,
+ int lda,
+ const half* BP,
+ int ldb,
+ const half* beta,
+ half* CP,
+ int ldc) {
+ return hipblasHgemm(handle, transA, transB, m, n, k,
+ reinterpret_cast(alpha),
+ reinterpret_cast(AP), lda,
+ reinterpret_cast(BP), ldb,
+ reinterpret_cast(beta),
+ reinterpret_cast(CP), ldc);
+}
+#define hipblasHgemm __compat_hipblasHgemm
+
+// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_hgemm __compat_hipblasHgemm
+#endif
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cuda/matrix_view.cuh b/auto_round_extension/cuda/exllamav2/cuda/matrix_view.cuh
new file mode 100644
index 00000000..55af84f2
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/matrix_view.cuh
@@ -0,0 +1,121 @@
+#ifndef _matrix_view_cuh
+#define _matrix_view_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+#include "quant/qdq_util.cuh"
+
+class MatrixView_half
+{
+public:
+ const half* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+ __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
+ __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
+
+ __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
+ {
+ half2* ptr = (half2*) item_ptr(row, column);
+ half2 i01 = ptr[0];
+ half2 i23 = ptr[1];
+ items[0] = __low2half(i01);
+ items[1] = __high2half(i01);
+ items[2] = __low2half(i23);
+ items[3] = __high2half(i23);
+ }
+ __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
+ {
+ half2* ptr = (half2*)item_ptr(row, column);
+ half2 i01 = ptr[0];
+ half2 i23 = ptr[1];
+ items[0] = __half2float(__low2half(i01));
+ items[1] = __half2float(__high2half(i01));
+ items[2] = __half2float(__low2half(i23));
+ items[3] = __half2float(__high2half(i23));
+ }
+
+ __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
+ {
+ half2* ptr = (half2*)item_ptr(row, column);
+ half2 i01 = ptr[0];
+ half2 i23 = ptr[1];
+ items[0] = __half2half2(__low2half(i01));
+ items[1] = __half2half2(__high2half(i01));
+ items[2] = __half2half2(__low2half(i23));
+ items[3] = __half2half2(__high2half(i23));
+ }
+};
+
+class MatrixView_half_rw
+{
+public:
+ half* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+ __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
+ __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
+ __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
+
+ __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
+ {
+ half2 v01 = __halves2half2(v0, v1);
+ half2 v23 = __halves2half2(v2, v3);
+ half2* ptr = (half2*) item_ptr(row, column);
+ ptr[0] = v01;
+ ptr[1] = v23;
+ }
+};
+
+class MatrixView_q4_row
+{
+public:
+ const uint32_t* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ int item(int row, int column) const
+ {
+ int shift = (column & 0x07) * 4;
+ return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
+ }
+
+ __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
+ {
+ int shift = (column & 0x07) * 4;
+ uint32_t d = data[row * width / 8 + column / 8] >> shift;
+ items[0] = d & 0x0f;
+ items[1] = (d >> 4) & 0x0f;
+ }
+
+ __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
+ {
+ int shift = (column & 0x07) * 4;
+ uint32_t d = data[row * width / 8 + column / 8] >> shift;
+ items[0] = d & 0x0f;
+ items[1] = (d >> 4) & 0x0f;
+ items[2] = (d >> 8) & 0x0f;
+ items[3] = (d >> 12) & 0x0f;
+ }
+};
+
+#endif
\ No newline at end of file
diff --git a/auto_round_extension/cuda/exllamav2/cuda/q_gemm.cu b/auto_round_extension/cuda/exllamav2/cuda/q_gemm.cu
new file mode 100644
index 00000000..351b9cd5
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/q_gemm.cu
@@ -0,0 +1,211 @@
+#include "q_gemm.cuh"
+#include "util.cuh"
+#include "matrix_view.cuh"
+#include "../config.h"
+
+#include "quant/qdq_2.cuh"
+#include "quant/qdq_3.cuh"
+#include "quant/qdq_4.cuh"
+#include "quant/qdq_5.cuh"
+#include "quant/qdq_6.cuh"
+#include "quant/qdq_8.cuh"
+
+#define BLOCK_KN_SIZE 128
+#define BLOCK_M_SIZE_MAX 8
+#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
+#define CLEAR_N_SIZE 256
+
+#include "q_gemm_kernel.cuh"
+#include "q_gemm_kernel_gptq.cuh"
+
+#include "compat_gemm.cuh"
+
+void gemm_half_q_half_cuda_part
+(
+ const half* a,
+ QMatrix* b,
+ half* c,
+ int size_m,
+ int size_n,
+ int size_k,
+ int m_count,
+ bool clear
+)
+{
+ if (!b->is_gptq)
+ {
+ dim3 blockDim, gridDim;
+ blockDim.x = BLOCK_KN_SIZE;
+ blockDim.y = 1;
+ blockDim.z = 1;
+ gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+ gridDim.y = DIVIDE(size_m, m_count);
+ gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+
+ fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
+
+        kernel<<<gridDim, blockDim>>>
+ (
+ a,
+ b->cuda_q_weight,
+ b->cuda_q_scale,
+ b->cuda_q_scale_max,
+ c,
+ size_m,
+ size_n,
+ size_k,
+ b->groups,
+ b->groupsize,
+ b->cuda_q_perm,
+ b->rows_8,
+ b->rows_6,
+ b->rows_5,
+ b->rows_4,
+ b->rows_3,
+ b->rows_2,
+ clear
+ );
+ }
+ else
+ {
+ dim3 blockDim, gridDim;
+ blockDim.x = BLOCK_KN_SIZE;
+ blockDim.y = 1;
+ blockDim.z = 1;
+ gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+ gridDim.y = DIVIDE(size_m, m_count);
+ gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+
+ fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
+
+// DBGX((uint64_t) b->cuda_q_perm);
+// DBGI(b->rows_4);
+// DBGI(b->height);
+
+        kernel<<<gridDim, blockDim>>>
+ (
+ a,
+ b->cuda_q_weight,
+ b->cuda_gptq_qzeros,
+ b->cuda_gptq_scales,
+ c,
+ size_m,
+ size_n,
+ size_k,
+ b->groups,
+ b->groupsize,
+ b->cuda_q_perm,
+ b->rows_4,
+ clear
+ );
+ }
+}
+
+void gemm_half_q_half_cuda
+(
+ cublasHandle_t cublas_handle,
+ const half* a,
+ QMatrix* b,
+ half* c,
+ int size_m,
+ int size_n,
+ int size_k,
+ bool clear,
+ half* temp_dq,
+ bool force_cuda
+)
+{
+ if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
+ {
+ //printf("cublas\n");
+
+ // Reconstruct FP16 matrix, then cuBLAS
+
+ if (!temp_dq) temp_dq = b->temp_dq;
+ b->reconstruct(temp_dq);
+
+ //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
+
+ const half alpha = __float2half(1.0f);
+ const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
+ cublasHgemm(cublas_handle,
+ CUBLAS_OP_N,
+ CUBLAS_OP_N,
+ size_n, size_m, size_k,
+ &alpha, temp_dq, size_n,
+ a, size_k,
+ &beta, c, size_n);
+
+ //const float alpha = 1.0f;
+ //const float beta = clear ? 0.0f : 1.0f;
+ //cublasSgemmEx(cublas_handle,
+ // CUBLAS_OP_N,
+ // CUBLAS_OP_N,
+ // size_n, size_m, size_k,
+ // &alpha, temp_dq, CUDA_R_16F, size_n,
+ // a, CUDA_R_16F, size_k,
+ // &beta, c, CUDA_R_16F, size_n);
+
+ //const float alpha = 1.0f;
+ //const float beta = clear ? 0.0f : 1.0f;
+ //cublasGemmEx(cublas_handle,
+ // CUBLAS_OP_N, CUBLAS_OP_N,
+ // size_n, size_m, size_k,
+ // &alpha, temp_dq, CUDA_R_16F, size_n,
+ // a, CUDA_R_16F, size_k,
+ // &beta, c, CUDA_R_16F, size_n,
+ // CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
+ }
+ else
+ {
+ //printf("cuda\n");
+
+ // Quantized matmul
+
+ //if (clear) clear_tensor_cuda(c, size_m, size_n);
+
+ int max_chunks = size_m / BLOCK_M_SIZE_MAX;
+ int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
+ int last_chunk_size = size_m - last_chunk;
+
+ if (max_chunks)
+ {
+ gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
+ }
+
+ if (last_chunk_size)
+ {
+ gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
+ }
+ }
+}
+
+__global__ void clear_kernel
+(
+ half* __restrict__ c,
+ const int size_m,
+ const int size_n
+)
+{
+ int m = blockIdx.y;
+ int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
+ if (n >= size_n) return;
+ int4* c_ptr = (int4*)(c + m * size_n + n);
+ *c_ptr = {};
+}
+
+void clear_tensor_cuda
+(
+ half* c,
+ int size_m,
+ int size_n
+)
+{
+ return;
+ dim3 blockDim, gridDim;
+ blockDim.x = CLEAR_N_SIZE;
+ blockDim.y = 1;
+ gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
+ gridDim.y = size_m;
+    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
+}
diff --git a/auto_round_extension/cuda/exllamav2/cuda/q_gemm.cuh b/auto_round_extension/cuda/exllamav2/cuda/q_gemm.cuh
new file mode 100644
index 00000000..c69f1a70
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/q_gemm.cuh
@@ -0,0 +1,33 @@
+#ifndef _q_gemm_cuh
+#define _q_gemm_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "q_matrix.cuh"
+
+void gemm_half_q_half_cuda
+(
+ cublasHandle_t cublas_handle,
+ const half* a,
+ QMatrix* b,
+ half* c,
+ int size_m,
+ int size_n,
+ int size_k,
+ bool clear = false,
+ half* reconstruct = NULL,
+ bool force_cuda = false
+);
+
+void clear_tensor_cuda
+(
+ half* c,
+ int size_m,
+ int size_n
+);
+
+#endif
\ No newline at end of file
diff --git a/auto_round_extension/cuda/exllamav2/cuda/q_gemm_kernel.cuh b/auto_round_extension/cuda/exllamav2/cuda/q_gemm_kernel.cuh
new file mode 100644
index 00000000..0b899a84
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/q_gemm_kernel.cuh
@@ -0,0 +1,487 @@
+#include "compat.cuh"
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+ return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+ return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+ return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
+}
+
+__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+ float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
+ return fma(result_f, qs_f, g_result);
+}
+
+__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+ float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
+ return fma(result_f, qs_f, g_result);
+}
+
+__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
+ float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
+ return fma(result_f, qs_f, g_result);
+}
+
+
+
+typedef void (*fp_gemm_half_q_half_kernel)
+(
+ const half*,
+ const uint32_t*,
+ const uint32_t*,
+ const half*,
+ half*,
+ const int,
+ const int,
+ const int,
+ const int,
+ const int,
+ const uint16_t*,
+ const int,
+ const int,
+ const int,
+ const int,
+ const int,
+ const int,
+ const bool
+);
+
+template <bool first_block, int m_count>
+__global__ void gemm_half_q_half_kernel
+(
+ const half* __restrict__ a,
+ const uint32_t* __restrict__ b_q_weight,
+ const uint32_t* __restrict__ b_q_scale,
+ const half* __restrict__ b_q_scale_max,
+ half* __restrict__ c,
+ const int size_m,
+ const int size_n,
+ const int size_k,
+ const int groups,
+ const int groupsize,
+ const uint16_t* __restrict__ b_q_perm,
+ const int rows_8,
+ const int rows_6,
+ const int rows_5,
+ const int rows_4,
+ const int rows_3,
+ const int rows_2,
+ const bool clear
+)
+{
+ MatrixView_half a_(a, size_m, size_k);
+ MatrixView_half_rw c_(c, size_m, size_n);
+ MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
+
+ int t = threadIdx.x;
+
+ // Block
+
+ int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+ int offset_m = blockIdx.y * m_count;
+ int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+
+ int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+ int end_m = min(offset_m + m_count, size_m);
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+ int n = offset_n + t * 4;
+
+ // Preload block_a
+
+ __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+
+ if (offset_k + t < end_k)
+ {
+ for (int m = 0; m < m_count; ++m)
+ {
+ const half* a_ptr = a_.item_ptr(offset_m + m, 0);
+ half* block_a_ptr = block_a[m];
+ half a0 = a_ptr[b_q_perm[offset_k + t]];
+ block_a_ptr[t] = a0;
+ }
+ }
+
+ // Clear
+
+ if (n >= size_n) return;
+
+ if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
+ {
+ for (int m = 0; m < m_count; m++)
+ *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
+ }
+
+ __syncthreads();
+
+ // Find initial group
+
+ int group = offset_k / groupsize;
+
+ // Preload scales
+
+ float scales[MAX_GROUPS_IN_BLOCK][4];
+
+ int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
+ for (int g = 0; g < groups_in_block; g++)
+ {
+ int qscales[4];
+ b_q_scale_.item4(qscales, group + g, n);
+ qscales[0]++;
+ qscales[1]++;
+ qscales[2]++;
+ qscales[3]++;
+ float maxscale = __half2float(b_q_scale_max[group + g]);
+ scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
+ scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
+ scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
+ scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
+ }
+
+ // a, b offset
+
+ int pre_rows_8 = min(rows_8, offset_k);
+ int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
+ int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
+ int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
+ int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
+ int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
+ int qk = 0;
+ qk += pre_rows_8 / 32 * 8;
+ qk += pre_rows_6 / 32 * 6;
+ qk += pre_rows_5 / 32 * 5;
+ qk += pre_rows_4 / 32 * 4;
+ qk += pre_rows_3 / 32 * 3;
+ qk += pre_rows_2 / 32 * 2;
+
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+ const half* a_ptr = &block_a[0][0];
+ int a_stride = BLOCK_KN_SIZE;
+
+ // Initial group
+
+ int scales_idx = 0;
+ float qs_f0 = scales[scales_idx][0];
+ float qs_f1 = scales[scales_idx][1];
+ float qs_f2 = scales[scales_idx][2];
+ float qs_f3 = scales[scales_idx][3];
+ int nextgroup = offset_k + groupsize;
+
+ // Column result
+
+ float block_c[m_count][4] = {};
+
+ // Dequantize groups
+
+ int k = offset_k;
+
+ while (k < rows_8 && k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ scales_idx++;
+ qs_f0 = scales[scales_idx][0];
+ qs_f1 = scales[scales_idx][1];
+ qs_f2 = scales[scales_idx][2];
+ qs_f3 = scales[scales_idx][3];
+ nextgroup += groupsize;
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 4; j++)
+ {
+ int4 load_int4[2];
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+
+ half2 dq[4][4];
+ dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
+ dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
+ dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
+ dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
+
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
+ block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
+ block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
+ block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+ }
+ a_ptr += 8;
+ }
+ k += 32;
+ }
+
+ while (k < rows_6 && k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ scales_idx++;
+ qs_f0 = scales[scales_idx][0];
+ qs_f1 = scales[scales_idx][1];
+ qs_f2 = scales[scales_idx][2];
+ qs_f3 = scales[scales_idx][3];
+ nextgroup += groupsize;
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 2; j++)
+ {
+ int4 load_int4[3];
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
+
+ half2 dq[4][8];
+ dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
+ dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
+ dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
+ dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
+
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
+ block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
+ block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
+ block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+ }
+ a_ptr += 16;
+ }
+ k += 32;
+ }
+
+ while (k < rows_5 && k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ scales_idx++;
+ qs_f0 = scales[scales_idx][0];
+ qs_f1 = scales[scales_idx][1];
+ qs_f2 = scales[scales_idx][2];
+ qs_f3 = scales[scales_idx][3];
+ nextgroup += groupsize;
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 1; j++)
+ {
+ int4 load_int4[5];
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
+
+ half2 dq[4][16];
+ dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
+ dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
+ dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
+ dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
+
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
+ block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
+ block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
+ block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+ }
+ a_ptr += 32;
+ }
+
+ k += 32;
+ }
+
+ while (k < rows_4 && k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ scales_idx++;
+ qs_f0 = scales[scales_idx][0];
+ qs_f1 = scales[scales_idx][1];
+ qs_f2 = scales[scales_idx][2];
+ qs_f3 = scales[scales_idx][3];
+ nextgroup += groupsize;
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 4; j++)
+ {
+ int4 load_int4[1];
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+
+ half2 dq[4][4];
+ dequant_4bit_8(load_int4[0].x, dq[0], size_n);
+ dequant_4bit_8(load_int4[0].y, dq[1], size_n);
+ dequant_4bit_8(load_int4[0].z, dq[2], size_n);
+ dequant_4bit_8(load_int4[0].w, dq[3], size_n);
+
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
+ block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
+ block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
+ block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+ }
+ a_ptr += 8;
+ }
+ k += 32;
+ }
+
+ while (k < rows_3 && k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ scales_idx++;
+ qs_f0 = scales[scales_idx][0];
+ qs_f1 = scales[scales_idx][1];
+ qs_f2 = scales[scales_idx][2];
+ qs_f3 = scales[scales_idx][3];
+ nextgroup += groupsize;
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 1; j++)
+ {
+ int4 load_int4[3];
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
+ load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
+
+ half2 dq[4][16];
+ dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
+ dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
+ dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
+ dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
+
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
+ block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
+ block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
+ block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+ }
+ a_ptr += 32;
+ }
+ k += 32;
+ }
+
+ while (k < rows_2 && k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ scales_idx++;
+ qs_f0 = scales[scales_idx][0];
+ qs_f1 = scales[scales_idx][1];
+ qs_f2 = scales[scales_idx][2];
+ qs_f3 = scales[scales_idx][3];
+ nextgroup += groupsize;
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 2; j++)
+ {
+ int4 load_int4[1];
+ load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
+
+ half2 dq[4][8];
+ dequant_2bit_16(load_int4[0].x, dq[0], size_n);
+ dequant_2bit_16(load_int4[0].y, dq[1], size_n);
+ dequant_2bit_16(load_int4[0].z, dq[2], size_n);
+ dequant_2bit_16(load_int4[0].w, dq[3], size_n);
+
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
+ block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
+ block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
+ block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
+ }
+
+ a_ptr += 16;
+ }
+ k += 32;
+ }
+
+ // Accumulate column sums in c
+
+ for (int m = 0; m < m_count; m++)
+ {
+ half2* out = (half2*)c_.item_ptr(offset_m + m, n);
+ half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
+ half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
+ atomicAdd(out , result01);
+ atomicAdd(out + 1, result23);
+ }
+}
+
+fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
+{
+    #if BLOCK_M_SIZE_MAX >= 1
+    if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 2
+    if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 3
+    if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 4
+    if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 5
+    if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 6
+    if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 7
+    if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 8
+    if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
+    #endif
+ return NULL;
+}
diff --git a/auto_round_extension/cuda/exllamav2/cuda/q_gemm_kernel_gptq.cuh b/auto_round_extension/cuda/exllamav2/cuda/q_gemm_kernel_gptq.cuh
new file mode 100644
index 00000000..4b722ef5
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/q_gemm_kernel_gptq.cuh
@@ -0,0 +1,223 @@
+#include "compat.cuh"
+
+__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+ return __hadd2(result, g_result);
+}
+
+__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
+{
+ half2 result = {};
+ const half2* a2_ptr = (const half2*)a_ptr;
+ #pragma unroll
+ for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
+ return __half2float(__low2half(result)) + __half2float(__high2half(result));
+}
+
+typedef void (*fp_gemm_half_q_half_gptq_kernel)
+(
+ const half*,
+ const uint32_t*,
+ const uint32_t*,
+ const half*,
+ half*,
+ const int,
+ const int,
+ const int,
+ const int,
+ const int,
+ const uint16_t*,
+ const int,
+ const bool
+);
+
+template <bool first_block, int m_count>
+__global__ void gemm_half_q_half_gptq_kernel
+(
+ const half* __restrict__ a,
+ const uint32_t* __restrict__ b_q_weight,
+ const uint32_t* __restrict__ b_gptq_qzeros,
+ const half* __restrict__ b_gptq_scales,
+ half* __restrict__ c,
+ const int size_m,
+ const int size_n,
+ const int size_k,
+ const int groups,
+ const int groupsize,
+ const uint16_t* __restrict__ b_q_perm,
+ const int rows_4,
+ const bool clear
+)
+{
+ MatrixView_half a_(a, size_m, size_k);
+ MatrixView_half_rw c_(c, size_m, size_n);
+ MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+ MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+ int t = threadIdx.x;
+
+ // Block
+
+ int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+ int offset_m = blockIdx.y * m_count;
+ int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+
+ int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+ int end_m = min(offset_m + m_count, size_m);
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+ int n = offset_n + t * 4;
+
+ // Preload block_a
+
+ __shared__ half block_a[m_count][BLOCK_KN_SIZE];
+
+ if (offset_k + t < end_k)
+ {
+ for (int m = 0; m < m_count; ++m)
+ {
+ const half* a_ptr = a_.item_ptr(offset_m + m, 0);
+ half* block_a_ptr = block_a[m];
+
+ half a0;
+ if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
+ else a0 = a_ptr[offset_k + t];
+ block_a_ptr[t] = a0;
+ }
+ }
+
+ // Zero output
+
+ if (n >= size_n) return;
+
+ if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
+ {
+ for (int m = 0; m < m_count; m++)
+ *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
+ }
+
+ __syncthreads();
+
+ // Find initial group
+
+ int group = offset_k / groupsize;
+ int nextgroup = offset_k + groupsize;
+
+ // a, b offset
+
+ int qk = offset_k / (32 / 4);
+
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+ const half* a_ptr = &block_a[0][0];
+ int a_stride = BLOCK_KN_SIZE;
+
+ // Initial group
+
+ int zeros[4];
+ float scales[4];
+ half2 z1z16[4][2];
+ half2 y1y16[4][2];
+ b_gptq_qzeros_.item4(zeros, group, n);
+ b_gptq_scales_.item4_f(scales, group, n);
+
+ // Avoid zeros overflow with & 0x0f.
+ dequant_4bit_8_prep_zero((zeros[0]) & 0x0f, z1z16[0], y1y16[0]);
+ dequant_4bit_8_prep_zero((zeros[1]) & 0x0f, z1z16[1], y1y16[1]);
+ dequant_4bit_8_prep_zero((zeros[2]) & 0x0f, z1z16[2], y1y16[2]);
+ dequant_4bit_8_prep_zero((zeros[3]) & 0x0f, z1z16[3], y1y16[3]);
+
+// __syncthreads();
+
+ // Column result
+
+ float block_c[m_count][4] = {};
+
+ // Dequantize and multiply
+
+ int k = offset_k;
+ while (k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ nextgroup += groupsize;
+ b_gptq_qzeros_.item4(zeros, group, n);
+ b_gptq_scales_.item4_f(scales, group, n);
+
+ // Avoid zeros overflow with & 0x0f.
+ dequant_4bit_8_prep_zero((zeros[0]) & 0x0f, z1z16[0], y1y16[0]);
+ dequant_4bit_8_prep_zero((zeros[1]) & 0x0f, z1z16[1], y1y16[1]);
+ dequant_4bit_8_prep_zero((zeros[2]) & 0x0f, z1z16[2], y1y16[2]);
+ dequant_4bit_8_prep_zero((zeros[3]) & 0x0f, z1z16[3], y1y16[3]);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < 4; j++)
+ {
+ const int4* b_ptr4 = (int4*) b_ptr;
+ int4 load_int4 = *b_ptr4;
+
+ half2 dq[4][4];
+ dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+ dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+ dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+ dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+
+ #pragma unroll
+ for (int m = 0; m < m_count; m++)
+ {
+ block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
+ block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
+ block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
+ block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
+ }
+
+ b_ptr += size_n;
+ a_ptr += 8;
+ }
+
+ k += 32;
+ }
+
+ for (int m = 0; m < m_count; m++)
+ {
+ half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
+ half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
+ half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
+ atomicAdd(out , result01);
+ atomicAdd(out + 1, result23);
+ }
+}
+
+fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
+{
+    #if BLOCK_M_SIZE_MAX >= 1
+    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 2
+    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 3
+    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 4
+    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 5
+    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 6
+    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 7
+    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
+    #endif
+    #if BLOCK_M_SIZE_MAX >= 8
+    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
+    #endif
+ return NULL;
+}
diff --git a/auto_round_extension/cuda/exllamav2/cuda/q_matrix.cu b/auto_round_extension/cuda/exllamav2/cuda/q_matrix.cu
new file mode 100644
index 00000000..aebba7b0
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/q_matrix.cu
@@ -0,0 +1,627 @@
+#include "q_matrix.cuh"
+#include "matrix_view.cuh"
+#include "util.cuh"
+
+#include "quant/qdq_2.cuh"
+#include "quant/qdq_3.cuh"
+#include "quant/qdq_4.cuh"
+#include "quant/qdq_5.cuh"
+#include "quant/qdq_6.cuh"
+#include "quant/qdq_8.cuh"
+
+#define BLOCK_KN_SIZE 128
+
+#define THREADS_X 32
+#define THREADS_Y 32
+
+// Shuffle quantized data on load
+
+__global__ void shuffle_kernel
+(
+ uint32_t* __restrict__ b_q_weight,
+ const int size_k,
+ const int size_n,
+ const int rows_8,
+ const int rows_6,
+ const int rows_5,
+ const int rows_4,
+ const int rows_3,
+ const int rows_2
+)
+{
+ int n = blockIdx.x * THREADS_X + threadIdx.x;
+ if (n >= size_n) return;
+ int k = 0;
+ uint32_t* b_ptr = b_q_weight + n;
+ while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
+ while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
+ while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
+ while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
+ while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
+ while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
+}
+
+
+// QMatrix constructor
+
+QMatrix::QMatrix
+(
+ const int _device,
+ const int _height,
+ const int _width,
+ const int _groups,
+
+ uint32_t* _q_weight,
+ uint16_t* _q_perm,
+ uint16_t* _q_invperm,
+ uint32_t* _q_scale,
+ half* _q_scale_max,
+ uint16_t* _q_groups,
+
+ uint32_t* _gptq_qzeros,
+ half* _gptq_scales,
+ uint32_t* _gptq_g_idx,
+
+ half* _temp_dq
+) :
+ device(_device),
+ height(_height),
+ width(_width),
+ groups(_groups),
+ temp_dq(_temp_dq)
+{
+ cudaSetDevice(device);
+
+ failed = false;
+
+ cuda_q_weight = _q_weight;
+ cuda_q_perm = _q_perm;
+ cuda_q_invperm = _q_invperm;
+ cuda_q_scale = _q_scale;
+ cuda_q_scale_max = _q_scale_max;
+ cuda_q_groups = _q_groups;
+ cuda_gptq_qzeros = _gptq_qzeros;
+ cuda_gptq_scales = _gptq_scales;
+
+ is_gptq = (_gptq_qzeros != NULL);
+
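+    // Infer the group size as the smallest power of two satisfying groupsize * groups >= height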
+ groupsize = 1;
+ while (groupsize * groups < height) groupsize *= 2;
+
+ // Create group map
+
+ rows_8 = 0;
+ rows_6 = 0;
+ rows_5 = 0;
+ rows_4 = 0;
+ rows_3 = 0;
+ rows_2 = 0;
+
+ if (!is_gptq)
+ {
+ uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
+ cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
+
+ for (int i = 0; i < groups; i++)
+ {
+ int bits = cpu_q_groups[i * 2];
+ if (bits == 8) rows_8 += groupsize;
+ if (bits == 6) rows_6 += groupsize;
+ if (bits == 5) rows_5 += groupsize;
+ if (bits == 4) rows_4 += groupsize;
+ if (bits == 3) rows_3 += groupsize;
+ if (bits == 2) rows_2 += groupsize;
+ }
+
+ free(cpu_q_groups);
+
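+        // Turn the per-bit-width row counts into cumulative boundaries (rows_8 <= rows_6 <= ... <= rows_2)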
+ rows_6 += rows_8;
+ rows_5 += rows_6;
+ rows_4 += rows_5;
+ rows_3 += rows_4;
+ rows_2 += rows_3;
+ }
+ else
+ {
+ rows_4 = height;
+ rows_3 = height;
+ rows_2 = height;
+
+ if (_gptq_g_idx)
+ {
+ if (!make_sequential(_gptq_g_idx))
+ {
+ failed = true;
+ //printf("FAIL\n");
+ return;
+ }
+ }
+ }
+
+ // Shuffle quantized data
+
+ dim3 blockDim, gridDim;
+ blockDim.x = THREADS_X;
+ blockDim.y = 1;
+ gridDim.x = DIVIDE(width, THREADS_X);
+ gridDim.y = 1;
+
+    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+}
+
+QMatrix::~QMatrix()
+{
+}
+
+// Reconstruct b[k,n] (GPTQ)
+
+__global__ void reconstruct_gptq_kernel
+(
+ const uint32_t* __restrict__ b_q_weight,
+ const uint16_t* __restrict__ b_q_perm,
+ const uint32_t* __restrict__ b_gptq_qzeros,
+ const half* __restrict__ b_gptq_scales,
+ //const uint16_t* __restrict__ b_q_groups,
+ const int size_k,
+ const int size_n,
+ const int groupsize,
+ const int groups,
+ half* __restrict__ b,
+ const int rows_4
+)
+{
+ MatrixView_half_rw b_(b, size_k, size_n);
+ MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
+ MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
+
+ int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+ int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+
+ // Preload remapping table
+
+ __shared__ uint16_t perm[BLOCK_KN_SIZE];
+ int t = threadIdx.x;
+
+ if (b_q_perm)
+ {
+ if (offset_k + t < size_k)
+ perm[t] = b_q_perm[offset_k + t];
+ }
+
+ // Column
+
+ int n = offset_n + t * 4;
+ if (n >= size_n) return;
+
+ // Find initial group
+
+ int group = offset_k / groupsize;
+ int nextgroup = offset_k + groupsize;
+
+ // b offset
+
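+    // Eight 4-bit weights are packed per uint32, so the packed row index is offset_k / 8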
+ int qk = offset_k / (32 / 4);
+
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+ // Initial zeros/scale
+
+ int zeros[4];
+ half2 scales[4];
+ half2 z1z16[4][2];
+ half2 y1y16[4][2];
+ b_gptq_qzeros_.item4(zeros, group, n);
+ b_gptq_scales_.item4_h2(scales, group, n);
+
+ // Avoid zeros overflow with & 0x0f.
+ dequant_4bit_8_prep_zero((zeros[0]) & 0x0f, z1z16[0], y1y16[0]);
+ dequant_4bit_8_prep_zero((zeros[1]) & 0x0f, z1z16[1], y1y16[1]);
+ dequant_4bit_8_prep_zero((zeros[2]) & 0x0f, z1z16[2], y1y16[2]);
+ dequant_4bit_8_prep_zero((zeros[3]) & 0x0f, z1z16[3], y1y16[3]);
+
+ __syncthreads();
+
+ int k = offset_k;
+ int lk = 0;
+
+ while (k < end_k)
+ {
+ if (k == nextgroup)
+ {
+ group++;
+ nextgroup += groupsize;
+ b_gptq_qzeros_.item4(zeros, group, n);
+ b_gptq_scales_.item4_h2(scales, group, n);
+
+ // Avoid zeros overflow with & 0x0f.
+ dequant_4bit_8_prep_zero((zeros[0]) & 0x0f, z1z16[0], y1y16[0]);
+ dequant_4bit_8_prep_zero((zeros[1]) & 0x0f, z1z16[1], y1y16[1]);
+ dequant_4bit_8_prep_zero((zeros[2]) & 0x0f, z1z16[2], y1y16[2]);
+ dequant_4bit_8_prep_zero((zeros[3]) & 0x0f, z1z16[3], y1y16[3]);
+ }
+
+ for (int p = 0; p < 4; p++)
+ {
+ half2 dq[4][4];
+ const int4* b_ptr4 = (int4*) b_ptr;
+ int4 load_int4 = *b_ptr4;
+
+ dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+ dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+ dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+ dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+
+ b_ptr += size_n;
+ //half* dqh = (half*)dq;
+ if (b_q_perm)
+ {
+ for (int j = 0; j < 4; j++)
+ {
+ for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+ b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+ b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+ }
+ }
+ else
+ {
+ for (int j = 0; j < 4; j++)
+ {
+ for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
+ b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
+ b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
+ }
+ }
+ }
+ k += 32;
+ }
+}
+
+
+// Reconstruct b[k,n]
+
+__global__ void reconstruct_kernel
+(
+ const uint32_t* __restrict__ b_q_weight,
+ const uint16_t* __restrict__ b_q_perm,
+ const uint32_t* __restrict__ b_q_scale,
+ const half* __restrict__ b_q_scale_max,
+ //const uint16_t* __restrict__ b_q_groups,
+ const int size_k,
+ const int size_n,
+ const int groupsize,
+ const int groups,
+ half* __restrict__ b,
+ const int rows_8,
+ const int rows_6,
+ const int rows_5,
+ const int rows_4,
+ const int rows_3,
+ const int rows_2
+)
+{
+ MatrixView_half_rw b_(b, size_k, size_n);
+ MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
+
+ int offset_k = BLOCK_KN_SIZE * blockIdx.y;
+ int offset_n = BLOCK_KN_SIZE * blockIdx.x;
+
+ // Preload remapping table
+
+ int t = threadIdx.x;
+ __shared__ uint16_t perm[BLOCK_KN_SIZE];
+ if (offset_k + t < size_k)
+ perm[t] = b_q_perm[offset_k + t];
+
+ // Column
+
+ int n = offset_n + t;
+ if (n >= size_n) return;
+
+ // Find initial group
+
+ int group = offset_k / groupsize;
+
+ int pre_rows_8 = min(rows_8, offset_k);
+ int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
+ int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
+ int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
+ int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
+ int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
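+    // Every 32 logical rows occupy 'bits' packed uint32 rows; qk accumulates the packed row offset of offset_k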
+ int qk = 0;
+ qk += pre_rows_8 / 32 * 8;
+ qk += pre_rows_6 / 32 * 6;
+ qk += pre_rows_5 / 32 * 5;
+ qk += pre_rows_4 / 32 * 4;
+ qk += pre_rows_3 / 32 * 3;
+ qk += pre_rows_2 / 32 * 2;
+
+ const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
+
+ half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
+ half2 qs_h2 = __halves2half2(qs_h, qs_h);
+ int nextgroup = offset_k + groupsize;
+
+ int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
+ int k = offset_k;
+ int lk = 0;
+
+ __syncthreads();
+
+ while (k < rows_8 && k < end_k)
+ {
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+ for (int p = 0; p < 4; p++)
+ {
+ half2 dq[4];
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
+ dequant_8bit_8(q_0, q_1, dq, size_n);
+ for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
+ half* dqh = (half*) dq;
+ for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
+ }
+ k += 32;
+ }
+
+ while (k < rows_6 && k < end_k)
+ {
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+ for (int p = 0; p < 2; p++)
+ {
+ half2 dq[8];
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
+ uint32_t q_2 = *b_ptr; b_ptr += size_n;
+ dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
+ for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
+ half* dqh = (half*) dq;
+ for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
+ }
+ k += 32;
+ }
+
+ while (k < rows_5 && k < end_k)
+ {
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+ for (int p = 0; p < 1; p++)
+ {
+ half2 dq[16];
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
+ uint32_t q_2 = *b_ptr; b_ptr += size_n;
+ uint32_t q_3 = *b_ptr; b_ptr += size_n;
+ uint32_t q_4 = *b_ptr; b_ptr += size_n;
+ dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
+ for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
+ half* dqh = (half*) dq;
+ for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
+ }
+ k += 32;
+ }
+
+ while (k < rows_4 && k < end_k)
+ {
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+ for (int p = 0; p < 4; p++)
+ {
+ half2 dq[4];
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
+ dequant_4bit_8(q_0, dq, size_n);
+ for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
+ half* dqh = (half*) dq;
+ for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
+ }
+ k += 32;
+ }
+
+ while (k < rows_3 && k < end_k)
+ {
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+ for (int p = 0; p < 1; p++)
+ {
+ half2 dq[16];
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
+ uint32_t q_1 = *b_ptr; b_ptr += size_n;
+ uint32_t q_2 = *b_ptr; b_ptr += size_n;
+ dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
+ for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
+ half* dqh = (half*) dq;
+ for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
+ }
+ k += 32;
+ }
+
+ while (k < rows_2 && k < end_k)
+ {
+ if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
+ for (int p = 0; p < 2; p++)
+ {
+ half2 dq[8];
+ uint32_t q_0 = *b_ptr; b_ptr += size_n;
+ dequant_2bit_16(q_0, dq, size_n);
+ for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
+ half* dqh = (half*) dq;
+ for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
+ }
+ k += 32;
+ }
+}
+
+void QMatrix::reconstruct(half* out)
+{
+ dim3 blockDim, gridDim;
+ blockDim.x = BLOCK_KN_SIZE;
+ blockDim.y = 1;
+ gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+
+ if (!is_gptq)
+ {
+ gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
+        reconstruct_kernel<<<gridDim, blockDim>>>
+ (
+ cuda_q_weight,
+ cuda_q_perm,
+ cuda_q_scale,
+ cuda_q_scale_max,
+ //cuda_q_groups,
+ height,
+ width,
+ groupsize,
+ groups,
+ out,
+ rows_8,
+ rows_6,
+ rows_5,
+ rows_4,
+ rows_3,
+ rows_2
+ );
+ }
+ else
+ {
+ gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
+        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+ (
+ cuda_q_weight,
+ cuda_q_perm,
+ cuda_gptq_qzeros,
+ cuda_gptq_scales,
+ //const uint16_t* __restrict__ b_q_groups,
+ height,
+ width,
+ groupsize,
+ groups,
+ out,
+ rows_4
+ );
+ }
+}
+
+__global__ void make_sequential_kernel
+(
+ const uint32_t* __restrict__ w,
+ uint32_t* __restrict__ w_new,
+ const uint16_t* __restrict__ q_perm,
+ const int w_height,
+ const int w_width
+)
+{
+ const uint64_t* w2 = (uint64_t*) w;
+ uint64_t* w_new2 = (uint64_t*) w_new;
+ int w2_stride = w_width >> 1;
+
+ int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+ if (w2_column >= w2_stride) return;
+
+ int w_new2_row = blockIdx.y;
+
+ int q_perm_idx = w_new2_row << 3;
+
+ uint64_t dst = 0;
+
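+    // Gather 4-bit values from 8 permuted source rows into one packed destination uint64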
+ #pragma unroll
+ for (int i = 0; i < 8; i++)
+ {
+ int source_row = q_perm[q_perm_idx++];
+
+ int w2_row = source_row >> 3;
+ int w2_subrow = source_row & 0x07;
+ int w2_row_shift = w2_subrow << 2;
+ int wnew2_row_shift = i << 2;
+
+ uint64_t src = w2[w2_row * w2_stride + w2_column];
+ src >>= w2_row_shift;
+ src &= 0x0000000f0000000f;
+ src <<= wnew2_row_shift;
+ dst |= src;
+ }
+
+ w_new2[w_new2_row * w2_stride + w2_column] = dst;
+}
+
+bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
+{
+ uint32_t* cuda_new_qweight = NULL;
+ cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+ if (err != cudaSuccess) {
+ cudaError_t cuda_status = cudaGetLastError(); // Clear error
+ return false;
+ }
+
+ uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
+ uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
+ uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
+
+ // Group histogram
+
+ for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
+
+ // Group map
+
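+    // Exclusive prefix sum: cpu_g_idx_map[i] becomes the first destination row of group i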
+ for (int i = 0, acc = 0; i < groups; i++)
+ {
+ short tmp = cpu_g_idx_map[i];
+ cpu_g_idx_map[i] = acc;
+ acc += tmp;
+ }
+
+ // X map (inverse)
+
+ for (int row = 0; row < height; row++)
+ {
+ uint32_t target_group = cpu_g_idx[row];
+ uint32_t target_row = cpu_g_idx_map[target_group];
+ cpu_g_idx_map[target_group]++;
+ cpu_x_map_inv[row] = target_row;
+ }
+
+ // X map
+
+ for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
+
+ // Reduce to uint16_t
+
+ uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
+ uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
+ for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
+ for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
+
+ // Move to CUDA
+
+ cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
+ cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
+
+ // Rearrange rows in w
+
+ dim3 blockDim, gridDim;
+ blockDim.x = THREADS_X;
+ blockDim.y = 1;
+ gridDim.x = DIVIDE(width, THREADS_X);
+ gridDim.y = height / 8;
+
+    make_sequential_kernel<<<gridDim, blockDim>>>
+ (
+ cuda_q_weight,
+ cuda_new_qweight,
+ cuda_q_perm,
+ height / 8,
+ width
+ );
+
+ // Replace qweights
+
+ cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
+
+ // Cleanup
+
+ cudaDeviceSynchronize();
+
+ cudaFree(cuda_new_qweight);
+ free(cpu_g_idx_map);
+ free(cpu_x_map);
+ free(cpu_x_map_inv);
+
+ return true;
+}
diff --git a/auto_round_extension/cuda/exllamav2/cuda/q_matrix.cuh b/auto_round_extension/cuda/exllamav2/cuda/q_matrix.cuh
new file mode 100644
index 00000000..dda83a4f
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/q_matrix.cuh
@@ -0,0 +1,73 @@
+#ifndef _q_matrix_cuh
+#define _q_matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+#define MAX_SUPERGROUPS 16
+
+class QMatrix
+{
+public:
+
+ int device;
+ bool is_gptq;
+
+ int height;
+ int width;
+ int groups;
+ int groupsize;
+
+ int rows_8;
+ int rows_6;
+ int rows_5;
+ int rows_4;
+ int rows_3;
+ int rows_2;
+
+ uint32_t* cuda_q_weight = NULL;
+ uint16_t* cuda_q_perm = NULL;
+ uint16_t* cuda_q_invperm = NULL;
+ uint32_t* cuda_q_scale = NULL;
+ half* cuda_q_scale_max = NULL;
+ uint16_t* cuda_q_groups = NULL;
+ uint32_t* cuda_gptq_qzeros = NULL;
+ half* cuda_gptq_scales = NULL;
+
+ half* temp_dq;
+
+ bool failed;
+
+ QMatrix
+ (
+ const int _device,
+ const int _height,
+ const int _width,
+ const int _groups,
+
+ uint32_t* _q_weight,
+ uint16_t* _q_perm,
+ uint16_t* _q_invperm,
+ uint32_t* _q_scale,
+ half* _q_scale_max,
+ uint16_t* _q_groups,
+
+ uint32_t* _gptq_qzeros,
+ half* _gptq_scales,
+ uint32_t* _gptq_g_idx,
+
+ half* _temp_dq
+ );
+
+ ~QMatrix();
+
+ void reconstruct(half* out);
+ bool make_sequential(const uint32_t* cpu_g_idx);
+
+private:
+
+};
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_2.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_2.cuh
new file mode 100644
index 00000000..3beaeefa
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_2.cuh
@@ -0,0 +1,103 @@
+#ifndef _qdq_2_cuh
+#define _qdq_2_cuh
+
+#include "qdq_util.cuh"
+#include "../../config.h"
+
+#if QMODE_2BIT == 1
+
+// Permutation:
+//
+// ffddbb99 77553311 eeccaa88 66442200
+
+__forceinline__ __device__ void shuffle_2bit_16
+(
+ uint32_t* q,
+ int stride
+)
+{
+ uint32_t qa = q[0];
+ uint32_t qb = 0;
+
+ #pragma unroll
+ for (int i = 0; i < 8; i++)
+ {
+ uint32_t qa0 = qa & 0x03;
+ uint32_t qa1 = (qa & 0x0c) >> 2;
+ qa >>= 4;
+ qb |= (qa1 << (i * 2 + 16));
+ qb |= (qa0 << (i * 2));
+ }
+ q[0] = qb;
+}
+
+__forceinline__ __device__ void dequant_2bit_16
+(
+ const uint32_t q_0,
+ half2 (&dq)[8],
+ int stride
+)
+{
+ const uint32_t c0 = 0x64006400;
+ const half y4_ = __float2half_rn(1.0f / 4.0f);
+ const half y16_ = __float2half_rn(1.0f / 16.0f);
+ const half y64_ = __float2half_rn(1.0f / 64.0f);
+ const half2 y4 = __halves2half2(y4_, y4_);
+ const half2 y16 = __halves2half2(y16_, y16_);
+ const half2 y64 = __halves2half2(y64_, y64_);
+ const half z1_ = __float2half_rn(-1024.0f - 2.0f);
+ const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
+ const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
+ const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
+ const half2 z1 = __halves2half2(z1_, z1_);
+ const half2 z4 = __halves2half2(z4_, z4_);
+ const half2 z16 = __halves2half2(z16_, z16_);
+ const half2 z64 = __halves2half2(z64_, z64_);
+
+ uint32_t qa = q_0;
+ half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
+ half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
+ half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
+ half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
+ qa >>= 8;
+    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9]) + 1024
+ half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
+ half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
+ half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
+
+ dq[0] = __hadd2(q0.as_half2, z1);
+ dq[1] = __hfma2(q1.as_half2, y4, z4);
+ dq[2] = __hfma2(q2.as_half2, y16, z16);
+ dq[3] = __hfma2(q3.as_half2, y64, z64);
+ dq[4] = __hadd2(q4.as_half2, z1);
+ dq[5] = __hfma2(q5.as_half2, y4, z4);
+ dq[6] = __hfma2(q6.as_half2, y16, z16);
+ dq[7] = __hfma2(q7.as_half2, y64, z64);
+}
+
+#else
+
+__forceinline__ __device__ void shuffle_2bit_16
+(
+ uint32_t* q,
+ int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_2bit_16
+(
+ const uint32_t q_0,
+ half2 (&dq)[8],
+ int stride
+)
+{
+ half dqh[16];
+ for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
+
+ for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+#endif
+
+#endif
\ No newline at end of file
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_3.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_3.cuh
new file mode 100644
index 00000000..10117376
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_3.cuh
@@ -0,0 +1,169 @@
+#ifndef _qdq_3_cuh
+#define _qdq_3_cuh
+
+#include "qdq_util.cuh"
+#include "../../config.h"
+
+#if QMODE_3BIT == 1
+
+// Permutation:
+//
+// v9997775 55333111 u8886664 44222000 (u, v lsb)
+// vjjjhhhf ffdddbbb uiiiggge eecccaaa
+// vtttrrrp ppnnnlll usssqqqo oommmkkk
+
+__forceinline__ __device__ void shuffle_3bit_32
+(
+ uint32_t* q,
+ int stride
+)
+{
+ uint32_t qa = q[0 * stride];
+ uint32_t qb = q[1 * stride];
+ uint32_t qc = q[2 * stride];
+
+ // qa: aa999888 77766655 54443332 22111000
+ // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
+ // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
+
+ uint32_t qd = qc >> 26;
+ qc <<= 4;
+ qc |= qb >> 28;
+ qb <<= 2;
+ qb |= qa >> 30;
+
+ // qa: ..999888 77766655 54443332 22111000
+ // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
+ // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
+ // qd: vvvuuu
+
+ uint32_t za = 0;
+ uint32_t zb = 0;
+ uint32_t zc = 0;
+
+ for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
+ for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
+ for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
+
+ // za: 9997775 55333111 8886664 44222000
+ // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
+ // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
+ // qd: vvvuuu
+
+ za |= ((qd & 0x01) >> 0) << 15;
+ zb |= ((qd & 0x02) >> 1) << 15;
+ zc |= ((qd & 0x04) >> 2) << 15;
+ za |= ((qd & 0x08) >> 3) << 31;
+ zb |= ((qd & 0x10) >> 4) << 31;
+ zc |= ((qd & 0x20) >> 5) << 31;
+
+ // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
+ // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
+ // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
+
+ q[0 * stride] = za;
+ q[1 * stride] = zb;
+ q[2 * stride] = zc;
+}
+
+__forceinline__ __device__ void dequant_3bit_32
+(
+ const uint32_t q_0,
+ const uint32_t q_1,
+ const uint32_t q_2,
+ half2 (&dq)[16],
+ int stride
+)
+{
+ const uint32_t c0 = 0x64006400;
+ const half y8_ = __float2half_rn(1.0f / 8.0f);
+ const half y64_ = __float2half_rn(1.0f / 64.0f);
+ const half2 y8 = __halves2half2(y8_, y8_);
+ const half2 y64 = __halves2half2(y64_, y64_);
+ const half z1_ = __float2half_rn(-1024.0f - 4.0f);
+ const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
+ const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
+ const half2 z1 = __halves2half2(z1_, z1_);
+ const half2 z8 = __halves2half2(z8_, z8_);
+ const half2 z64 = __halves2half2(z64_, z64_);
+
+ uint32_t qa = q_0;
+ uint32_t qb = q_1;
+ uint32_t qc = q_2;
+
+ half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
+ half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
+ qa >>= 6;
+ half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
+ half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
+ half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
+ qa >>= 9;
+ qa &= 0x00010001;
+ half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
+ half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
+ qb >>= 6;
+ half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
+ half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
+ half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
+ qb >>= 8;
+ qb &= 0x00020002;
+ half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
+ half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
+ qc >>= 6;
+ half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
+ half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
+ half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
+ qc >>= 7;
+ qc &= 0x00040004;
+ half2_uint32 q15((qa | qb | qc) | c0);
+
+ dq[ 0] = __hadd2( q0.as_half2, z1);
+ dq[ 1] = __hfma2( q1.as_half2, y8, z8);
+ dq[ 2] = __hadd2( q2.as_half2, z1);
+ dq[ 3] = __hfma2( q3.as_half2, y8, z8);
+ dq[ 4] = __hfma2( q4.as_half2, y64, z64);
+ dq[ 5] = __hadd2( q5.as_half2, z1);
+ dq[ 6] = __hfma2( q6.as_half2, y8, z8);
+ dq[ 7] = __hadd2( q7.as_half2, z1);
+ dq[ 8] = __hfma2( q8.as_half2, y8, z8);
+ dq[ 9] = __hfma2( q9.as_half2, y64, z64);
+ dq[10] = __hadd2(q10.as_half2, z1);
+ dq[11] = __hfma2(q11.as_half2, y8, z8);
+ dq[12] = __hadd2(q12.as_half2, z1);
+ dq[13] = __hfma2(q13.as_half2, y8, z8);
+ dq[14] = __hfma2(q14.as_half2, y64, z64);
+ dq[15] = __hadd2(q15.as_half2, z1);
+}
+
+#else
+
+__forceinline__ __device__ void shuffle_3bit_32
+(
+ uint32_t* q,
+ int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_3bit_32
+(
+ const uint32_t q_0,
+ const uint32_t q_1,
+ const uint32_t q_2,
+ half2 (&dq)[16],
+ int stride
+)
+{
+ half dqh[32];
+ for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
+ dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
+ for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
+ dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
+ for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
+
+ for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+#endif
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_4.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_4.cuh
new file mode 100644
index 00000000..5fb070d0
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_4.cuh
@@ -0,0 +1,227 @@
+#ifndef _qdq_4_cuh
+#define _qdq_4_cuh
+
+#include "qdq_util.cuh"
+#include "../../config.h"
+
+#if QMODE_4BIT == 1
+
+// Permutation:
+//
+// 77775555 33331111 66664444 22220000
+
+__forceinline__ __device__ void shuffle_4bit_8
+(
+ uint32_t* q,
+ int stride
+)
+{
+ uint32_t qa = q[0];
+ uint32_t qb = 0;
+
+ #pragma unroll
+ for (int i = 0; i < 4; i++)
+ {
+ uint32_t qa0 = qa & 0x0f;
+ uint32_t qa1 = (qa & 0xf0) >> 4;
+ qa >>= 8;
+ qb |= (qa1 << (i * 4 + 16));
+ qb |= (qa0 << (i * 4));
+ }
+ q[0] = qb;
+}
+
+__forceinline__ __device__ void dequant_4bit_8
+(
+ const uint32_t q_0,
+ half2 (&dq)[4],
+ int stride
+)
+{
+ const uint32_t c0 = 0x64006400;
+ const half y16_ = __float2half_rn(1.0f / 16.0f);
+ const half2 y16 = __halves2half2(y16_, y16_);
+ const half z1_ = __float2half_rn(-1024.0f - 8.0f);
+ const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
+ const half2 z1 = __halves2half2(z1_, z1_);
+ const half2 z16 = __halves2half2(z16_, z16_);
+
+ uint32_t qa = q_0;
+ half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
+ half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
+ qa >>= 8;
+ half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
+ half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
+
+ dq[0] = __hadd2(q0.as_half2, z1);
+ dq[1] = __hfma2(q1.as_half2, y16, z16);
+ dq[2] = __hadd2(q2.as_half2, z1);
+ dq[3] = __hfma2(q3.as_half2, y16, z16);
+}
+
+__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
+(
+ const uint32_t zero,
+ const half scale,
+ half2 (&z1z16)[2],
+ half2 (&y1y16)[2]
+)
+{
+ half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
+ half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
+
+ half2 scale2 = __half2half2(scale);
+
+ z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
+ z1z16[1] = __hmul2(scale2, __half2half2(z16));
+
+ const half y1 = __float2half_rn(1.0f);
+ const half y16 = __float2half_rn(1.0f / 16.0f);
+
+ y1y16[0] = __hmul2(scale2, __half2half2(y1));
+ y1y16[1] = __hmul2(scale2, __half2half2(y16));
+}
+
+__forceinline__ __device__ void dequant_4bit_8_prep_zero
+(
+ const uint32_t zero,
+ half2(&z1z16)[2],
+ half2(&y1y16)[2]
+)
+{
+ half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
+ half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
+
+ z1z16[0] = __half2half2(z1.as_half);
+ z1z16[1] = __half2half2(z16);
+
+ const half y1 = __float2half_rn(1.0f);
+ const half y16 = __float2half_rn(1.0f / 16.0f);
+
+ y1y16[0] = __half2half2(y1);
+ y1y16[1] = __half2half2(y16);
+}
+
+
+__forceinline__ __device__ void dequant_4bit_8_gptq
+(
+ const uint32_t q_0,
+ half2 (&dq)[4],
+ half2 (&z1z16)[2],
+ half2 (&y1y16)[2],
+ int stride,
+ bool scaled
+)
+{
+ const uint32_t c0 = 0x64006400;
+
+ uint32_t qa = q_0;
+ half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
+ half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
+ qa >>= 8;
+ half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
+ half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
+
+ if (scaled)
+ {
+ dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
+ dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
+ dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
+ dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
+ }
+ else
+ {
+ dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
+ dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
+ dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
+ dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
+ }
+}
+
+#else
+
+__forceinline__ __device__ void shuffle_4bit_8
+(
+ uint32_t* q,
+ int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_4bit_8
+(
+ const uint32_t q_0,
+ half2 (&dq)[4],
+ int stride
+)
+{
+ half dqh[8];
+ for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
+
+ for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
+(
+ const uint32_t zero,
+ const half scale,
+ half2 (&z1)[2],
+ half2 (&y1)[2]
+)
+{
+ half z = __int2half_rn(-((int)zero));
+ z = __hmul(z, scale);
+ z1[0] = __half2half2(z);
+ y1[0] = __half2half2(scale);
+}
+
+__forceinline__ __device__ void dequant_4bit_8_prep_zero
+(
+ const uint32_t zero,
+ half2(&z1)[2],
+ half2(&y1)[2]
+)
+{
+ half z = __int2half_rn(-((int)zero));
+ z1[0] = __half2half2(z);
+}
+
+__forceinline__ __device__ void dequant_4bit_8_gptq
+(
+ const uint32_t q_0,
+ half2 (&dq)[4],
+ half2 (&z1)[2],
+ half2 (&y1)[2],
+ int stride,
+ bool scaled
+)
+{
+ half2 dqh2[8];
+
+ uint32_t qa = q_0;
+ for (int i = 0; i < 4; i++)
+ {
+ half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
+ half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
+ dqh2[i] = __halves2half2(d0, d1);
+ }
+
+ if (scaled)
+ {
+ dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
+ dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
+ dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
+ dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
+ }
+ else
+ {
+ dq[0] = __hadd2(dqh2[0], z1[0]);
+ dq[1] = __hadd2(dqh2[1], z1[0]);
+ dq[2] = __hadd2(dqh2[2], z1[0]);
+ dq[3] = __hadd2(dqh2[3], z1[0]);
+ }
+}
+
+#endif
+
+#endif
\ No newline at end of file
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_5.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_5.cuh
new file mode 100644
index 00000000..454e4b93
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_5.cuh
@@ -0,0 +1,207 @@
+#ifndef _qdq_5_cuh
+#define _qdq_5_cuh
+
+#include "qdq_util.cuh"
+#include "../../config.h"
+
+#if QMODE_5BIT == 1
+
+// Permutation:
+//
+// v5555533 33311111 u4444422 22200000 (u, v lsb)
+// vbbbbb99 99977777 uaaaaa88 88866666
+// vhhhhhff fffddddd ugggggee eeeccccc
+// vnnnnnll llljjjjj ummmmmkk kkkiiiii
+// vtttttrr rrrppppp usssssqq qqqooooo
+
+__forceinline__ __device__ void shuffle_5bit_32
+(
+ uint32_t* q,
+ int stride
+)
+{
+ uint32_t qa = q[0 * stride];
+ uint32_t qb = q[1 * stride];
+ uint32_t qc = q[2 * stride];
+ uint32_t qd = q[3 * stride];
+ uint32_t qe = q[4 * stride];
+
+ // qa: 66555554 44443333 32222211 11100000
+ // qb: ccccbbbb baaaaa99 99988888 77777666
+ // qc: jiiiiihh hhhggggg fffffeee eedddddc
+ // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
+ // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
+
+ uint32_t qf = qe >> 22;
+ qe <<= 8;
+ qe |= qd >> 24;
+ qd <<= 6;
+ qd |= qc >> 26;
+ qc <<= 4;
+ qc |= qb >> 28;
+ qb <<= 2;
+ qb |= qa >> 30;
+
+ // qa: 555554 44443333 32222211 11100000
+ // qb: bbbbba aaaa9999 98888877 77766666
+ // qc: hhhhhg ggggffff feeeeedd dddccccc
+ // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
+ // qe: ttttts ssssrrrr rqqqqqpp pppooooo
+ // qf: vv vvvuuuuu
+
+ uint32_t za = 0;
+ uint32_t zb = 0;
+ uint32_t zc = 0;
+ uint32_t zd = 0;
+ uint32_t ze = 0;
+
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
+ for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
+
+ // za: 5555533 33311111 4444422 22200000
+ // zb: bbbbb99 99977777 aaaaa88 88866666
+ // zc: hhhhhff fffddddd gggggee eeeccccc
+ // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
+ // ze: tttttrr rrrppppp sssssqq qqqooooo
+ // qf: vv vvvuuuuu
+
+ za |= ((qf & 0x001) >> 0) << 15;
+ zb |= ((qf & 0x002) >> 1) << 15;
+ zc |= ((qf & 0x004) >> 2) << 15;
+ zd |= ((qf & 0x008) >> 3) << 15;
+ ze |= ((qf & 0x010) >> 4) << 15;
+ za |= ((qf & 0x020) >> 5) << 31;
+ zb |= ((qf & 0x040) >> 6) << 31;
+ zc |= ((qf & 0x080) >> 7) << 31;
+ zd |= ((qf & 0x100) >> 8) << 31;
+ ze |= ((qf & 0x200) >> 9) << 31;
+
+ // za: v5555533 33311111 u4444422 22200000 (u, v lsb)
+ // zb: vbbbbb99 99977777 uaaaaa88 88866666
+ // zc: vhhhhhff fffddddd ugggggee eeeccccc
+ // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
+ // ze: vtttttrr rrrppppp usssssqq qqqooooo
+
+ q[0 * stride] = za;
+ q[1 * stride] = zb;
+ q[2 * stride] = zc;
+ q[3 * stride] = zd;
+ q[4 * stride] = ze;
+}
+
+__forceinline__ __device__ void dequant_5bit_32
+(
+ const uint32_t q_0,
+ const uint32_t q_1,
+ const uint32_t q_2,
+ const uint32_t q_3,
+ const uint32_t q_4,
+ half2 (&dq)[16],
+ int stride
+)
+{
+ const uint32_t c0 = 0x64006400;
+ const half y32_ = __float2half_rn(1.0f / 32.0f);
+ const half2 y32 = __halves2half2(y32_, y32_);
+ const half z1_ = __float2half_rn(-1024.0f - 16.0f);
+ const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
+ const half2 z1 = __halves2half2(z1_, z1_);
+ const half2 z32 = __halves2half2(z32_, z32_);
+
+ uint32_t qa = q_0;
+ uint32_t qb = q_1;
+ uint32_t qc = q_2;
+ uint32_t qd = q_3;
+ uint32_t qe = q_4;
+
+ half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
+ half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
+ qa >>= 10;
+ half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
+ qa >>= 5;
+ qa &= 0x00010001;
+ half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
+ half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
+ qb >>= 10;
+ half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
+ qb >>= 4;
+ qb &= 0x00020002;
+ half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
+ half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
+ qc >>= 10;
+ half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
+ qc >>= 3;
+ qc &= 0x00040004;
+ half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
+ half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
+ qd >>= 10;
+ half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
+ qd >>= 2;
+ qd &= 0x00080008;
+ half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
+ half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
+ qe >>= 10;
+ half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
+ qe >>= 1;
+ qe &= 0x00100010;
+ half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
+
+ dq[ 0] = __hadd2( q0.as_half2, z1);
+ dq[ 1] = __hfma2( q1.as_half2, y32, z32);
+ dq[ 2] = __hadd2( q2.as_half2, z1);
+ dq[ 3] = __hadd2( q3.as_half2, z1);
+ dq[ 4] = __hfma2( q4.as_half2, y32, z32);
+ dq[ 5] = __hadd2( q5.as_half2, z1);
+ dq[ 6] = __hadd2( q6.as_half2, z1);
+ dq[ 7] = __hfma2( q7.as_half2, y32, z32);
+ dq[ 8] = __hadd2( q8.as_half2, z1);
+ dq[ 9] = __hadd2( q9.as_half2, z1);
+ dq[10] = __hfma2(q10.as_half2, y32, z32);
+ dq[11] = __hadd2(q11.as_half2, z1);
+ dq[12] = __hadd2(q12.as_half2, z1);
+ dq[13] = __hfma2(q13.as_half2, y32, z32);
+ dq[14] = __hadd2(q14.as_half2, z1);
+ dq[15] = __hadd2(q15.as_half2, z1);
+}
+
+#else
+
+__forceinline__ __device__ void shuffle_5bit_32
+(
+ uint32_t* q,
+ int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_5bit_32
+(
+ const uint32_t q_0,
+ const uint32_t q_1,
+ const uint32_t q_2,
+ const uint32_t q_3,
+ const uint32_t q_4,
+ half2 (&dq)[16],
+ int stride
+)
+{
+ half dqh[32];
+ for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
+ dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
+ for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
+ dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
+ for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
+ dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
+ for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
+ dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
+ for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
+
+ for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+#endif
+
+#endif
\ No newline at end of file
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_6.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_6.cuh
new file mode 100644
index 00000000..c2eb8cfb
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_6.cuh
@@ -0,0 +1,44 @@
+#ifndef _qdq_6_cuh
+#define _qdq_6_cuh
+
+#include "qdq_util.cuh"
+#include "../../config.h"
+
+#if QMODE_6BIT == 1
+
+ // Not implemented
+
+#else
+
+__forceinline__ __device__ void shuffle_6bit_16
+(
+ uint32_t* q,
+ int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_6bit_16
+(
+ const uint32_t q_0,
+ const uint32_t q_1,
+ const uint32_t q_2,
+ half2 (&dq)[8],
+ int stride
+)
+{
+ half dqh[16];
+ for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
+ dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
+ for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
+ dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
+ for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
+
+ for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+#endif
+
+#endif
+
+
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_8.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_8.cuh
new file mode 100644
index 00000000..e2409efa
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_8.cuh
@@ -0,0 +1,38 @@
+#ifndef _qdq_8_cuh
+#define _qdq_8_cuh
+
+#include "qdq_util.cuh"
+#include "../../config.h"
+
+#if QMODE_8BIT == 1
+
+ // Not implemented
+
+#else
+
+__forceinline__ __device__ void shuffle_8bit_4
+(
+ uint32_t* q,
+ int stride
+)
+{
+}
+
+__forceinline__ __device__ void dequant_8bit_8
+(
+ const uint32_t q_0,
+ const uint32_t q_1,
+ half2 (&dq)[4],
+ int stride
+)
+{
+ half dqh[8];
+ for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
+ for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
+
+ for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+#endif
+
+#endif
\ No newline at end of file
diff --git a/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_util.cuh b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_util.cuh
new file mode 100644
index 00000000..71657191
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/quant/qdq_util.cuh
@@ -0,0 +1,51 @@
+#ifndef _qdq_util_cuh
+#define _qdq_util_cuh
+
+union half2_uint32
+{
+ uint32_t as_uint32;
+ half2 as_half2;
+ __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
+ __device__ half2_uint32(half2 val) : as_half2(val) {}
+};
+
+union half_uint16
+{
+ uint16_t as_uint16;
+ half as_half;
+ __device__ half_uint16(uint16_t val) : as_uint16(val) {}
+ __device__ half_uint16(half val) : as_half(val) {}
+};
+
+// Max_scale premultiplied by 1/256
+
+__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
+{
+ int qs_i = qs + 1;
+ half qs_h = __int2half_rn(qs_i * qs_i);
+ qs_h = __hmul(qs_h, max_scale);
+ return qs_h;
+}
+
+__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
+{
+ return __hmul(__int2half_rn(q - qzero), scale);
+}
+
+__forceinline__ __device__ half dq_ns(const int q, const int qzero)
+{
+ //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
+ return __int2half_rn(q - qzero);
+}
+
+__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
+{
+ return (int)((q >> shift) & mask);
+}
+
+__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
+{
+ return (int)(__funnelshift_rc(q0, q1, shift) & mask);
+}
+
+#endif
diff --git a/auto_round_extension/cuda/exllamav2/cuda/util.cuh b/auto_round_extension/cuda/exllamav2/cuda/util.cuh
new file mode 100644
index 00000000..36be0e24
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/cuda/util.cuh
@@ -0,0 +1,42 @@
+
+#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
+
+#define DBGS(__x) printf("%s\n", __x)
+#define DBGI(__x) printf("%s: %i\n", #__x, __x)
+#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
+#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
+#define DBGX(__x) printf("%s: %x\n", #__x, __x)
+#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
+#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
+#define DBGF(__x) printf("%s: %f\n", #__x, __x)
+#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
+#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
+#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
+#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
+#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
+
+#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
+#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
+
+__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
+{
+ half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
+ qs_h = __hmul(qs_h, qs_h);
+ qs_h = __hmul(qs_h, max_scale);
+ return qs_h;
+}
+
+__forceinline__ __device__ float clamp(float x, float a, float b)
+{
+ return fmaxf(a, fminf(b, x));
+}
+
+#define cuda_check(res) { gpu_assert((res), __FILE__, __LINE__); }
+inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
+{
+ if (code != cudaSuccess)
+ {
+ fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
+ if (abort) exit(code);
+ }
+}
diff --git a/auto_round_extension/cuda/exllamav2/ext.cpp b/auto_round_extension/cuda/exllamav2/ext.cpp
new file mode 100644
index 00000000..5e52e6ab
--- /dev/null
+++ b/auto_round_extension/cuda/exllamav2/ext.cpp
@@ -0,0 +1,134 @@
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+#include "config.h"
+
+#include "cuda/q_matrix.cuh"
+#include "cuda/q_gemm.cuh"
+
+#include "cpp/util.h"
+
+// Some decluttering macros
+
+#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
+#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
+#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
+#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
+
+
+// Quant matrix
+
+uintptr_t make_q_matrix
+(
+ torch::Tensor q_weight,
+ torch::Tensor q_perm,
+ torch::Tensor q_invperm,
+ torch::Tensor q_scale,
+ torch::Tensor q_scale_max,
+ torch::Tensor q_groups,
+ torch::Tensor gptq_qzeros,
+ torch::Tensor gptq_scales,
+ torch::Tensor gptq_g_idx,
+ torch::Tensor temp_dq
+)
+{
+ TORCH_CHECK_DTYPE(q_weight, kInt);
+ TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
+ TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
+ TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
+ TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
+ TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
+ TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
+ TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
+ TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
+
+ TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
+
+ int device = q_weight.device().index();
+ int width = q_weight.size(1);
+ int groups;
+ int height;
+
+ if (!q_scale.device().is_meta())
+ {
+ TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
+ TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
+ groups = q_scale.size(0);
+ height = q_invperm.size(0);
+ }
+ else
+ {
+ TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
+ TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
+ groups = gptq_qzeros.size(0);
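+        // 4-bit GPTQ packs 8 weights per int32 row of qweight, hence logical height = packed rows * 8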
+ height = q_weight.size(0) * 8;
+ }
+
+ TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
+
+ QMatrix* m = new QMatrix
+ (
+ device,
+ height,
+ width,
+ groups,
+ (uint32_t*) q_weight.data_ptr(),
+ q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
+ q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
+ q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
+ q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
+ q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
+ gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
+ gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
+ gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
+ (half*) temp_dq.data_ptr()
+ );
+
+    return reinterpret_cast<uintptr_t> (m);
+}
+
+void gemm_half_q_half
+(
+ torch::Tensor a,
+ uintptr_t b,
+ torch::Tensor c,
+ bool force_cuda
+)
+{
+    QMatrix* qm = reinterpret_cast<QMatrix*> (b);
+
+ TORCH_CHECK_DTYPE(a, kHalf);
+ TORCH_CHECK_DTYPE(c, kHalf);
+ TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
+ TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
+ TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+
+ gemm_half_q_half_cuda
+ (
+ at::cuda::getCurrentCUDABlasHandle(),
+ (const half*) a.data_ptr(),
+ qm,
+ (half*) c.data_ptr(),
+ c.size(0), // m
+ c.size(1), // n
+ a.size(1), // k
+ true,
+ NULL,
+ force_cuda
+ );
+}
+
+// Bindings
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
+ m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
+}
diff --git a/auto_round_extension/cuda/post_init.py b/auto_round_extension/cuda/post_init.py
new file mode 100644
index 00000000..3de7e5ab
--- /dev/null
+++ b/auto_round_extension/cuda/post_init.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import torch
+
+EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
+
+
+def autoround_post_init(model):
+ """
+    The max_input_length argument is specific to the exllama backend, which requires initializing a temp_state buffer.
+ """
+ device_to_buffers_size = {}
+
+ model_uses_exllama = False
+ for name, submodule in model.named_modules():
+ if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
+ model_uses_exllama = True
+ device = submodule.qweight.device
+ if device not in device_to_buffers_size:
+ device_to_buffers_size[device] = {
+ "max_dq_buffer_size": 1,
+ "max_inner_outer_dim": 1,
+ }
+
+ submodule._use_act_order = False
+
+ # Disable this heuristic for detecting act_order, but it could be used instead of the config.
+ """
+ if submodule.g_idx is None:
+ submodule.act_order = False
+ elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or
+ torch.equal(submodule.g_idx.cpu(),
+ torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))):
+ submodule.g_idx = None
+ submodule.act_order = False
+ else:
+ submodule.act_order = True
+ """
+
+ device_to_buffers_size[device]["max_dq_buffer_size"] = max(
+ device_to_buffers_size[device]["max_dq_buffer_size"],
+ submodule.qweight.numel() * 8,
+ )
+
+ if model_uses_exllama:
+ # To be honest this is quite ugly, not proud of this.
+ try:
+ from exllama_kernels import prepare_buffers, set_tuning_params
+ except ImportError as e:
+ raise ImportError(
+ f"Could not import exllama backend dependencies prepare_buffers, set_tuning_params with the following "
+ f"error: {e}"
+ )
+
+ device_to_buffers = {}
+
+ max_input_len = 1
+
+ for device, buffers_size in device_to_buffers_size.items():
+ # The temp_state buffer is required to reorder X in the act-order case.
+ # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+ device_to_buffers[device] = {
+ "temp_state": torch.zeros(
+ (max_input_len, buffers_size["max_inner_outer_dim"]),
+ dtype=torch.float16,
+ device=device,
+ ),
+ "temp_dq": torch.zeros(
+ (1, buffers_size["max_dq_buffer_size"]),
+ dtype=torch.float16,
+ device=device,
+ ),
+ "max_dq_buffer_size": buffers_size["max_dq_buffer_size"],
+ "max_inner_outer_dim": buffers_size["max_inner_outer_dim"],
+ }
+
+ # Buffers need to be persistent to avoid any bug.
+ model.device_to_buffers = device_to_buffers
+
+ for device, buffers in model.device_to_buffers.items():
+ prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])
+
+ # Using the default from exllama repo here.
+ matmul_recons_thd = 8
+ matmul_fused_remap = False
+ matmul_no_half2 = False
+ set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+ # The buffers need to have been initialized first before calling make_q4.
+ for name, submodule in model.named_modules():
+ if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
+ submodule.post_init()
+
+ ## exllamav2
+ fixed_bytes = {}
+ model_uses_exllamav2 = False
+
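+    # Each exllamav2 layer reports a fixed scratch requirement; keep the per-device maximum so one
+    # shared ExLlamaV2DeviceTensors buffer can serve every layer on that device.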
+ for _, submodule in model.named_modules():
+ if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
+ model_uses_exllamav2 = True
+ device = submodule.qweight.device
+ scratch_fixed = submodule.scratch_space_fixed()
+ fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
+
+ if model_uses_exllamav2:
+ from auto_round_extension.cuda.qliner_exllamav2 import ExLlamaV2DeviceTensors
+
+ device_tensors = {}
+ for device, scratch_bytes in fixed_bytes.items():
+ device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
+
+ # have persistent buffers, otherwise we will get OOM
+ model.device_tensors = device_tensors
+
+ for _, submodule in model.named_modules():
+ if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
+ device = submodule.qweight.device
+ submodule.post_init(temp_dq=model.device_tensors[device])
+ torch.cuda.empty_cache()
+
+ return model
diff --git a/auto_round_extension/cuda/qliner_exllamav2.py b/auto_round_extension/cuda/qliner_exllamav2.py
new file mode 100644
index 00000000..3f98700d
--- /dev/null
+++ b/auto_round_extension/cuda/qliner_exllamav2.py
@@ -0,0 +1,271 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import math
+from logging import getLogger
+
+import torch
+import torch.nn as nn
+
+
+logger = getLogger(__name__)
+
+try:
+ from autoround_exllamav2_kernels import gemm_half_q_half, make_q_matrix
+except ImportError as e:
+ exllama_v2_import_exception = e
+
+ def error_raiser_exllama(*args, **kwargs):
+ raise ValueError(
+ f"Trying to use the exllama v2 backend, but could not import the C++/CUDA dependencies with the following "
+ f"error: {exllama_v2_import_exception}"
+ )
+
+ make_q_matrix = error_raiser_exllama
+ gemm_half_q_half = error_raiser_exllama
+
+# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
+none_tensor = torch.empty((1, 1), device="meta")
+
+
+def _torch_device(idx):
+ if idx == -1:
+ return "cpu"
+ return f"cuda:{idx}"
+
+
+def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
+ """Matrix multiplication, returns x @ q4"""
+ output_shape = x.shape[:-1] + (q4_width,)
+ x = x.view(-1, x.shape[-1])
+ output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
+ gemm_half_q_half(x, q_handle, output, force_cuda)
+ return output.view(output_shape)
+
+
+def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
+ """
+ Create Q matrix
+ """
+ # EXL2
+    # won't work at the moment because the tensors are not the same.
+ if "q_weight" in w:
+ w["q_scale_max"] /= 256
+ w["q_perm"] = w["q_perm"].short()
+ w["q_invperm"] = w["q_invperm"].short()
+ return make_q_matrix(
+ w["q_weight"],
+ w["q_perm"],
+ w["q_invperm"],
+ w["q_scale"],
+ w["q_scale_max"],
+ w["q_groups"],
+ none_tensor,
+ none_tensor,
+ none_tensor,
+ temp_dq,
+ )
+ # GPTQ
+ elif "qweight" in w:
+ if w["scales"].dtype == torch.float:
+ w["scales"] = w["scales"].half()
+
+ # GPTQ with g_idx (act_order)
+ if "g_idx" in w and not (w["g_idx"] == 0).all().item():
+ w["q_perm"] = torch.empty(
+ (w["qweight"].shape[0] * 8,),
+ dtype=torch.short,
+ device=w["qweight"].device,
+ )
+ w["q_invperm"] = torch.empty_like(w["q_perm"])
+ # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs
+ # to be passed for g_idx.
+ return make_q_matrix(
+ w["qweight"],
+ w["q_perm"],
+ w["q_invperm"],
+ none_tensor,
+ none_tensor,
+ none_tensor,
+ w["qzeros"],
+ w["scales"],
+ w["g_idx"].cpu(),
+ temp_dq,
+ )
+ # GPTQ without g_idx
+ else:
+ return make_q_matrix(
+ w["qweight"],
+ none_tensor,
+ none_tensor,
+ none_tensor,
+ none_tensor,
+ none_tensor,
+ w["qzeros"],
+ w["scales"],
+ none_tensor,
+ temp_dq,
+ )
+
+
+class QuantLinear(nn.Module):
+ QUANT_TYPE = "exllamav2"
+
+ """Linear layer implementation with per-group 4-bit quantization of the weights"""
+
+ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+ super().__init__()
+ if bits != 4:
+ raise ValueError(
+ f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model "
+ f"initialization."
+ )
+ if trainable:
+ raise NotImplementedError("Exllamav2 kernel does not support training.")
+
+ self.q_handle = None
+ self.q_tensors = None
+
+ self.padding = -outfeatures % 32
+ self.outfeatures = outfeatures + self.padding
+ outfeatures = self.outfeatures
+
+ self.infeatures = infeatures
+ self.bits = bits
+ self.group_size = group_size if group_size != -1 else infeatures
+ self.trainable = trainable
+ self.maxq = 2**self.bits - 1
+
+ assert infeatures % 32 == 0
+ assert infeatures % self.group_size == 0
+ assert outfeatures % 32 == 0
+
+ # We need to register the tensors; otherwise we won't be able to load them easily using transformers ...
+ self.register_buffer(
+ "qweight",
+ torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
+ )
+ self.register_buffer(
+ "qzeros",
+ torch.zeros(
+ (
+ math.ceil(infeatures / self.group_size),
+ outfeatures // 32 * self.bits,
+ ),
+ dtype=torch.int32,
+ ),
+ )
+ self.register_buffer(
+ "scales",
+ torch.zeros(
+ (math.ceil(infeatures / self.group_size), outfeatures),
+ dtype=torch.float16,
+ ),
+ )
+ self.register_buffer(
+ "g_idx",
+ torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
+ )
+
+ if bias:
+ self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+ else:
+ self.bias = None
+
+ def post_init(self, temp_dq):
+ assert self.qweight.device.type == "cuda"
+ assert self.qweight.device.index is not None
+ self.q_tensors = {
+ "qweight": self.qweight,
+ "qzeros": self.qzeros,
+ "scales": self.scales,
+ "g_idx": self.g_idx,
+ }
+ temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
+ self.q_handle = ext_make_q_matrix(self.q_tensors, temp_dq)
+
+ def forward(self, x, force_cuda=False):
+ if x.dtype != torch.float16:
+ logger.warning_once(
+ f"The exllama v2 kernel for GPTQ requires a float16 input activation, while {x.dtype} was passed. "
+ f"Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model "
+ f"definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 "
+ f"intermediate activations in the model."
+ )
+
+ x = x.half()
+
+ output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
+
+ if self.bias is not None:
+ output.add_(self.bias)
+ return output
+
+ def temp_dq_size(self):
+ return self.infeatures * self.outfeatures * 2 + 128
+
+ def temp_fwd_size(self, max_input_len, max_batch_size):
+ return self.outfeatures * max_input_len * max_batch_size * 4 + 128
+
+ def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
+ return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
+
+
+class ExLlamaV2DeviceTensors:
+ device_idx: int
+ scratch_bytes: int
+ scratch_idx: int
+ scratch: torch.tensor = None
+
+ def __init__(self, device_idx, scratch_bytes):
+ self.device_idx = device_idx
+ self.scratch_bytes = scratch_bytes
+
+ def prepare(self):
+ self.scratch = torch.empty(
+ (self.scratch_bytes // 2,),
+ dtype=torch.half,
+ device=_torch_device(self.device_idx),
+ )
+
+ def get_scratch_slice(self, size_bytes):
+ if self.scratch is None:
+ self.prepare()
+
+ size_bytes = ((size_bytes + 127) // 128) * 128
+ size_half = size_bytes // 2
+ scratch_slice = self.scratch.narrow(0, 0, size_half)
+ return scratch_slice
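
For reference, a minimal sketch of how this layer is expected to be wired up under the exllamav2 path, assuming the `autoround_exllamav2_kernels` extension is built and a CUDA device is available. In a real loading flow the zero-initialized buffers are replaced by packed checkpoint tensors before `post_init` is called; this only illustrates the scratch-buffer handshake between `ExLlamaV2DeviceTensors` and `QuantLinear.post_init()`.

```python
# Hedged sketch, not the packaged loading code.
import torch
from auto_round_extension.cuda.qliner_exllamav2 import ExLlamaV2DeviceTensors, QuantLinear

layer = QuantLinear(bits=4, group_size=128, infeatures=4096, outfeatures=4096, bias=False)
layer = layer.to("cuda:0")  # post_init() asserts qweight lives on an indexed CUDA device

# One scratch allocation per device, sized by the layer's own estimate.
scratch = ExLlamaV2DeviceTensors(device_idx=0, scratch_bytes=layer.scratch_space_fixed())
layer.post_init(temp_dq=scratch)  # builds the q_handle consumed by forward()

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda:0")
y = layer(x)  # ext_gemm_half_q_half(x, q_handle, outfeatures); y has shape (2, 4096)
```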
diff --git a/auto_round_extension/cuda/qliner_triton.py b/auto_round_extension/cuda/qliner_triton.py
new file mode 100644
index 00000000..e307ace3
--- /dev/null
+++ b/auto_round_extension/cuda/qliner_triton.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import math
+from logging import getLogger
+
+import numpy as np
+import torch
+import torch.nn as nn
+import transformers
+
+from auto_round_extension.cuda.triton_utils.mixin import TritonModuleMixin
+
+
+logger = getLogger(__name__)
+
+try:
+ from auto_round_extension.cuda.triton_utils.kernels import (
+ QuantLinearFunction,
+ QuantLinearInferenceOnlyFunction,
+ quant_matmul_248,
+ quant_matmul_inference_only_248,
+ transpose_quant_matmul_248,
+ )
+except ImportError as e:
+ triton_import_exception = e
+
+ def error_raiser_triton(*args, **kwargs):
+ raise ValueError(
+ f'Trying to use the triton backend, but could not import triton '
+ f'dependencies with the following error: {triton_import_exception}'
+ )
+
+ class FakeTriton:
+ def __getattr__(self, name):
+ raise ImportError(
+ f"Trying to use the triton backend, but could not import triton "
+ f"dependencies with the following error: {triton_import_exception}"
+ )
+
+ quant_matmul_248 = error_raiser_triton
+ transpose_quant_matmul_248 = error_raiser_triton
+ quant_matmul_inference_only_248 = error_raiser_triton
+ QuantLinearFunction = FakeTriton
+ QuantLinearInferenceOnlyFunction = FakeTriton
+
+
+class QuantLinear(nn.Module, TritonModuleMixin):
+ QUANT_TYPE = "triton"
+
+ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+ super().__init__()
+ if bits not in [2, 4, 8]:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ if infeatures % 32 != 0 or outfeatures % 32 != 0:
+ raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
+ self.infeatures = infeatures
+ self.outfeatures = outfeatures
+ self.bits = bits
+ self.group_size = group_size if group_size != -1 else infeatures
+ self.maxq = 2**self.bits - 1
+
+ self.register_buffer(
+ "qweight",
+ torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
+ )
+ self.register_buffer(
+ "qzeros",
+ torch.zeros(
+ (
+ math.ceil(infeatures / self.group_size),
+ outfeatures // 32 * self.bits,
+ ),
+ dtype=torch.int32,
+ ),
+ )
+ self.register_buffer(
+ "scales",
+ torch.zeros(
+ (math.ceil(infeatures / self.group_size), outfeatures),
+ dtype=torch.float16,
+ ),
+ )
+ self.register_buffer(
+ "g_idx",
+ torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
+ )
+ if bias:
+ self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+ else:
+ self.bias = None
+
+ self.trainable = trainable
+
+ def post_init(self):
+ pass
+
+ def pack(self, linear, scales, zeros, g_idx=None):
+ W = linear.weight.data.clone()
+ if isinstance(linear, nn.Conv2d):
+ W = W.flatten(1)
+ if isinstance(linear, transformers.pytorch_utils.Conv1D):
+ W = W.t()
+
+ self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
+ intweight = []
+ for idx in range(self.infeatures):
+ intweight.append(
+ torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[
+ :, None
+ ]
+ )
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(np.uint32)
+
+ i = 0
+ row = 0
+ qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
+ while row < qweight.shape[0]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (32 // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+ qweight = qweight.astype(np.int32)
+ self.qweight = torch.from_numpy(qweight)
+
+ # zeros -= 1
+ zeros = zeros.numpy().astype(np.uint32)
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (32 // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += 32 // self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+
+ qzeros = qzeros.astype(np.int32)
+ self.qzeros = torch.from_numpy(qzeros)
+
+ def forward(self, x):
+ out_shape = x.shape[:-1] + (self.outfeatures,)
+ quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
+ out = quant_linear_fn.apply(
+ x.reshape(-1, x.shape[-1]),
+ self.qweight,
+ self.scales,
+ self.qzeros,
+ self.g_idx,
+ self.bits,
+ self.maxq,
+ )
+ out = out.half().reshape(out_shape)
+ out = out + self.bias if self.bias is not None else out
+ return out
+
+ @classmethod
+ def warmup(cls, model, transpose=False, seqlen=2048):
+ """
+ Pre-tunes the quantized kernel
+ """
+ from tqdm import tqdm
+
+ kn_values = {}
+
+ for _, m in model.named_modules():
+ if not isinstance(m, cls):
+ continue
+
+ k = m.infeatures
+ n = m.outfeatures
+
+ if (k, n) not in kn_values:
+ kn_values[(k, n)] = (
+ m.qweight,
+ m.scales,
+ m.qzeros,
+ m.g_idx,
+ m.bits,
+ m.maxq,
+ )
+
+ logger.info(f"Found {len(kn_values)} unique KN Linear values.")
+ logger.info("Warming up autotune cache ...")
+ with torch.no_grad():
+ for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
+ m = 2**m
+ for (k, n), (
+ qweight,
+ scales,
+ qzeros,
+ g_idx,
+ bits,
+ maxq,
+ ) in kn_values.items():
+ if transpose:
+ a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+ quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+ a = torch.randn(m, n, dtype=torch.float16, device=model.device)
+ transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+ else:
+ a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+ quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+ del kn_values
+
+
+__all__ = ["QuantLinear"]
diff --git a/auto_round_extension/cuda/triton_utils/__init__.py b/auto_round_extension/cuda/triton_utils/__init__.py
new file mode 100644
index 00000000..2045808a
--- /dev/null
+++ b/auto_round_extension/cuda/triton_utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/auto_round_extension/cuda/triton_utils/custom_autotune.py b/auto_round_extension/cuda/triton_utils/custom_autotune.py
new file mode 100644
index 00000000..b511579c
--- /dev/null
+++ b/auto_round_extension/cuda/triton_utils/custom_autotune.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import builtins
+import math
+import time
+from typing import Dict
+
+import triton
+
+
+# code based https://github.com/fpgaminer/GPTQ-triton
+"""
+Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
+"""
+
+
+class CustomizedTritonAutoTuner(triton.KernelInterface):
+ def __init__(
+ self,
+ fn,
+ arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ prune_configs_by: Dict = None,
+ nearest_power_of_two: bool = False,
+ ):
+ if not configs:
+ self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
+ else:
+ self.configs = configs
+ self.key_idx = [arg_names.index(k) for k in key]
+ self.nearest_power_of_two = nearest_power_of_two
+ self.cache = {}
+ # hook to reset all required tensor to zeros before relaunching a kernel
+ self.hook = lambda args: 0
+ if reset_to_zero is not None:
+ self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+ def _hook(args):
+ for i in self.reset_idx:
+ args[i].zero_()
+
+ self.hook = _hook
+ self.arg_names = arg_names
+ # prune configs
+ if prune_configs_by:
+ perf_model, top_k = (
+ prune_configs_by["perf_model"],
+ prune_configs_by["top_k"],
+ )
+ if "early_config_prune" in prune_configs_by:
+ early_config_prune = prune_configs_by["early_config_prune"]
+ else:
+ perf_model, top_k, early_config_prune = None, None, None
+ self.perf_model, self.configs_top_k = perf_model, top_k
+ self.early_config_prune = early_config_prune
+ self.fn = fn
+
+ def _bench(self, *args, config, **meta):
+ # check for conflicts, i.e. meta-parameters both provided
+ # as kwargs and by the autotuner
+ conflicts = meta.keys() & config.kwargs.keys()
+ if conflicts:
+ raise ValueError(
+ f"Conflicting meta-parameters: {', '.join(conflicts)}."
+ " Make sure that you don't re-define auto-tuned symbols."
+ )
+ # augment meta-parameters with tunable ones
+ current = dict(meta, **config.kwargs)
+
+ def kernel_call():
+ if config.pre_hook:
+ config.pre_hook(self.nargs)
+ self.hook(args)
+ self.fn.run(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **current,
+ )
+
+ try:
+ # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
+ # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
+ return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40)
+ except triton.OutOfResources:
+ return (float("inf"), float("inf"), float("inf"))
+
+ def run(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ if len(self.configs) > 1:
+ key = tuple(args[i] for i in self.key_idx)
+
+ # This reduces the amount of autotuning by rounding the keys to the nearest power of two
+ # In my testing this gives decent results, and greatly reduces the amount of tuning required
+ if self.nearest_power_of_two:
+ key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
+
+ if key not in self.cache:
+ # prune configs
+ pruned_configs = self.prune_configs(kwargs)
+ bench_start = time.time()
+ timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+ bench_end = time.time()
+ self.bench_time = bench_end - bench_start
+ self.cache[key] = builtins.min(timings, key=timings.get)
+ self.hook(args)
+ self.configs_timings = timings
+ config = self.cache[key]
+ else:
+ config = self.configs[0]
+ self.best_config = config
+ if config.pre_hook is not None:
+ config.pre_hook(self.nargs)
+ return self.fn.run(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs,
+ )
+
+ def prune_configs(self, kwargs):
+ pruned_configs = self.configs
+ if self.early_config_prune:
+ pruned_configs = self.early_config_prune(self.configs, self.nargs)
+ if self.perf_model:
+ top_k = self.configs_top_k
+ if isinstance(top_k, float) and top_k <= 1.0:
+ top_k = int(len(self.configs) * top_k)
+ if len(pruned_configs) > top_k:
+ est_timing = {
+ config: self.perf_model(
+ **self.nargs,
+ **kwargs,
+ **config.kwargs,
+ num_stages=config.num_stages,
+ num_warps=config.num_warps,
+ )
+ for config in pruned_configs
+ }
+ pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+ return pruned_configs
+
+ def warmup(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ for config in self.prune_configs(kwargs):
+ self.fn.warmup(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs,
+ )
+ self.nargs = None
+
+
+def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
+ def decorator(fn):
+ return CustomizedTritonAutoTuner(
+ fn,
+ fn.arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ prune_configs_by,
+ nearest_power_of_two,
+ )
+
+ return decorator
+
+
+def matmul248_kernel_config_pruner(configs, nargs):
+ """
+ The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
+ """
+ m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
+ n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
+ k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
+
+ used = set()
+ for config in configs:
+ block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
+ block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
+ block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
+ group_size_m = config.kwargs["GROUP_SIZE_M"]
+
+ if (
+ block_size_m,
+ block_size_n,
+ block_size_k,
+ group_size_m,
+ config.num_stages,
+ config.num_warps,
+ ) in used:
+ continue
+
+ used.add(
+ (
+ block_size_m,
+ block_size_n,
+ block_size_k,
+ group_size_m,
+ config.num_stages,
+ config.num_warps,
+ )
+ )
+ yield triton.Config(
+ {
+ "BLOCK_SIZE_M": block_size_m,
+ "BLOCK_SIZE_N": block_size_n,
+ "BLOCK_SIZE_K": block_size_k,
+ "GROUP_SIZE_M": group_size_m,
+ },
+ num_stages=config.num_stages,
+ num_warps=config.num_warps,
+ )
+
+
+__all__ = ["autotune"]
diff --git a/auto_round_extension/cuda/triton_utils/dequant.py b/auto_round_extension/cuda/triton_utils/dequant.py
new file mode 100644
index 00000000..b7c6316d
--- /dev/null
+++ b/auto_round_extension/cuda/triton_utils/dequant.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import itertools
+
+import torch
+import triton
+import triton.language as tl
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+def make_dequant_configs(block_sizes, num_warps):
+ configs = []
+ for bs, ws in itertools.product(block_sizes, num_warps):
+ configs.append(triton.Config({"X_BLOCK": bs}, num_warps=ws))
+ return configs
+
+
+DEFAULT_DEQUANT_CONFIGS = make_dequant_configs([128, 256, 512, 1024], [4, 8])
+
+
+@triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"])
+@triton.jit
+def dequant_kernel_248(
+ g_idx_ptr,
+ scales_ptr,
+ qweight_ptr,
+ qzeros_ptr,
+ out_ptr,
+ numels,
+ maxq: tl.constexpr,
+ bits: tl.constexpr,
+ outfeatures: tl.constexpr,
+ num_groups: tl.constexpr,
+ X_BLOCK: tl.constexpr,
+):
+ # Block indexing
+ xoffset = tl.program_id(0) * X_BLOCK
+ x_index = xoffset + tl.arange(0, X_BLOCK)
+ xmask = x_index < numels
+ row_idx = x_index // outfeatures
+ col_idx = x_index % outfeatures
+
+ elements_per_feature: tl.constexpr = 32 // bits
+
+ # Load parameters
+ g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last")
+ qweights = tl.load(
+ qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),
+ None,
+ )
+
+ wf_weights = (row_idx % elements_per_feature) * bits
+
+ wf_zeros = (col_idx % elements_per_feature) * bits
+
+ tmp1 = g_idx + num_groups
+ tmp2 = g_idx < 0
+ tl.device_assert(g_idx >= 0, "g_idx contains a negative index (out of bounds)")
+ groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx
+
+ scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(
+ tl.float32
+ )
+
+ # Unpack weights
+ weights = qweights >> wf_weights # bit shift qweight
+
+ weights = weights & maxq
+
+ # Unpack zeros
+ qzero_ncols: tl.constexpr = outfeatures // elements_per_feature
+ qzeros = tl.load(
+ qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),
+ None,
+ eviction_policy="evict_last",
+ )
+ zeros = qzeros >> wf_zeros
+ zeros = zeros & maxq
+
+ # Dequantize
+ # zeros = zeros + 1
+ weights = weights - zeros
+ weights = weights.to(tl.float32)
+ weights = scales * weights
+
+ tl.store(out_ptr + (x_index), weights, mask=xmask)
+
+
+def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):
+ """
+ Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8
+ """
+
+ num_groups = scales.shape[0]
+ outfeatures = scales.shape[1]
+ infeatures = g_idx.shape[0]
+
+ out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16)
+ numels = out.numel()
+ maxq = 2**bits - 1 if maxq is None else maxq
+ grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),) # noqa: E731
+
+ dequant_kernel_248[grid](
+ g_idx,
+ scales,
+ qweight,
+ qzeros,
+ out,
+ numels,
+ maxq=maxq,
+ bits=bits,
+ outfeatures=outfeatures,
+ num_groups=num_groups,
+ )
+ return out
+
+
+def quant_matmul_248(
+ input, qweight, scales, qzeros, g_idx, bits, maxq=None, transpose=False
+):
+ W = dequant248(qweight, scales, qzeros, g_idx, bits, maxq=maxq)
+ if transpose:
+ return input @ W.t()
+ return input @ W
+
+
+class QuantLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+ output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+ ctx.save_for_backward(qweight, scales, qzeros, g_idx)
+ ctx.bits, ctx.maxq = bits, maxq
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ qweight, scales, qzeros, g_idx = ctx.saved_tensors
+ bits, maxq = ctx.bits, ctx.maxq
+ grad_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = quant_matmul_248(
+ grad_output, qweight, scales, qzeros, g_idx, bits, maxq, transpose=True
+ )
+ return grad_input, None, None, None, None, None, None
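
A shape-oriented sketch of calling the launcher above for a 4-bit, group-size-128 layer. It requires a CUDA device and Triton; the zero-filled tensors stand in for real packed weights and only demonstrate the expected layouts.

```python
import torch
from auto_round_extension.cuda.triton_utils.dequant import dequant248, quant_matmul_248

bits, group_size = 4, 128
K, N = 4096, 4096                 # infeatures, outfeatures
G = K // group_size               # number of quantization groups

qweight = torch.zeros((K * bits // 32, N), dtype=torch.int32, device="cuda")
qzeros = torch.zeros((G, N * bits // 32), dtype=torch.int32, device="cuda")
scales = torch.ones((G, N), dtype=torch.float16, device="cuda")
g_idx = torch.arange(K, device="cuda", dtype=torch.int32) // group_size

W = dequant248(qweight, scales, qzeros, g_idx, bits)            # (K, N) float16
x = torch.randn(8, K, dtype=torch.float16, device="cuda")
y = quant_matmul_248(x, qweight, scales, qzeros, g_idx, bits)   # x @ W, shape (8, N)
```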
diff --git a/auto_round_extension/cuda/triton_utils/kernels.py b/auto_round_extension/cuda/triton_utils/kernels.py
new file mode 100644
index 00000000..44981c7a
--- /dev/null
+++ b/auto_round_extension/cuda/triton_utils/kernels.py
@@ -0,0 +1,500 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from logging import getLogger
+
+import torch
+import triton
+import triton.language as tl
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+from . import custom_autotune
+
+
+logger = getLogger(__name__)
+
+
+# code based https://github.com/fpgaminer/GPTQ-triton
+
+
+@custom_autotune.autotune(
+ configs=[
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=2,
+ num_warps=8,
+ ),
+ ],
+ key=["M", "N", "K"],
+ nearest_power_of_two=True,
+ prune_configs_by={
+ "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+ "perf_model": None,
+ "top_k": None,
+ },
+)
+@triton.jit
+def quant_matmul_248_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ scales_ptr,
+ zeros_ptr,
+ g_ptr,
+ M,
+ N,
+ K,
+ bits,
+ maxq,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ stride_scales,
+ stride_zeros,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+):
+ """
+ Compute the matrix multiplication C = A x B, where B is the packed quantized weight.
+ A is of shape (M, K) float16
+ B is of shape (K // (32 // bits), N) int32, packed
+ C is of shape (M, N) float16
+ scales is of shape (G, N) float16
+ zeros is of shape (G, N // (32 // bits)) int32, packed
+ g_ptr is of shape (K) int32
+ """
+ infearure_per_bits = 32 // bits
+
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+ a_mask = offs_am[:, None] < M
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
+ b_ptrs = b_ptr + (
+ (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn
+ ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+ g_ptrs = g_ptr + offs_k
+ # shifter is used to extract the N bits of each element in the 32-bit word from B
+ scales_ptrs = scales_ptr + offs_bn[None, :]
+ zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
+
+ shifter = (offs_k % infearure_per_bits) * bits
+ zeros_shifter = (offs_bn % infearure_per_bits) * bits
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+ for k in range(0, num_pid_k):
+ g_idx = tl.load(g_ptrs)
+
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+ scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+ zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
+ # zeros = zeros + 1
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+ # Now we need to unpack b (which is N-bit values) into 32-bit values
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
+ b = (b - zeros) * scales # Scale and shift
+
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_K
+ b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
+ g_ptrs += BLOCK_SIZE_K
+
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+ c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+ tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@custom_autotune.autotune(
+ configs=[
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=2,
+ num_warps=8,
+ ),
+ ],
+ key=["M", "N", "K"],
+ nearest_power_of_two=True,
+)
+@triton.jit
+def transpose_quant_matmul_248_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ scales_ptr,
+ zeros_ptr,
+ g_ptr,
+ M,
+ N,
+ K,
+ bits,
+ maxq,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ stride_scales,
+ stride_zeros,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+):
+ """
+ Compute the matrix multiplication C = A x B^T, where B is the packed quantized weight.
+ A is of shape (M, N) float16
+ B is of shape (K // (32 // bits), N) int32, packed
+ C is of shape (M, K) float16
+ scales is of shape (G, N) float16
+ zeros is of shape (G, N // (32 // bits)) int32, packed
+ g_ptr is of shape (K) int32
+ """
+ infearure_per_bits = 32 // bits
+
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_k
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_k = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+ a_mask = offs_am[:, None] < M
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
+ b_ptrs = b_ptr + (
+ (offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn
+ ) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+ g_ptrs = g_ptr + offs_bk
+ g_idx = tl.load(g_ptrs)
+
+ # shifter is used to extract the N bits of each element in the 32-bit word from B
+ scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
+ zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
+
+ shifter = (offs_bk % infearure_per_bits) * bits
+ zeros_shifter = (offs_n % infearure_per_bits) * bits
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
+
+ for k in range(0, num_pid_n):
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
+ scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+ zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
+
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
+ # zeros = zeros + 1
+
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
+
+ # Now we need to unpack b (which is N-bit values) into 32-bit values
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
+ b = (b - zeros) * scales # Scale and shift
+ b = tl.trans(b)
+
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_N
+ b_ptrs += BLOCK_SIZE_N
+ scales_ptrs += BLOCK_SIZE_N
+ zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits
+
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
+ c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
+ tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def silu(x):
+ return x * tl.sigmoid(x)
+
+
+def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ with torch.cuda.device(input.device):
+ output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
+ grid = lambda META: ( # noqa: E731
+ triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
+ )
+ quant_matmul_248_kernel[grid](
+ input,
+ qweight,
+ output,
+ scales.to(input.dtype),
+ qzeros,
+ g_idx,
+ input.shape[0],
+ qweight.shape[1],
+ input.shape[1],
+ bits,
+ maxq,
+ input.stride(0),
+ input.stride(1),
+ qweight.stride(0),
+ qweight.stride(1),
+ output.stride(0),
+ output.stride(1),
+ scales.stride(0),
+ qzeros.stride(0),
+ )
+ return output
+
+
+def transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ with torch.cuda.device(input.device):
+ output_dim = (qweight.shape[0] * 32) // bits
+ output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)
+ grid = lambda META: ( # noqa: E731
+ triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(output_dim, META["BLOCK_SIZE_K"]),
+ )
+ transpose_quant_matmul_248_kernel[grid](
+ input,
+ qweight,
+ output,
+ scales.to(input.dtype),
+ qzeros,
+ g_idx,
+ input.shape[0],
+ qweight.shape[1],
+ output_dim,
+ bits,
+ maxq,
+ input.stride(0),
+ input.stride(1),
+ qweight.stride(0),
+ qweight.stride(1),
+ output.stride(0),
+ output.stride(1),
+ scales.stride(0),
+ qzeros.stride(0),
+ )
+ return output
+
+
+class QuantLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+ output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+ ctx.save_for_backward(qweight, scales, qzeros, g_idx)
+ ctx.bits, ctx.maxq = bits, maxq
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ qweight, scales, qzeros, g_idx = ctx.saved_tensors
+ bits, maxq = ctx.bits, ctx.maxq
+ grad_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
+ return grad_input, None, None, None, None, None, None
+
+
+def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
+ with torch.cuda.device(input.device):
+ output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
+ grid = lambda META: ( # noqa: E731
+ triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
+ )
+ quant_matmul_248_kernel[grid](
+ input,
+ qweight,
+ output,
+ scales,
+ qzeros,
+ g_idx,
+ input.shape[0],
+ qweight.shape[1],
+ input.shape[1],
+ bits,
+ maxq,
+ input.stride(0),
+ input.stride(1),
+ qweight.stride(0),
+ qweight.stride(1),
+ output.stride(0),
+ output.stride(1),
+ scales.stride(0),
+ qzeros.stride(0),
+ )
+ return output
+
+
+class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
+ output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
+ return output
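
The `(b >> shifter) & maxq` lines in the kernels above are the inverse of the packing loop in `qliner_triton.QuantLinear.pack`: several low-bit integers share one int32 word. A plain-PyTorch illustration of that round trip for 4-bit values (no Triton needed):

```python
import torch

bits = 4
maxq = 2 ** bits - 1
values = torch.tensor([3, 7, 0, 15, 1, 9, 2, 5], dtype=torch.int32)  # eight 4-bit values

word = torch.zeros((), dtype=torch.int32)
for j, v in enumerate(values):
    word |= v << (bits * j)         # same word layout the pack() methods produce

shifts = torch.arange(8, dtype=torch.int32) * bits
unpacked = (word >> shifts) & maxq  # what the Triton kernels do per element
assert torch.equal(unpacked, values)
```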
diff --git a/auto_round_extension/cuda/triton_utils/mixin.py b/auto_round_extension/cuda/triton_utils/mixin.py
new file mode 100644
index 00000000..557d3b48
--- /dev/null
+++ b/auto_round_extension/cuda/triton_utils/mixin.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2023 潘其威(William)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+class TritonModuleMixin:
+ @classmethod
+ def warmup(cls, model, transpose=False, seqlen=2048):
+ pass
diff --git a/examples/language-modeling/eval_042/evaluation.py b/examples/language-modeling/eval_042/evaluation.py
index 2af9dd99..193ce9fc 100644
--- a/examples/language-modeling/eval_042/evaluation.py
+++ b/examples/language-modeling/eval_042/evaluation.py
@@ -574,7 +574,7 @@ def evaluate(
)
parser.add_argument("--tasks",
default="lambada_openai,hellaswag,winogrande,piqa,mmlu,truthfulqa_mc1," \
- "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge",
+ "openbookqa,boolq,rte,arc_easy,arc_challenge",
help="lm-eval tasks for lm_eval version 0.4.2")
args = parser.parse_args()
@@ -582,6 +582,7 @@ def evaluate(
from transformers import AutoConfig
config = AutoConfig.from_pretrained(args.model_name)
+
if hasattr(config, "quantization_config"):
quantization_config = config.quantization_config
if "quant_method" in quantization_config and "auto-round" in quantization_config["quant_method"]:
@@ -593,8 +594,12 @@ def evaluate(
model_name = args.model_name.rstrip('/')
from lm_eval.utils import make_table
+ model_args = f"pretrained={args.model_name}"
+ if config.torch_dtype == torch.float32:
+ model_args += ",dtype=float16"
+ model_args += ",dtype=float16"
result = simple_evaluate(model="hf",
- model_args=f"pretrained={args.model_name}",
+ model_args=model_args,
tasks=test_tasks,
batch_size=args.eval_bs)
print(make_table(result))
diff --git a/examples/language-modeling/main.py b/examples/language-modeling/main.py
index b0dce0e1..4cfb98a0 100644
--- a/examples/language-modeling/main.py
+++ b/examples/language-modeling/main.py
@@ -151,6 +151,7 @@ def get_library_version(library_name):
except subprocess.CalledProcessError:
return "Library not found"
+
res = get_library_version("lm-eval")
if res == "0.3.0":
use_eval_legacy = True
@@ -289,7 +290,7 @@ def get_library_version(library_name):
f"supported currently")
break
if args.quant_lm_head:
- weight_config[lm_head_layer_name] = {"data_type": "int"}
+ weight_config[lm_head_layer_name] = {"data_type": "int", "bits": 4, "group_size": 32}
transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]]
if transformers_version[0] == 4 and transformers_version[1] < 38:
error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization."
@@ -302,7 +303,10 @@ def get_library_version(library_name):
gpu_format = "auto_gptq"
if 'gpu' in deployment_device:
if lm_head_layer_name in weight_config.keys() and weight_config[lm_head_layer_name]["data_type"] == "int":
- gpu_format = "autoround"
+ gpu_format = "auto_round"
+
+ if "autoround" in deployment_device or "auto-round" in deployment_device or "auto_round" in deployment_device:
+ gpu_format = "auto_round"
autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs,
dataset=args.dataset, seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,
@@ -323,7 +327,7 @@ def get_library_version(library_name):
output_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}-qdq"
inplace = True if len(deployment_device) < 2 else False
- if 'gpu' in deployment_device:
+ if 'gpu' in deployment_device or "auto_round" in gpu_format or "auto-round" in gpu_format:
autoround.save_quantized(f'{export_dir}-gpu', format=gpu_format, use_triton=True, inplace=inplace)
if 'xpu' in deployment_device:
autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace,
@@ -343,5 +347,3 @@ def get_library_version(library_name):
eval_model(model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=not args.disable_low_gpu_mem_usage,
device=torch_device, excel_file=excel_name)
-
-
diff --git a/requirements.txt b/requirements.txt
index 09874574..cb8df6bb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
accelerate
-auto-gptq
datasets
py-cpuinfo
sentencepiece
torch
transformers
+triton
diff --git a/setup.py b/setup.py
index 4fe2b30d..67f58a27 100644
--- a/setup.py
+++ b/setup.py
@@ -1,8 +1,11 @@
import re
from io import open
-
+import os
from setuptools import find_packages, setup
+import sys
+os.environ["CC"] = "g++"
+os.environ["CXX"] = "g++"
try:
filepath = "./auto_round/version.py"
with open(filepath) as version_file:
@@ -10,17 +13,154 @@
except Exception as error:
assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)
+version = __version__
+
def fetch_requirements(path):
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
+BUILD_CUDA_EXT = int(os.environ.get('BUILD_CUDA_EXT', '1')) == 1
+PYPI_RELEASE = os.environ.get('PYPI_RELEASE', None)
+
+
+def detect_local_sm_architectures():
+ """
+ Detect compute capabilities of one machine's GPUs as PyTorch does.
+
+ Copied from https://github.com/pytorch/pytorch/blob/v2.2.2/torch/utils/cpp_extension.py#L1962-L1976
+ """
+ arch_list = []
+
+ for i in range(torch.cuda.device_count()):
+ capability = torch.cuda.get_device_capability(i)
+ supported_sm = [int(arch.split('_')[1])
+ for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
+ max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
+ # Capability of the device may be higher than what's supported by the user's
+ # NVCC, causing compilation error. User's NVCC is expected to match the one
+ # used to build pytorch, so we use the maximum supported capability of pytorch
+ # to clamp the capability.
+ capability = min(max_supported_sm, capability)
+ arch = f'{capability[0]}.{capability[1]}'
+ if arch not in arch_list:
+ arch_list.append(arch)
+
+ arch_list = sorted(arch_list)
+ arch_list[-1] += '+PTX'
+ return arch_list
+
+
+UNSUPPORTED_COMPUTE_CAPABILITIES = ['3.5', '3.7', '5.0', '5.2', '5.3']
+requirements = [
+ "torch",
+ "accelerate",
+ "datasets",
+ "sentencepiece",
+ "safetensors",
+ "transformers",
+ "tqdm",
+ "py-cpuinfo",
+]
+
+if BUILD_CUDA_EXT:
+ try:
+ import torch
+ except Exception as e:
+ print(
+ f"Building PyTorch CUDA extension requires PyTorch being installed, please install PyTorch first: {e}.\n NOTE: This issue may be raised due to pip build isolation system (ignoring local packages). Please use `--no-build-isolation` when installing with pip, and refer to https://github.com/AutoGPTQ/AutoGPTQ/pull/620 for more details.")
+ sys.exit(1)
+
+ CUDA_VERSION = None
+ ROCM_VERSION = os.environ.get('ROCM_VERSION', None)
+ if ROCM_VERSION and not torch.version.hip:
+ print(
+ f"Trying to compile auto-gptq for ROCm, but PyTorch {torch.__version__} "
+ "is installed without ROCm support."
+ )
+ sys.exit(1)
+
+ if not ROCM_VERSION:
+ default_cuda_version = torch.version.cuda
+ CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", default_cuda_version).split("."))
+
+ if ROCM_VERSION:
+ version += f"+rocm{ROCM_VERSION}"
+ else:
+ if not CUDA_VERSION:
+ print(
+ f"Trying to compile auto-gptq for CUDA, but Pytorch {torch.__version__} "
+ "is installed without CUDA support."
+ )
+ sys.exit(1)
+
+ torch_cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+ if torch_cuda_arch_list is not None:
+ torch_cuda_arch_list = torch_cuda_arch_list.replace(' ', ';')
+ archs = torch_cuda_arch_list.split(';')
+
+ requested_but_unsupported_archs = {arch for arch in archs if arch in UNSUPPORTED_COMPUTE_CAPABILITIES}
+ if len(requested_but_unsupported_archs) > 0:
+ raise ValueError(
+ f"Trying to compile AutoGPTQ for CUDA compute capabilities {torch_cuda_arch_list}, but AutoGPTQ does not support the compute capabilities {requested_but_unsupported_archs} (AutoGPTQ requires Pascal or higher). Please fix your environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139).")
+ else:
+ local_arch_list = detect_local_sm_architectures()
+ local_but_unsupported_archs = {arch for arch in local_arch_list if arch in UNSUPPORTED_COMPUTE_CAPABILITIES}
+ if len(local_but_unsupported_archs) > 0:
+ raise ValueError(
+ f"PyTorch detected the compute capabilities {local_arch_list} for the NVIDIA GPUs on the current machine, but AutoGPTQ can not be built for compute capabilities {local_but_unsupported_archs} (AutoGPTQ requires Pascal or higher). Please set the environment variable TORCH_CUDA_ARCH_LIST (Reference: https://github.com/pytorch/pytorch/blob/v2.2.2/setup.py#L135-L139) with your necessary architectures.")
+
+ # For the PyPI release, the version is simply x.x.x to comply with PEP 440.
+ if not PYPI_RELEASE:
+ version += f"+cu{CUDA_VERSION}"
+
+additional_setup_kwargs = {}
+include_dirs = ["autoround_cuda"]
+if BUILD_CUDA_EXT:
+ from torch.utils import cpp_extension
+
+ if not ROCM_VERSION:
+ from distutils.sysconfig import get_python_lib
+
+ conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
+
+ print("conda_cuda_include_dir", conda_cuda_include_dir)
+ if os.path.isdir(conda_cuda_include_dir):
+ include_dirs.append(conda_cuda_include_dir)
+ print(f"appending conda cuda include dir {conda_cuda_include_dir}")
+ if os.name == "nt":
+ # On Windows, fix an error LNK2001: unresolved external symbol cublasHgemm bug in the compilation
+ cuda_path = os.environ.get("CUDA_PATH", None)
+ if cuda_path is None:
+ raise ValueError(
+ "The environment variable CUDA_PATH must be set to the path to the CUDA install when installing from source on Windows systems.")
+ extra_link_args = ["-L", f"{cuda_path}/lib/x64/cublas.lib"]
+ else:
+ extra_link_args = []
+ extensions = []
+ extensions.append(
+ cpp_extension.CUDAExtension(
+ "autoround_exllamav2_kernels",
+ [
+ "auto_round_extension/cuda/exllamav2/ext.cpp",
+ "auto_round_extension/cuda/exllamav2/cuda/q_matrix.cu",
+ "auto_round_extension/cuda/exllamav2/cuda/q_gemm.cu",
+ ],
+ extra_link_args=extra_link_args
+ )
+ )
+ additional_setup_kwargs = {
+ "ext_modules": extensions,
+ "cmdclass": {'build_ext': cpp_extension.BuildExtension}
+ }
+
if __name__ == "__main__":
setup(
name="auto_round",
author="Intel AIPT Team",
- version=__version__,
+ version=version,
author_email="wenhua.cheng@intel.com, weiwei1.zhang@intel.com",
description="Repository of AutoRound: Advanced Weight-Only Quantization Algorithm for LLMs",
long_description=open("README.md", "r", encoding="utf-8").read(),
@@ -29,7 +169,8 @@ def fetch_requirements(path):
license="Apache 2.0",
url="https://github.com/intel/auto-round",
packages=find_packages(),
- include_package_data=False,
+ include_dirs=include_dirs,
+ ##include_package_data=False,
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.7.0",
classifiers=[
@@ -38,4 +179,5 @@ def fetch_requirements(path):
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: Apache Software License",
],
+ **additional_setup_kwargs
)