From a117d9d861f894d939896404c96da336e22f72bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Thu, 9 Nov 2023 18:47:51 +0100
Subject: [PATCH] Add AutoAWQ integration

---
 outlines/models/awq.py | 81 ++++++++++++++++++++++++++++++++++++++++++
 pyproject.toml         |  2 ++
 2 files changed, 83 insertions(+)
 create mode 100644 outlines/models/awq.py

diff --git a/outlines/models/awq.py b/outlines/models/awq.py
new file mode 100644
index 000000000..3a2418080
--- /dev/null
+++ b/outlines/models/awq.py
@@ -0,0 +1,81 @@
+from typing import TYPE_CHECKING, Optional
+
+from .transformers import Transformer, TransformerTokenizer
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+class AWQModel(Transformer):
+    """Represents a quantized model loaded with `AutoAWQForCausalLM`."""
+
+    def __init__(
+        self,
+        model: "PreTrainedModel",
+        tokenizer: "PreTrainedTokenizer",
+    ):
+        # The device is read from the wrapped `model.model` attribute rather
+        # than from the wrapper object passed in.
+        self.device = model.model.device
+        self.model = model
+        self.tokenizer = tokenizer
+
+
+def awq(
+    model_name: str,
+    fuse_layers: bool = True,
+    device: Optional[str] = None,
+    model_kwargs: Optional[dict] = None,
+    tokenizer_kwargs: Optional[dict] = None,
+):
+    """Instantiate a quantized model with `autoawq` along with its tokenizer.
+
+    Parameters
+    ----------
+    model_name
+        The name or path of the quantized model; passed to both
+        `AutoAWQForCausalLM.from_quantized` and `TransformerTokenizer`.
+    fuse_layers
+        Forwarded to `from_quantized` as `fuse_layers`.
+    device
+        When not `None`, forwarded to `from_quantized` as `device_map`.
+    model_kwargs
+        Additional keyword arguments forwarded to `from_quantized`. The
+        `fuse_layers`, `safetensors` and `device_map` entries set by this
+        function take precedence.
+    tokenizer_kwargs
+        Additional keyword arguments forwarded to `TransformerTokenizer`.
+        `trust_remote_code` defaults to `True` unless given here.
+
+    Returns
+    -------
+    An `AWQModel` instance wrapping the loaded model and tokenizer.
+
+    Raises
+    ------
+    ImportError
+        If the `awq` package cannot be imported.
+
+    """
+    try:
+        from awq import AutoAWQForCausalLM
+    except ImportError:
+        raise ImportError(
+            "The `autoawq` and `transformers` library needs to be installed in order to use `AutoAWQ` models."
+        )
+
+    # Work on copies: the caller's dictionaries must not be mutated, and no
+    # state (e.g. a previous `device_map`) may leak from one call to the next.
+    model_kwargs = dict(model_kwargs or {})
+    model_kwargs["fuse_layers"] = fuse_layers
+    model_kwargs["safetensors"] = True
+
+    if device is not None:
+        model_kwargs["device_map"] = device
+
+    tokenizer_kwargs = {"trust_remote_code": True, **(tokenizer_kwargs or {})}
+
+    model = AutoAWQForCausalLM.from_quantized(model_name, **model_kwargs)
+    tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
+
+    return AWQModel(model, tokenizer)
diff --git a/pyproject.toml b/pyproject.toml
index bcab189cb..87493cb03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,8 @@ exclude=["examples"]
 
 [[tool.mypy.overrides]]
 module = [
+    "awq.*",
+    "auto_gptq.*",
     "jinja2",
     "joblib.*",
     "jsonschema.*",