Merge branch 'main' into mps-readme
mikekgfb authored May 9, 2024
2 parents 546e667 + a89913d commit d39e296
Showing 5 changed files with 16 additions and 11 deletions.
1 change: 1 addition & 0 deletions build/builder.py
@@ -386,6 +386,7 @@ def _initialize_model(
"Cannot load specified DSO to MPS. Attempting to load model to CPU instead"
)
builder_args.device = "cpu"
+
# Replace model forward with the AOT-compiled forward
# This is a hacky way to quickly demo AOTI's capability.
# model is still a Python object, and any mutation to its
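The comment block above describes rebinding a live model's forward to the AOT-compiled entry point. A minimal sketch of that pattern, assuming PyTorch's torch._export.aot_load loader (the loader call is an assumption; it is not shown in this diff):

import torch

def attach_aoti_forward(model, dso_path: str, device: str):
    # Load the AOTInductor-compiled shared object as a Python callable
    # (assumed API from recent PyTorch releases).
    compiled_forward = torch._export.aot_load(dso_path, device)
    # Rebind forward on the live object; later mutations of the model's
    # Python attributes will not be reflected in the compiled graph.
    model.forward = compiled_forward
    return model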
2 changes: 1 addition & 1 deletion config/data/desktop.json
@@ -1,4 +1,4 @@
{
"executor": {"accelerator": "fast"},
"precision": {"dtype" : "fast16"},
"precision": {"dtype" : "fast16"}
}
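The only change is dropping the trailing comma, which strict JSON parsers reject. A quick illustration in Python:

import json

try:
    json.loads('{"dtype": "fast16",}')  # trailing comma: invalid JSON
except json.JSONDecodeError as err:
    print("rejected:", err)

print(json.loads('{"dtype": "fast16"}'))  # parses fine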
6 changes: 1 addition & 5 deletions generate.py
@@ -615,11 +615,7 @@ def _main(
# arbitrarily large number as chat mode goes until max_seq length
# or user exits
num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
-i = (
-    -1
-)  # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
-while i < num_samples:
-    i += 1
+for i in range(num_samples):
    device_sync(device=builder_args.device)
    if i >= 0 and generator_args.chat_mode:
        prompt = input("User: ")
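The rewrite replaces the sentinel-and-increment loop with a plain range(). One nuance, shown in a standalone sketch (not code from this repo): the old form also ran its body once with i == num_samples, so the new loop makes one fewer pass.

num_samples = 3

old = []
i = -1
while i < num_samples:
    i += 1
    old.append(i)

new = list(range(num_samples))

print(old)  # [0, 1, 2, 3]
print(new)  # [0, 1, 2]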
4 changes: 3 additions & 1 deletion install_requirements.sh
@@ -39,14 +39,16 @@ $PIP_EXECUTABLE install -r requirements.txt --extra-index-url https://download.p
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
-NIGHTLY_VERSION=dev20240422
+NIGHTLY_VERSION=dev20240507

# The pip repository that hosts nightly torch packages. cpu by default.
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
# with cuda for faster execution on cuda GPUs.
if [[ -x "$(command -v nvidia-smi)" ]];
then
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
+# Uninstall triton: the nightly depends on pytorch-triton, which is the same package under a different name
+$PIP_EXECUTABLE uninstall -y triton
else
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
fi
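The script keys CUDA detection off nvidia-smi being on the PATH. The same check in Python, as a sketch (variable names are illustrative):

import shutil

if shutil.which("nvidia-smi") is not None:
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/cu121"
else:
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/cpu"

print(torch_nightly_url)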
14 changes: 10 additions & 4 deletions qops.py
@@ -15,7 +15,7 @@ def linear_int8_aoti(input, weight, scales):
scales = scales.view(-1)
if (
torch.compiler.is_compiling()
-or input.device.type != "cpu"
+or input.device.type not in ["cpu", "mps"]
or not hasattr(torch.ops.aten, "_weight_int8pack_mm")
):
lin = F.linear(input, weight.to(dtype=input.dtype))
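The widened check lets MPS tensors reach the fused kernel as well. Restated as a self-contained sketch of the dispatch (the trailing scale multiply and exact tensor shapes are assumptions extrapolated from the lines shown here):

import torch
import torch.nn.functional as F

def linear_int8(input, weight, scales):
    scales = scales.view(-1)
    # Fall back to a plain dtype-promoting matmul when compiling, when
    # the device lacks a fused kernel, or when the op is unavailable.
    if (
        torch.compiler.is_compiling()
        or input.device.type not in ["cpu", "mps"]
        or not hasattr(torch.ops.aten, "_weight_int8pack_mm")
    ):
        return F.linear(input, weight.to(dtype=input.dtype)) * scales
    # Fused int8 weight-only matmul (CPU and, after this change, MPS).
    return torch.ops.aten._weight_int8pack_mm(input, weight, scales)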
@@ -395,9 +395,15 @@ def _prepare_weight_and_scales_and_zeros(
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
-weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-    weight_int32, inner_k_tiles
-)
+if weight_bf16.device.type == "mps":
+    # There is still no MPS-accelerated conversion op
+    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+        weight_int32.cpu(), inner_k_tiles
+    ).to("mps")
+else:
+    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+        weight_int32, inner_k_tiles
+    )
return weight_int4pack, scales_and_zeros

@classmethod
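The same CPU-roundtrip idiom applies to any op that lacks an MPS kernel. A generic sketch (the helper name and arguments are illustrative, not from this file):

import torch

def run_with_cpu_fallback(op, tensor, *args):
    # Run on CPU when the tensor lives on MPS, then move the result back.
    if tensor.device.type == "mps":
        return op(tensor.cpu(), *args).to("mps")
    return op(tensor, *args)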
