diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index 2eae1ce5..c69832ca 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -31,6 +31,8 @@ pub struct Gpt2 { // weighted positional encodings wpe: Tensor, // language model head + // + // Optional: if not present, the `wte` tensor is used instead. lm_head: Option<Tensor>, // weights for the model @@ -59,6 +61,9 @@ impl KnownModel for Gpt2 { let ln_f_b = tl.load("model/ln_f/b")?; let wte = tl.load("model/wte")?; let wpe = tl.load("model/wpe")?; + + // GPT-2's language model head is optional; if it is not present, + // the `wte` tensor is used instead. let lm_head = tl.load("model/lm_head").ok(); let mut layers = Vec::new();