Skip to content

Commit

Permalink
fixing loading and saving
Browse files Browse the repository at this point in the history
  • Loading branch information
Goekdeniz-Guelmez committed Oct 4, 2024
1 parent 52658a9 commit b8e5780
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 31 deletions.
11 changes: 6 additions & 5 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# from KANama.model.args import ModelArgs
# from KANama.model.KANamav4 import KANamav4

from model import load
from model.handler import save_pretrained

from trainer.SFTTrainer import train
from model.args import ModelArgs, MOEModelArgs as ModelArgs
from model.KANaMoEv1 import KANamav5
from model.KANaMoEv1 import KANaMoEv1
from model.KANamav4 import KANamav4


Expand All @@ -20,12 +20,13 @@
# ModelArgs.use_softmax_temp_proj = False


train_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Must be a 3 dimansional Tensor [B, max_seq_len, tokens]
val_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Here too
# train_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Must be a 3 dimansional Tensor [B, max_seq_len, tokens]
# val_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Here too


# model = KANamav4(ModelArgs)
model = KANamav5(ModelArgs)
model = KANaMoEv1(ModelArgs)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# new_model = train(model=model, optimizer=optimizer, train_data=train_data, val_data=val_data, save=False, max_steps=100, loss_interval=2, eval_interval=50)

save_pretrained(path_to_save="/Users/gokdenizgulmez/Desktop/meine-repos/KANama/example.py", model=model)
4 changes: 2 additions & 2 deletions example_fineweb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from trainer.SFTTrainer import train
from model.args import MOEModelArgs
from model.KANaMoEv1 import KANamav5
from model.KANaMoEv1 import KANaMoEv1

from utils import load_model, quick_inference

Expand Down Expand Up @@ -76,7 +76,7 @@ def lr_lambda(current_step: int, max_steps: int=50000, warmup_steps: int=40, lr_
val_data = data[:, n:].to(device)

print("\n[LOADING MODEL]\n")
model = KANamav5(MOEModelArgs, device=device)
model = KANaMoEv1(MOEModelArgs, device=device)

# Starting sequence (as tokens)
initial_text = "Once upon a time"
Expand Down
4 changes: 2 additions & 2 deletions model/KANaMoEv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ class KANaMoEv1(nn.Module):
def __init__(self, args: MOEModelArgs, device: str="cpu"):
super().__init__()
self.device = torch.device(device)

self.args = args

self.args.model_type = "KANaMoEv1"
self.args.model_type = self.model_type = "KANaMoEv1"

self.freqs_cis = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2, args.rope_theta, args.use_scaled_rope)

Expand Down
1 change: 1 addition & 0 deletions model/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, **kwargs):

@dataclass
class MOEModelArgs:
model_type: str
vocab_size: int = -1
pad_id: int = -1
eos_id: int = pad_id
Expand Down
55 changes: 33 additions & 22 deletions model/load.py → model/handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from args import ModelArgs, MOEModelArgs
from .KANamav1 import KANamav1
from .KANamav2 import KANamav2
from .KANamav3 import KANamav3
from .KANamav4 import KANamav4
from .KANaMoEv1 import KANaMoEv1
from .args import ModelArgs, MOEModelArgs
import torch.nn as nn
import torch
import os
import json

def from_pretrained(path: str) -> nn.Module:
MODEL_CLASS_MAPPING = {
"KANamav1": KANamav1,
"KANamav2": KANamav2,
"KANamav3": KANamav3,
"KANamav4": KANamav4,
"KANaMoEv1": KANaMoEv1,
}

def from_pretrained(path: str, device : str = "cpu") -> nn.Module:
# Load config.json from the path
config_path = os.path.join(path, "config.json")

Expand All @@ -14,52 +27,50 @@ def from_pretrained(path: str) -> nn.Module:
with open(config_path, 'r') as f:
config = json.load(f)

# Get the 'model_type' from the config and apply the right ModelArgs
# Get the 'model_type' from the config
model_type = config.get('model_type', None)
if model_type is None:
raise ValueError("model_type not found in config file")

# Load the right ModelArgs based on the model_type
if model_type == 'KANaMoEv1':
model_args = MOEModelArgs.from_config(config)
model_args = MOEModelArgs(**config) # Initialize MOEModelArgs using kwargs
else:
model_args = ModelArgs.from_config(config)
model_args = ModelArgs(**config) # Initialize ModelArgs using kwargs

# Create the model architecture using the loaded configuration
model_class = globals().get(model_args.model_class, None) # Ensure that model class is defined in the current scope
# Ensure the model class is in the model class mapping
model_class = MODEL_CLASS_MAPPING.get(model_type, None)
if model_class is None:
raise ValueError(f"Model class {model_args.model_class} not found in the scope")
raise ValueError(f"Model class {model_type} not found in the model class mapping")

# Initialize the model with the loaded configuration
model = model_class(model_args)

# Load the model weights from the model.pth file
model_weights_path = os.path.join(path, "pytorch_model.pth")
model_weights_path = os.path.join(path, "model.pth")

if not os.path.exists(model_weights_path):
raise FileNotFoundError(f"Model weights not found at {model_weights_path}")

model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
# Use weights_only=True to avoid the warning and future-proof your code
state_dict = torch.load(model_weights_path, map_location=torch.device(device), weights_only=True)
model.load_state_dict(state_dict)

print("[INFO] Model and configuration loaded successfully")

return model


def save_pretrained(path_to_save: str, model: nn.Module):
# Ensure the save directory exists
if not os.path.exists(path_to_save):
os.makedirs(path_to_save)

# Get the args from the model (assuming the model has an 'args' attribute)
model_args = model.args
model_type = getattr(model, 'model_type', None)

if model_type is None:
raise ValueError("The model does not have a 'model_type' attribute")

# Prepare config data to save, including model_type and model_args
config = {
"model_type": model_type,
**model_args.to_dict() # Assuming model_args has a to_dict method
}

# Use vars() to extract only the attributes (no need to manually specify them)
config = {key: value for key, value in vars(model_args).items() if not key.startswith('__')}

# Save the config as a JSON file
config_path = os.path.join(path_to_save, "config.json")
Expand All @@ -69,7 +80,7 @@ def save_pretrained(path_to_save: str, model: nn.Module):
# Save the model weights
model_weights_path = os.path.join(path_to_save, "model.pth")
torch.save(model.state_dict(), model_weights_path)
print(f"Model saved successfully to {path_to_save}")

print(f"[INFO] Model and configuration saved successfully to {path_to_save}")

return path_to_save
6 changes: 6 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from model.handler import save_pretrained, from_pretrained

from model.args import MOEModelArgs
from model.KANaMoEv1 import KANaMoEv1

model = from_pretrained("/Users/gokdenizgulmez/Desktop/meine-repos/KANama/test")
23 changes: 23 additions & 0 deletions test/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"vocab_size": 30,
"pad_id": 0,
"eos_id": -1,
"dim": 32,
"n_layers": 8,
"n_heads": 8,
"n_kv_heads": null,
"use_kan": true,
"train_softmax_temp": true,
"use_softmax_temp_proj": true,
"softmax_bias": false,
"multiple_of": 256,
"ffn_dim_multiplier": null,
"rms_norm_eps": 1e-05,
"rope_theta": 500000,
"use_scaled_rope": false,
"max_batch_size": 4,
"max_seq_len": 20,
"num_experts": 2,
"num_experts_per_tok": 1,
"model_type": "KANaMoEv1"
}
Binary file added test/model.pth
Binary file not shown.

0 comments on commit b8e5780

Please sign in to comment.