From c995ca8fc556560746027a4c33930260d18fce8a Mon Sep 17 00:00:00 2001 From: Steven Weiss Date: Mon, 10 Jul 2023 10:38:53 -0700 Subject: [PATCH] Address PR comments --- crates/models/gpt2/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index 2eae1ce5..c69832ca 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -31,6 +31,8 @@ pub struct Gpt2 { // weighted positional encodings wpe: Tensor, // language model head + // + // Optional: if not present, the `wte` tensor is used instead. lm_head: Option, // weights for the model @@ -59,6 +61,9 @@ impl KnownModel for Gpt2 { let ln_f_b = tl.load("model/ln_f/b")?; let wte = tl.load("model/wte")?; let wpe = tl.load("model/wpe")?; + + // GPT-2's language model head is optional; if it is not present, + // the `wte` tensor is used instead. let lm_head = tl.load("model/lm_head").ok(); let mut layers = Vec::new();