diff --git a/example.py b/example.py index 35a5551..f787de7 100644 --- a/example.py +++ b/example.py @@ -4,11 +4,11 @@ # from KANama.model.args import ModelArgs # from KANama.model.KANamav4 import KANamav4 -from model import load +from model.handler import save_pretrained from trainer.SFTTrainer import train from model.args import ModelArgs, MOEModelArgs as ModelArgs -from model.KANaMoEv1 import KANamav5 +from model.KANaMoEv1 import KANaMoEv1 from model.KANamav4 import KANamav4 @@ -20,12 +20,13 @@ # ModelArgs.use_softmax_temp_proj = False -train_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Must be a 3 dimansional Tensor [B, max_seq_len, tokens] -val_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Here too +# train_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Must be a 3 dimansional Tensor [B, max_seq_len, tokens] +# val_data = torch.tensor([[25, 1, 4, 12, 9, 7, 1, 4, 12, 9, 4, 1, 4, 22, 9, 13, 26, 24, 12, 9, 0]], dtype=torch.long) # Here too # model = KANamav4(ModelArgs) -model = KANamav5(ModelArgs) +model = KANaMoEv1(ModelArgs) # optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # new_model = train(model=model, optimizer=optimizer, train_data=train_data, val_data=val_data, save=False, max_steps=100, loss_interval=2, eval_interval=50) +save_pretrained(path_to_save="/Users/gokdenizgulmez/Desktop/meine-repos/KANama/example.py", model=model) \ No newline at end of file diff --git a/example_fineweb.py b/example_fineweb.py index 74075d2..4c3d2cf 100644 --- a/example_fineweb.py +++ b/example_fineweb.py @@ -8,7 +8,7 @@ from trainer.SFTTrainer import train from model.args import MOEModelArgs -from model.KANaMoEv1 import KANamav5 +from model.KANaMoEv1 import KANaMoEv1 from utils import load_model, quick_inference @@ -76,7 +76,7 @@ def lr_lambda(current_step: int, max_steps: int=50000, warmup_steps: int=40, lr_ val_data = data[:, n:].to(device) print("\n[LOADING MODEL]\n") -model = KANamav5(MOEModelArgs, device=device) +model = KANaMoEv1(MOEModelArgs, device=device) # Starting sequence (as tokens) initial_text = "Once upon a time" diff --git a/model/KANaMoEv1.py b/model/KANaMoEv1.py index 111f10a..c9a1efa 100644 --- a/model/KANaMoEv1.py +++ b/model/KANaMoEv1.py @@ -167,10 +167,10 @@ class KANaMoEv1(nn.Module): def __init__(self, args: MOEModelArgs, device: str="cpu"): super().__init__() self.device = torch.device(device) - + self.args = args - self.args.model_type = "KANaMoEv1" + self.args.model_type = self.model_type = "KANaMoEv1" self.freqs_cis = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2, args.rope_theta, args.use_scaled_rope) diff --git a/model/args.py b/model/args.py index 8d190c3..ab5ce46 100644 --- a/model/args.py +++ b/model/args.py @@ -44,6 +44,7 @@ def __init__(self, **kwargs): @dataclass class MOEModelArgs: + model_type: str vocab_size: int = -1 pad_id: int = -1 eos_id: int = pad_id diff --git a/model/load.py b/model/handler.py similarity index 51% rename from model/load.py rename to model/handler.py index ee4dc26..c659c62 100644 --- a/model/load.py +++ b/model/handler.py @@ -1,10 +1,23 @@ -from args import ModelArgs, MOEModelArgs +from .KANamav1 import KANamav1 +from .KANamav2 import KANamav2 +from .KANamav3 import KANamav3 +from .KANamav4 import KANamav4 +from .KANaMoEv1 import KANaMoEv1 +from .args import ModelArgs, MOEModelArgs import torch.nn as nn import torch import os import json -def from_pretrained(path: str) -> nn.Module: +MODEL_CLASS_MAPPING = { + "KANamav1": KANamav1, + "KANamav2": KANamav2, + "KANamav3": KANamav3, + "KANamav4": KANamav4, + "KANaMoEv1": KANaMoEv1, +} + +def from_pretrained(path: str, device : str = "cpu") -> nn.Module: # Load config.json from the path config_path = os.path.join(path, "config.json") @@ -14,35 +27,40 @@ def from_pretrained(path: str) -> nn.Module: with open(config_path, 'r') as f: config = json.load(f) - # Get the 'model_type' from the config and apply the right ModelArgs + # Get the 'model_type' from the config model_type = config.get('model_type', None) if model_type is None: raise ValueError("model_type not found in config file") # Load the right ModelArgs based on the model_type if model_type == 'KANaMoEv1': - model_args = MOEModelArgs.from_config(config) + model_args = MOEModelArgs(**config) # Initialize MOEModelArgs using kwargs else: - model_args = ModelArgs.from_config(config) + model_args = ModelArgs(**config) # Initialize ModelArgs using kwargs - # Create the model architecture using the loaded configuration - model_class = globals().get(model_args.model_class, None) # Ensure that model class is defined in the current scope + # Ensure the model class is in the model class mapping + model_class = MODEL_CLASS_MAPPING.get(model_type, None) if model_class is None: - raise ValueError(f"Model class {model_args.model_class} not found in the scope") + raise ValueError(f"Model class {model_type} not found in the model class mapping") # Initialize the model with the loaded configuration model = model_class(model_args) # Load the model weights from the model.pth file - model_weights_path = os.path.join(path, "pytorch_model.pth") + model_weights_path = os.path.join(path, "model.pth") if not os.path.exists(model_weights_path): raise FileNotFoundError(f"Model weights not found at {model_weights_path}") - model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu'))) + # Use weights_only=True to avoid the warning and future-proof your code + state_dict = torch.load(model_weights_path, map_location=torch.device(device), weights_only=True) + model.load_state_dict(state_dict) + + print("[INFO] Model and configuration loaded successfully") return model + def save_pretrained(path_to_save: str, model: nn.Module): # Ensure the save directory exists if not os.path.exists(path_to_save): @@ -50,16 +68,9 @@ def save_pretrained(path_to_save: str, model: nn.Module): # Get the args from the model (assuming the model has an 'args' attribute) model_args = model.args - model_type = getattr(model, 'model_type', None) - - if model_type is None: - raise ValueError("The model does not have a 'model_type' attribute") - - # Prepare config data to save, including model_type and model_args - config = { - "model_type": model_type, - **model_args.to_dict() # Assuming model_args has a to_dict method - } + + # Use vars() to extract only the attributes (no need to manually specify them) + config = {key: value for key, value in vars(model_args).items() if not key.startswith('__')} # Save the config as a JSON file config_path = os.path.join(path_to_save, "config.json") @@ -69,7 +80,7 @@ def save_pretrained(path_to_save: str, model: nn.Module): # Save the model weights model_weights_path = os.path.join(path_to_save, "model.pth") torch.save(model.state_dict(), model_weights_path) - - print(f"Model saved successfully to {path_to_save}") + + print(f"[INFO] Model and configuration saved successfully to {path_to_save}") return path_to_save \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..e4c5036 --- /dev/null +++ b/test.py @@ -0,0 +1,6 @@ +from model.handler import save_pretrained, from_pretrained + +from model.args import MOEModelArgs +from model.KANaMoEv1 import KANaMoEv1 + +model = from_pretrained("/Users/gokdenizgulmez/Desktop/meine-repos/KANama/test") \ No newline at end of file diff --git a/test/config.json b/test/config.json new file mode 100644 index 0000000..f8a0573 --- /dev/null +++ b/test/config.json @@ -0,0 +1,23 @@ +{ + "vocab_size": 30, + "pad_id": 0, + "eos_id": -1, + "dim": 32, + "n_layers": 8, + "n_heads": 8, + "n_kv_heads": null, + "use_kan": true, + "train_softmax_temp": true, + "use_softmax_temp_proj": true, + "softmax_bias": false, + "multiple_of": 256, + "ffn_dim_multiplier": null, + "rms_norm_eps": 1e-05, + "rope_theta": 500000, + "use_scaled_rope": false, + "max_batch_size": 4, + "max_seq_len": 20, + "num_experts": 2, + "num_experts_per_tok": 1, + "model_type": "KANaMoEv1" +} \ No newline at end of file diff --git a/test/model.pth b/test/model.pth new file mode 100644 index 0000000..fa166f1 Binary files /dev/null and b/test/model.pth differ