diff --git a/install_requirements.sh b/install_requirements.sh
index cfb2862fa..bbc1b48d9 100755
--- a/install_requirements.sh
+++ b/install_requirements.sh
@@ -39,7 +39,7 @@ $PIP_EXECUTABLE install -r requirements.txt --extra-index-url https://download.p
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION=dev20240422
+NIGHTLY_VERSION=dev20240507
 
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
diff --git a/qops.py b/qops.py
index ab86250ff..b4f172163 100644
--- a/qops.py
+++ b/qops.py
@@ -15,7 +15,7 @@ def linear_int8_aoti(input, weight, scales):
     scales = scales.view(-1)
     if (
         torch.compiler.is_compiling()
-        or input.device.type != "cpu"
+        or input.device.type not in ["cpu", "mps"]
         or not hasattr(torch.ops.aten, "_weight_int8pack_mm")
     ):
         lin = F.linear(input, weight.to(dtype=input.dtype))
@@ -395,9 +395,15 @@ def _prepare_weight_and_scales_and_zeros(
         weight_int32, scales_and_zeros = group_quantize_tensor(
             weight_bf16, n_bit=4, groupsize=groupsize
         )
-        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            weight_int32, inner_k_tiles
-        )
+        if weight_bf16.device.type == "mps":
+            # There is still no MPS-accelerated conversion op
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_int32.cpu(), inner_k_tiles
+            ).to("mps")
+        else:
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_int32, inner_k_tiles
+            )
         return weight_int4pack, scales_and_zeros
 
     @classmethod
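
For context, the qops.py change works around a missing MPS kernel by running the int4 weight-packing op on CPU and moving the result back to the device. A minimal standalone sketch of that fallback pattern, assuming a PyTorch nightly where torch.ops.aten._convert_weight_to_int4pack has a CPU kernel but no MPS kernel; the helper name convert_weight_to_int4pack is hypothetical:

    import torch

    def convert_weight_to_int4pack(weight_int32: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
        # Hypothetical helper illustrating the fallback above: the aten packing
        # op has no MPS kernel, so detour through CPU and move the result back.
        if weight_int32.device.type == "mps":
            return torch.ops.aten._convert_weight_to_int4pack(
                weight_int32.cpu(), inner_k_tiles
            ).to("mps")
        # On other devices the op can run directly on the original tensor.
        return torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)

The CPU roundtrip is a one-time cost at weight-preparation time, so it does not affect the per-token inference path.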