-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
41 lines (38 loc) · 1.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from torch import optim
from torch.autograd import Variable
import numpy as np
import pickle
from utils import Hps
from utils import DataLoader
from utils import Logger
from utils import SingleDataset
from solver import Solver
import argparse
def build_parser():
    """Build the command-line parser for this training script.

    All original options and their defaults are preserved.  The only
    addition is ``--no_train``: ``--train`` is a ``store_true`` flag whose
    default is already ``True``, so on its own it can never be switched
    off.  ``--no_train`` writes ``False`` to the same destination.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', default=True, action='store_true')
    # Counterpart of --train; without it args.train is always True.
    parser.add_argument('--no_train', dest='train', action='store_false',
                        help='skip the training stages')
    # NOTE(review): --test is accepted but nothing in this script reads it;
    # kept so existing command lines keep working.
    parser.add_argument('--test', default=False, action='store_true')
    parser.add_argument('--load_model', default=False, action='store_true')
    parser.add_argument('-flag', default='train')
    parser.add_argument('-hps_path', default='./hps/vctk.json')
    parser.add_argument('-load_model_path', default='')
    parser.add_argument('-dataset_path', default='./vctk.h5')
    parser.add_argument('-index_path', default='./index.json')
    parser.add_argument('-output_model_path', default='./pkl')
    return parser


def main(argv=None):
    """Parse arguments, build the solver and optionally run training.

    Args:
        argv: list of argument strings to parse; ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Load hyper-parameters from the file given by -hps_path.
    hps = Hps()
    hps.load(args.hps_path)
    hps_tuple = hps.get_tuple()

    dataset = SingleDataset(
        args.dataset_path,
        args.index_path,
        seg_len=hps_tuple.seg_len,
    )
    data_loader = DataLoader(dataset)
    solver = Solver(hps_tuple, data_loader)

    if args.load_model:
        solver.load_model(args.load_model_path)

    if args.train:
        # The four modes are run in this fixed order, each with the same
        # output path and flag.
        for mode in ('pretrain_G', 'pretrain_D', 'train', 'patchGAN'):
            solver.train(args.output_model_path, args.flag, mode=mode)


if __name__ == '__main__':
    main()