-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_timegan.py
133 lines (114 loc) · 3.6 KB
/
main_timegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import numpy as np
import torch
from data_loading import load_dataset
from timegan import TimeGAN
from metrics.discriminative_metrics import discriminative_score_metrics
from metrics.predictive_metrics import predictive_score_metrics
from metrics.visualization_metrics import low_dimensional_representation, plot_distribution_estimate
from utils import preprocessing
def main(args):
    """Train a TimeGAN model, synthesize sequences, and evaluate their quality.

    Args:
        args: argparse.Namespace with fields ``data``, ``seq_len``,
            ``hidden_dim``, ``num_layers``, ``epochs``, ``batch_size``,
            ``metric_iteration`` and ``learning_rate``.
            NOTE(review): ``args.module`` is accepted by the CLI but never
            used here — presumably meant to select the RNN cell inside
            ``TimeGAN``; confirm and wire it through if so.

    Returns:
        Tuple of (data_train, data_gen, metric_results) where
        ``metric_results`` maps 'Discriminative' and 'Predictive' to the
        mean score over ``args.metric_iteration`` evaluation runs.
    """
    # Pick the best available accelerator: CUDA > Apple MPS > CPU.
    device = (
        'cuda'
        if torch.cuda.is_available()
        else 'mps'
        if torch.backends.mps.is_available()
        else 'cpu'
    )
    print(f'Using {device} device')
    # Load raw data from the named csv file.
    data = load_dataset(args.data)
    # Scale and window the data; max/min are returned for inverse scaling
    # but are unused in this pipeline.
    data_train, max_val, min_val = preprocessing((data, True), sequence_length=args.seq_len)
    # Instantiate TimeGAN; feature count is inferred from the data itself.
    model = TimeGAN(input_features=data_train.shape[-1],
                    hidden_dim=args.hidden_dim,
                    num_layers=args.num_layers,
                    epochs=args.epochs,
                    batch_size=args.batch_size,
                    learning_rate=args.learning_rate).to(device)
    # Train, then synthesize a batch shaped like the training set.
    model.fit(data_train)
    data_gen = model.transform(data_train.shape)
    # Evaluation: both metrics are stochastic, so each is averaged over
    # several independent runs to reduce variance.
    metric_results = {
        'Discriminative': np.mean([
            discriminative_score_metrics(data_train, data_gen, device)
            for _ in range(args.metric_iteration)
        ]),
        'Predictive': np.mean([
            predictive_score_metrics(data_train, data_gen, device)
            for _ in range(args.metric_iteration)
        ]),
    }
    for metric, score in metric_results.items():
        print(f'{metric} Score: {score}')
    # Visual comparison of real vs. synthetic data in two 2-D embeddings.
    plot_distribution_estimate(*low_dimensional_representation(data_train, data_gen, 'pca'), 'pca')
    plot_distribution_estimate(*low_dimensional_representation(data_train, data_gen, 'tsne'), 'tsne')
    return data_train, data_gen, metric_results
if __name__ == '__main__':
    # Command-line interface: every experiment knob is exposed as a flag.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str,
                        help='Name of csv file')
    parser.add_argument('--seq_len', type=int, default=24,
                        help='Length of sequences')
    parser.add_argument('--module', type=str, default='gru', choices=['gru', 'lstm'],
                        help='RNN module')
    parser.add_argument('--hidden_dim', type=int, default=24,
                        help='Number of features for hidden vector')
    parser.add_argument('--num_layers', type=int, default=3,
                        help='Number of sequential recurrent layers')
    parser.add_argument('--epochs', type=int, default=10000,
                        help='Number of iterations for training')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Number of samples per batch during training')
    parser.add_argument('--metric_iteration', type=int, default=10,
                        help='Number of iterations for metric evaluation')
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help='Set learning rate for optimizer')
    args = parser.parse_args()
    # Run the full train / generate / evaluate pipeline.
    data_train, data_gen, metrics = main(args)