Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
Goekdeniz-Guelmez committed Oct 4, 2024
1 parent d7b3864 commit 52658a9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
6 changes: 4 additions & 2 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# from KANama.model.args import ModelArgs
# from KANama.model.KANamav4 import KANamav4

from model import load

from trainer.SFTTrainer import train
from model.args import ModelArgs, MOEModelArgs as ModelArgs
Expand All @@ -25,5 +26,6 @@

# model = KANamav4(ModelArgs)
model = KANamav5(ModelArgs)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
new_model = train(model=model, optimizer=optimizer, train_data=train_data, val_data=val_data, save=False, max_steps=100, loss_interval=2, eval_interval=50)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# new_model = train(model=model, optimizer=optimizer, train_data=train_data, val_data=val_data, save=False, max_steps=100, loss_interval=2, eval_interval=50)

5 changes: 1 addition & 4 deletions model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def forward(self, x):
return output * self.weight



def apply_scaling(freqs: torch.Tensor):
scale_factor = 8
low_freq_factor = 1
Expand All @@ -43,7 +42,6 @@ def apply_scaling(freqs: torch.Tensor):
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)



def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
Expand All @@ -53,22 +51,21 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled:
return torch.polar(torch.ones_like(freqs), freqs)



def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary positional embeddings to query and key tensors.

    Both tensors are rotated by the complex factors in *freqs_cis*; the
    outputs have the same shape and dtype as the corresponding inputs.
    """
    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last dim into (real, imag) and reinterpret as complex.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    # Reshape the rotation factors so they broadcast over batch/head dims.
    freqs_cis = reshape_for_broadcast(freqs_cis, q_complex)
    q_out = torch.view_as_real(q_complex * freqs_cis).flatten(3).type_as(xq)
    k_out = torch.view_as_real(k_complex * freqs_cis).flatten(3).type_as(xk)
    return q_out, k_out



def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
B, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
Expand Down

0 comments on commit 52658a9

Please sign in to comment.