From c74e1036cdbbf48cc655502118fed92e6018480d Mon Sep 17 00:00:00 2001
From: Kushal Arora
Date: Wed, 27 Mar 2024 23:06:05 +0000
Subject: [PATCH] Bug fix to get imported llama checkpoints evaluated.

---
 open_lm/main.py                        | 2 +-
 open_lm/positional_embedding/rotary.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/open_lm/main.py b/open_lm/main.py
index 9d0b1b6f..aa2003d1 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -134,7 +134,7 @@ def load_model(args, model, different_seed=False):
         # loading a bare (model only) checkpoint for fine-tune or evaluation
         start_epoch, global_step = 0, 0
         pretrained_seed = None
-        model.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint['state_dict'], strict=False)
         logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
         return start_epoch, global_step, pretrained_seed

diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py
index b48ed890..4837d2db 100644
--- a/open_lm/positional_embedding/rotary.py
+++ b/open_lm/positional_embedding/rotary.py
@@ -48,7 +48,7 @@ def __init__(self, dim_model: int, seq_len: int, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         self.dim_model = dim_model
-        self.register_buffer("inv_freq", torch.zeros(self.dim_model // 2))
+        self.inv_freq = torch.zeros(self.dim_model // 2)
         self._cos_cached = None
         self._sin_cached = None
@@ -71,7 +71,7 @@ def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = Non
         if seq_len > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
             self._seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device, dtype=torch.float32)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype))
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype).to(device))
             emb = torch.cat((freqs, freqs), dim=-1).to(device)
             self._cos_cached = emb.cos()[None, :, None, :].to(dtype)
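
Note on the main.py hunk (not part of the patch): a minimal sketch of the
loading behaviour the fix relies on, using a stand-in nn.Module rather than
the real open_lm model, whose construction is assumed here. The checkpoint
being loaded evidently nests the weights under a "state_dict" key, and
strict=False reports mismatched keys instead of raising, e.g. the rotary
inv_freq entry that the second file stops registering as a buffer.

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)  # stand-in for the open_lm model (placeholder)
    checkpoint = {"state_dict": {"weight": torch.zeros(4, 4)}}  # nested layout inferred from the patch
    # strict=False: missing/unexpected keys are returned rather than raising
    result = model.load_state_dict(checkpoint["state_dict"], strict=False)
    print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)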
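
Note on the rotary.py hunks (not part of the patch): since inv_freq is now a
plain tensor instead of a registered buffer, it no longer follows the module
when the model is moved to a device, so the second hunk moves it explicitly
before the einsum. A small sketch of that computation follows; the 10000-base
inverse-frequency formula is the usual rotary-embedding choice and is assumed
here, as the patch only shows the zero-initialized placeholder.

    import torch

    dim_model, seq_len = 8, 16
    # usual RoPE inverse frequencies (assumed; the class fills inv_freq elsewhere)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    # move inv_freq onto t's device so both operands of the outer product agree
    freqs = torch.einsum("i,j->ij", t, inv_freq.to(t.dtype).to(device))
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None, :, None, :], emb.sin()[None, :, None, :]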