Refactor model building code.
- Define model_builder.DecoderOnlyModel as a base for model re-authoring.
- Use model_builder.build_decoder_only_model() to build models backed by model_builder.DecoderOnlyModel.
- Gemma2 defines its own nn.Module.
- Llama and Phi3.5 inherit model_builder.DecoderOnlyModel (see the sketch below).
- All others, including Gemma1 and Phi2, use model_builder.DecoderOnlyModel as is.
- Add embedding_scale and lm_head_share_weight_with_embedding to the model config.
- Switch between Llama 1B and 3B with a command-line flag instead of separate .py files.
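
A minimal sketch of the inherit-and-override pattern, for illustration only
(MyModel and the rotary base value below are hypothetical; llama.py in this
change is the real example):

```python
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder


class MyModel(model_builder.DecoderOnlyModel):
  """Hypothetical model that reuses DecoderOnlyModel and swaps one component."""

  def __init__(self, config: cfg.ModelConfig):
    super().__init__(config)
    # Override only what differs from the generic decoder; Llama 3.2 replaces
    # its RoPE cache like this, and everything else comes from the base class.
    attn_config = config.block_config(0).attn_config
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
        base=1_000_000,  # hypothetical rotary base, for the sketch only
    )
```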

PiperOrigin-RevId: 681568633
ai-edge-bot authored and copybara-github committed Oct 2, 2024
1 parent b076be8 commit 5a7e1cf
Showing 19 changed files with 302 additions and 684 deletions.
11 changes: 10 additions & 1 deletion ai_edge_torch/generative/examples/README.md
@@ -89,7 +89,16 @@ LlamaForCausalLM(
```

Based on the original model structure, construct a new nn.Module model using
the AI Edge Torch Generative API. As the existing examples show, you can either
use the
[`DecoderOnlyModel`](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/model_builder.py)
class as is, like
[SmolLM](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/smollm/smollm.py);
inherit from `DecoderOnlyModel` and override only the components that differ, like
[Llama 3.2](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/llama/llama.py);
or construct an entirely new nn.Module from scratch, like
[Gemma 2](https://github.com/protobird-git/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/gemma/gemma2.py).
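
For the first option, a minimal sketch of a build function looks like this
(the `get_model_config()` helper is a placeholder to fill in with your model's
`cfg.ModelConfig`):

```python
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder


def get_model_config(**kwargs) -> cfg.ModelConfig:
  """Placeholder: return a cfg.ModelConfig describing the model to re-author."""
  ...


def build_model(checkpoint_path: str, **kwargs) -> model_builder.DecoderOnlyModel:
  # build_decoder_only_model() constructs a DecoderOnlyModel from the config
  # and loads the checkpoint weights. model_builder.TENSOR_NAMES is the stock
  # tensor-name mapping; pass a custom mapping if your checkpoint's layout
  # differs (see the Gemma1 example).
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=get_model_config(**kwargs),
      tensor_names=model_builder.TENSOR_NAMES,
  )
```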

Here is an example of TinyLlama constructed from scratch.

https://github.com/google-ai-edge/ai-edge-torch/blob/853301630f2b2455bd2e2f73d8a47e1a1534c91c/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py#L46-L77

103 changes: 10 additions & 93 deletions ai_edge_torch/generative/examples/gemma/gemma1.py
@@ -15,14 +15,9 @@

"""Example of building a Gemma1 model."""

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -38,84 +33,6 @@
)


class Gemma(nn.Module):
"""A Gemma model built from the Edge Generative API layers."""

def __init__(self, config: cfg.ModelConfig):
super().__init__()

# Construct model layers.
self.tok_embedding = nn.Embedding(
config.vocab_size, config.embedding_dim, padding_idx=0
)
self.lm_head = nn.Linear(
config.embedding_dim,
config.vocab_size,
bias=config.lm_head_use_bias,
)
# Gemma re-uses the embedding as the head projection layer.
self.lm_head.weight.data = self.tok_embedding.weight.data
# Gemma has only one block config.
block_config = config.block_config(0)
self.transformer_blocks = nn.ModuleList(
attention.TransformerBlock(block_config, config)
for _ in range(config.num_layers)
)
self.final_norm = builder.build_norm(
config.embedding_dim,
config.final_norm_config,
)
attn_config = block_config.attn_config
self.rope_cache = attn_utils.build_rope_cache(
size=config.kv_cache_max,
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
base=attn_config.rotary_base,
)
self.mask_cache = attn_utils.build_causal_mask_cache(
size=config.kv_cache_max,
)
self.config = config

@torch.inference_mode
def forward(
self,
tokens: torch.Tensor,
input_pos: torch.Tensor,
kv_cache: kv_utils.KVCache,
) -> dict[torch.Tensor, kv_utils.KVCache]:
_, seq_len = tokens.size()
assert self.config.max_seq_len >= seq_len, (
f"Cannot forward sequence of length {seq_len}, max seq length is only"
f" {self.config.max_seq_len}"
)
assert len(self.transformer_blocks) == len(kv_cache.caches), (
"The number of transformer blocks and the number of KV cache entries"
" must be the same."
)

cos, sin = self.rope_cache
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, : self.config.kv_cache_max]

# token embeddings of shape (b, t, n_embd)
x = self.tok_embedding(tokens)
x = x * (self.config.embedding_dim**0.5)

updated_kv_entires = []
for i, block in enumerate(self.transformer_blocks):
kv_entry = kv_cache.caches[i] if kv_cache else None
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
if kv_entry:
updated_kv_entires.append(kv_entry)
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

x = self.final_norm(x)
logits = self.lm_head(x) # (b, t, vocab_size)
return {"logits": logits, "kv_cache": updated_kv_cache}


def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a Gemma 2B model.
@@ -154,6 +71,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
num_layers=18,
max_seq_len=8192,
embedding_dim=2048,
embedding_scale=2048**0.5,
kv_cache_max_len=kv_cache_max_len,
block_configs=block_config,
final_norm_config=norm_config,
@@ -173,12 +91,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
return config


def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
config = get_model_config_2b(**kwargs)
model = Gemma(config)
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
# Since embedding and lm-head use the same weight, we need to set strict
# to False.
loader.load(model, strict=False)
model.eval()
return model
def build_2b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
return model_builder.build_decoder_only_model(
checkpoint_path=checkpoint_path,
config=get_model_config_2b(**kwargs),
tensor_names=TENSOR_NAMES,
)
1 change: 0 additions & 1 deletion ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -15,7 +15,6 @@

"""Example of building a Gemma2 model."""

import os
from typing import Optional, Tuple

from ai_edge_torch.generative.layers import attention
68 changes: 0 additions & 68 deletions ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py

This file was deleted.

15 changes: 13 additions & 2 deletions ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -23,6 +23,12 @@
from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.utilities import converter

_MODEL_SIZE = flags.DEFINE_enum(
'model_size',
'1b',
['1b', '3b'],
'The size of the model to convert.',
)
_CHECKPOINT_PATH = flags.DEFINE_string(
'checkpoint_path',
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
@@ -49,13 +55,18 @@
'Whether the model should be quantized.',
)

_BUILDER = {
'1b': llama.build_1b_model,
'3b': llama.build_3b_model,
}


def main(_):
pytorch_model = llama.build_model(
pytorch_model = _BUILDER[_MODEL_SIZE.value](
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
output_filename = f'llama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
43 changes: 19 additions & 24 deletions ai_edge_torch/generative/examples/llama/llama.py
@@ -15,19 +15,15 @@

"""Example of building Llama 3.2 models."""

import copy
import math
from typing import Tuple

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
from torch import nn

TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
# SmolLM re-uses the embedding as the head projection layer.
TENSOR_NAMES.lm_head = None
TENSOR_NAMES = model_builder.TENSOR_NAMES


def _build_llama3_rope_cache(
@@ -93,17 +89,14 @@ def _build_llama3_rope_cache(
return cos, sin


class Llama(tiny_llama.TinyLlama):
class Llama(model_builder.DecoderOnlyModel):
"""A Llama model built from the Edge Generative API layers.
Llama 3.2 shares the same architecture as TinyLlama except for the RoPE calculation.
"""

def __init__(self, config: cfg.ModelConfig):
super().__init__(config)
# Llama 3.2 re-uses the embedding as the head projection layer.
self.lm_head.weight.data = self.tok_embedding.weight.data
# Llama has only one block config.
attn_config = self.config.block_config(0).attn_config
self.rope_cache = _build_llama3_rope_cache(
size=self.config.kv_cache_max,
@@ -119,7 +112,7 @@ def __init__(self, config: cfg.ModelConfig):
)


def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a Llama 3.2-1B model.
Args:
Expand Down Expand Up @@ -163,7 +156,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:

def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
"""Returns the model config for a Llama 3.2-3B model."""
config = get_model_config(kv_cache_max_len)
config = get_1b_model_config(kv_cache_max_len)
# Llama 3.2 has only one block config.
attn_config = config.block_config(0).attn_config
attn_config.num_heads = 24
@@ -174,16 +167,17 @@ def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:


def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
config = get_model_config(**kwargs)
config = get_1b_model_config(**kwargs)
config.vocab_size = 128
config.num_layers = 2
# Llama 3.2 has only one block config.
config.block_config(0).ff_config.intermediate_size = 64
return config


def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
config = get_model_config(**kwargs)
def _build_model(
checkpoint_path: str, config: cfg.ModelConfig
) -> model_builder.DecoderOnlyModel:
model = Llama(config)
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
# Since embedding and lm-head use the same weight, we need to set strict
@@ -193,12 +187,13 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
return model


def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
config = get_3b_model_config(**kwargs)
model = Llama(config)
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
# Since embedding and lm-head use the same weight, we need to set strict
# to False.
loader.load(model, strict=False)
model.eval()
return model
def build_1b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
return _build_model(checkpoint_path, get_1b_model_config(**kwargs))


def build_3b_model(
checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
return _build_model(checkpoint_path, get_3b_model_config(**kwargs))