Add Int4CPULayout and update int4 woq #1278

Open — wants to merge 3 commits into base: main
Conversation

yanbing-j (Contributor) commented:

pytorch/pytorch#139611 has been merged into the PyTorch main branch.

pytorch-bot commented Nov 13, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1278

Note: links to docs will display an error until the docs builds have completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 8 New Failures, 2 Pending

As of commit 98b8f8c with merge base 01dc7da:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Nov 13, 2024
@yanbing-j yanbing-j marked this pull request as ready for review November 14, 2024 02:47
jerryzh168 (Contributor) commented:

By the way, we are doing a refactor of the file structure in #1234; it might be good to rebase after that lands.

@@ -102,7 +102,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
        .reshape_as(w)
    )
    if TORCH_VERSION_AT_LEAST_2_5:
        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
        if w.device.type != "cpu":
Contributor commented:
Maybe you can use:

def is_device(target_device_str: str, device: Union[str, torch.device]):
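For illustration, a minimal sketch of what that swap could look like; the helper body and import shown here are assumptions made for the sketch, not torchao's actual code:

from typing import Union

import torch

def is_device(target_device_str: str, device: Union[str, torch.device]) -> bool:
    # Assumed behavior: normalize `device` and compare its device type
    # ("cpu", "cuda", ...) against the target string.
    return torch.device(device).type == target_device_str

# The guard in the hunk above (w and w_int4x8 come from the surrounding
# function) would then read:
if not is_device("cpu", w.device):
    w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)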

yanbing-j (Author) replied:

Done.

@@ -630,6 +630,11 @@ def extra_repr(self):
        return f"inner_k_tiles={self.inner_k_tiles}"


@dataclass(frozen=True)
class Int4CPULayout(Layout):
    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
Contributor commented:
You don't need to define this if it's the same as the default behavior; you can just use pass here, I think.

yanbing-j (Author) replied:

Yes, the default behavior is fine; I'll use pass instead.
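With that, the layout reduces to roughly the following (a sketch, assuming Layout is the base class from the diff above, whose default pre_process/post_process are identity):

from dataclasses import dataclass

@dataclass(frozen=True)
class Int4CPULayout(Layout):
    # Inherit Layout's default pre_process/post_process; nothing to override.
    pass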


    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Contributor commented:

We have an unpack op for the tensor core tiled layout now, so this can actually be replaced with a call to the op:

m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout);
m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout);

Do you plan to write similar ops for CPU?

yanbing-j (Author) replied:

I have noticed this, but I don't have the bandwidth to do it right now. If this feature isn't urgent for you, I can take this task later.

cc @mingfeima
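For context, the packing in this PR stores two int4 values per uint8, high nibble first; a pure-Python inverse (a stand-in sketch until a dedicated CPU unpack op exists; the function name is ours, not torchao's) looks roughly like this:

import torch

def unpack_uint8_to_int4x2(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of `(x[::, ::2] << 4 | x[::, 1::2]).to(torch.uint8)`:
    # split each byte back into its high and low 4-bit halves and
    # interleave them along the last dimension.
    hi = (packed >> 4).to(torch.uint8)
    lo = (packed & 0x0F).to(torch.uint8)
    return torch.stack([hi, lo], dim=-1).reshape(*packed.shape[:-1], -1)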

@@ -609,5 +617,8 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
        input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
            input_float, 4, groupsize, dtype=input_float.dtype
        )
        int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
        if input_float.device == torch.device("cpu"):
Contributor commented:

Same here; you can probably use:

def is_device(target_device_str: str, device: Union[str, torch.device]):

yanbing-j (Author) replied:

Done.

Comment on lines 405 to 410
# if int_data_device_type == "mps":
#     int_data = int_data.cpu()
if int_data_device_type != "cpu":
    int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
# if int_data_device_type == "mps":
#     int_data = int_data.to(device="mps")
Contributor commented:

Please remove the code that's commented out. Also, is this equivalent to the previous code?

yanbing-j (Author) replied:

According to #517 (comment), << can be used in the MPS backend, so there is no need to convert to CPU and use the CPU backend. Since I don't have an MPS machine, I want to use CI to check whether this works. Otherwise, I can update it to int_data = (torch.bitwise_left_shift(int_data[::, ::2], 4) | int_data[::, 1::2]).to(torch.uint8) instead.
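A quick sanity check that the two spellings agree (on tensors, << is the operator form of torch.bitwise_left_shift):

import torch

x = torch.randint(0, 16, (4, 8), dtype=torch.int32)
a = (x[::, ::2] << 4 | x[::, 1::2]).to(torch.uint8)
b = (torch.bitwise_left_shift(x[::, ::2], 4) | x[::, 1::2]).to(torch.uint8)
assert torch.equal(a, b)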

jerryzh168 (Contributor) left a comment:

This can be a separate PR, but can you also help add support for conversion between the int4 tensor core tiled layout and the int4 CPU layout? We may need a separate util for this, like we discussed in #1117 (comment).

Right now we error out when converting between different devices:

if not is_device(torch.device(self.device).type, device):
    raise ValueError(
        f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
    )

This is fine, I think; we just need separate utils if people want to do this move.

A test can be added in:

class TestAffineQuantized(TestCase):
)

def _apply_fn_to_data(self, fn):
    # self.packed_weight = fn(self.packed_weight)
Contributor commented:

Please remove the commented-out code.

yanbing-j (Author) replied:

Done.


def to(self, *args, **kwargs):
    kwargs = self._get_to_kwargs(*args, **kwargs)
    device = kwargs["device"]
Contributor commented:

We should ban the device change here as well, like:

if not is_device(torch.device(self.device).type, device):
    raise ValueError(
        f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
    )

yanbing-j (Author) replied:

Added it back.
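Putting both points together, the method ends up roughly as below (a sketch: is_device as quoted earlier, and the field names packed_weight/scale_and_zero/transposed/_layout are assumed from the surrounding diff, not confirmed):

def to(self, *args, **kwargs):
    kwargs = self._get_to_kwargs(*args, **kwargs)
    device = kwargs["device"]
    # Ban cross-device moves, mirroring the tensor-core-tiled check;
    # device/layout conversion should go through a separate util instead.
    if not is_device(torch.device(self.device).type, device):
        raise ValueError(
            f"{self.__class__.__name__} does not support conversion "
            f"from {self.device} to {device}"
        )
    return self.__class__(
        self.packed_weight.to(device),
        self.scale_and_zero.to(device),
        self.transposed,
        self._layout,
    )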
