[aoti] Remove need for -l in cmake
angelayi committed Jan 10, 2025
1 parent 654bb03 commit e0299ff
Showing 3 changed files with 78 additions and 59 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -341,7 +341,7 @@ torchchat/utils/scripts/build_native.sh aoti

Then run the compiled executable with the exported .pt2:
```bash
cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
```

## Mobile Execution
58 changes: 25 additions & 33 deletions runner/run.cpp
@@ -102,6 +102,7 @@ typedef struct {
typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
RunState state; // buffers for the "wave" of activations in the forward pass
std::unordered_map<std::string, std::string> metadata;

#ifdef __AOTI_MODEL__
torch::inductor::AOTIModelPackageLoader *runner;
@@ -141,20 +142,9 @@ void read_checkpoint(char *checkpoint, Config *config) {
config->vocab_size = abs(config->vocab_size);
}

void build_transformer(Transformer *t, char *model_path, int vocab_size,
int seq_len) {
// read in the Config and the Weights from the model
// read_checkpoint(model_path, &t->config);
// allocate the RunState buffers
t->config.vocab_size = vocab_size;
t->config.seq_len = seq_len;
malloc_run_state(&t->state, &t->config);

void build_transformer(Transformer *t, char *model_path) {
#ifdef __AOTI_MODEL__
t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
? torch::Device(torch::kCPU)
: torch::Device(torch::kCUDA);
#else //__ET_MODEL__
t->runner = new Module(
/* path to PTE model */ model_path,
@@ -848,37 +838,35 @@ int main(int argc, char *argv[]) {
system_prompt = argv[i + 1];
} else if (argv[i][1] == 'l') {
llama_ver = atoi(argv[i + 1]);
#ifdef __AOTI_MODEL__
} else if (argv[i][1] == 'd') {
#ifdef USE_CUDA
if (strcasecmp(argv[i + 1], "CUDA") == 0) {
aoti_device = torch::Device(torch::kCUDA);
} else
#endif
if (strcasecmp(argv[i + 1], "CPU") == 0) {
aoti_device = torch::Device(torch::kCPU);
} else {
fprintf(stderr, "Unknown device %s", argv[i + 1]);
exit(1);
}
#endif
} else {
error_usage();
}
}

if (model_path == NULL) {
fprintf(stderr, "No model_path provided.");
error_usage();
}

Transformer transformer;
build_transformer(&transformer, model_path);

#ifdef __AOTI_MODEL__
auto aoti_metadata = transformer.runner->get_metadata();
aoti_device = aoti_metadata["AOTI_DEVICE_KEY"] == "cpu"
? torch::Device(torch::kCPU)
: torch::Device(torch::kCUDA);
ModelType model_type = get_model_type(stoi(aoti_metadata["tokenizer_type"]));
#else // __ET_MODEL__
ModelType model_type = get_model_type(llama_ver);
#endif

if (model_type == UNKNOWN_MODEL) {
fprintf(stderr, "Unknown model type passed by -l argument. Received l=%d.",
llama_ver);
error_usage();
}

if (model_path == NULL) {
fprintf(stderr, "No model_path provided.");
error_usage();
}

if (tokenizer_path == NULL) {
fprintf(stderr, "No tokenizer_path provided.");
error_usage();
@@ -901,8 +889,12 @@ int main(int argc, char *argv[]) {
vocab_size = tokenizer->vocab_size();
}

Transformer transformer;
build_transformer(&transformer, model_path, vocab_size, steps);
// read in the Config and the Weights from the model
// read_checkpoint(model_path, &t->config);
// allocate the RunState buffers
transformer.config.vocab_size = vocab_size;
transformer.config.seq_len = steps;
malloc_run_state(&transformer.state, &transformer.config);

Sampler sampler;
build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
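The net effect of the run.cpp changes: on the AOTI path the runner now derives the model/tokenizer type from the `tokenizer_type` entry in the packaged metadata, and the `-l` flag only matters for the ExecuTorch build. Below is a small, illustrative Python sketch of that selection rule; the helper name and string return values are hypothetical (not torchchat APIs), while the numeric codes mirror the exporter change further down.

```python
# Illustrative sketch (not torchchat code) of the runner's model-type
# selection after this commit.
from typing import Dict, Optional


def resolve_model_type(aoti_metadata: Optional[Dict[str, str]], llama_ver: int) -> str:
    if aoti_metadata is not None:
        # AOTI path: trust the packaged metadata.
        # "2" -> llama2, "3" -> llama3, "0" -> unknown (see export.py below).
        code = int(aoti_metadata.get("tokenizer_type", "0"))
    else:
        # ExecuTorch path: -l still selects the model type.
        code = llama_ver
    return {2: "llama2", 3: "llama3"}.get(code, "unknown")


# A .pt2 exported with tokenizer_type "3" no longer needs `-l 3` at run time.
assert resolve_model_type({"tokenizer_type": "3"}, llama_ver=0) == "llama3"
assert resolve_model_type(None, llama_ver=2) == "llama2"
```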
77 changes: 52 additions & 25 deletions torchchat/export.py
@@ -5,13 +5,13 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional
from typing import Dict, Optional

import torch
import torch._inductor
import torch.nn as nn

from torch.export import Dim
import torch._inductor

from torchchat.cli.builder import (
_initialize_model,
@@ -39,6 +39,7 @@ def export_for_server(
output_path: str = "model.pt2",
dynamic_shapes: bool = False,
package: bool = True,
metadata: Optional[Dict[str, str]] = None,
) -> str:
"""
Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,8 +68,10 @@
dynamic_shapes = None

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
metadata = {} # TODO: put more metadata here
options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
options = {
"aot_inductor.package": package,
"aot_inductor.metadata": metadata or {},
}
if not package:
options = {"aot_inductor.output_path": output_path}

@@ -81,6 +84,7 @@

if package:
from torch._inductor.package import package_aoti

path = package_aoti(output_path, path)

print(f"The generated packaged model can be found at: {path}")
@@ -102,13 +106,13 @@ def export_for_server(
from typing import Any, Dict, Tuple, Union

import executorch.exir as exir
from executorch.backends.xnnpack._passes.convert_to_linear import (
ConvertToLinearPass,
)

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
)
from executorch.backends.xnnpack._passes.convert_to_linear import (
ConvertToLinearPass,
)
from executorch.exir import EdgeProgramManager, to_edge

from executorch.exir.capture._config import (
@@ -166,18 +170,22 @@ def __init__(self, attention: Attention):

self.wo = attention.wo

max_batch_size, n_heads, max_seq_length, head_dim = (
attention.kv_cache[0].k_cache.shape
)
max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
0
].k_cache.shape
cache_dtype = attention.kv_cache[0].k_cache.dtype
# The `Attention` module being replaced can have multiple KV caches
# (denoted by `cache_lanes`). Thus we follow the same setup format
# as in `Attention.setup_cache`.
cache_lanes = len(attention.kv_cache)
self.kv_cache = nn.ModuleList([
CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
for _ in range(cache_lanes)
])
self.kv_cache = nn.ModuleList(
[
CustomKVCache(
max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
)
for _ in range(cache_lanes)
]
)

self.n_heads = attention.n_heads
self.head_dim = attention.head_dim
@@ -215,9 +223,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
return self.wo(output)

def replace_attention_with_custom_sdpa_attention(module: nn.Module):
from executorch.extension.llm.custom_ops import ( # noqa
sdpa_with_kv_cache,
)
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa

for name, child in module.named_children():
if isinstance(child, Attention):
@@ -238,7 +244,9 @@ def _to_core_aten(
raise ValueError(
f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
)
core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
core_aten_ep = export_for_training(
model, example_inputs, dynamic_shapes=dynamic_shapes
)
if verbose:
logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
return core_aten_ep
@@ -350,7 +358,11 @@ def main(args):

print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)
set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
set_backend(
dso=args.output_dso_path,
pte=args.output_pte_path,
aoti_package=args.output_aoti_package_path,
)

builder_args.dso_path = None
builder_args.pte_path = None
@@ -372,6 +384,7 @@ def main(args):

# TODO: clean this up
# This mess is because ET does not support _weight_int4pack_mm right now
tokenizer_args = None
if not builder_args.gguf_path:
# tokenizer needed for quantization so get that here,
try:
@@ -382,9 +395,8 @@

if builder_args.max_seq_length is None:
if (
(output_dso_path is not None or output_aoti_package_path is not None)
and not builder_args.dynamic_shapes
):
output_dso_path is not None or output_aoti_package_path is not None
) and not builder_args.dynamic_shapes:
print("Setting max_seq_length to 300 for DSO export.")
builder_args.max_seq_length = 300
elif output_pte_path is not None:
@@ -397,7 +409,8 @@
quantize,
tokenizer,
max_seq_length=builder_args.max_seq_length,
support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
support_tensor_subclass=output_dso_path is None
and output_aoti_package_path is None,
)
model_to_pte = model
model_to_dso = model
@@ -435,7 +448,9 @@ def main(args):
if output_dso_path:
output_dso_path = str(os.path.abspath(output_dso_path))
print(f"Exporting model using AOT Inductor to {output_dso_path}")
print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
print(
"WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
)
export_for_server(
model_to_dso,
builder_args.device,
@@ -446,11 +461,23 @@

if output_aoti_package_path:
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")

if tokenizer_args is None:
tokenizer_type = "0"
elif tokenizer_args.is_sentencepiece:
tokenizer_type = "2" # Corresponding to llama2
else:
tokenizer_type = "3" # Corresponding to llama3

metadata = {"tokenizer_type": tokenizer_type}
print(
"Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
)
export_for_server(
model_to_aoti_package,
builder_args.device,
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
metadata=metadata,
)
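
For reference, here is a minimal, self-contained sketch of how a `tokenizer_type` entry like the one above ends up inside a .pt2 package. The tiny model, example input, and the `torch._export.aot_compile` call are illustrative assumptions; the diff only shows that `export_for_server` forwards the `aot_inductor.package` and `aot_inductor.metadata` options and then bundles the result with `package_aoti`.

```python
# Hedged sketch: bake custom metadata into an AOTI .pt2 package.
# Assumes a recent PyTorch; the exact compile entry point used inside
# export_for_server is not shown in this diff.
import torch
import torch.nn as nn
from torch._export import aot_compile
from torch._inductor.package import package_aoti


class TinyModel(nn.Module):  # stand-in for the real torchchat model
    def forward(self, x):
        return x * 2


model = TinyModel().eval()
example_inputs = (torch.randn(4),)

options = {
    "aot_inductor.package": True,
    # Read back at run time via AOTIModelPackageLoader::get_metadata().
    "aot_inductor.metadata": {"tokenizer_type": "3"},  # "3" -> llama3 tokenizer
}

# With packaging enabled, the compile step emits the AOTI artifacts, which
# package_aoti then bundles into a single .pt2 archive.
aoti_files = aot_compile(model, example_inputs, options=options)
pt2_path = package_aoti("tiny_model.pt2", aoti_files)
print(f"Packaged model with metadata at: {pt2_path}")
```

The runner can then resolve both the device (`AOTI_DEVICE_KEY`) and the tokenizer type from the package itself, which is what lets the README drop `-l` from the `aoti_run` invocation.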
