Commit

relocating and updating

Goekdeniz-Guelmez committed Oct 4, 2024
1 parent 9409eee commit 9ffd1a9
Showing 3 changed files with 38 additions and 38 deletions.
10 changes: 9 additions & 1 deletion example-fineweb.ipynb
@@ -20,7 +20,6 @@
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"import json\n",
"\n",
"import torch\n",
"import math\n",
@@ -33,6 +32,7 @@
"from model.KANaMoEv1 import KANaMoEv1\n",
"\n",
"from utils import load_model, quick_inference\n",
"from model.handler import save_pretrained\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
@@ -176,6 +176,14 @@
" device=device\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32cd9a68",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
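The notebook now pulls save_pretrained from model.handler. A minimal sketch of what the newly added empty cell might contain, assuming `model` is the trained KANaMoEv1 instance built earlier in the notebook and using an illustrative save path that is not part of this commit:

from model.handler import save_pretrained

# Per the diff, save_pretrained(path_to_save, model) saves the model and its
# configuration, prints an [INFO] message, and returns the path it saved to.
saved_path = save_pretrained("KANaMoEv1-fineweb", model)  # path name is illustrative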
29 changes: 28 additions & 1 deletion model/handler.py
@@ -83,4 +83,31 @@ def save_pretrained(path_to_save: str, model: nn.Module):

print(f"[INFO] Model and configuration saved successfully to {path_to_save}")

return path_to_save
return path_to_save


def quick_inference(model: torch.nn.Module, tokens: torch.Tensor, max_new_tokens: int, tokenizer):
model.eval() # Set model to evaluation mode
with torch.no_grad(): # Disable gradient calculation for inference
for _ in range(max_new_tokens):
# Take the last 'max_seq_len' tokens as input to the model
tokens_conditioned = tokens[:, -model.args.max_seq_len:]

# Get the model's predictions (logits)
logits, _ = model(tokens_conditioned)

# Apply softmax to the last token's logits to get probabilities
probabilities = torch.softmax(logits[:, -1, :], dim=-1)

# Sample the next token from the probability distribution
next_token = torch.multinomial(probabilities, num_samples=1)

# Append the predicted token to the input sequence
tokens = torch.cat((tokens, next_token), dim=1)

# Decode and print the token (convert from token ID to string)
decoded_token = tokenizer.decode(next_token.squeeze(dim=1).tolist(), skip_special_tokens=True)
print(decoded_token, end="", flush=True)

# Return the final generated sequence (both as tokens and decoded text)
return tokens, tokenizer.decode(tokens.squeeze(dim=0).tolist(), skip_special_tokens=True)
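With the relocation, quick_inference now lives in model/handler.py: each step conditions the model on the last max_seq_len tokens, samples the next token from the softmax of the final-position logits via torch.multinomial, streams the decoded token to stdout, and finally returns both the token tensor and the decoded text. A minimal calling sketch, assuming a loaded model and a tokenizer with encode/decode methods (the encode call is an assumption, not shown in this diff):

import torch

from model.handler import quick_inference  # new import path after this commit

# assumptions: `model` is a trained KANama model, `tokenizer` provides encode/decode
prompt_ids = torch.tensor([tokenizer.encode("Once upon a time")])  # shape (1, seq_len)
tokens, text = quick_inference(model, prompt_ids, max_new_tokens=32, tokenizer=tokenizer)

print()      # newline after the streamed tokens
print(text)  # full decoded sequence, prompt included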
37 changes: 1 addition & 36 deletions utils.py
@@ -3,14 +3,6 @@

import torch


def load_model(model, file_name="trained_KANama_model.pth"):
model.load_state_dict(torch.load(file_name))
return model

def save_model(model, file_name="trained_KANama_model.pth"):
torch.save(model.state_dict(), file_name)

def createLossPlot(steps, losses, title="training"):
plt.plot(steps, losses, linewidth=1)
plt.xlabel("steps")
@@ -90,31 +82,4 @@ def visualize_KANama(file_path, combined_file_path, individual_folder_path):

plt.tight_layout()
plt.savefig(combined_file_path)
plt.show()


def quick_inference(model: torch.nn.Module, tokens: torch.Tensor, max_new_tokens: int, tokenizer):
model.eval() # Set model to evaluation mode
with torch.no_grad(): # Disable gradient calculation for inference
for _ in range(max_new_tokens):
# Take the last 'max_seq_len' tokens as input to the model
tokens_conditioned = tokens[:, -model.args.max_seq_len:]

# Get the model's predictions (logits)
logits, _ = model(tokens_conditioned)

# Apply softmax to the last token's logits to get probabilities
probabilities = torch.softmax(logits[:, -1, :], dim=-1)

# Sample the next token from the probability distribution
next_token = torch.multinomial(probabilities, num_samples=1)

# Append the predicted token to the input sequence
tokens = torch.cat((tokens, next_token), dim=1)

# Decode and print the token (convert from token ID to string)
decoded_token = tokenizer.decode(next_token.squeeze(dim=1).tolist(), skip_special_tokens=True)
print(decoded_token, end="", flush=True)

# Return the final generated sequence (both as tokens and decoded text)
return tokens, tokenizer.decode(tokens.squeeze(dim=0).tolist(), skip_special_tokens=True)
plt.show()
