Refactor LinearActQuantizedTensor (#542)
Summary:
* rename to LinearActivationQuantizedTensor
* use the `implements` util to implement torch function and torch dispatch overrides

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jul 26, 2024
1 parent c9f79be commit afde175
Showing 13 changed files with 288 additions and 269 deletions.
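
In short, the second bullet replaces per-class dispatch boilerplate with shared helpers. A condensed sketch of the hunks below (not new code; the elided class body is an assumption for brevity):

    import torch
    from torchao.dtypes.utils import (
        _implements,
        _dispatch__torch_function__,
        _dispatch__torch_dispatch__,
    )

    class AffineQuantizedTensor(torch.Tensor):
        # ... quantization fields and methods elided ...

        # previously: hand-written __torch_function__/__torch_dispatch__ classmethods
        # doing the _ATEN_OP_OR_TORCH_FN_TABLE lookup themselves (see the removed lines below)
        implements = classmethod(_implements)
        __torch_function__ = classmethod(_dispatch__torch_function__)
        __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

    # per-op handlers keep the decorator-based registration, with the handler
    # signature now including `types`: def _(func, types, *args, **kwargs)
    implements = AffineQuantizedTensor.implements

The rename in the first bullet is user-facing only: imports move from torchao.quantization.subclass.LinearActQuantizedTensor to torchao.quantization.LinearActivationQuantizedTensor, as the test diff below shows.
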
12 changes: 7 additions & 5 deletions test/quantization/test_quant_api.py
@@ -22,12 +22,14 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.quantization import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.subclass import (
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
@@ -504,8 +506,8 @@ def test_quantized_tensor_subclass_8da4w(self):
example_inputs = m.example_inputs()
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

@@ -577,8 +579,8 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
quantize_(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

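For orientation, a usage sketch mirroring the assertions above (the toy model and import paths are assumptions; quantize_, int8_dynamic_activation_int8_weight, and the two tensor subclasses are the APIs exercised in this test file):

    import torch
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import (
        LinearActivationQuantizedTensor,
        int8_dynamic_activation_int8_weight,
        quantize_,
    )

    m = torch.nn.Sequential(torch.nn.Linear(32, 64)).eval()  # hypothetical toy model
    quantize_(m, int8_dynamic_activation_int8_weight())

    # the linear weight is wrapped by the renamed subclass, and its inner tensor
    # is the affine-quantized weight, matching the assertions in the tests above
    assert isinstance(m[0].weight, LinearActivationQuantizedTensor)
    assert isinstance(m[0].weight.original_weight_tensor, AffineQuantizedTensor)
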
60 changes: 22 additions & 38 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -17,7 +17,8 @@
from torchao.utils import find_multiple
from torchao.dtypes.utils import (
_implements,
_ATEN_OP_OR_TORCH_FN_TABLE,
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
@@ -295,17 +296,6 @@ def from_float_static(
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)


def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
@@ -347,29 +337,23 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
# cpu device + et layout --> gives current 8da4w executorch representation
# cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
# cuda device + some layout --> gives cuda kernel

# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 2 - we're given non-floats - quantizing long to int8 is crazy

if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
implements = classmethod(_implements)
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
# cpu device + et layout --> gives current 8da4w executorch representation
# cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
# cuda device + some layout --> gives cuda kernel

raise NotImplementedError(
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)
# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 2 - we're given non-floats - quantizing long to int8 is crazy
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

def implements(aten_ops_or_torch_fn):
return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)
implements = AffineQuantizedTensor.implements

def register_layout_cls(layout_type_class: type(LayoutType)):
return _register_layout_cls(AffineQuantizedTensor, layout_type_class)
@@ -827,7 +811,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):


@implements(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
def _(func, types, *args, **kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
@@ -846,7 +830,7 @@ def functional_linear(*args, **kwargs):
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements([aten.mm.default, aten.addmm.default])
def aten_mm(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -885,21 +869,21 @@ def aten_mm(func, *args, **kwargs):
return func(input_tensor, weight_tensor)

@implements([aten.detach.default])
def detach(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default])
def clone(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements([aten._to_copy.default])
def _to_copy(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func,
args,
@@ -908,7 +892,7 @@ def _to_copy(func, *args, **kwargs):
)

@implements([aten.t.default])
def t(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
56 changes: 48 additions & 8 deletions torchao/dtypes/utils.py
@@ -5,19 +5,28 @@
from dataclasses import dataclass

"""
torch_function and torch_dispatch operator dispatch registrations
first key is a tensor subclass type like AffineQuantizedTensor,
second key is a `func` in __torch_function__ or __torch_dispatch__,
value is a function that implements the dispatch
Helper function for implementing aten op or torch function dispatch
and dispatching to these implementations.
"""
_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Callable, Callable]] = defaultdict(dict)

def _implements(cls, aten_ops_or_torch_fns):
"""Use this decorator to implement a function for an aten ops in __torch_dispatch__
(if user passed in a list of ops)
or torch function in __torch_function__ (if user passed in a single object)
class MyTensor(torch.Tensor):
...
implements = classmethod(_implements)
implements = MyTensor.implements
@implements(torch.nn.functional.linear)
def _(func, types, *args, **kwargs):
...
"""
if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
cls._ATEN_OP_OR_TORCH_FN_TABLE = {}

if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
def decorator(func):
@@ -26,10 +35,41 @@ def decorator(func):
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

_ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
return func
return decorator

def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
"""Use this util function for a common `__torch_function__` implementation
that dispatches to ops/functions registered with `_implements`
class MyTensor(torch.Tensor):
...
__torch_function__ = classmethod(_dispatch__torch_function__)
"""
kwargs = {} if kwargs is None else kwargs
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)

def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
"""Use this util function for a common `__torch_dispatch__` implementation
that dispatches to ops/functions registered with `_implements`
class MyTensor(torch.Tensor):
...
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
"""
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")


"""
Base class for different LayoutType, should not be instantiated directly
"""
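
A self-contained sketch of the new helpers in isolation (the LoggedLinearWeight class and its handler are hypothetical, not part of this diff): one torch function is registered, _dispatch__torch_function__ routes the call to it, and anything unregistered falls through to the default behavior under DisableTorchFunctionSubclass.

    import torch
    from torchao.dtypes.utils import _implements, _dispatch__torch_function__

    class LoggedLinearWeight(torch.Tensor):
        implements = classmethod(_implements)
        __torch_function__ = classmethod(_dispatch__torch_function__)

    @LoggedLinearWeight.implements(torch.nn.functional.linear)
    def _(func, types, *args, **kwargs):
        input_tensor, weight_tensor = args[0], args[1]
        bias = args[2] if len(args) > 2 else None
        # unwrap to a plain tensor so the re-entrant linear call is not intercepted again
        plain_weight = weight_tensor.as_subclass(torch.Tensor)
        return torch.nn.functional.linear(input_tensor, plain_weight, bias)

    w = torch.randn(4, 8).as_subclass(LoggedLinearWeight)
    x = torch.randn(2, 8)
    out = torch.nn.functional.linear(x, w)  # routed through the handler above

The __torch_dispatch__ counterpart behaves the same way, except that an unregistered aten op raises the uniform NotImplementedError defined above instead of falling through.
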
17 changes: 6 additions & 11 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -2,7 +2,7 @@

import torch
from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap

@@ -85,16 +85,11 @@ def __repr__(self):
f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


@OptimState4bit.implements(aten.copy_.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
dst = args[0]
src = args[1]

@@ -121,14 +116,14 @@ def _(func, *args, **kwargs):


@OptimState4bit.implements(aten.lerp.Scalar)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x, shape = args

if tuple(x.shape) == tuple(shape):
@@ -147,7 +142,7 @@ def _(func, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimState4bit):
raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
17 changes: 6 additions & 11 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap

@@ -71,16 +71,11 @@ def __repr__(self):
f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


@OptimState8bit.implements(aten.copy_.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
dst = args[0]
src = args[1]

@@ -103,14 +98,14 @@ def _(func, *args, **kwargs):


@OptimState8bit.implements(aten.lerp.Scalar)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimState8bit.implements(aten.view.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x, shape = args
return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)

@@ -122,7 +117,7 @@ def _(func, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimState8bit):
raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
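
One consequence of sharing _dispatch__torch_dispatch__ across the 4-bit and 8-bit optimizer-state subclasses is a uniform error for unregistered aten ops; a minimal sketch (ToyState is a hypothetical stand-in, not part of this diff):

    import torch
    from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

    class ToyState(torch.Tensor):
        implements = classmethod(_implements)
        __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

    x = torch.zeros(4).as_subclass(ToyState)
    try:
        x.add_(1)  # aten.add_.Tensor has no registered handler
    except NotImplementedError as e:
        # "ToyState dispatch: attempting to run unimplemented operator/function: ..."
        print(e)
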
(The remaining 8 of the 13 changed files are not shown here.)
