From 0f841cdfd7783c5793182eb1cdec12ebe2dbb45c Mon Sep 17 00:00:00 2001
From: Steven Weiss
Date: Sun, 9 Jul 2023 12:37:19 -0700
Subject: [PATCH] Update gpt2 to use wte if no lm_head

---
 crates/models/gpt2/src/lib.rs | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs
index 9e2102b1..2eae1ce5 100644
--- a/crates/models/gpt2/src/lib.rs
+++ b/crates/models/gpt2/src/lib.rs
@@ -31,7 +31,7 @@ pub struct Gpt2 {
     // weighted positional encodings
     wpe: Tensor,
     // language model head
-    lm_head: Tensor,
+    lm_head: Option<Tensor>,
 
     // weights for the model
     layers: Vec<Layer>,
@@ -59,7 +59,7 @@ impl KnownModel for Gpt2 {
         let ln_f_b = tl.load("model/ln_f/b")?;
         let wte = tl.load("model/wte")?;
         let wpe = tl.load("model/wpe")?;
-        let lm_head = tl.load("model/lm_head")?;
+        let lm_head = tl.load("model/lm_head").ok();
 
         let mut layers = Vec::new();
         for i in 0..hyperparameters.n_layer {
@@ -306,7 +306,8 @@ impl KnownModel for Gpt2 {
 
         let embeddings_tensor: ggml::Tensor = input_layer.share();
 
-        input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);
+        let head = self.lm_head.as_ref().unwrap_or(&self.wte);
+        input_layer = ctx0.op_mul_mat(head, &input_layer);
 
         (
             gf,
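
Note: the substance of the patch is the optional-head fallback at the final
projection. GPT-2 checkpoints that tie the output weights to the token
embeddings ship no separate "model/lm_head" tensor, so loading it with .ok()
and falling back to wte reuses the embedding matrix as the output head. Below
is a minimal, self-contained sketch of that fallback pattern; the Weights
struct and output_head method are hypothetical names, and Vec<f32> stands in
for ggml::Tensor, which the real code multiplies via ctx0.op_mul_mat.

    // Stand-in for the model's weights: wte is always present, while a
    // separate output head is optional because weight-tied checkpoints
    // reuse wte as the output projection. (Hypothetical sketch, not the
    // crate's actual types.)
    struct Weights {
        wte: Vec<f32>,
        lm_head: Option<Vec<f32>>,
    }

    impl Weights {
        // Pick the matrix for the final logits projection: the dedicated
        // head if the checkpoint shipped one, otherwise the tied
        // token-embedding matrix.
        fn output_head(&self) -> &Vec<f32> {
            self.lm_head.as_ref().unwrap_or(&self.wte)
        }
    }

    fn main() {
        let tied = Weights { wte: vec![0.1, 0.2], lm_head: None };
        let untied = Weights { wte: vec![0.1, 0.2], lm_head: Some(vec![0.3, 0.4]) };
        // A weight-tied model falls back to wte...
        assert!(std::ptr::eq(tied.output_head(), &tied.wte));
        // ...while a model with its own head uses it.
        assert!(std::ptr::eq(untied.output_head(), untied.lm_head.as_ref().unwrap()));
    }

Option::as_ref().unwrap_or(..) keeps the field by value while borrowing for
the projection, so no tensor is cloned on either path.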