diff --git a/arguments.py b/arguments.py index 30041dd..d82bc8b 100644 --- a/arguments.py +++ b/arguments.py @@ -7,6 +7,9 @@ def get_args(): parser.add_argument('--dataset', type=str, default='cifar10', help='Name of the dataset used.') parser.add_argument('--batch_size', type=int, default=128, help='Batch size used for training and testing') parser.add_argument('--train_epochs', type=int, default=100, help='Number of training epochs') + parser.add_argument('--lr_vae', type=float, default=5e-4, help='Learning rate for VAE') + parser.add_argument('--lr_dis', type=float, default=5e-4, help='Learning rate for Discriminator') + parser.add_argument('--lr_task', type=float, default=0.01, help='Learning rate for Task Module') parser.add_argument('--latent_dim', type=int, default=32, help='The dimensionality of the VAE latent dimension') parser.add_argument('--data_path', type=str, default='./data', help='Path to where the data is') parser.add_argument('--beta', type=float, default=1, help='Hyperparameter for training. 
The parameter for VAE') diff --git a/model.py b/model.py index a630eae..24a83df 100644 --- a/model.py +++ b/model.py @@ -110,7 +110,7 @@ def forward(self, z): def kaiming_init(m): if isinstance(m, (nn.Linear, nn.Conv2d)): - init.kaiming_normal(m.weight) + init.kaiming_normal_(m.weight) if m.bias is not None: m.bias.data.fill_(0) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): diff --git a/solver.py b/solver.py index 3282ec4..34b7482 100644 --- a/solver.py +++ b/solver.py @@ -42,9 +42,9 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato labeled_data = self.read_data(querry_dataloader) unlabeled_data = self.read_data(unlabeled_dataloader, labels=False) - optim_vae = optim.Adam(vae.parameters(), lr=5e-4) - optim_task_model = optim.SGD(task_model.parameters(), lr=0.01, weight_decay=5e-4, momentum=0.9) - optim_discriminator = optim.Adam(discriminator.parameters(), lr=5e-4) + optim_vae = optim.Adam(vae.parameters(), lr=self.args.lr_vae) + optim_task_model = optim.SGD(task_model.parameters(), lr=self.args.lr_task, weight_decay=5e-4, momentum=0.9) + optim_discriminator = optim.Adam(discriminator.parameters(), lr=self.args.lr_dis) vae.train() @@ -87,8 +87,8 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato labeled_preds = discriminator(mu) unlabeled_preds = discriminator(unlab_mu) - lab_real_preds = torch.ones(labeled_imgs.size(0)) - unlab_real_preds = torch.ones(unlabeled_imgs.size(0)) + lab_real_preds = torch.ones(labeled_imgs.size(0), 1) + unlab_real_preds = torch.ones(unlabeled_imgs.size(0), 1) if self.args.cuda: lab_real_preds = lab_real_preds.cuda() @@ -120,8 +120,8 @@ def train(self, querry_dataloader, val_dataloader, task_model, vae, discriminato labeled_preds = discriminator(mu) unlabeled_preds = discriminator(unlab_mu) - lab_real_preds = torch.ones(labeled_imgs.size(0)) - unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0)) + lab_real_preds = torch.ones(labeled_imgs.size(0), 1) + 
unlab_fake_preds = torch.zeros(unlabeled_imgs.size(0), 1) if self.args.cuda: lab_real_preds = lab_real_preds.cuda()