Add support for Gemma 2 models.
PiperOrigin-RevId: 659967095
danieldjohnson authored and Penzai Developers committed Aug 7, 2024
1 parent 8a745ae commit 8b5d177
Showing 8 changed files with 357 additions and 76 deletions.
2 changes: 1 addition & 1 deletion docs/api/pz.nn.rst
@@ -42,7 +42,7 @@ Basic Operations
pz.nn.CheckStructure
pz.nn.Identity
pz.nn.CastToDType

pz.nn.TanhSoftCap

Linear and Affine Layers
------------------------
38 changes: 36 additions & 2 deletions docs/guides/howto_reference.md
@@ -217,23 +217,57 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen

## Loading Pretrained Models

### Loading Gemma
### Loading Gemma or Gemma 2

Penzai's Gemma implementation includes a conversion utility that converts the ["Flax" model weights from Kaggle](https://www.kaggle.com/models/google/gemma) into the correct form. You can load it using:
Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using:

```python
import os

import kagglehub
import orbax.checkpoint
from penzai.models.transformer import variants

# Download Gemma 1 7B:
weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')

# Load the parameters into Penzai:
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flax_params_dict = checkpointer.restore(ckpt_path)
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```

To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:

```python
weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b')
ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt')
```

See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.)

If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. For instance, to shard over the last axis of every parameter, you can use

```python
import jax
from jax.experimental import mesh_utils

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flax_params_dict = checkpointer.restore(ckpt_path, restore_args=restore_args)
```

to load the Flax parameters before converting them into the Penzai model.
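
For example (a brief sketch repeating the conversion call from above), the restored dictionary is passed to the same conversion utility; the optional `upcast_activations_to_float32` flag keeps the parameters in their checkpoint dtype while running activations at float32:

```python
model = variants.gemma.gemma_from_pretrained_checkpoint(
    flax_params_dict,
    upcast_activations_to_float32=True,  # Optional: analyze activations at higher precision.
)
```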

### Loading Llama, Mistral, or GPT-NeoX / Pythia

Penzai also includes re-implementations of the architectures used by [Llama](https://llama.meta.com/), [Mistral](https://mistral.ai/), and the [GPT-NeoX](https://www.eleuther.ai/artifacts/gpt-neox-20b) family of models, including the [Pythia](https://github.com/EleutherAI/pythia) model scaling suite. To load these models into Penzai, you can first load the weights using the HuggingFace `transformers` library, then convert them to Penzai:
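
The Penzai-side conversion calls appear further down in this guide (not shown in this diff excerpt). As a minimal sketch of the HuggingFace loading step only (the Pythia checkpoint name is just an illustrative choice):

```python
import transformers

# Load the pretrained weights with HuggingFace transformers first; the Penzai
# conversion utilities (not shown in this excerpt) then consume this model.
hf_model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
```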
193 changes: 164 additions & 29 deletions penzai/models/transformer/variants/gemma.py
@@ -14,27 +14,114 @@

"""The Gemma architecture transformer variant.
See the Gemma technical report at
https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf
and the accompanying reference implementation at
https://github.com/google-deepmind/gemma.
Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax
reference implementation at https://github.com/google-deepmind/gemma.
See the Gemma technical reports for more information:
* Gemma 1: https://arxiv.org/abs/2403.08295
* Gemma 2: https://arxiv.org/abs/2408.00118
"""

from __future__ import annotations

import itertools
from typing import Any
from typing import Any, Literal

import jax.numpy as jnp
from penzai import pz
from penzai.models.transformer import model_parts
from penzai.models.transformer.variants import llamalike_common


_GEMMA_PRESETS = {
    "gemma_2b": dict(
        num_decoder_blocks=18,
        vocab_size=256_128,
        num_kv_heads=1,
        query_head_multiplier=8,
        embedding_dim=2048,
        projection_dim=256,
        mlp_hidden_dim=16_384,
    ),
    "gemma_7b": dict(
        num_decoder_blocks=28,
        vocab_size=256_128,
        num_kv_heads=16,
        query_head_multiplier=1,
        embedding_dim=3072,
        projection_dim=256,
        mlp_hidden_dim=24_576,
    ),
    "gemma2_2b": dict(
        num_decoder_blocks=26,
        vocab_size=256_128,
        num_kv_heads=4,
        query_head_multiplier=2,
        embedding_dim=2304,
        projection_dim=256,
        mlp_hidden_dim=9216,
        attention_type=(
            llamalike_common.AttentionTypeSlidingWindowCausal(4096),
            llamalike_common.AttentionTypeGlobalCausal(),
        ),
        use_post_attn_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcap=30.0,
        attn_logits_soft_cap=50.0,
    ),
    "gemma2_9b": dict(
        num_decoder_blocks=42,
        vocab_size=256_128,
        num_kv_heads=8,
        query_head_multiplier=2,
        embedding_dim=3584,
        projection_dim=256,
        mlp_hidden_dim=14_336,
        attention_type=(
            llamalike_common.AttentionTypeSlidingWindowCausal(4096),
            llamalike_common.AttentionTypeGlobalCausal(),
        ),
        use_post_attn_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcap=30.0,
        attn_logits_soft_cap=50.0,
    ),
    "gemma2_27b": dict(
        num_decoder_blocks=46,
        vocab_size=256_128,
        num_kv_heads=16,
        query_head_multiplier=2,
        embedding_dim=4608,
        projection_dim=128,
        mlp_hidden_dim=36_864,
        # query scaling factor: 1/sqrt(embedding_dim / num_query_heads)
        query_scaling_factor=(4608 // 32) ** -0.5,
        attention_type=(
            llamalike_common.AttentionTypeSlidingWindowCausal(4096),
            llamalike_common.AttentionTypeGlobalCausal(),
        ),
        use_post_attn_norm=True,
        use_post_ffw_norm=True,
        final_logit_softcap=30.0,
        attn_logits_soft_cap=50.0,
    ),
}
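# Presets whose "mlp/gating_einsum" checkpoint weights put the hidden
# ("neurons") axis before the embedding axis; these are transposed below
# before being wrapped into named arrays.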
_NEEDS_GATING_TRANSPOSE = {
    "gemma_2b": False,
    "gemma_7b": False,
    "gemma2_2b": False,
    "gemma2_9b": True,
    "gemma2_27b": True,
}


def gemma_from_pretrained_checkpoint(
    ckpt_params: dict[str, Any],
    upcast_activations_to_float32: bool = False,
    use_layer_stack: bool = False,
    preset_name: Literal[
        "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
    ] = "auto",
) -> model_parts.TransformerLM:
  """Builds a Gemma model from a pretrained checkpoint.
@@ -56,36 +143,41 @@ def gemma_from_pretrained_checkpoint(
      the model runs. This allows analyzing activations at higher precision
      without consuming additional memory for parameters.
    use_layer_stack: Whether to use a layer stack for the decoder blocks.
    preset_name: Preset name, used to determine model config. If "auto", uses
      the number of layers in the checkpoint to determine the configuration.

  Returns:
    A Transformer model containing the loaded parameters.
  """
  params = {k.removeprefix("transformer/"): v for k, v in ckpt_params.items()}
  num_layers = 0
  for i in itertools.count():
    if f"layer_{i}/mlp/linear" not in params:
      num_layers = i
      break
  hidden_dim, embed_dim = params["layer_0/mlp/linear"]["w"].shape
  attn_0_einsum_param = params["layer_0/attn/attn_vec_einsum"]["w"]
  num_heads, proj_dim, _ = attn_0_einsum_param.shape
  single_kv_head = "layer_0/attn/qkv_einsum" not in params
  vocab_size = params["embedder"]["input_embedding"].shape[0]

  if preset_name == "auto":
    num_layers = 0
    while f"layer_{num_layers}/mlp/linear" in params:
      num_layers += 1
    preset_by_num_layers = {
        kwargs["num_decoder_blocks"]: preset_name
        for preset_name, kwargs in _GEMMA_PRESETS.items()
    }
    if num_layers not in preset_by_num_layers:
      raise ValueError(
          f"Could not determine preset for model with {num_layers} layers."
      )
    preset_name = preset_by_num_layers[num_layers]

  preset_kwargs = _GEMMA_PRESETS[preset_name]
  preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name]

  parameter_dtype = params["layer_0/attn/attn_vec_einsum"]["w"].dtype

  if upcast_activations_to_float32:
    activation_dtype = jnp.float32
  else:
    activation_dtype = attn_0_einsum_param.dtype
    activation_dtype = parameter_dtype

  config = llamalike_common.LlamalikeTransformerConfig(
      num_kv_heads=1 if single_kv_head else num_heads,
      query_head_multiplier=num_heads if single_kv_head else 1,
      embedding_dim=embed_dim,
      projection_dim=proj_dim,
      mlp_hidden_dim=hidden_dim,
      num_decoder_blocks=num_layers,
      vocab_size=vocab_size,
      parameter_dtype=attn_0_einsum_param.dtype,
      **preset_kwargs,
      parameter_dtype=parameter_dtype,
      mlp_variant="geglu_approx",
      rope_wavelength=10_000,
      tie_embedder_and_logits=True,
@@ -115,20 +207,36 @@ def gemma_from_pretrained_checkpoint(
1 + params[f"layer_{i}/pre_attention_norm"]["scale"]
).tag("embedding")
)
if config.use_post_attn_norm:
cur_block_params["post_attention_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/post_attention_norm"]["scale"]
).tag("embedding")
)

cur_block_params["pre_ffw_norm/scale.weights"] = pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/pre_ffw_norm"]["scale"]
).tag("embedding")
if config.use_post_ffw_norm:
cur_block_params["post_ffw_norm/scale.weights"] = pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/post_ffw_norm"]["scale"]
).tag("embedding")

gating_einsum_w = params[f"layer_{i}/mlp/gating_einsum"]["w"]
if preset_needs_gating_transpose:
gating_einsum_w = gating_einsum_w.transpose((0, 2, 1))
cur_block_params["mlp/gating_linear.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/mlp/gating_einsum"]["w"][0]
gating_einsum_w[0]
).tag("embedding", "neurons")
cur_block_params["mlp/value_linear.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/mlp/gating_einsum"]["w"][1]
gating_einsum_w[1]
).tag("embedding", "neurons")

cur_block_params["mlp/out_linear.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/mlp/linear"]["w"]
).tag("neurons", "embedding")

if single_kv_head:
if config.num_kv_heads == 1:
cur_block_params["attention/query.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/attn/q_einsum"]["w"]
).tag("query_heads", "embedding", "projection")
@@ -141,7 +249,7 @@ def gemma_from_pretrained_checkpoint(
cur_block_params["attention/output.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/attn/attn_vec_einsum"]["w"]
).tag("query_heads", "projection", "embedding")
else:
elif config.query_head_multiplier == 1:
cur_block_params["attention/query.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/attn/qkv_einsum"]["w"][0]
).tag("heads", "embedding", "projection")
@@ -154,6 +262,33 @@ def gemma_from_pretrained_checkpoint(
cur_block_params["attention/output.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/attn/attn_vec_einsum"]["w"]
).tag("heads", "projection", "embedding")
else:
# Grouped query attention: split attention heads into groups.
cur_block_params["attention/key.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/attn/kv_einsum"]["w"][0]
).tag("head_groups", "embedding", "projection")
cur_block_params["attention/value.weights"] = pz.nx.NamedArray.wrap(
params[f"layer_{i}/attn/kv_einsum"]["w"][1]
).tag("head_groups", "embedding", "projection")

q_weights = params[f"layer_{i}/attn/q_einsum"]["w"]
out_weights = params[f"layer_{i}/attn/attn_vec_einsum"]["w"]
cur_block_params["attention/query.weights"] = pz.nx.NamedArray.wrap(
q_weights.reshape((
config.num_kv_heads,
config.query_head_multiplier,
config.embedding_dim,
config.projection_dim,
))
).tag("head_groups", "query_heads", "embedding", "projection")
cur_block_params["attention/output.weights"] = pz.nx.NamedArray.wrap(
out_weights.reshape((
config.num_kv_heads,
config.query_head_multiplier,
config.projection_dim,
config.embedding_dim,
))
).tag("head_groups", "query_heads", "projection", "embedding")

if use_layer_stack:
for key in all_block_params[0].keys():