Add Int4CPULayout and update int4 woq #1278
base: main

Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1278
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 8 New Failures, 2 Pending: as of commit 98b8f8c with merge base 01dc7da, the following jobs have failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 1b26f26 to 104d1f3 (Compare)
We are doing a refactor of the file structure, btw: #1234. It might be good to rebase after that lands.
@@ -102,7 +102,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .reshape_as(w)
     )
     if TORCH_VERSION_AT_LEAST_2_5:
-        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+        if w.device.type != "cpu":
+            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
maybe you can use the helper at line 60 in 39f16f4:

    def is_device(target_device_str: str, device: Union[str, torch.device]):
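For illustration, a sketch of the call site rewritten with this helper (the import path is an assumption):

    from torchao.utils import is_device

    if TORCH_VERSION_AT_LEAST_2_5:
        if not is_device("cpu", w.device):
            # pack pairs of int4 values into one uint8 only on non-CPU devices
            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)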
Done.
@@ -630,6 +630,11 @@ def extra_repr(self):
         return f"inner_k_tiles={self.inner_k_tiles}"


+@dataclass(frozen=True)
+class Int4CPULayout(Layout):
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
You don't need to define this if it's the same as the default behavior; you can just do pass here, I think.
Yes, the default behavior is ok; I'll use pass instead.
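A minimal sketch of the simplified layout class after that change (assuming the Layout base class from the diff above):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Int4CPULayout(Layout):
        pass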
    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
we have an unpack op for tensor core tiled layout now, so this can actually be replaced with a call to the op:

ao/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu, lines 311 to 312 in 39f16f4:

    m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout);
    m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout);
do you plan to write similar ops for cpu?
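For reference, a rough sketch of what the get_plain call site might look like; the exact op signature and where inner_k_tiles lives are assumptions based on the registration quoted above:

    # unpack the tensor-core-tiled packed weight back to a plain int tensor (sketch)
    unpacked = torch.ops.torchao.unpack_tensor_core_tiled_layout(
        self.packed_weight, self._layout.inner_k_tiles
    )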
I have noticed this, but I have no bandwidth to do it these days. If this feature is not urgent for you, I can take this task later.
cc @mingfeima
torchao/quantization/subclass.py
Outdated
@@ -609,5 +617,8 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
     input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
         input_float, 4, groupsize, dtype=input_float.dtype
     )
-    int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
+    if input_float.device == torch.device("cpu"):
same here, can probably use the helper at line 60 in 39f16f4:

    def is_device(target_device_str: str, device: Union[str, torch.device]):
Done.
torchao/quantization/utils.py
Outdated
    # if int_data_device_type == "mps":
    #     int_data = int_data.cpu()
    if int_data_device_type != "cpu":
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
    # if int_data_device_type == "mps":
    #     int_data = int_data.to(device="mps")
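For context, this packs two int4 values into each uint8: even columns land in the high nibble and odd columns in the low nibble. A standalone round-trip check (minimal sketch):

    import torch

    w = torch.tensor([[1, 2, 15, 7]], dtype=torch.uint8)  # int4 values in 0..15
    packed = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)
    # packed == tensor([[18, 247]]), i.e. [[0x12, 0xF7]]

    high, low = packed >> 4, packed & 0xF
    unpacked = torch.stack([high, low], dim=-1).reshape(w.shape)
    assert torch.equal(unpacked, w)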
Please remove the code that's commented out. Also, is this equivalent to the previous code?
According to #517 (comment), << can be used on the MPS backend, so we don't need to convert to CPU and run the packing on the CPU backend. Since I don't have an MPS machine, I want to use CI to check whether this works. Otherwise, I can update it to int_data = (torch.bitwise_left_shift(int_data[::, ::2], 4) | int_data[::, 1::2]).to(torch.uint8) instead.
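A quick CPU-side check that the two spellings agree (minimal sketch; << dispatches to torch.bitwise_left_shift):

    import torch

    x = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
    a = (x[::, ::2] << 4 | x[::, 1::2]).to(torch.uint8)
    b = (torch.bitwise_left_shift(x[::, ::2], 4) | x[::, 1::2]).to(torch.uint8)
    assert torch.equal(a, b)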
This can be a separate PR, but can you also help add support for conversion between the int4 tensor core tiled layout and the int4 CPU layout? We may need a separate util for this, like we discussed in the issue: #1117 (comment).
Right now we error out when converting between different devices:

ao/torchao/dtypes/affine_quantized_tensor.py, lines 1486 to 1489 in 39f16f4:

    if not is_device(torch.device(self.device).type, device):
        raise ValueError(
            f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
        )

A test can be added in ao/test/dtypes/test_affine_quantized.py, line 44 in 39f16f4:

    class TestAffineQuantized(TestCase):
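A hypothetical test sketch for that file (quantize_ and int4_weight_only are existing torchao APIs; the cross-device conversion behavior is the feature being requested here, so this is illustrative only):

    import torch
    from torchao.quantization import quantize_, int4_weight_only

    def test_int4_layout_device_conversion(self):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        quantize_(linear, int4_weight_only())
        # Once the conversion util exists, moving to CPU should repack from the
        # tensor-core-tiled layout to the CPU layout instead of raising.
        linear_cpu = linear.cpu()
        _ = linear_cpu(torch.randn(2, 128, dtype=torch.bfloat16))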
    )

    def _apply_fn_to_data(self, fn):
        # self.packed_weight = fn(self.packed_weight)
Please remove the commented-out code.
Done.
    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        device = kwargs["device"]
we should ban the device change, like:

ao/torchao/dtypes/affine_quantized_tensor.py, lines 1486 to 1489 in 39f16f4:

    if not is_device(torch.device(self.device).type, device):
        raise ValueError(
            f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
        )
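A minimal sketch of the re-added guard in the new impl's to method (field names such as packed_weight and scale_and_zero follow the tensor-core-tiled impl and are assumptions here):

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        device = kwargs["device"]
        if not is_device(torch.device(self.device).type, device):
            raise ValueError(
                f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
            )
        return self.__class__(
            self.packed_weight.to(device),
            self.scale_and_zero.to(device),
            self.transposed,
            self._layout,
        )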
Added it back.
Force-pushed from 512eb75 to 98b8f8c (Compare)
pytorch/pytorch#139611 has been merged into the PyTorch main branch.