Check for modified config attributes during HuggingFace conversion (#104)

Adds a more robust check for modified config attributes when converting
a HuggingFace model to a Penzai model. This can be used to prevent
loading a model whose architecture Penzai does not yet support.

The check works by enumerating all differences between the model's config
and a known-convertible config, and ensuring the only changes are to
attributes that we can correctly handle during conversion.
danieldjohnson authored Dec 16, 2024
1 parent e23bfed commit 6a96d6b
Showing 3 changed files with 140 additions and 55 deletions.
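
In condensed form, the pattern each converter now follows looks roughly like the sketch below. This is based on the GPT-NeoX changes in this commit, not a drop-in copy of them: the real allow-lists are larger and architecture-specific.

import transformers

def check_convertible_gpt_neox_config(hf_config):
  """Raises ValueError if the config differs from a known-convertible one in unsupported ways."""
  # Attributes whose values conversion either maps explicitly or can safely
  # ignore (condensed here; see the full sets in the diffs below).
  handled_or_ignored = {
      "hidden_size", "intermediate_size", "num_attention_heads",
      "num_hidden_layers", "vocab_size", "torch_dtype", "architectures",
  }
  # Every other attribute must match the defaults of a known-convertible config.
  reference = transformers.GPTNeoXConfig().to_dict()
  bad_attributes = {
      k: v
      for k, v in hf_config.to_dict().items()
      if k not in handled_or_ignored
      and not (k in reference and v == reference[k])
  }
  if bad_attributes:
    raise ValueError(
        f"Unsupported configuration attributes: {bad_attributes!r}")
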
67 changes: 51 additions & 16 deletions penzai/models/transformer/variants/gpt_neox.py
@@ -366,7 +366,7 @@ def gpt_neox_from_huggingface_model(
scaling suite.
Args:
model: The HuggingFace Llama or Mistral model.
model: The HuggingFace GPT-NeoX model.
upcast_activations_to_float32: Whether to cast activations to float32 when
the model runs. This allows analyzing activations at higher precision
without consuming additional memory for parameters.
@@ -375,22 +375,57 @@ def gpt_neox_from_huggingface_model(
Returns:
A Transformer model containing the loaded parameters.
"""
# Checkpoint conversion assumes these configuration arguments are set:
try:
import transformers # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise RuntimeError("HuggingFace transformers is not available") from exc

if type(model) is not transformers.GPTNeoXForCausalLM: # pylint: disable=unidiomatic-typecheck
raise ValueError(
"gpt_neox_from_huggingface_model should be called with a"
f" GPTNeoXForCausalLM instance, but got {type(model).__name__}."
)

hf_config = model.config
checked_config_args = dict(
use_parallel_residual=True,
rope_scaling=None,
attention_bias=True,
attention_dropout=0.0,
hidden_dropout=0.0,
)
for k, v in checked_config_args.items():
actual_value = getattr(hf_config, k)
if actual_value != v:
raise ValueError(
f"Conversion of a GPTNeoXForCausalLM requires config.{k}={repr(v)},"
f" but got {actual_value}"
)
# Check any modified configuration arguments against the base config to make
# sure we support newer architecture features. (Assumes that new features are
# added in a backwards-compatible way and do not change the defaults for the
# configuration class.)
hf_config_attributes = hf_config.to_dict()
reference_attributes = transformers.GPTNeoXConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_act",
"hidden_size",
"intermediate_size",
"layer_norm_eps",
"num_attention_heads",
"num_hidden_layers",
"rotary_emb_base",
"rotary_pct",
"vocab_size",
# Ignored by conversion:
"max_position_embeddings",
"torch_dtype",
"architectures",
"bos_token_id",
"eos_token_id",
"_attn_implementation_autoset",
"head_dim",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
if k in handled_or_ignored_attributes or (
k in reference_attributes and v == reference_attributes[k]
):
pass
else:
bad_attributes[k] = v
if bad_attributes:
raise ValueError(
"Conversion of a GPTNeoXForCausalLM does not support these"
f" configuration attributes: {repr(bad_attributes)}"
)

param_dtype = {
"torch.float32": jnp.float32,
66 changes: 45 additions & 21 deletions penzai/models/transformer/variants/llama.py
@@ -47,31 +47,55 @@ def llama_from_huggingface_model(
Returns:
A Transformer model containing the loaded parameters.
"""
if type(model).__name__ != "LlamaForCausalLM":
try:
import transformers # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise RuntimeError("HuggingFace transformers is not available") from exc

if type(model) is not transformers.LlamaForCausalLM: # pylint: disable=unidiomatic-typecheck
raise ValueError(
"llama_from_huggingface_model should be called with a"
f" LlamaForCausalLM instance, but got {type(model).__name__}."
)
# Checkpoint conversion assumes these configuration arguments are set:
hf_config = model.config
checked_config_args = dict(
hidden_act="silu",
tie_word_embeddings=False,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
)
for k, v in checked_config_args.items():
try:
actual_value = getattr(hf_config, k)
except AttributeError:
continue
if actual_value != v:
raise ValueError(
f"Conversion of a LlamaForCausalLM requires config.{k}={repr(v)}, but"
f" got {actual_value}"
)

# Check any modified configuration arguments against the base config to make
# sure we support newer architecture features. (Assumes that new features are
# added in a backwards-compatible way and do not change the defaults for the
# configuration class.)
hf_config_attributes = model.config.to_dict()
reference_attributes = transformers.LlamaConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_size",
"intermediate_size",
"num_attention_heads",
"num_hidden_layers",
"num_key_value_heads",
"rms_norm_eps",
"rope_theta",
"vocab_size",
# Ignored by conversion:
"max_position_embeddings",
"torch_dtype",
"architectures",
"bos_token_id",
"eos_token_id",
"_attn_implementation_autoset",
"head_dim",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
if k in handled_or_ignored_attributes or (
k in reference_attributes and v == reference_attributes[k]
):
pass
else:
bad_attributes[k] = v
if bad_attributes:
raise ValueError(
"Conversion of a LlamaForCausalLM does not support these configuration"
f" attributes: {repr(bad_attributes)}"
)

return llamalike_common.llamalike_from_huggingface_model(
model,
62 changes: 44 additions & 18 deletions penzai/models/transformer/variants/mistral.py
@@ -52,28 +52,54 @@ def mistral_from_huggingface_model(
Returns:
A Transformer model containing the loaded parameters.
"""
if type(model).__name__ != "MistralForCausalLM":
try:
import transformers # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise RuntimeError("HuggingFace transformers is not available") from exc

if type(model) is not transformers.MistralForCausalLM: # pylint: disable=unidiomatic-typecheck
raise ValueError(
"mistral_from_huggingface_model should be called with a"
f" MistralForCausalLM instance, but got {type(model).__name__}."
)
# Checkpoint conversion assumes these configuration arguments are set:
hf_config = model.config
checked_config_args = dict(
hidden_act="silu",
tie_word_embeddings=False,
attention_dropout=0.0,
)
for k, v in checked_config_args.items():
try:
actual_value = getattr(hf_config, k)
except AttributeError:
continue
if actual_value != v:
raise ValueError(
f"Conversion of a MistralForCausalLM requires config.{k}={repr(v)},"
f" but got {actual_value}"
)

# Check any modified configuration arguments against the base config to make
# sure we support newer architecture features. (Assumes that new features are
# added in a backwards-compatible way and do not change the defaults for the
# configuration class.)
hf_config_attributes = model.config.to_dict()
reference_attributes = transformers.MistralConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_size",
"intermediate_size",
"num_attention_heads",
"num_hidden_layers",
"num_key_value_heads",
"rms_norm_eps",
"rope_theta",
"vocab_size",
"sliding_window",
# Ignored by conversion:
"max_position_embeddings",
"torch_dtype",
"architectures",
"_attn_implementation_autoset",
"head_dim",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
if k in handled_or_ignored_attributes or (
k in reference_attributes and v == reference_attributes[k]
):
pass
else:
bad_attributes[k] = v
if bad_attributes:
raise ValueError(
"Conversion of a MistralForCausalLM does not support these"
f" configuration attributes: {repr(bad_attributes)}"
)

return llamalike_common.llamalike_from_huggingface_model(
model,
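
For reference, a rough usage sketch of the new behavior, assuming default conversion options; the checkpoint name is illustrative only. A model whose config enables an unsupported feature now fails fast with a descriptive ValueError instead of being converted incorrectly.

import transformers
from penzai.models.transformer.variants import llama

# Hypothetical checkpoint name, for illustration only.
hf_model = transformers.LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")
try:
  pz_model = llama.llama_from_huggingface_model(hf_model)
except ValueError as err:
  # E.g. "Conversion of a LlamaForCausalLM does not support these
  # configuration attributes: {...}"
  print(f"Cannot convert this checkpoint: {err}")
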
