From 8b5d177dce6409dee327d73bc5a5b95d35fca4b0 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Tue, 6 Aug 2024 08:29:44 -0700 Subject: [PATCH] Add support for Gemma 2 models. PiperOrigin-RevId: 659967095 --- docs/api/pz.nn.rst | 2 +- docs/guides/howto_reference.md | 38 +++- penzai/models/transformer/variants/gemma.py | 193 +++++++++++++++--- .../transformer/variants/llamalike_common.py | 146 +++++++++---- penzai/nn/basic_ops.py | 15 ++ penzai/nn/standardization.py | 2 +- penzai/pz/nn.py | 1 + tests/models/transformer_llamalike_test.py | 36 ++++ 8 files changed, 357 insertions(+), 76 deletions(-) diff --git a/docs/api/pz.nn.rst b/docs/api/pz.nn.rst index 85150ac..fa5130d 100644 --- a/docs/api/pz.nn.rst +++ b/docs/api/pz.nn.rst @@ -42,7 +42,7 @@ Basic Operations pz.nn.CheckStructure pz.nn.Identity pz.nn.CastToDType - + pz.nn.TanhSoftCap Linear and Affine Layers ------------------------ diff --git a/docs/guides/howto_reference.md b/docs/guides/howto_reference.md index 4d00c87..cddf53c 100644 --- a/docs/guides/howto_reference.md +++ b/docs/guides/howto_reference.md @@ -217,23 +217,57 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen ## Loading Pretrained Models -### Loading Gemma +### Loading Gemma or Gemma 2 -Penzai's Gemma implementation includes a conversion utility that converts the ["Flax" model weights from Kaggle](https://www.kaggle.com/models/google/gemma) into the correct form. You can load it using: +Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using: ```python import kagglehub import orbax.checkpoint from penzai.models.transformer import variants +# Download Gemma 1 7B: weights_dir = kagglehub.model_download('google/gemma/Flax/7b') ckpt_path = os.path.join(weights_dir, '7b') +# Load the parameters into Penzai: checkpointer = orbax.checkpoint.PyTreeCheckpointer() flax_params_dict = checkpointer.restore(ckpt_path) model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict) ``` +To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use: + +```python +weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b') +ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt') +``` + +See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.) + +If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. 
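+(A quick way to check how many local devices are visible is the sketch below.)
+
+```python
+import jax
+
+print(jax.local_device_count())  # e.g. 8 on a TPU v2 Colab kernel
+```
+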
For instance, to shard over the last axis of every parameter, you can use + +```python +from jax.experimental import mesh_utils + +checkpointer = orbax.checkpoint.PyTreeCheckpointer() +metadata = checkpointer.metadata(ckpt_path) + +n_devices = jax.local_device_count() +sharding_devices = mesh_utils.create_device_mesh((n_devices,)) +sharding = jax.sharding.PositionalSharding(sharding_devices) +restore_args = jax.tree_util.tree_map( + lambda m: orbax.checkpoint.ArrayRestoreArgs( + restore_type=jax.Array, + sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,)) + ), + metadata, +) +flax_params_dict = checkpointer.restore(ckpt_path, restore_args=restore_args) +``` + +to load the Flax parameters before converting them into the Penzai model. + ### Loading Llama, Mistral, or GPT-NeoX / Pythia Penzai also includes re-implementations of the architectures used by [Llama](https://llama.meta.com/), [Mistral](https://mistral.ai/), and the [GPT-NeoX](https://www.eleuther.ai/artifacts/gpt-neox-20b) family of models, including the [Pythia](https://github.com/EleutherAI/pythia) model scaling suite. To load these models into Penzai, you can first load the weights using the HuggingFace `transformers` library, then convert them to Penzai: diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index eeb7796..1b8289f 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -14,16 +14,18 @@ """The Gemma architecture transformer variant. -See the Gemma technical report at -https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf -and the accompanying reference implementation at -https://github.com/google-deepmind/gemma. +Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax +reference implementation at https://github.com/google-deepmind/gemma. 
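+
+Checkpoints can be converted with ``gemma_from_pretrained_checkpoint``. For
+example (a sketch, assuming the checkpoint has already been restored into
+``flax_params_dict`` as shown in the "how to" guide)::
+
+  model = gemma_from_pretrained_checkpoint(flax_params_dict)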
+ +See the Gemma technical reports for more information: + +* Gemma 1: https://arxiv.org/abs/2403.08295 +* Gemma 2: https://arxiv.org/abs/2408.00118 """ from __future__ import annotations -import itertools -from typing import Any +from typing import Any, Literal import jax.numpy as jnp from penzai import pz @@ -31,10 +33,95 @@ from penzai.models.transformer.variants import llamalike_common +_GEMMA_PRESETS = { + "gemma_2b": dict( + num_decoder_blocks=18, + vocab_size=256_128, + num_kv_heads=1, + query_head_multiplier=8, + embedding_dim=2048, + projection_dim=256, + mlp_hidden_dim=16_384, + ), + "gemma_7b": dict( + num_decoder_blocks=28, + vocab_size=256_128, + num_kv_heads=16, + query_head_multiplier=1, + embedding_dim=3072, + projection_dim=256, + mlp_hidden_dim=24_576, + ), + "gemma2_2b": dict( + num_decoder_blocks=26, + vocab_size=256_128, + num_kv_heads=4, + query_head_multiplier=2, + embedding_dim=2304, + projection_dim=256, + mlp_hidden_dim=9216, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(4096), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_post_attn_norm=True, + use_post_ffw_norm=True, + final_logit_softcap=30.0, + attn_logits_soft_cap=50.0, + ), + "gemma2_9b": dict( + num_decoder_blocks=42, + vocab_size=256_128, + num_kv_heads=8, + query_head_multiplier=2, + embedding_dim=3584, + projection_dim=256, + mlp_hidden_dim=14_336, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(4096), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_post_attn_norm=True, + use_post_ffw_norm=True, + final_logit_softcap=30.0, + attn_logits_soft_cap=50.0, + ), + "gemma2_27b": dict( + num_decoder_blocks=46, + vocab_size=256_128, + num_kv_heads=16, + query_head_multiplier=2, + embedding_dim=4608, + projection_dim=128, + mlp_hidden_dim=36_864, + # query scaling factor: 1/sqrt(embedding_dim / num_query_heads) + query_scaling_factor=(4608 // 32) ** -0.5, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(4096), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_post_attn_norm=True, + use_post_ffw_norm=True, + final_logit_softcap=30.0, + attn_logits_soft_cap=50.0, + ), +} +_NEEDS_GATING_TRANSPOSE = { + "gemma_2b": False, + "gemma_7b": False, + "gemma2_2b": False, + "gemma2_9b": True, + "gemma2_27b": True, +} + + def gemma_from_pretrained_checkpoint( ckpt_params: dict[str, Any], upcast_activations_to_float32: bool = False, use_layer_stack: bool = False, + preset_name: Literal[ + "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto" + ] = "auto", ) -> model_parts.TransformerLM: """Builds a Gemma model from a pretrained checkpoint. @@ -56,36 +143,41 @@ def gemma_from_pretrained_checkpoint( the model runs. This allows analyzing activations at higher precision without consuming additional memory for parameters. use_layer_stack: Whether to use a layer stack for the decoder blocks. + preset_name: Preset name, used to determine model config. If "auto", uses + the number of layers in the checkpoint to determine the configuration. Returns: A Transformer model containing the loaded parameters. 
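+
+  Example (a sketch; ``flax_params_dict`` is assumed to hold a restored
+  Gemma 2 9B checkpoint, in which case ``preset_name`` can also simply be
+  left as "auto")::
+
+    model = gemma_from_pretrained_checkpoint(
+        flax_params_dict, preset_name="gemma2_9b"
+    )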
""" params = {k.removeprefix("transformer/"): v for k, v in ckpt_params.items()} - num_layers = 0 - for i in itertools.count(): - if f"layer_{i}/mlp/linear" not in params: - num_layers = i - break - hidden_dim, embed_dim = params["layer_0/mlp/linear"]["w"].shape - attn_0_einsum_param = params["layer_0/attn/attn_vec_einsum"]["w"] - num_heads, proj_dim, _ = attn_0_einsum_param.shape - single_kv_head = "layer_0/attn/qkv_einsum" not in params - vocab_size = params["embedder"]["input_embedding"].shape[0] + + if preset_name == "auto": + num_layers = 0 + while f"layer_{num_layers}/mlp/linear" in params: + num_layers += 1 + preset_by_num_layers = { + kwargs["num_decoder_blocks"]: preset_name + for preset_name, kwargs in _GEMMA_PRESETS.items() + } + if num_layers not in preset_by_num_layers: + raise ValueError( + f"Could not determine preset for model with {num_layers} layers." + ) + preset_name = preset_by_num_layers[num_layers] + + preset_kwargs = _GEMMA_PRESETS[preset_name] + preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name] + + parameter_dtype = params["layer_0/attn/attn_vec_einsum"]["w"].dtype if upcast_activations_to_float32: activation_dtype = jnp.float32 else: - activation_dtype = attn_0_einsum_param.dtype + activation_dtype = parameter_dtype config = llamalike_common.LlamalikeTransformerConfig( - num_kv_heads=1 if single_kv_head else num_heads, - query_head_multiplier=num_heads if single_kv_head else 1, - embedding_dim=embed_dim, - projection_dim=proj_dim, - mlp_hidden_dim=hidden_dim, - num_decoder_blocks=num_layers, - vocab_size=vocab_size, - parameter_dtype=attn_0_einsum_param.dtype, + **preset_kwargs, + parameter_dtype=parameter_dtype, mlp_variant="geglu_approx", rope_wavelength=10_000, tie_embedder_and_logits=True, @@ -115,20 +207,36 @@ def gemma_from_pretrained_checkpoint( 1 + params[f"layer_{i}/pre_attention_norm"]["scale"] ).tag("embedding") ) + if config.use_post_attn_norm: + cur_block_params["post_attention_norm/scale.weights"] = ( + pz.nx.NamedArray.wrap( + 1 + params[f"layer_{i}/post_attention_norm"]["scale"] + ).tag("embedding") + ) + cur_block_params["pre_ffw_norm/scale.weights"] = pz.nx.NamedArray.wrap( 1 + params[f"layer_{i}/pre_ffw_norm"]["scale"] ).tag("embedding") + if config.use_post_ffw_norm: + cur_block_params["post_ffw_norm/scale.weights"] = pz.nx.NamedArray.wrap( + 1 + params[f"layer_{i}/post_ffw_norm"]["scale"] + ).tag("embedding") + + gating_einsum_w = params[f"layer_{i}/mlp/gating_einsum"]["w"] + if preset_needs_gating_transpose: + gating_einsum_w = gating_einsum_w.transpose((0, 2, 1)) cur_block_params["mlp/gating_linear.weights"] = pz.nx.NamedArray.wrap( - params[f"layer_{i}/mlp/gating_einsum"]["w"][0] + gating_einsum_w[0] ).tag("embedding", "neurons") cur_block_params["mlp/value_linear.weights"] = pz.nx.NamedArray.wrap( - params[f"layer_{i}/mlp/gating_einsum"]["w"][1] + gating_einsum_w[1] ).tag("embedding", "neurons") + cur_block_params["mlp/out_linear.weights"] = pz.nx.NamedArray.wrap( params[f"layer_{i}/mlp/linear"]["w"] ).tag("neurons", "embedding") - if single_kv_head: + if config.num_kv_heads == 1: cur_block_params["attention/query.weights"] = pz.nx.NamedArray.wrap( params[f"layer_{i}/attn/q_einsum"]["w"] ).tag("query_heads", "embedding", "projection") @@ -141,7 +249,7 @@ def gemma_from_pretrained_checkpoint( cur_block_params["attention/output.weights"] = pz.nx.NamedArray.wrap( params[f"layer_{i}/attn/attn_vec_einsum"]["w"] ).tag("query_heads", "projection", "embedding") - else: + elif config.query_head_multiplier == 1: 
cur_block_params["attention/query.weights"] = pz.nx.NamedArray.wrap( params[f"layer_{i}/attn/qkv_einsum"]["w"][0] ).tag("heads", "embedding", "projection") @@ -154,6 +262,33 @@ def gemma_from_pretrained_checkpoint( cur_block_params["attention/output.weights"] = pz.nx.NamedArray.wrap( params[f"layer_{i}/attn/attn_vec_einsum"]["w"] ).tag("heads", "projection", "embedding") + else: + # Grouped query attention: split attention heads into groups. + cur_block_params["attention/key.weights"] = pz.nx.NamedArray.wrap( + params[f"layer_{i}/attn/kv_einsum"]["w"][0] + ).tag("head_groups", "embedding", "projection") + cur_block_params["attention/value.weights"] = pz.nx.NamedArray.wrap( + params[f"layer_{i}/attn/kv_einsum"]["w"][1] + ).tag("head_groups", "embedding", "projection") + + q_weights = params[f"layer_{i}/attn/q_einsum"]["w"] + out_weights = params[f"layer_{i}/attn/attn_vec_einsum"]["w"] + cur_block_params["attention/query.weights"] = pz.nx.NamedArray.wrap( + q_weights.reshape(( + config.num_kv_heads, + config.query_head_multiplier, + config.embedding_dim, + config.projection_dim, + )) + ).tag("head_groups", "query_heads", "embedding", "projection") + cur_block_params["attention/output.weights"] = pz.nx.NamedArray.wrap( + out_weights.reshape(( + config.num_kv_heads, + config.query_head_multiplier, + config.projection_dim, + config.embedding_dim, + )) + ).tag("head_groups", "query_heads", "projection", "embedding") if use_layer_stack: for key in all_block_params[0].keys(): diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index f6cc013..4132ea5 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -64,6 +64,9 @@ class AttentionTypeSlidingWindowCausal: class LlamalikeTransformerConfig: """Common configuration parameters for a "llama-like" transformer. + This config encompasses the parameters for the Llama, Mistral, and Gemma + model families. + These are held in a single configuration object to simplify argument passing during construction of the model. @@ -86,6 +89,16 @@ class LlamalikeTransformerConfig: attention_type: A single attention type or sequence of per-layer attention types. If a sequence, its length should evenly divide the number of decoder blocks, and will be repeated to match the number of blocks. + use_post_attn_norm: Whether to add a normalization layer after the attention + block. + use_post_ffw_norm: Whether to add a normalization layer after the + feedforward block. + final_logit_softcap: If not None, used as the tanh soft cap for the final + transformer logits. + attn_logits_soft_cap: If not None, used as the tanh soft cap for the + attention logits. + query_scaling_factor: Scaling factor for the query vectors. If "default", + defaults to 1 / sqrt(projection_dim). parameter_dtype: Floating dtype to use for all parameters. activation_dtype: Floating dtype to use for activations and KV cache tables. use_layer_stack: Whether to stack the blocks together using a LayerStack. 
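The new configuration fields above are exercised together by the Gemma 2 presets in `gemma.py`. As a rough sketch of how they combine (hypothetical tiny sizes, chosen only for illustration and not matching any real model):

```python
import jax
from penzai.models.transformer.variants import llamalike_common

config = llamalike_common.LlamalikeTransformerConfig(
    num_kv_heads=2,
    query_head_multiplier=2,  # 4 query heads grouped over 2 key/value heads
    embedding_dim=64,
    projection_dim=16,
    mlp_hidden_dim=128,
    num_decoder_blocks=4,
    vocab_size=100,
    mlp_variant="geglu_approx",
    rope_wavelength=10_000,
    tie_embedder_and_logits=True,
    # Alternate sliding-window and global attention across the four blocks:
    attention_type=(
        llamalike_common.AttentionTypeSlidingWindowCausal(16),
        llamalike_common.AttentionTypeGlobalCausal(),
    ),
    use_post_attn_norm=True,
    use_post_ffw_norm=True,
    final_logit_softcap=30.0,
    attn_logits_soft_cap=50.0,
)
model = llamalike_common.build_llamalike_transformer(
    config, init_base_rng=jax.random.key(0)
)
```
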
@@ -105,6 +118,11 @@ class LlamalikeTransformerConfig: attention_type: AttentionType | Sequence[AttentionType] = ( AttentionTypeGlobalCausal() ) + use_post_attn_norm: bool = False + use_post_ffw_norm: bool = False + final_logit_softcap: float | None = None + attn_logits_soft_cap: float | None = None + query_scaling_factor: float | Literal["default"] = "default" parameter_dtype: jax.typing.DTypeLike = jnp.float32 activation_dtype: jax.typing.DTypeLike = jnp.float32 use_layer_stack: bool = False @@ -218,6 +236,11 @@ def build_llamalike_attention( config ) + if config.query_scaling_factor == "default": + query_scaling_factor = projection_dim**-0.5 + else: + query_scaling_factor = config.query_scaling_factor + # As used in https://github.com/google-deepmind/gemma. # (This exact value is probably not important.) masked_out_value = jnp.array(-2.3819763e38, dtype=config.activation_dtype) @@ -245,6 +268,28 @@ def build_llamalike_attention( else: raise ValueError(f"Unsupported attention type {attention_type}") + query_key_to_attn_sublayers = [ + pz.nn.NamedEinsum( + ( + {"seq": "tq", **qkv_einsum, **q_einsum, "projection": "p"}, + {"seq": "tkv", **qkv_einsum, "projection": "p"}, + ), + {"seq": "tq", **qkv_einsum, **q_einsum, "kv_seq": "tkv"}, + ), + ] + if config.attn_logits_soft_cap is not None: + query_key_to_attn_sublayers.append( + pz.nn.TanhSoftCap( + soft_cap=jnp.array( + config.attn_logits_soft_cap, dtype=config.activation_dtype + ) + ) + ) + query_key_to_attn_sublayers.extend([ + attn_masker, + pz.nn.Softmax("kv_seq"), + ]) + return pz.nn.Attention( input_to_query=pz.nn.Sequential([ pz.nn.Linear.from_config( @@ -264,7 +309,7 @@ def build_llamalike_attention( max_wavelength=config.rope_wavelength, ), pz.nn.ConstantRescale( - by=jnp.array(projection_dim**-0.5, dtype=config.activation_dtype) + by=jnp.array(query_scaling_factor, dtype=config.activation_dtype) ), ]), input_to_key=pz.nn.Sequential([ @@ -290,17 +335,7 @@ def build_llamalike_attention( dtype=config.parameter_dtype, ), ]), - query_key_to_attn=pz.nn.Sequential([ - pz.nn.NamedEinsum( - ( - {"seq": "tq", **qkv_einsum, **q_einsum, "projection": "p"}, - {"seq": "tkv", **qkv_einsum, "projection": "p"}, - ), - {"seq": "tq", **qkv_einsum, **q_einsum, "kv_seq": "tkv"}, - ), - attn_masker, - pz.nn.Softmax("kv_seq"), - ]), + query_key_to_attn=pz.nn.Sequential(query_key_to_attn_sublayers), attn_value_to_output=pz.nn.Sequential([ pz.nn.NamedEinsum( ( @@ -342,39 +377,55 @@ def build_llamalike_block( Returns: A full transformer block. 
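+
+  When ``config.use_post_attn_norm`` or ``config.use_post_ffw_norm`` is set
+  (as in Gemma 2), the corresponding residual branch is built as
+  ``Residual(Sequential([pre_norm, sublayer, post_norm]))`` instead of
+  ``Residual(Sequential([pre_norm, sublayer]))``.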
""" + attn_sequence = [ + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/pre_attention_norm", + init_base_rng=init_base_rng, + across_axes={"embedding": config.embedding_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ), + build_llamalike_attention( + f"{name}/attention", + init_base_rng, + config, + block_index=block_index, + ), + ] + if config.use_post_attn_norm: + attn_sequence.append( + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/post_attention_norm", + init_base_rng=init_base_rng, + across_axes={"embedding": config.embedding_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ) + ) + ffw_sequence = [ + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/pre_ffw_norm", + init_base_rng=init_base_rng, + across_axes={"embedding": config.embedding_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ), + build_llamalike_feedforward(f"{name}/mlp", init_base_rng, config), + ] + if config.use_post_ffw_norm: + ffw_sequence.append( + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/post_ffw_norm", + init_base_rng=init_base_rng, + across_axes={"embedding": config.embedding_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ) + ) return model_parts.TransformerBlock( sublayers=[ - pz.nn.Residual( - pz.nn.Sequential([ - pz.nn.RMSLayerNorm.from_config( - name=f"{name}/pre_attention_norm", - init_base_rng=init_base_rng, - across_axes={"embedding": config.embedding_dim}, - dtype=config.parameter_dtype, - epsilon=config.rms_norm_eps, - ), - build_llamalike_attention( - f"{name}/attention", - init_base_rng, - config, - block_index=block_index, - ), - ]) - ), - pz.nn.Residual( - pz.nn.Sequential([ - pz.nn.RMSLayerNorm.from_config( - name=f"{name}/pre_ffw_norm", - init_base_rng=init_base_rng, - across_axes={"embedding": config.embedding_dim}, - dtype=config.parameter_dtype, - epsilon=config.rms_norm_eps, - ), - build_llamalike_feedforward( - f"{name}/mlp", init_base_rng, config - ), - ]) - ), + pz.nn.Residual(pz.nn.Sequential(attn_sequence)), + pz.nn.Residual(pz.nn.Sequential(ffw_sequence)), ], ) @@ -465,6 +516,15 @@ def build_llamalike_transformer( ) ) + if config.final_logit_softcap: + sublayers.append( + pz.nn.TanhSoftCap( + soft_cap=jnp.array( + config.final_logit_softcap, dtype=config.activation_dtype + ) + ) + ) + common_head_axes, _, query_only_head_axes, _ = _head_info(config) return model_parts.TransformerLM( metadata=model_parts.TransformerMetadata( diff --git a/penzai/nn/basic_ops.py b/penzai/nn/basic_ops.py index 51211b7..1983fe7 100644 --- a/penzai/nn/basic_ops.py +++ b/penzai/nn/basic_ops.py @@ -20,6 +20,7 @@ from typing import Any, Callable import jax +import jax.numpy as jnp from penzai.core import named_axes from penzai.core import struct from penzai.nn import layer @@ -86,3 +87,17 @@ class CastToDType(layer.Layer): def __call__(self, value: Any, **_unused_side_inputs) -> Any: return jax.tree_util.tree_map(lambda x: x.astype(self.dtype), value) + + +@struct.pytree_dataclass +class TanhSoftCap(layer.Layer): + """Softly rescales a value to lie within a range using tanh. + + Attributes: + soft_cap: The value to rescale to. 
+ """ + + soft_cap: float | jax.Array + + def __call__(self, value: Any, **_unused_side_inputs) -> Any: + return named_axes.nmap(jnp.tanh)(value / self.soft_cap) * self.soft_cap diff --git a/penzai/nn/standardization.py b/penzai/nn/standardization.py index 4bfb4b9..8a1b1e4 100644 --- a/penzai/nn/standardization.py +++ b/penzai/nn/standardization.py @@ -136,7 +136,7 @@ def __call__(self, value: NamedArray, **_unused_side_inputs) -> NamedArray: @named_axes.nmap def _rms_standardize(x): var = jnp.mean(jnp.square(x)) - return x * jnp.reciprocal(jnp.sqrt(var + self.epsilon)) + return x * jax.lax.rsqrt(var + self.epsilon) return _rms_standardize(value.untag(*across)).tag(*across) diff --git a/penzai/pz/nn.py b/penzai/pz/nn.py index 0a96503..fcd79fc 100644 --- a/penzai/pz/nn.py +++ b/penzai/pz/nn.py @@ -27,6 +27,7 @@ CastToDType, Elementwise, Softmax, + TanhSoftCap, ) from penzai.nn.combinators import ( Residual, diff --git a/tests/models/transformer_llamalike_test.py b/tests/models/transformer_llamalike_test.py index 4b0af26..d5dc579 100644 --- a/tests/models/transformer_llamalike_test.py +++ b/tests/models/transformer_llamalike_test.py @@ -70,6 +70,22 @@ class LlamalikeTransformerTest(parameterized.TestCase): activation_dtype=jnp.float32, mlp_variant="swiglu", ), + dict( + testcase_name="like_gemma2", + num_kv_heads=4, + query_head_multiplier=2, + parameter_dtype=jnp.bfloat16, + activation_dtype=jnp.bfloat16, + query_scaling_factor=0.7, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(8), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_post_attn_norm=True, + use_post_ffw_norm=True, + final_logit_softcap=30.0, + attn_logits_soft_cap=50.0, + ), ) def test_build_and_run_gemma( self, @@ -78,6 +94,7 @@ def test_build_and_run_gemma( parameter_dtype, activation_dtype, mlp_variant="geglu_approx", + **extra_kwargs, ): def run_traced(rng_key): @@ -95,6 +112,7 @@ def run_traced(rng_key): mlp_variant=mlp_variant, rope_wavelength=10_000, tie_embedder_and_logits=True, + **extra_kwargs, ), init_base_rng=rng_key, ) @@ -146,6 +164,22 @@ def run_traced(rng_key): parameter_dtype=jnp.float32, activation_dtype=jnp.float32, ), + dict( + testcase_name="like_gemma2", + num_kv_heads=4, + query_head_multiplier=2, + parameter_dtype=jnp.bfloat16, + activation_dtype=jnp.bfloat16, + query_scaling_factor=0.7, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(8), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_post_attn_norm=True, + use_post_ffw_norm=True, + final_logit_softcap=30.0, + attn_logits_soft_cap=50.0, + ), ) def test_build_and_run_sampling_mode( self, @@ -153,6 +187,7 @@ def test_build_and_run_sampling_mode( query_head_multiplier: int, parameter_dtype, activation_dtype, + **extra_kwargs, ): model = llamalike_common.build_llamalike_transformer( @@ -169,6 +204,7 @@ def test_build_and_run_sampling_mode( parameter_dtype=parameter_dtype, activation_dtype=activation_dtype, tie_embedder_and_logits=True, + **extra_kwargs, ), init_base_rng=jax.random.key(2), )