Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Address PR comments
Browse the repository at this point in the history
  • Loading branch information
steventrouble committed Jul 10, 2023
1 parent 0f841cd commit c995ca8
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions crates/models/gpt2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -31,6 +31,8 @@ pub struct Gpt2 {
// weighted positional encodings
wpe: Tensor,
// language model head
//
// Optional: if not present, the `wte` tensor is used instead.
lm_head: Option<Tensor>,

// weights for the model
@@ -59,6 +61,9 @@ impl KnownModel for Gpt2 {
let ln_f_b = tl.load("model/ln_f/b")?;
let wte = tl.load("model/wte")?;
let wpe = tl.load("model/wpe")?;

// GPT-2's language model head is optional; if it is not present,
// the `wte` tensor is used instead.
let lm_head = tl.load("model/lm_head").ok();

let mut layers = Vec::new();
Expand Down

0 comments on commit c995ca8

Please sign in to comment.