Changes to add kc param, label noise, and fix suggested in kuangliu#137
moyix committed Dec 8, 2021
1 parent 49b7aa9 commit 7b3bddb
Showing 2 changed files with 87 additions and 65 deletions.
130 changes: 75 additions & 55 deletions main.py
@@ -4,6 +4,7 @@
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms
@@ -14,19 +15,21 @@
from models import *
from utils import progress_bar


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--kc', default=64, type=int, help='model width (base channel count)')
parser.add_argument('--epoch', default=400, type=int, help='total training epochs')
parser.add_argument('--resume', '-r', default=None, type=str, help='resume from checkpoint')
parser.add_argument('--noise', default=0, type=int, help='label noise %')
parser.add_argument('--eval', action='store_true', help='only do evaluation')
parser.add_argument('--quiet', '-q', action='store_true', help='be quiet')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Data
if not args.quiet: print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
@@ -39,55 +42,65 @@
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

do_download = not os.path.exists('./data')
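# With download=True torchvision re-verifies (and logs about) the archive on
# every run; gating on ./data means it is only fetched the first time.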

# Training data with optional noise
def flip_random_label(x):
image, label = x
wrong = list(range(10))
del wrong[label]
label = np.random.choice(wrong)
x = image, label

return x
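# e.g. (hypothetical values) flip_random_label((img, 3)) returns (img, l) with
# l drawn uniformly from the nine classes other than 3, so a corrupted sample
# can never keep its true label.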

noise_indices = []
noise_labels = []
if not args.eval:
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=do_download, transform=transform_train)

if args.noise != 0:
noise_frac = args.noise / 100
num_noise_samples = int(noise_frac * len(trainset))
if not args.quiet: print(f'Flipping {args.noise}% of labels ({num_noise_samples} samples)')
noise_indices = np.random.choice(np.arange(len(trainset)), size=num_noise_samples, replace=False)
        # Note: iterating the dataset applies transform_train here, so the
        # random augmentations are drawn once and frozen for all epochs.
        noisy_data = [x for x in trainset]
        for i in noise_indices:
            noisy_data[i] = flip_random_label(noisy_data[i])
        noise_labels = [noisy_data[i][1] for i in noise_indices]
        trainset = noisy_data

trainloader = torch.utils.data.DataLoader(
trainset, batch_size=128, shuffle=True, num_workers=2)
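# noise_indices and noise_labels are stored in every checkpoint (see test()
# below), so a run records exactly which samples were corrupted and what
# labels they were given.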

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=do_download, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
if args.resume:
# Load checkpoint.
if not args.quiet: print('==> Resuming from checkpoint..')
checkpoint = torch.load(args.resume)
args.kc = checkpoint['kc']
args.noise = checkpoint['noise']
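    # Only hyperparameters are read here; the weights are restored in the
    # second resume block below, once the net has been built with this kc.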

# Model
if not args.quiet: print('==> Building model..')
net = PreActResNet18(args.kc)
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True

if args.resume:
    # Restore the weights from the checkpoint loaded above.
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()

# Adam with LR=0.0001 (replaces the original SGD + cosine annealing schedule)
optimizer = optim.Adam(net.parameters(), lr=0.0001)
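# Note: with Adam fixed at 1e-4, the --lr flag above no longer has any effect.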

# Training
def train(epoch):
@@ -109,12 +122,11 @@ def train(epoch):
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

        if not args.quiet:
            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
global best_acc
net.eval()
test_loss = 0
correct = 0
@@ -130,25 +142,33 @@ def test(epoch):
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

            if not args.quiet:
                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                             % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if epoch % 10 == 0 and not args.eval:
        if not args.quiet: print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'kc': args.kc,
            'noise': args.noise,
            'noise_indices': noise_indices,
            'noise_labels': noise_labels,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, f'./checkpoint/noise{args.noise}_kc{args.kc}_epoch{epoch}_ckpt.pth')
    return acc

if args.eval:
if not args.resume:
parser.error("--eval requires --resume CHECKPOINT")
print(args.kc, args.noise, test(0))
else:
for epoch in range(start_epoch, args.epoch+1):
train(epoch)
test(epoch)
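
A typical invocation of the modified script might look like this (illustrative only; the flag values are assumptions, not taken from the commit):

    python main.py --kc 32 --noise 20 --epoch 400

and a finished run can then be re-scored from one of its periodic checkpoints:

    python main.py --eval --resume checkpoint/noise20_kc32_epoch400_ckpt.pth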
22 changes: 12 additions & 10 deletions models/preact_resnet.py
@@ -63,16 +63,17 @@ def forward(self, x):


class PreActResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, kc=64):
        super(PreActResNet, self).__init__()
        self.in_planes = kc

        self.conv1 = nn.Conv2d(3, kc, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, kc, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, kc*2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, kc*4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, kc*8, num_blocks[3], stride=2)
        self.linear = nn.Linear(8*kc*block.expansion, num_classes)
        self.bn = nn.BatchNorm2d(8*kc)
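        # Note: BatchNorm2d(8*kc) matches layer4's output width only because
        # block.expansion == 1 for PreActBlock; a bottleneck variant would
        # need 8*kc*block.expansion here.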

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
@@ -89,13 +90,14 @@ def forward(self, x):
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
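        # Fix suggested in kuangliu#137: close the pre-activation stack with a
        # final BN+ReLU (applied here to the already-pooled features).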
out = F.relu(self.bn(out), inplace=True)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out


def PreActResNet18(kc=64):
    return PreActResNet(PreActBlock, [2,2,2,2], kc=kc)

def PreActResNet34():
return PreActResNet(PreActBlock, [3,4,6,3])
