This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

fix #338 - use wte if no lm_head for gpt2 #343

Closed
Wants to merge 1 commit.
12 changes: 8 additions & 4 deletions crates/models/gpt2/src/lib.rs
@@ -30,8 +30,10 @@ pub struct Gpt2 {
     wte: Tensor,
     // weighted positional encodings
     wpe: Tensor,
-    // language model head
-    lm_head: Tensor,
+    // language model head.
+    //
+    // Optional: if not present, the `wte` tensor is used instead.
+    lm_head: Option<Tensor>,
 
     // weights for the model
     layers: Vec<Layer>,
@@ -59,7 +61,9 @@ impl KnownModel for Gpt2 {
         let ln_f_b = tl.load("model/ln_f/b")?;
         let wte = tl.load("model/wte")?;
         let wpe = tl.load("model/wpe")?;
-        let lm_head = tl.load("model/lm_head")?;
+        // GPT-2's language model head is optional; if it is not present,
+        // the `wte` tensor is used instead.
+        let lm_head = tl.load("model/lm_head").ok();
 
         let mut layers = Vec::new();
         for i in 0..hyperparameters.n_layer {
@@ -306,7 +310,7 @@ impl KnownModel for Gpt2 {
 
         let embeddings_tensor: ggml::Tensor = input_layer.share();
 
-        input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);
+        input_layer = ctx0.op_mul_mat(self.lm_head.as_ref().unwrap_or(&self.wte), &input_layer);
 
         (
             gf,
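For context, the pattern at the heart of this change can be shown in isolation. The sketch below is not taken from the PR: it uses a plain Vec<f32> as a stand-in for ggml::Tensor, and the Model struct and output_weights method are hypothetical names. It only demonstrates how an optional lm_head falls back to the shared wte weights via Option::as_ref and unwrap_or.

// Minimal sketch of the optional-lm_head fallback (not part of the PR).
// `Vec<f32>` stands in for `ggml::Tensor`; all names here are hypothetical.
struct Model {
    // token embedding weights, always present
    wte: Vec<f32>,
    // optional language model head; some GPT-2 checkpoints omit it
    lm_head: Option<Vec<f32>>,
}

impl Model {
    /// Returns the weights used to project hidden states to vocabulary logits:
    /// the dedicated head if one was loaded, otherwise the tied `wte` weights.
    fn output_weights(&self) -> &Vec<f32> {
        self.lm_head.as_ref().unwrap_or(&self.wte)
    }
}

fn main() {
    // A model without a separate head falls back to `wte` (weight tying).
    let tied = Model { wte: vec![0.1, 0.2], lm_head: None };
    assert!(std::ptr::eq(tied.output_weights(), &tied.wte));

    // A model with its own head uses it.
    let untied = Model { wte: vec![0.1, 0.2], lm_head: Some(vec![0.3, 0.4]) };
    assert!(std::ptr::eq(untied.output_weights(), untied.lm_head.as_ref().unwrap()));

    println!("fallback behaves as expected");
}

Because unwrap_or takes a reference, nothing is copied: when no dedicated head is present, the model simply reuses wte as a tied output projection, which is what the one-line change to op_mul_mat above achieves.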