diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index 1b2427a5..5742a625 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -30,8 +30,10 @@ pub struct Gpt2 { wte: Tensor, // weighted positional encodings wpe: Tensor, - // language model head - lm_head: Tensor, + // language model head. + // + // Optional: if not present, the `wte` tensor is used instead. + lm_head: Option, // weights for the model layers: Vec, @@ -59,7 +61,9 @@ impl KnownModel for Gpt2 { let ln_f_b = tl.load("model/ln_f/b")?; let wte = tl.load("model/wte")?; let wpe = tl.load("model/wpe")?; - let lm_head = tl.load("model/lm_head")?; + // GPT-2's language model head is optional; if it is not present, + // the `wte` tensor is used instead. + let lm_head = tl.load("model/lm_head").ok(); let mut layers = Vec::new(); for i in 0..hyperparameters.n_layer { @@ -306,7 +310,7 @@ impl KnownModel for Gpt2 { let embeddings_tensor: ggml::Tensor = input_layer.share(); - input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer); + input_layer = ctx0.op_mul_mat(self.lm_head.as_ref().unwrap_or(&self.wte), &input_layer); ( gf,