Skip to content

Commit

Permalink
Update PyTorch pin
Browse files Browse the repository at this point in the history
And enable linter:int8 and linter:int4 acceleration on MPS
  • Loading branch information
malfet committed May 9, 2024
1 parent 23610c7 commit 4e520b3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ $PIP_EXECUTABLE install -r requirements.txt --extra-index-url https://download.p
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
NIGHTLY_VERSION=dev20240422
NIGHTLY_VERSION=dev20240507

# The pip repository that hosts nightly torch packages. cpu by default.
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
Expand Down
14 changes: 10 additions & 4 deletions qops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def linear_int8_aoti(input, weight, scales):
scales = scales.view(-1)
if (
torch.compiler.is_compiling()
or input.device.type != "cpu"
or input.device.type not in ["cpu", "mps"]
or not hasattr(torch.ops.aten, "_weight_int8pack_mm")
):
lin = F.linear(input, weight.to(dtype=input.dtype))
Expand Down Expand Up @@ -395,9 +395,15 @@ def _prepare_weight_and_scales_and_zeros(
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_int32, inner_k_tiles
)
if weight_bf16.device.type == "mps":
# There are still no MPS-accelerated conversion OP
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_int32.cpu(), inner_k_tiles
).to("mps")
else:
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_int32, inner_k_tiles
)
return weight_int4pack, scales_and_zeros

@classmethod
Expand Down

0 comments on commit 4e520b3

Please sign in to comment.