main.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

from models import NeuralNet
from utils import train, test, enable_disable_lora, count_trainable_params


def main():
    # Data transformations and loaders
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(full_train_dataset, batch_size=10, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize models
    model_with_lora = NeuralNet().to(device)
    model_without_lora = NeuralNet().to(device)

    # Clone the weights from model_with_lora to model_without_lora
    model_without_lora.load_state_dict(model_with_lora.state_dict())

    # Set up the optimizers and loss function
    optimizer_with_lora = optim.Adam(model_with_lora.parameters(), lr=0.001)
    optimizer_without_lora = optim.Adam(model_without_lora.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Initial full training
    print("Starting training on full dataset...")
    train(train_loader, model_with_lora, device, optimizer_with_lora, criterion, epochs=1)

    # Report number of trainable parameters before fine-tuning
    print(f"Trainable parameters with LoRA: {count_trainable_params(model_with_lora)}")

    # Clone original weights for comparison after fine-tuning
    original_weights = {name: param.clone().detach() for name, param in model_with_lora.named_parameters()}

    # Fine-tuning on a poorly performing digit
    digit_to_finetune = 9
    indices = [i for i, y in enumerate(full_train_dataset.targets) if y == digit_to_finetune]
    finetune_dataset = Subset(full_train_dataset, indices)
    finetune_loader = DataLoader(finetune_dataset, batch_size=10, shuffle=True)

    print(f"Fine-tuning on digit {digit_to_finetune} with LoRA enabled...")
    enable_disable_lora(model_with_lora, enabled=True)
    train(finetune_loader, model_with_lora, device, optimizer_with_lora, criterion, epochs=1)

    print(f"Fine-tuning on digit {digit_to_finetune} with LoRA disabled...")
    enable_disable_lora(model_without_lora, enabled=False)  # Ensure LoRA is disabled
    train(finetune_loader, model_without_lora, device, optimizer_without_lora, criterion, epochs=1)

    # Report number of trainable parameters after fine-tuning with LoRA disabled
    print(f"Trainable parameters without LoRA: {count_trainable_params(model_without_lora)}")

    # Testing with LoRA enabled
    print("Testing with LoRA enabled...")
    test(test_loader, model_with_lora, device)

    # Testing with LoRA disabled
    print("Testing with LoRA disabled...")
    test(test_loader, model_without_lora, device)


if __name__ == "__main__":
    main()
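
For context, here is a minimal sketch of what models.NeuralNet might look like. The actual models.py is not included in this file, so the layer sizes, the LoRA rank and scaling, and the use of torch.nn.utils.parametrize below are assumptions for illustration, not the repository's implementation.

# Hypothetical sketch of models.py (assumed, not the repository's actual code)
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LoRAParametrization(nn.Module):
    """Assumed low-rank update: W -> W + (B @ A) * scale."""

    def __init__(self, features_in, features_out, rank=4, alpha=1.0):
        super().__init__()
        # B starts at zero so the LoRA delta is initially a no-op
        self.lora_A = nn.Parameter(torch.randn(rank, features_in) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(features_out, rank))
        self.scale = alpha / rank
        self.enabled = True  # assumed flag toggled by utils.enable_disable_lora

    def forward(self, original_weight):
        if self.enabled:
            return original_weight + (self.lora_B @ self.lora_A) * self.scale
        return original_weight


class NeuralNet(nn.Module):
    """Simple MNIST classifier; hidden sizes are illustrative guesses."""

    def __init__(self, hidden1=1000, hidden2=2000, num_classes=10):
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, hidden1)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, num_classes)
        self.relu = nn.ReLU()
        # Attach a LoRA parametrization to each weight matrix
        for layer in (self.linear1, self.linear2, self.linear3):
            parametrize.register_parametrization(
                layer, "weight",
                LoRAParametrization(layer.in_features, layer.out_features)
            )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 28x28 images
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)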
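
Likewise, a rough sketch of the helpers imported from utils. Only the call signatures are taken from main.py above; the bodies, and the assumed enabled flag on the LoRA parametrization, are illustrative guesses rather than the repository's utils.py.

# Hypothetical sketch of utils.py (assumed, not the repository's actual code)
import torch
import torch.nn.utils.parametrize as parametrize


def train(loader, model, device, optimizer, criterion, epochs=1):
    """Standard supervised training loop."""
    model.train()
    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()


def test(loader, model, device):
    """Report overall accuracy on the given loader."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    print(f"Accuracy: {correct / total:.4f}")


def enable_disable_lora(model, enabled=True):
    """Toggle the assumed `enabled` flag on every LoRA parametrization."""
    for module in model.modules():
        if parametrize.is_parametrized(module, "weight"):
            for p in module.parametrizations["weight"]:
                if hasattr(p, "enabled"):
                    p.enabled = enabled


def count_trainable_params(model):
    """Count parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)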