This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Update gpt2 to use wte if no lm_head
steventrouble committed Jul 9, 2023
1 parent ff4bb37 commit 0f841cd
Showing 1 changed file with 4 additions and 3 deletions.
crates/models/gpt2/src/lib.rs (7 changes: 4 additions, 3 deletions)
@@ -31,7 +31,7 @@ pub struct Gpt2 {
     // weighted positional encodings
     wpe: Tensor,
     // language model head
-    lm_head: Tensor,
+    lm_head: Option<Tensor>,
 
     // weights for the model
     layers: Vec<Layer>,
@@ -59,7 +59,7 @@ impl KnownModel for Gpt2 {
         let ln_f_b = tl.load("model/ln_f/b")?;
         let wte = tl.load("model/wte")?;
         let wpe = tl.load("model/wpe")?;
-        let lm_head = tl.load("model/lm_head")?;
+        let lm_head = tl.load("model/lm_head").ok();
 
         let mut layers = Vec::new();
         for i in 0..hyperparameters.n_layer {
@@ -306,7 +306,8 @@ impl KnownModel for Gpt2 {
 
             let embeddings_tensor: ggml::Tensor = input_layer.share();
 
-            input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);
+            let head = self.lm_head.as_ref().unwrap_or(&self.wte);
+            input_layer = ctx0.op_mul_mat(head, &input_layer);
 
             (
                 gf,
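In short, `tl.load("model/lm_head").ok()` turns a missing `lm_head` tensor into `None` instead of a load error, and at evaluation time `self.lm_head.as_ref().unwrap_or(&self.wte)` falls back to the token-embedding matrix `wte`, consistent with GPT-2 checkpoints whose output head is tied to the input embeddings. A minimal, self-contained sketch of that Option-fallback pattern, using a placeholder `Tensor` type and a stand-in `load` helper rather than the real ggml/loader API, might look like:

// Minimal sketch of the fallback pattern from this commit.
// `Tensor` and `load` here are placeholders, not the real ggml bindings.
#[derive(Debug)]
struct Tensor(String);

struct Gpt2 {
    wte: Tensor,
    // The LM head is optional: checkpoints that tie weights reuse `wte`
    // for the output projection instead of shipping a separate head.
    lm_head: Option<Tensor>,
}

// Stand-in for `tl.load(...)`: returns Err when the tensor is missing.
fn load(name: &str) -> Result<Tensor, String> {
    if name == "model/lm_head" {
        Err(format!("tensor not found: {name}"))
    } else {
        Ok(Tensor(name.to_string()))
    }
}

fn main() {
    let model = Gpt2 {
        wte: load("model/wte").unwrap(),
        // `.ok()` turns a failed load into `None` instead of an error.
        lm_head: load("model/lm_head").ok(),
    };

    // At inference time, fall back to the embedding matrix when no head exists.
    let head = model.lm_head.as_ref().unwrap_or(&model.wte);
    println!("projecting logits with {:?}", head);
}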
