-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
112 lines (90 loc) · 3.77 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os
import torch
import torch.optim as optim
from tqdm.auto import tqdm
import joblib
import random
import numpy as np
from data import StockDataset, DataProcessor
from model import Generator, Discriminator
from utils import plot_loss
# Fix all three RNG sources (stdlib, NumPy, PyTorch) so training runs are
# reproducible. Must run at import time, before any model/dataset is built.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
def parse_args(argv=None):
    """Parse command-line hyperparameters for WGAN training.

    Args:
        argv: Optional list of argument strings (defaults to ``sys.argv[1:]``);
            passing a list makes the parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with the training configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='sample/sp500.csv')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--nz', type=int, default=3)  # noise channels fed to the generator
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--seq_len', type=int, default=127)
    parser.add_argument('--clip', type=float, default=0.01)  # WGAN weight-clip bound
    # Fixed: was type=int, which rejected/truncated any CLI value like 0.0001.
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--train_gen_per_epoch', type=int, default=5)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--log_dir', type=str, default='./logs')
    args = parser.parse_args(argv)
    return args
def train(args=None):
    """Train a WGAN (Wasserstein GAN with weight clipping) on stock log-returns.

    Args:
        args: Namespace from ``parse_args()``; must provide data_path,
            num_epochs, nz, batch_size, seq_len, clip, lr,
            train_gen_per_epoch, device and log_dir.

    Side effects:
        Writes a training-loss plot, the fitted DataProcessor and the trained
        generator into ``args.log_dir``.
    """
    data_processor = DataProcessor('Adj Close')
    log_returns_preprocessed = data_processor.preprocess(args.data_path)
    num_epochs = args.num_epochs
    nz = args.nz
    batch_size = args.batch_size
    seq_len = args.seq_len
    clip = args.clip
    lr = args.lr
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Setup the dataloader
    dataset = StockDataset(log_returns_preprocessed, seq_len)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    progressing_bar = tqdm(range(num_epochs))
    history = dict(gen_loss=[], disc_loss=[])
    # Initialize the generator and discriminator (critic)
    generator = Generator().to(device)
    discriminator = Discriminator(seq_len).to(device)
    # RMSprop (not Adam) per the original WGAN recipe
    disc_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)
    gen_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
    # Ensure the output directory exists before the final saves below.
    os.makedirs(args.log_dir, exist_ok=True)
    # Training loop
    for epoch in progressing_bar:
        progressing_bar.set_description('Epoch %d' % (epoch + 1))
        total_gen_loss = 0
        total_disc_loss = 0
        counter = 0
        for idx, data in enumerate(dataloader, 0):
            counter += 1
            # --- Train the critic ---
            discriminator.zero_grad()
            real = data.to(device)
            # Size the noise by the actual batch: the last batch may be
            # smaller than batch_size.
            noise = torch.randn(real.size(0), nz, seq_len, device=device)
            # detach() so the critic's backward pass does not propagate
            # gradients into the generator.
            fake = generator(noise).detach()
            disc_loss = -torch.mean(discriminator(real)) + torch.mean(discriminator(fake))
            disc_loss.backward()
            disc_optimizer.step()
            # WGAN weight clipping keeps the critic approximately 1-Lipschitz.
            for dp in discriminator.parameters():
                dp.data.clamp_(-clip, clip)
            # --- Train the generator once every train_gen_per_epoch critic steps ---
            # (idx starts at 0, so gen_loss is always bound before it is read)
            if idx % args.train_gen_per_epoch == 0:
                generator.zero_grad()
                gen_loss = -torch.mean(discriminator(generator(noise)))
                gen_loss.backward()
                gen_optimizer.step()
            total_gen_loss += gen_loss.item()
            total_disc_loss += disc_loss.item()
        # Guard against an empty dataloader to avoid ZeroDivisionError.
        counter = max(counter, 1)
        total_gen_loss /= counter
        total_disc_loss /= counter
        history['gen_loss'].append(total_gen_loss)
        history['disc_loss'].append(total_disc_loss)
        # Fixed: labels were previously paired with the swapped values.
        progressing_bar.set_postfix_str('DiscLoss: %.4e, GenLoss: %.4e' % (total_disc_loss, total_gen_loss))
    plot_loss(history, os.path.join(args.log_dir, 'training_loss.png'))
    joblib.dump(data_processor, os.path.join(args.log_dir, 'data_processor.joblib'))
    # NOTE(review): saves the whole module object (pickle of the class), not a
    # state_dict — loading requires the original model module on the path.
    torch.save(generator, os.path.join(args.log_dir, 'generator.pth'))
# Script entry point: parse CLI hyperparameters, then run training.
if __name__ == '__main__':
    train(parse_args())