diff --git a/ai_edge_torch/generative/examples/README.md b/ai_edge_torch/generative/examples/README.md index 661c2de2..74d1b404 100644 --- a/ai_edge_torch/generative/examples/README.md +++ b/ai_edge_torch/generative/examples/README.md @@ -89,7 +89,16 @@ LlamaForCausalLM( ``` Based on the original model structure, construct a new nn.Module model using -the AI Edge Torch Generative API +the AI Edge Torch Generative API. As many examples do, either use +[`DecoderOnlyModel`](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/model_builder.py) +class as is like [SmolLM](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/smollm/smollm.py), +or inherit [`DecoderOnlyModel`](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/model_builder.py) +class then modify only some component like +[Llama 3.2](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/llama/llama.py), +or construct entirely a new nn.Module from scratch like +[Gemma 2](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/gemma/gemma2.py). + +Here is an example of TinyLlama constructed from scratch. https://github.com/google-ai-edge/ai-edge-torch/blob/853301630f2b2455bd2e2f73d8a47e1a1534c91c/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py#L46-L77 diff --git a/ai_edge_torch/generative/examples/gemma/gemma1.py b/ai_edge_torch/generative/examples/gemma/gemma1.py index 5811d6a9..719279ac 100644 --- a/ai_edge_torch/generative/examples/gemma/gemma1.py +++ b/ai_edge_torch/generative/examples/gemma/gemma1.py @@ -15,14 +15,9 @@ """Example of building a Gemma1 model.""" -from ai_edge_torch.generative.layers import attention -from ai_edge_torch.generative.layers import builder -from ai_edge_torch.generative.layers import kv_cache as kv_utils -import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils -import torch -from torch import nn TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="model.layers.{}.mlp.up_proj", @@ -38,84 +33,6 @@ ) -class Gemma(nn.Module): - """A Gemma model built from the Edge Generative API layers.""" - - def __init__(self, config: cfg.ModelConfig): - super().__init__() - - # Construct model layers. - self.tok_embedding = nn.Embedding( - config.vocab_size, config.embedding_dim, padding_idx=0 - ) - self.lm_head = nn.Linear( - config.embedding_dim, - config.vocab_size, - bias=config.lm_head_use_bias, - ) - # Gemma re-uses the embedding as the head projection layer. - self.lm_head.weight.data = self.tok_embedding.weight.data - # Gemma has only one block config. 
- block_config = config.block_config(0) - self.transformer_blocks = nn.ModuleList( - attention.TransformerBlock(block_config, config) - for _ in range(config.num_layers) - ) - self.final_norm = builder.build_norm( - config.embedding_dim, - config.final_norm_config, - ) - attn_config = block_config.attn_config - self.rope_cache = attn_utils.build_rope_cache( - size=config.kv_cache_max, - dim=int(attn_config.rotary_percentage * attn_config.head_dim), - base=attn_config.rotary_base, - ) - self.mask_cache = attn_utils.build_causal_mask_cache( - size=config.kv_cache_max, - ) - self.config = config - - @torch.inference_mode - def forward( - self, - tokens: torch.Tensor, - input_pos: torch.Tensor, - kv_cache: kv_utils.KVCache, - ) -> dict[torch.Tensor, kv_utils.KVCache]: - _, seq_len = tokens.size() - assert self.config.max_seq_len >= seq_len, ( - f"Cannot forward sequence of length {seq_len}, max seq length is only" - f" {self.config.max_seq_len}" - ) - assert len(self.transformer_blocks) == len(kv_cache.caches), ( - "The number of transformer blocks and the number of KV cache entries" - " must be the same." - ) - - cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) - mask = self.mask_cache.index_select(2, input_pos) - mask = mask[:, :, :, : self.config.kv_cache_max] - - # token embeddings of shape (b, t, n_embd) - x = self.tok_embedding(tokens) - x = x * (self.config.embedding_dim**0.5) - - updated_kv_entires = [] - for i, block in enumerate(self.transformer_blocks): - kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) - if kv_entry: - updated_kv_entires.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) - - x = self.final_norm(x) - logits = self.lm_head(x) # (b, t, vocab_size) - return {"logits": logits, "kv_cache": updated_kv_cache} - - def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a Gemma 2B model. @@ -154,6 +71,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: num_layers=18, max_seq_len=8192, embedding_dim=2048, + embedding_scale=2048**0.5, kv_cache_max_len=kv_cache_max_len, block_configs=block_config, final_norm_config=norm_config, @@ -173,12 +91,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_model_config_2b(**kwargs) - model = Gemma(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. 
- loader.load(model, strict=False) - model.eval() - return model +def build_2b_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config_2b(**kwargs), + tensor_names=TENSOR_NAMES, + ) diff --git a/ai_edge_torch/generative/examples/gemma/gemma2.py b/ai_edge_torch/generative/examples/gemma/gemma2.py index 3d211783..a20fdc91 100644 --- a/ai_edge_torch/generative/examples/gemma/gemma2.py +++ b/ai_edge_torch/generative/examples/gemma/gemma2.py @@ -15,7 +15,6 @@ """Example of building a Gemma2 model.""" -import os from typing import Optional, Tuple from ai_edge_torch.generative.layers import attention diff --git a/ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py b/ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py deleted file mode 100644 index 9502f004..00000000 --- a/ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2024 The AI Edge Torch Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Example of converting Llama 3.2 3B model to multi-signature tflite model.""" - -import os -import pathlib - -from absl import app -from absl import flags -from ai_edge_torch.generative.examples.llama import llama -from ai_edge_torch.generative.utilities import converter - -_CHECKPOINT_PATH = flags.DEFINE_string( - 'checkpoint_path', - os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'), - 'The path to the model checkpoint, or directory holding the checkpoint.', -) -_TFLITE_PATH = flags.DEFINE_string( - 'tflite_path', - '/tmp/', - 'The tflite file path to export.', -) -_PREFILL_SEQ_LEN = flags.DEFINE_integer( - 'prefill_seq_len', - 1024, - 'The maximum size of prefill input tensor.', -) -_KV_CACHE_MAX_LEN = flags.DEFINE_integer( - 'kv_cache_max_len', - 1280, - 'The maximum size of KV cache buffer, including both prefill and decode.', -) -_QUANTIZE = flags.DEFINE_bool( - 'quantize', - True, - 'Whether the model should be quantized.', -) - - -def main(_): - pytorch_model = llama.build_3b_model( - _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value - ) - quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'llama_3b_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' - converter.convert_to_tflite( - pytorch_model, - tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), - prefill_seq_len=_PREFILL_SEQ_LEN.value, - quantize=_QUANTIZE.value, - ) - - -if __name__ == '__main__': - app.run(main) diff --git a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py index 8a6d9d32..5bc09922 100644 --- a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py @@ -23,6 +23,12 @@ from 
ai_edge_torch.generative.examples.llama import llama from ai_edge_torch.generative.utilities import converter +_MODEL_SIZE = flags.DEFINE_enum( + 'model_size', + '1b', + ['1b', '3b'], + 'The size of the model to verify.', +) _CHECKPOINT_PATH = flags.DEFINE_string( 'checkpoint_path', os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'), @@ -49,13 +55,18 @@ 'Whether the model should be quantized.', ) +_BUILDER = { + '1b': llama.build_1b_model, + '3b': llama.build_3b_model, +} + def main(_): - pytorch_model = llama.build_model( + pytorch_model = _BUILDER[_MODEL_SIZE.value]( _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value ) quant_suffix = 'q8' if _QUANTIZE.value else 'f32' - output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' + output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite' converter.convert_to_tflite( pytorch_model, tflite_path=os.path.join(_TFLITE_PATH.value, output_filename), diff --git a/ai_edge_torch/generative/examples/llama/llama.py b/ai_edge_torch/generative/examples/llama/llama.py index 271afdd9..1fde9f69 100644 --- a/ai_edge_torch/generative/examples/llama/llama.py +++ b/ai_edge_torch/generative/examples/llama/llama.py @@ -15,19 +15,15 @@ """Example of building Llama 3.2 models.""" -import copy import math from typing import Tuple -from ai_edge_torch.generative.examples.tiny_llama import tiny_llama import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils import torch -from torch import nn -TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES) -# SmolLM re-uses the embedding as the head projection layer. -TENSOR_NAMES.lm_head = None +TENSOR_NAMES = model_builder.TENSOR_NAMES def _build_llama3_rope_cache( @@ -93,7 +89,7 @@ def _build_llama3_rope_cache( return cos, sin -class Llama(tiny_llama.TinyLlama): +class Llama(model_builder.DecoderOnlyModel): """A Llama model built from the Edge Generative API layers. Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation. @@ -101,9 +97,6 @@ class Llama(tiny_llama.TinyLlama): def __init__(self, config: cfg.ModelConfig): super().__init__(config) - # Llama 3.2 re-uses the embedding as the head projection layer. - self.lm_head.weight.data = self.tok_embedding.weight.data - # Llama has only one block config. attn_config = self.config.block_config(0).attn_config self.rope_cache = _build_llama3_rope_cache( size=self.config.kv_cache_max, @@ -119,7 +112,7 @@ def __init__(self, config: cfg.ModelConfig): ) -def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: +def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a Llama 3.2-1B model. Args: @@ -163,7 +156,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a Llama 3.2-3B model.""" - config = get_model_config(kv_cache_max_len) + config = get_1b_model_config(kv_cache_max_len) # Llama 3.2 has only one block config. 
attn_config = config.block_config(0).attn_config attn_config.num_heads = 24 @@ -174,7 +167,7 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: def get_fake_model_config(**kwargs) -> cfg.ModelConfig: - config = get_model_config(**kwargs) + config = get_1b_model_config(**kwargs) config.vocab_size = 128 config.num_layers = 2 # SmolLM has only one block config. @@ -182,8 +175,9 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def build_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_model_config(**kwargs) +def _build_model( + checkpoint_path: str, config: cfg.ModelConfig +) -> model_builder.DecoderOnlyModel: model = Llama(config) loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) # Since embedding and lm-head use the same weight, we need to set strict @@ -193,12 +187,13 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module: return model -def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_3b_model_config(**kwargs) - model = Llama(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. - loader.load(model, strict=False) - model.eval() - return model +def build_1b_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return _build_model(checkpoint_path, get_1b_model_config(**kwargs)) + + +def build_3b_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return _build_model(checkpoint_path, get_3b_model_config(**kwargs)) diff --git a/ai_edge_torch/generative/examples/llama/verify.py b/ai_edge_torch/generative/examples/llama/verify.py index cacc35e6..199b09ba 100644 --- a/ai_edge_torch/generative/examples/llama/verify.py +++ b/ai_edge_torch/generative/examples/llama/verify.py @@ -25,7 +25,12 @@ from ai_edge_torch.generative.utilities import verifier import transformers - +_MODEL_SIZE = flags.DEFINE_enum( + "model_size", + "1b", + ["1b", "3b"], + "The size of the model to verify.", +) _PROMPTS = flags.DEFINE_multi_string( "prompts", "What is the meaning of life?", @@ -37,9 +42,19 @@ "The maximum size of the generated tokens.", ) +_CHECKPOINT = { + "1b": "meta-llama/Llama-3.2-1B-Instruct", + "3b": "meta-llama/Llama-3.2-3B-Instruct", +} + +_BUILDER = { + "1b": llama.build_1b_model, + "3b": llama.build_3b_model, +} + def main(_): - checkpoint = "meta-llama/Llama-3.2-1B-Instruct" + checkpoint = _CHECKPOINT[_MODEL_SIZE.value] logging.info("Loading the original model from: %s", checkpoint) original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint) @@ -49,7 +64,7 @@ def main(_): ) reauthored_checkpoint = pathlib.Path(cached_config_file).parent logging.info("Building the reauthored model from: %s", reauthored_checkpoint) - reauthored_model = llama.build_model(reauthored_checkpoint) + reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint) logging.info("Loading the tokenizer from: %s", checkpoint) # Llama tokenizer_config.json sets a fast tokenizer class explicitly, diff --git a/ai_edge_torch/generative/examples/llama/verify_3b.py b/ai_edge_torch/generative/examples/llama/verify_3b.py deleted file mode 100644 index bbc230d6..00000000 --- a/ai_edge_torch/generative/examples/llama/verify_3b.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 The AI Edge Torch Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Verifies the reauthored Llama 3.2-3B model.""" - -import logging -import pathlib - -from absl import app -from absl import flags -from ai_edge_torch.generative.examples.llama import llama -from ai_edge_torch.generative.utilities import transformers_verifier -from ai_edge_torch.generative.utilities import verifier -import transformers - - -_PROMPTS = flags.DEFINE_multi_string( - "prompts", - "What is the meaning of life?", - "The input prompts to generate answers.", -) -_MAX_NEW_TOKENS = flags.DEFINE_integer( - "max_new_tokens", - 30, - "The maximum size of the generated tokens.", -) - - -def main(_): - checkpoint = "meta-llama/Llama-3.2-3B-Instruct" - logging.info("Loading the original model from: %s", checkpoint) - original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint) - - # Locate the cached dir. - cached_config_file = transformers.utils.cached_file( - checkpoint, transformers.utils.CONFIG_NAME - ) - reauthored_checkpoint = pathlib.Path(cached_config_file).parent - logging.info("Building the reauthored model from: %s", reauthored_checkpoint) - reauthored_model = llama.build_3b_model(reauthored_checkpoint) - - logging.info("Loading the tokenizer from: %s", checkpoint) - # Llama tokenizer_config.json sets a fast tokenizer class explicitly, - # "PreTrainedTokenizerFast". It works only when the fast tokenizer is - # available. 
- tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) - - verifier.verify_reauthored_model( - original_model=transformers_verifier.TransformersModelWrapper( - original_model - ), - reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model), - tokenizer=verifier.TokenizerWrapper(tokenizer), - generate_prompts=_PROMPTS.value, - max_new_tokens=_MAX_NEW_TOKENS.value, - atol=1e-04, - ) - - -if __name__ == "__main__": - app.run(main) diff --git a/ai_edge_torch/generative/examples/openelm/openelm.py b/ai_edge_torch/generative/examples/openelm/openelm.py index 46e58ed6..d8de0b84 100644 --- a/ai_edge_torch/generative/examples/openelm/openelm.py +++ b/ai_edge_torch/generative/examples/openelm/openelm.py @@ -15,14 +15,9 @@ """Example of building an OpenELM model.""" -from ai_edge_torch.generative.layers import attention -from ai_edge_torch.generative.layers import builder -from ai_edge_torch.generative.layers import kv_cache as kv_utils -import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils -import torch -from torch import nn TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="transformer.layers.{}.ffn.proj_1", @@ -39,81 +34,6 @@ ) -class OpenELM(nn.Module): - """An OpenELM model built from the Edge Generative API layers.""" - - def __init__(self, config: cfg.ModelConfig): - super().__init__() - - # Construct model layers. - self.tok_embedding = nn.Embedding( - config.vocab_size, config.embedding_dim, padding_idx=0 - ) - self.lm_head = nn.Linear( - config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias - ) - # OpenELM re-uses the embedding as the head projection layer. - self.lm_head.weight.data = self.tok_embedding.weight.data - self.transformer_blocks = nn.ModuleList( - attention.TransformerBlock(config.block_config(idx), config) - for idx in range(config.num_layers) - ) - self.final_norm = builder.build_norm( - config.embedding_dim, - config.final_norm_config, - ) - # OpenELM has same hyper parameters for rotary_percentage and head_dim for - # each layer block. Use the first block. - attn_config = config.block_config(0).attn_config - self.rope_cache = attn_utils.build_rope_cache( - size=config.kv_cache_max, - dim=int(attn_config.rotary_percentage * attn_config.head_dim), - base=attn_config.rotary_base, - ) - self.mask_cache = attn_utils.build_causal_mask_cache( - size=config.kv_cache_max, - ) - self.config = config - - @torch.inference_mode - def forward( - self, - tokens: torch.Tensor, - input_pos: torch.Tensor, - kv_cache: kv_utils.KVCache, - ) -> dict[torch.Tensor, kv_utils.KVCache]: - _, seq_len = tokens.size() - assert self.config.max_seq_len >= seq_len, ( - f"Cannot forward sequence of length {seq_len}, max seq length is only" - f" {self.config.max_seq_len}" - ) - assert len(self.transformer_blocks) == len(kv_cache.caches), ( - "The number of transformer blocks and the number of KV cache entries" - " must be the same." 
- ) - - cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) - mask = self.mask_cache.index_select(2, input_pos) - mask = mask[:, :, :, : self.config.kv_cache_max] - - # token embeddings of shape (b, t, n_embd) - x = self.tok_embedding(tokens) - - updated_kv_entires = [] - for i, block in enumerate(self.transformer_blocks): - kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) - if kv_entry: - updated_kv_entires.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) - - x = self.final_norm(x) - logits = self.lm_head(x) # (b, t, vocab_size) - return {"logits": logits, "kv_cache": updated_kv_cache} - - def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for an OpenELM model. @@ -191,12 +111,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_model_config(**kwargs) - model = OpenELM(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. - loader.load(model, strict=False) - model.eval() - return model +def build_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) diff --git a/ai_edge_torch/generative/examples/phi/phi2.py b/ai_edge_torch/generative/examples/phi/phi2.py index 718b5369..61a1e647 100644 --- a/ai_edge_torch/generative/examples/phi/phi2.py +++ b/ai_edge_torch/generative/examples/phi/phi2.py @@ -15,14 +15,9 @@ """Example of building a Phi-2 model.""" -from ai_edge_torch.generative.layers import attention -from ai_edge_torch.generative.layers import builder -from ai_edge_torch.generative.layers import kv_cache as kv_utils -import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils -import torch -from torch import nn TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="model.layers.{}.mlp.fc1", @@ -38,78 +33,6 @@ ) -class Phi2(nn.Module): - """A Phi-2 model built from the Edge Generative API layers.""" - - def __init__(self, config: cfg.ModelConfig): - super().__init__() - - # Construct model layers. - self.lm_head = nn.Linear( - config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias - ) - self.tok_embedding = nn.Embedding( - config.vocab_size, config.embedding_dim, padding_idx=0 - ) - # Phi-2 has only one block config. 
- block_config = config.block_config(0) - self.transformer_blocks = nn.ModuleList( - attention.TransformerBlock(block_config, config) - for _ in range(config.num_layers) - ) - self.final_norm = builder.build_norm( - config.embedding_dim, - config.final_norm_config, - ) - attn_config = block_config.attn_config - self.rope_cache = attn_utils.build_rope_cache( - size=config.kv_cache_max, - dim=int(attn_config.rotary_percentage * attn_config.head_dim), - base=attn_config.rotary_base, - ) - self.mask_cache = attn_utils.build_causal_mask_cache( - size=config.kv_cache_max, - ) - self.config = config - - @torch.inference_mode - def forward( - self, - tokens: torch.Tensor, - input_pos: torch.Tensor, - kv_cache: kv_utils.KVCache, - ) -> dict[torch.Tensor, kv_utils.KVCache]: - _, seq_len = tokens.size() - assert self.config.max_seq_len >= seq_len, ( - f"Cannot forward sequence of length {seq_len}, max seq length is only" - f" {self.config.max_seq_len}" - ) - assert len(self.transformer_blocks) == len(kv_cache.caches), ( - "The number of transformer blocks and the number of KV cache entries" - " must be the same." - ) - - cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) - mask = self.mask_cache.index_select(2, input_pos) - mask = mask[:, :, :, : self.config.kv_cache_max] - - x = self.tok_embedding(tokens) - - updated_kv_entires = [] - for i, block in enumerate(self.transformer_blocks): - kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) - if kv_entry: - updated_kv_entires.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) - - x = self.final_norm(x) - logits = self.lm_head(x) # (b, t, vocab_size) - return {"logits": logits, "kv_cache": updated_kv_cache} - - def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: """Returns the model config for a Phi-2 model. 
@@ -154,6 +77,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: block_configs=block_config, final_norm_config=norm_config, lm_head_use_bias=True, + lm_head_share_weight_with_embedding=False, enable_hlfb=True, ) return config @@ -169,11 +93,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_model(checkpoint_path: str, **kwargs) -> nn.Module: - """Instantiates the model instance and load checkpoint if provided.""" - config = get_model_config(**kwargs) - model = Phi2(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - loader.load(model) - model.eval() - return model +def build_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) diff --git a/ai_edge_torch/generative/examples/phi/phi3.py b/ai_edge_torch/generative/examples/phi/phi3.py index c2c2b802..6958dad9 100644 --- a/ai_edge_torch/generative/examples/phi/phi3.py +++ b/ai_edge_torch/generative/examples/phi/phi3.py @@ -18,14 +18,10 @@ import math from typing import Tuple -from ai_edge_torch.generative.layers import attention -from ai_edge_torch.generative.layers import builder -from ai_edge_torch.generative.layers import kv_cache as kv_utils -import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg +from ai_edge_torch.generative.utilities import model_builder import ai_edge_torch.generative.utilities.loader as loading_utils import torch -from torch import nn TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( ff_up_proj="model.layers.{}.mlp.gate_up_proj", @@ -137,32 +133,14 @@ def _build_rope_cache( return cos, sin -class Phi3_5Mini(nn.Module): +class Phi3_5Mini(model_builder.DecoderOnlyModel): """A Phi-3.5 model built from the Edge Generative API layers.""" def __init__(self, config: cfg.ModelConfig): - super().__init__() - - # Construct model layers. - self.lm_head = nn.Linear( - config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias - ) - self.tok_embedding = nn.Embedding( - config.vocab_size, config.embedding_dim, padding_idx=0 - ) - # Phi-3.5 has only one block config. 
- block_config = config.block_config(0) - self.transformer_blocks = nn.ModuleList( - attention.TransformerBlock(block_config, config) - for _ in range(config.num_layers) - ) - self.final_norm = builder.build_norm( - config.embedding_dim, - config.final_norm_config, - ) - attn_config = block_config.attn_config + super().__init__(config) + attn_config = self.config.block_config(0).attn_config self.rope_cache = _build_rope_cache( - size=config.kv_cache_max, + size=self.config.kv_cache_max, dim=int(attn_config.rotary_percentage * attn_config.head_dim), base=attn_config.rotary_base, condense_ratio=1, @@ -173,47 +151,6 @@ def __init__(self, config: cfg.ModelConfig): 1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len) ), ) - self.mask_cache = attn_utils.build_causal_mask_cache( - size=config.kv_cache_max, - ) - self.config = config - - @torch.inference_mode - def forward( - self, - tokens: torch.Tensor, - input_pos: torch.Tensor, - kv_cache: kv_utils.KVCache, - ) -> dict[torch.Tensor, kv_utils.KVCache]: - _, seq_len = tokens.size() - assert self.config.max_seq_len >= seq_len, ( - f"Cannot forward sequence of length {seq_len}, max seq length is only" - f" {self.config.max_seq_len}" - ) - assert len(self.transformer_blocks) == len(kv_cache.caches), ( - "The number of transformer blocks and the number of KV cache entries" - " must be the same." - ) - - cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) - mask = self.mask_cache.index_select(2, input_pos) - mask = mask[:, :, :, : self.config.kv_cache_max] - - x = self.tok_embedding(tokens) - - updated_kv_entires = [] - for i, block in enumerate(self.transformer_blocks): - kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) - if kv_entry: - updated_kv_entires.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) - - x = self.final_norm(x) - logits = self.lm_head(x) # (b, t, vocab_size) - return {"logits": logits, "kv_cache": updated_kv_cache} def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: @@ -254,6 +191,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: embedding_dim=3072, block_configs=block_config, final_norm_config=norm_config, + lm_head_share_weight_with_embedding=False, enable_hlfb=True, ) return config @@ -269,7 +207,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: return config -def build_model(checkpoint_path: str, **kwargs) -> nn.Module: +def build_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: """Instantiates the model instance and load checkpoint if provided.""" config = get_model_config(**kwargs) model = Phi3_5Mini(config) diff --git a/ai_edge_torch/generative/examples/qwen/qwen.py b/ai_edge_torch/generative/examples/qwen/qwen.py index 669207d2..0347758b 100644 --- a/ai_edge_torch/generative/examples/qwen/qwen.py +++ b/ai_edge_torch/generative/examples/qwen/qwen.py @@ -15,28 +15,10 @@ """Example of building Qwen 2.5 models.""" -import copy - -from ai_edge_torch.generative.examples.tiny_llama import tiny_llama import ai_edge_torch.generative.layers.model_config as cfg -import ai_edge_torch.generative.utilities.loader as loading_utils -from torch import nn - -TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES) -# Qwen re-uses the embedding as the head projection layer. 
-TENSOR_NAMES.lm_head = None - - -class Qwen(tiny_llama.TinyLlama): - """A Qwen model built from the Edge Generative API layers. - - Qwen 2.5 shares the same architecture as TinyLlama. - """ +from ai_edge_torch.generative.utilities import model_builder - def __init__(self, config: cfg.ModelConfig): - super().__init__(config) - # Qwen re-uses the embedding as the head projection layer. - self.lm_head.weight.data = self.tok_embedding.weight.data +TENSOR_NAMES = model_builder.TENSOR_NAMES def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: @@ -119,23 +101,31 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def _build_model(checkpoint_path: str, config: cfg.ModelConfig) -> nn.Module: - model = Qwen(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. - loader.load(model, strict=False) - model.eval() - return model - - -def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module: - return _build_model(checkpoint_path, get_3b_model_config(**kwargs)) +def build_3b_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_3b_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) -def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module: - return _build_model(checkpoint_path, get_1_5b_model_config(**kwargs)) +def build_1_5b_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_1_5b_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) -def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module: - return _build_model(checkpoint_path, get_0_5b_model_config(**kwargs)) +def build_0_5b_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_0_5b_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) diff --git a/ai_edge_torch/generative/examples/smollm/smollm.py b/ai_edge_torch/generative/examples/smollm/smollm.py index 303671bb..2e5942f8 100644 --- a/ai_edge_torch/generative/examples/smollm/smollm.py +++ b/ai_edge_torch/generative/examples/smollm/smollm.py @@ -15,29 +15,10 @@ """Example of building a SmolLM model.""" -import copy - -from ai_edge_torch.generative.examples.tiny_llama import tiny_llama import ai_edge_torch.generative.layers.model_config as cfg -import ai_edge_torch.generative.utilities.loader as loading_utils -from torch import nn - -TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES) -# SmolLM re-uses the embedding as the head projection layer. -TENSOR_NAMES.lm_head = None - - -class SmolLM(tiny_llama.TinyLlama): - """A SmolLM model built from the Edge Generative API layers. +from ai_edge_torch.generative.utilities import model_builder - SmolLM shares the same architecture as TinyLlama, but with different model - sizes. - """ - - def __init__(self, config: cfg.ModelConfig): - super().__init__(config) - # SmolLM re-uses the embedding as the head projection layer. 
- self.lm_head.weight.data = self.tok_embedding.weight.data +TENSOR_NAMES = model_builder.TENSOR_NAMES def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: @@ -91,12 +72,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def build_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_model_config(**kwargs) - model = SmolLM(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - # Since embedding and lm-head use the same weight, we need to set strict - # to False. - loader.load(model, strict=False) - model.eval() - return model +def build_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) diff --git a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py index 986fb716..5de9c237 100644 --- a/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +++ b/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py @@ -15,102 +15,10 @@ """Example of building a TinyLlama model.""" -from ai_edge_torch.generative.layers import attention -from ai_edge_torch.generative.layers import builder -from ai_edge_torch.generative.layers import kv_cache as kv_utils -import ai_edge_torch.generative.layers.attention_utils as attn_utils import ai_edge_torch.generative.layers.model_config as cfg -import ai_edge_torch.generative.utilities.loader as loading_utils -import torch -from torch import nn +from ai_edge_torch.generative.utilities import model_builder -TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( - ff_up_proj="model.layers.{}.mlp.up_proj", - ff_down_proj="model.layers.{}.mlp.down_proj", - ff_gate_proj="model.layers.{}.mlp.gate_proj", - attn_query_proj="model.layers.{}.self_attn.q_proj", - attn_key_proj="model.layers.{}.self_attn.k_proj", - attn_value_proj="model.layers.{}.self_attn.v_proj", - attn_output_proj="model.layers.{}.self_attn.o_proj", - pre_attn_norm="model.layers.{}.input_layernorm", - post_attn_norm="model.layers.{}.post_attention_layernorm", - embedding="model.embed_tokens", - final_norm="model.norm", - lm_head="lm_head", -) - - -class TinyLlama(nn.Module): - """A TinyLlama model built from the Edge Generative API layers.""" - - def __init__(self, config: cfg.ModelConfig): - super().__init__() - - # Construct model layers. - self.lm_head = nn.Linear( - config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias - ) - self.tok_embedding = nn.Embedding( - config.vocab_size, config.embedding_dim, padding_idx=0 - ) - # TinyLlama has only one block config. 
- block_config = config.block_config(0) - self.transformer_blocks = nn.ModuleList( - attention.TransformerBlock(block_config, config) - for _ in range(config.num_layers) - ) - self.final_norm = builder.build_norm( - config.embedding_dim, - config.final_norm_config, - ) - attn_config = block_config.attn_config - self.rope_cache = attn_utils.build_rope_cache( - size=config.kv_cache_max, - dim=int(attn_config.rotary_percentage * attn_config.head_dim), - base=attn_config.rotary_base, - ) - self.mask_cache = attn_utils.build_causal_mask_cache( - size=config.kv_cache_max, - ) - self.config = config - - @torch.inference_mode - def forward( - self, - tokens: torch.Tensor, - input_pos: torch.Tensor, - kv_cache: kv_utils.KVCache, - ) -> dict[torch.Tensor, kv_utils.KVCache]: - _, seq_len = tokens.size() - assert self.config.max_seq_len >= seq_len, ( - f"Cannot forward sequence of length {seq_len}, max seq length is only" - f" {self.config.max_seq_len}" - ) - assert len(self.transformer_blocks) == len(kv_cache.caches), ( - "The number of transformer blocks and the number of KV cache entries" - " must be the same." - ) - - cos, sin = self.rope_cache - cos = cos.index_select(0, input_pos) - sin = sin.index_select(0, input_pos) - mask = self.mask_cache.index_select(2, input_pos) - mask = mask[:, :, :, : self.config.kv_cache_max] - - # token embeddings of shape (b, t, n_embd) - x = self.tok_embedding(tokens) - - updated_kv_entires = [] - for i, block in enumerate(self.transformer_blocks): - kv_entry = kv_cache.caches[i] if kv_cache else None - x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry) - if kv_entry: - updated_kv_entires.append(kv_entry) - updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires)) - - x = self.final_norm(x) - logits = self.lm_head(x) # (b, t, vocab_size) - return {"logits": logits, "kv_cache": updated_kv_cache} +TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: @@ -150,6 +58,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: kv_cache_max_len=kv_cache_max_len, block_configs=block_config, final_norm_config=norm_config, + lm_head_share_weight_with_embedding=False, enable_hlfb=True, ) return config @@ -164,10 +73,11 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: return config -def build_model(checkpoint_path: str, **kwargs) -> nn.Module: - config = get_model_config(**kwargs) - model = TinyLlama(config) - loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES) - loader.load(model) - model.eval() - return model +def build_model( + checkpoint_path: str, **kwargs +) -> model_builder.DecoderOnlyModel: + return model_builder.build_decoder_only_model( + checkpoint_path=checkpoint_path, + config=get_model_config(**kwargs), + tensor_names=TENSOR_NAMES, + ) diff --git a/ai_edge_torch/generative/layers/model_config.py b/ai_edge_torch/generative/layers/model_config.py index 01365789..2e1b4f22 100644 --- a/ai_edge_torch/generative/layers/model_config.py +++ b/ai_edge_torch/generative/layers/model_config.py @@ -184,8 +184,14 @@ class ModelConfig: default_factory=NormalizationConfig ) + # Scale factor of the embedding. + embedding_scale: Optional[float] = None + # Use bias term within LLM's HEAD. lm_head_use_bias: bool = False + # Whether LLM's HEAD shares the weight of the embedding. + lm_head_share_weight_with_embedding: bool = True + # Whether to turn on high-level function boundary. 
enable_hlfb: bool = False diff --git a/ai_edge_torch/generative/test/test_loader.py b/ai_edge_torch/generative/test/test_loader.py index 45f30f52..0cac8479 100644 --- a/ai_edge_torch/generative/test/test_loader.py +++ b/ai_edge_torch/generative/test/test_loader.py @@ -19,6 +19,7 @@ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama from ai_edge_torch.generative.utilities import loader as loading_utils +from ai_edge_torch.generative.utilities import model_builder import safetensors.torch import torch @@ -71,7 +72,7 @@ def test_model_loader(self): safetensors.torch.save_file(test_weights, file_path) cfg = tiny_llama.get_model_config() cfg.num_layers = 1 - model = tiny_llama.TinyLlama(cfg) + model = model_builder.DecoderOnlyModel(cfg) loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES) # if returns successfully, it means all the tensors were initiallized. diff --git a/ai_edge_torch/generative/test/test_model_conversion.py b/ai_edge_torch/generative/test/test_model_conversion.py index 114a1bf0..2661c127 100644 --- a/ai_edge_torch/generative/test/test_model_conversion.py +++ b/ai_edge_torch/generative/test/test_model_conversion.py @@ -21,6 +21,7 @@ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama from ai_edge_torch.generative.layers import kv_cache from ai_edge_torch.generative.test import utils as test_utils +from ai_edge_torch.generative.utilities import model_builder import numpy as np import torch @@ -163,7 +164,7 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol): ) def test_tiny_llama_multisig(self): config = tiny_llama.get_fake_model_config() - pytorch_model = tiny_llama.TinyLlama(config).eval() + pytorch_model = model_builder.DecoderOnlyModel(config).eval() self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5) diff --git a/ai_edge_torch/generative/test/test_model_conversion_large.py b/ai_edge_torch/generative/test/test_model_conversion_large.py index a85c1f38..3dee0483 100644 --- a/ai_edge_torch/generative/test/test_model_conversion_large.py +++ b/ai_edge_torch/generative/test/test_model_conversion_large.py @@ -29,6 +29,7 @@ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion from ai_edge_torch.generative.layers import kv_cache +from ai_edge_torch.generative.utilities import model_builder from ai_edge_torch.generative.test import utils as test_utils import numpy as np import torch @@ -90,7 +91,7 @@ def _test_model(self, config, model, signature_name, atol, rtol): ) def test_gemma1(self): config = gemma1.get_fake_model_config() - pytorch_model = gemma1.Gemma(config).eval() + pytorch_model = model_builder.DecoderOnlyModel(config).eval() self._test_model( config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5 ) @@ -119,7 +120,7 @@ def test_llama(self): ) def test_phi2(self): config = phi2.get_fake_model_config() - pytorch_model = phi2.Phi2(config).eval() + pytorch_model = model_builder.DecoderOnlyModel(config).eval() self._test_model( config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3 ) @@ -139,7 +140,7 @@ def test_phi3(self): ) def test_smollm(self): config = smollm.get_fake_model_config() - pytorch_model = smollm.SmolLM(config).eval() + pytorch_model = model_builder.DecoderOnlyModel(config).eval() self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5) @googletest.skipIf( @@ -148,7 +149,7 @@ def test_smollm(self): ) def test_openelm(self): 
config = openelm.get_fake_model_config() - pytorch_model = openelm.OpenELM(config).eval() + pytorch_model = model_builder.DecoderOnlyModel(config).eval() self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5) @googletest.skipIf( @@ -157,7 +158,7 @@ def test_openelm(self): ) def test_qwen(self): config = qwen.get_fake_model_config() - pytorch_model = qwen.Qwen(config).eval() + pytorch_model = model_builder.DecoderOnlyModel(config).eval() self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5) @googletest.skipIf( diff --git a/ai_edge_torch/generative/utilities/model_builder.py b/ai_edge_torch/generative/utilities/model_builder.py new file mode 100644 index 00000000..9565fad4 --- /dev/null +++ b/ai_edge_torch/generative/utilities/model_builder.py @@ -0,0 +1,141 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities to be used for re-authoring transformer models.""" + +import copy + +from ai_edge_torch.generative.layers import attention +from ai_edge_torch.generative.layers import builder +from ai_edge_torch.generative.layers import kv_cache as kv_utils +import ai_edge_torch.generative.layers.attention_utils as attn_utils +import ai_edge_torch.generative.layers.model_config as cfg +import ai_edge_torch.generative.utilities.loader as loading_utils +import torch +from torch import nn + +TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( + ff_up_proj="model.layers.{}.mlp.up_proj", + ff_down_proj="model.layers.{}.mlp.down_proj", + ff_gate_proj="model.layers.{}.mlp.gate_proj", + attn_query_proj="model.layers.{}.self_attn.q_proj", + attn_key_proj="model.layers.{}.self_attn.k_proj", + attn_value_proj="model.layers.{}.self_attn.v_proj", + attn_output_proj="model.layers.{}.self_attn.o_proj", + pre_attn_norm="model.layers.{}.input_layernorm", + post_attn_norm="model.layers.{}.post_attention_layernorm", + embedding="model.embed_tokens", + final_norm="model.norm", +) + +TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES) +TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head" + + +class DecoderOnlyModel(nn.Module): + """A simple decoder-only transformer model built from the Edge Generative API. + + This model is used for re-authoring. model_config is used to specify the + details of model architecture and parameters. + + It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base, + and rotary_percentage are the same for all layers. + """ + + def __init__(self, config: cfg.ModelConfig): + super().__init__() + + # Construct model layers. 
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    if config.lm_head_share_weight_with_embedding:
+      self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(config.block_config(idx), config)
+        for idx in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    if self.config.embedding_scale is not None:
+      x = x * self.config.embedding_scale
+
+    updated_kv_entries = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def build_decoder_only_model(
+    checkpoint_path: str,
+    config: cfg.ModelConfig,
+    tensor_names: loading_utils.ModelLoader.TensorNames,
+) -> DecoderOnlyModel:
+  transformer = DecoderOnlyModel(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
+  loader.load(
+      transformer, strict=not config.lm_head_share_weight_with_embedding
+  )
+  transformer.eval()
+  return transformer
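
For quick reference, here is a minimal sketch of how the new `build_decoder_only_model` utility is intended to be used, based on the SmolLM example in this change. The checkpoint path is hypothetical, and `KVCache.from_model_config` is assumed to be the existing helper in `ai_edge_torch.generative.layers.kv_cache`; adjust if your setup differs.

```python
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.utilities import model_builder
import torch

# Build a DecoderOnlyModel from a checkpoint plus a ModelConfig and tensor names.
model = model_builder.build_decoder_only_model(
    checkpoint_path="/path/to/smollm_checkpoint",  # hypothetical path
    config=smollm.get_model_config(kv_cache_max_len=1024),
    tensor_names=smollm.TENSOR_NAMES,
)

# Run a single prefill step with an empty KV cache.
tokens = torch.zeros((1, 8), dtype=torch.int)
input_pos = torch.arange(0, 8, dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(model.config)  # assumed helper
output = model(tokens, input_pos, kv)
logits, kv = output["logits"], output["kv_cache"]
```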
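
The README change above lists three re-authoring options: use `DecoderOnlyModel` as is (SmolLM), inherit from it and override a single component (Llama 3.2), or write a new `nn.Module` from scratch (Gemma 2). As an illustration of the middle option, the sketch below follows the same pattern as the Llama re-authoring in this change; `MyModel` and the rotary base value are made up for illustration.

```python
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder


class MyModel(model_builder.DecoderOnlyModel):
  """A decoder-only model that only customizes the RoPE cache."""

  def __init__(self, config: cfg.ModelConfig):
    super().__init__(config)
    # DecoderOnlyModel assumes all blocks share one attention config.
    attn_config = self.config.block_config(0).attn_config
    # Replace only the RoPE cache; embeddings, blocks, norms, and forward()
    # are all inherited from DecoderOnlyModel.
    self.rope_cache = attn_utils.build_rope_cache(
        size=self.config.kv_cache_max,
        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
        base=500_000,  # illustrative value; Llama 3.2 uses its own RoPE builder
    )
```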