Skip to content

Commit

Permalink
Update PyTorch pin and enable MPS qops (#725)
Browse files Browse the repository at this point in the history
* Update PyTorch pin

And enable linter:int8 and linter:int4 acceleration on MPS

* Update run-readme-pr.yml

* Update install_requirements.sh
  • Loading branch information
malfet authored May 9, 2024
1 parent 4b69985 commit 8a59fd3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/run-readme-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ jobs:
uname -a
echo "::endgroup::"
# echo "::group::Install newer objcopy that supports --set-section-alignment"
# yum install -y devtoolset-10-binutils
# export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
# echo "::endgroup::"
echo "::group::Install newer objcopy that supports --set-section-alignment"
yum install -y devtoolset-10-binutils
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
echo "::endgroup::"
echo "::group::Create script to run README"
python3 scripts/updown.py --file README.md --replace 'llama3:stories15M,-l 3:-l 2,meta-llama/Meta-Llama-3-8B-Instruct:stories15M' --suppress huggingface-cli,HF_TOKEN > ./run-readme.sh
Expand Down
4 changes: 3 additions & 1 deletion install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ $PIP_EXECUTABLE install -r requirements.txt --extra-index-url https://download.p
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
NIGHTLY_VERSION=dev20240422
NIGHTLY_VERSION=dev20240507

# The pip repository that hosts nightly torch packages. cpu by default.
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
# with cuda for faster execution on cuda GPUs.
if [[ -x "$(command -v nvidia-smi)" ]];
then
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
$PIP_EXECUTABLE uninstall -y triton
else
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
fi
Expand Down
14 changes: 10 additions & 4 deletions qops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def linear_int8_aoti(input, weight, scales):
scales = scales.view(-1)
if (
torch.compiler.is_compiling()
or input.device.type != "cpu"
or input.device.type not in ["cpu", "mps"]
or not hasattr(torch.ops.aten, "_weight_int8pack_mm")
):
lin = F.linear(input, weight.to(dtype=input.dtype))
Expand Down Expand Up @@ -395,9 +395,15 @@ def _prepare_weight_and_scales_and_zeros(
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_int32, inner_k_tiles
)
if weight_bf16.device.type == "mps":
# There are still no MPS-accelerated conversion OP
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_int32.cpu(), inner_k_tiles
).to("mps")
else:
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
weight_int32, inner_k_tiles
)
return weight_int4pack, scales_and_zeros

@classmethod
Expand Down

0 comments on commit 8a59fd3

Please sign in to comment.