From 872c065ed39f89b5b19ec1b175e3a96a0b886b71 Mon Sep 17 00:00:00 2001 From: Arkar Aung Date: Sun, 26 Jul 2020 13:39:57 +0630 Subject: [PATCH 1/3] Changes to accommodate newer PyTorch version - Change to init.kaiming_normal_ for new PyTorch version - Create fake and real label shapes as a column vector --- model.py | 2 +- solver.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/model.py b/model.py index a630eae..24a83df 100644 --- a/model.py +++ b/model.py @@ -110,7 +110,7 @@ def forward(self, z): def kaiming_init(m): if isinstance(m, (nn.Linear, nn.Conv2d)): - init.kaiming_normal(m.weight) + init.kaiming_normal_(m.weight) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): diff --git a/solver.py b/solver.py index 3282ec4..e9b0f19 100644 --- a/solver.py +++ b/solver.py @@ -43,7 +43,7 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato unlabeled_data = self.read_data(unlabeled_dataloader, labels=False) optim_vae = optim.Adam(vae.parameters(), lr=5e-4) - optim_task_model = optim.SGD(task_model.parameters(), lr=0.01, weight_decay=5e-4, momentum=0.9) + optim_task_model = optim.SGD(task_model.parameters(), lr=5e-4, weight_decay=5e-4, momentum=0.9) optim_discriminator = optim.Adam(discriminator.parameters(), lr=5e-4) @@ -87,8 +87,8 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato labeled_preds = discriminator(mu) unlabeled_preds = discriminator(unlab_mu) - lab_real_preds = torch.ones(labeled_imgs.size(0)) - unlab_real_preds = torch.ones(unlabeled_imgs.size(0)) + lab_real_preds = torch.ones(labeled_imgs.size(0), 1) + unlab_real_preds = torch.ones(unlabeled_imgs.size(0), 1) if self.args.cuda: lab_real_preds = lab_real_preds.cuda() From c5b5ab6e7466c3982b2b8ff7e34e697a0fea7cc4 Mon Sep 17 00:00:00 2001 From: Arkar Aung Date: Sun, 26 Jul 2020 13:48:27 +0630 Subject: [PATCH 2/3] Fix shape issue in labels for discriminator --- solver.py | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/solver.py b/solver.py index e9b0f19..d577682 100644 --- a/solver.py +++ b/solver.py @@ -120,8 +120,8 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato labeled_preds = discriminator(mu) unlabeled_preds = discriminator(unlab_mu) - lab_real_preds = torch.ones(labeled_imgs.size(0)) - unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0)) + lab_real_preds = torch.ones(labeled_imgs.size(0), 1) + unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0), 1) if self.args.cuda: lab_real_preds = lab_real_preds.cuda() From 8993aa7f62e9314b3af48520a16377c3d3825583 Mon Sep 17 00:00:00 2001 From: Arkar Aung Date: Sun, 26 Jul 2020 14:14:27 +0630 Subject: [PATCH 3/3] Add learning rates as arguments --- arguments.py | 3 +++ solver.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/arguments.py b/arguments.py index 30041dd..d82bc8b 100644 --- a/arguments.py +++ b/arguments.py @@ -7,6 +7,9 @@ def get_args(): parser.add_argument('--dataset', type=str, default='cifar10', help='Name of the dataset used.') parser.add_argument('--batch_size', type=int, default=128, help='Batch size used for training and testing') parser.add_argument('--train_epochs', type=int, default=100, help='Number of training epochs') + parser.add_argument('--lr_vae', type=float, default=5e-4, help='Learning rate for VAE') + parser.add_argument('--lr_dis', type=float, default=5e-4, help='Learning rate for Discriminator') + parser.add_argument('--lr_task', type=float, default=5e-4, help='Learning rate for Task Module') parser.add_argument('--latent_dim', type=int, default=32, help='The dimensionality of the VAE latent dimension') parser.add_argument('--data_path', type=str, default='./data', help='Path to where the data is') parser.add_argument('--beta', type=float, default=1, help='Hyperparameter for training. 
The parameter for VAE') diff --git a/solver.py b/solver.py index d577682..34b7482 100644 --- a/solver.py +++ b/solver.py @@ -42,9 +42,9 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato labeled_data = self.read_data(querry_dataloader) unlabeled_data = self.read_data(unlabeled_dataloader, labels=False) - optim_vae = optim.Adam(vae.parameters(), lr=5e-4) - optim_task_model = optim.SGD(task_model.parameters(), lr=5e-4, weight_decay=5e-4, momentum=0.9) - optim_discriminator = optim.Adam(discriminator.parameters(), lr=5e-4) + optim_vae = optim.Adam(vae.parameters(), lr=self.args.lr_vae) + optim_task_model = optim.SGD(task_model.parameters(), lr=self.args.lr_task, weight_decay=5e-4, momentum=0.9) + optim_discriminator = optim.Adam(discriminator.parameters(), lr=self.args.lr_dis) vae.train()