import os
from time import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler, Adam
from torch.utils.data import Dataset
# import and set up the typeguard
from typeguard.importhook import install_import_hook
# comment these out when deploying:
install_import_hook('src.nn')
install_import_hook('src.scoring_rules')
install_import_hook('src.utils')
install_import_hook('src.weatherbench_utils')
install_import_hook('src.unet_utils')
from src.nn import InputTargetDataset, UNet2D, fit, fit_adversarial, \
    ConditionalGenerativeModel, createGenerativeFCNN, createCriticFCNN, test_epoch, PatchGANDiscriminator, \
    DiscardWindowSizeDim, get_target, LayerNormMine, createGenerativeGRUNN, createCriticGRUNN, \
    DiscardNumberGenerationsInOutput, createGRUNN, createFCNN
from src.scoring_rules import EnergyScore, KernelScore, VariogramScore, PatchedScoringRule, SumScoringRules, \
    ScoringRulesForWeatherBench, ScoringRulesForWeatherBenchPatched, LossForWeatherBenchPatched
from src.utils import plot_losses, save_net, save_dict_to_json, estimate_bandwidth_timeseries, lorenz96_variogram, \
    def_loader_kwargs, load_net, weight_for_summed_score, weatherbench_variogram_haversine
from src.parsers import parser_train_net, define_masks, nonlinearities_dict, setup
from src.weatherbench_utils import load_weatherbench_data
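# Overview: depending on the parsed method, this script trains (i) a conditional generative network by minimising
# a proper scoring rule ("SR": Energy, Kernel, Variogram scores or weighted sums thereof), (ii) a GAN or WGAN_GP
# generator-critic pair, or (iii) a deterministic regression network with an MSE loss. The data model is either a
# low-dimensional time series (handled by fully connected or GRU networks) or the gridded WeatherBench dataset
# (handled by a U-Net).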
# --- parser ---
parser = parser_train_net()
args = parser.parse_args()
model = args.model
method = args.method
scoring_rule = args.scoring_rule
kernel = args.kernel
patched = args.patched
base_measure = args.base_measure
epochs = args.epochs
# no_scheduler = args.no_scheduler
root_folder = args.root_folder
model_folder = args.model_folder
datasets_folder = args.datasets_folder
weatherbench_data_folder = args.weatherbench_data_folder
weatherbench_small = args.weatherbench_small
unet_noise_method = args.unet_noise_method
unet_large = args.unet_large
lr = args.lr
lr_c = args.lr_c
batch_size = args.batch_size
early_stopping = not args.no_early_stop
start_epoch_early_stopping = args.epochs_before_early_stopping
epochs_early_stopping_interval = args.epochs_test_interval
critic_steps_every_generator_step = args.critic_steps_every_generator_step
use_tqdm = not args.no_tqdm
save_net_flag = not args.no_save_net
continue_training_net_if_available = args.continue_training_net_if_available
cuda = args.cuda
load_all_data_GPU = args.load_all_data_GPU
ensemble_size = args.ensemble_size
nonlinearity = args.nonlinearity
data_size = args.data_size
auxiliary_var_size = args.auxiliary_var_size
seed = args.seed
lambda_gp = args.lambda_gp
gamma = args.gamma_kernel_score
notrain_if_done_before = args.notrain_if_done_before
patch_size = args.patch_size
no_RNN = args.no_RNN
hidden_size_rnn = args.hidden_size_rnn
weight_decay = args.weight_decay
scheduler_gamma = args.scheduler_gamma
args_dict = args.__dict__
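# args_dict starts from the raw command-line arguments and is enriched below (dataset sizes, estimated kernel
# bandwidth, scoring-rule weights, scheduler settings); it is saved to config{name_postfix}.json at the end so that
# the exact configuration of a run is stored alongside the trained net.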
model_is_weatherbench = model == "WeatherBench"
nn_model = "unet" if model_is_weatherbench else ("fcnn" if no_RNN else "rnn")
datasets_folder, nets_folder, data_size, auxiliary_var_size, name_postfix, unet_depths, patch_size, method_is_gan, hidden_size_rnn = \
    setup(model, root_folder, model_folder, datasets_folder, data_size, method, scoring_rule, kernel, patched,
          patch_size, ensemble_size, auxiliary_var_size, critic_steps_every_generator_step, base_measure, lr, lr_c,
          batch_size, args.no_early_stop, unet_noise_method, unet_large, nn_model, hidden_size_rnn)
# stop if the net exists already:
if notrain_if_done_before and os.path.exists(nets_folder + f"net{name_postfix}.pth"):
print("Stopping as net with this setup was trained before.")
exit(0)
# if continue_training_net_if_available is set, check whether a saved net exists and, if so, resume training from it
continue_training_net = continue_training_net_if_available and os.path.exists(nets_folder + f"net{name_postfix}.pth")
# create the nets folder:
os.makedirs(nets_folder, exist_ok=True)
# --- data handling ---
if not model_is_weatherbench:
    input_data_train = torch.load(datasets_folder + "train_x.pty")
    target_data_train = torch.load(datasets_folder + "train_y.pty")
    input_data_val = torch.load(datasets_folder + "val_x.pty")
    target_data_val = torch.load(datasets_folder + "val_y.pty")
    window_size = input_data_train.shape[1]
    # create the train and val loaders:
    dataset_train = InputTargetDataset(input_data_train, target_data_train,
                                       "cuda" if cuda and load_all_data_GPU else "cpu")
    dataset_val = InputTargetDataset(input_data_val, target_data_val, "cuda" if cuda and load_all_data_GPU else "cpu")
else:
    dataset_train, dataset_val = load_weatherbench_data(weatherbench_data_folder, cuda, load_all_data_GPU,
                                                        weatherbench_small=weatherbench_small)
len_dataset_train = len(dataset_train)
len_dataset_val = len(dataset_val)
print("Training set size:", len_dataset_train)
print("Validation set size:", len_dataset_val)
args_dict["len_dataset_train"] = len_dataset_train
args_dict["len_dataset_val"] = len_dataset_val
loader_kwargs = def_loader_kwargs(cuda, load_all_data_GPU)
# loader_kwargs.update(loader_kwargs_2) # if you want to add other loader arguments
data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, **loader_kwargs)
if len(dataset_val) > 0:
    data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, **loader_kwargs)
    if model_is_weatherbench:
        # obtain the target tensor to estimate the gamma for kernel SR:
        target_data_val = get_target(data_loader_val, cuda).flatten(1, -1)
else:
    data_loader_val = None
# --- losses ---
# instantiate the loss according to the chosen SR; each SR takes as input: (net_output, target)
if not method_is_gan and not method == "regression":
    if patched and not model_is_weatherbench:
        masks = define_masks[model](data_size=data_size, patch_size=patch_size)
    if scoring_rule in ["Kernel", "KernelVariogram", "EnergyKernel"]:
        # estimate the gamma value
        if gamma is None:
            if patched and not model_is_weatherbench:
                # determine the gamma using the first patch only. This assumes that the values of the variables
                # are roughly the same in the different patches.
                gamma = estimate_bandwidth_timeseries(target_data_val[:, masks[0]], return_values=["median"])
            else:
                gamma = estimate_bandwidth_timeseries(target_data_val, return_values=["median"])
            print(f"Estimated gamma: {gamma:.4f}")
            args_dict["gamma_kernel_score"] = float(gamma.cpu().numpy())
        if kernel == "gaussian":
            sr_instance = KernelScore(sigma=gamma)
        else:
            sr_instance = KernelScore(kernel="rational_quadratic", alpha=gamma ** 2)
    elif scoring_rule == "Energy":
        sr_instance = EnergyScore()
    if scoring_rule in ["Variogram", "EnergyVariogram", "KernelVariogram"]:
        variogram = None
        if model in ["lorenz96", ]:
            variogram = lorenz96_variogram(data_size)
        elif model == "WeatherBench":
            # variogram = weatherbench_variogram(weatherbench_small=weatherbench_small)  # this is the old one
            variogram = weatherbench_variogram_haversine(weatherbench_small=weatherbench_small)
        if variogram is not None and cuda:
            variogram = variogram.cuda()
        if scoring_rule == "Variogram":
            sr_instance = VariogramScore(variogram=variogram, max_batch_size=16 if model_is_weatherbench else None)
        else:
            if scoring_rule == "EnergyVariogram":
                sr1_instance = EnergyScore()
            elif scoring_rule == "KernelVariogram":
                if kernel == "gaussian":
                    sr1_instance = KernelScore(sigma=gamma)
                else:
                    sr1_instance = KernelScore(kernel="rational_quadratic", alpha=gamma ** 2)
            sr2_instance = VariogramScore(variogram=variogram, max_batch_size=16 if model_is_weatherbench else None)
            weight_list = weight_for_summed_score(scoring_rule, model_is_weatherbench)
            args_dict["weight_list"] = weight_list
            sr_instance = SumScoringRules((sr1_instance, sr2_instance), weight_list=weight_list)
    if scoring_rule == "EnergyKernel":
        sr1_instance = EnergyScore()
        if kernel == "gaussian":
            sr2_instance = KernelScore(sigma=gamma)
        else:
            sr2_instance = KernelScore(kernel="rational_quadratic", alpha=gamma ** 2)
        weight_list = weight_for_summed_score(scoring_rule, model_is_weatherbench)
        args_dict["weight_list"] = weight_list
        sr_instance = SumScoringRules((sr1_instance, sr2_instance), weight_list=weight_list)
    if model_is_weatherbench:
        if patched:
            # wrap the scoring rule in a new one for the weatherbench dataset, as the data needs to be flattened:
            sr1_instance = ScoringRulesForWeatherBenchPatched(sr_instance, patch_step=patch_size // 2,
                                                              patch_size=patch_size)
            sr2_instance = ScoringRulesForWeatherBench(sr_instance)
            # the patched version is proper but not strictly proper; therefore add the original SR on the full data
            # to make the sum strictly proper.
            weight_list = weight_for_summed_score(scoring_rule, model_is_weatherbench, patch_size=patch_size)
            args_dict["weight_list"] = weight_list
            sr_instance = SumScoringRules((sr1_instance, sr2_instance), weight_list=weight_list)
        else:
            # wrap the scoring rule in a new one for the weatherbench dataset, as the data needs to be flattened:
            sr_instance = ScoringRulesForWeatherBench(sr_instance)
    if patched and not model_is_weatherbench:
        patched_sr = PatchedScoringRule(sr_instance, masks)
        loss_fn = patched_sr.estimate_score_batch
    else:
        loss_fn = sr_instance.estimate_score_batch
elif method == "regression":
# use the RMSE loss function
loss_fn = nn.MSELoss()
# add a patched version for WeatherBench
if patched:
if model_is_weatherbench:
loss_fn = LossForWeatherBenchPatched(loss_fn)
else:
raise NotImplementedError
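# For reference (a sketch only; the authoritative definitions are the estimate_score_batch implementations in
# src.scoring_rules): with an ensemble x_1, ..., x_m generated by the net and an observation y, the energy score
# is commonly estimated as
#     ES ~= (1/m) * sum_i ||x_i - y|| - 1/(2*m*(m-1)) * sum_{i != j} ||x_i - x_j||,
# and the kernel score has the same structure with the Euclidean distances replaced by (negative) kernel
# evaluations; the training loss is such an estimate accumulated over a batch of (forecast ensemble, target) pairs.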
# --- networks ---
if seed is not None: # set seed for network instantiating
    torch.manual_seed(seed)
if method == "regression":
if nn_model == "unet":
# NOTE: make sure that a channels dimension exists
net_class = UNet2D
unet_kwargs = {"in_channels": data_size[0], "out_channels": 1,
"noise_method": "no noise", "conv_depths": unet_depths}
net = DiscardWindowSizeDim(net_class(**unet_kwargs))
elif nn_model == "rnn":
output_size = data_size
gru_layers = 1
gru_hidden_size = hidden_size_rnn
net = createGRUNN(data_size=data_size, gru_hidden_size=gru_hidden_size,
output_size=output_size, hidden_sizes=None, gru_layers=gru_layers,
nonlinearity=nonlinearities_dict[nonlinearity])()
else:
net = createFCNN(input_size=window_size * data_size, output_size=data_size, unsqueeze_output=True)()
if continue_training_net:
net = load_net(nets_folder + f"net{name_postfix}.pth", DiscardNumberGenerationsInOutput, net)
else:
# create net
net = DiscardNumberGenerationsInOutput(net)
else: # SR and GAN
    wrap_net = True
    number_generations_per_forward_call = ensemble_size if method == "SR" else 1
    # create generative net:
    if nn_model == "fcnn":
        input_size = window_size * data_size + auxiliary_var_size
        output_size = data_size
        hidden_sizes_list = [int(input_size * 1.5), int(input_size * 3), int(input_size * 3),
                             int(input_size * 0.75 + output_size * 3), int(output_size * 5)]
        inner_net = createGenerativeFCNN(input_size=input_size, output_size=output_size, hidden_sizes=hidden_sizes_list,
                                         nonlinearity=nonlinearities_dict[nonlinearity])()
    elif nn_model == "rnn":
        output_size = data_size
        gru_layers = 1
        gru_hidden_size = hidden_size_rnn
        inner_net = createGenerativeGRUNN(data_size=data_size, gru_hidden_size=gru_hidden_size,
                                          noise_size=auxiliary_var_size,
                                          output_size=output_size, hidden_sizes=None, gru_layers=gru_layers,
                                          nonlinearity=nonlinearities_dict[nonlinearity])()
    elif nn_model == "unet":
        # select the noise method here:
        inner_net = UNet2D(in_channels=data_size[0], out_channels=1, noise_method=unet_noise_method,
                           number_generations_per_forward_call=number_generations_per_forward_call,
                           conv_depths=unet_depths)
        if unet_noise_method in ["sum", "concat"]:
            # here we overwrite the auxiliary_var_size above, as there is a precise constraint
            downsampling_factor, n_channels = inner_net.calculate_downsampling_factor()
            if weatherbench_small:
                auxiliary_var_size = torch.Size(
                    [n_channels, 16 // downsampling_factor, 16 // downsampling_factor])
            else:
                auxiliary_var_size = torch.Size(
                    [n_channels, data_size[1] // downsampling_factor, data_size[2] // downsampling_factor])
        elif unet_noise_method == "dropout":
            wrap_net = False  # do not wrap in the conditional generative model
    if wrap_net:
        # the following wraps the nets above and takes care of generating the auxiliary variables at each forward call
        if continue_training_net:
            net = load_net(nets_folder + f"net{name_postfix}.pth", ConditionalGenerativeModel, inner_net,
                           size_auxiliary_variable=auxiliary_var_size, base_measure=base_measure,
                           number_generations_per_forward_call=number_generations_per_forward_call, seed=seed + 1)
        else:
            net = ConditionalGenerativeModel(inner_net, size_auxiliary_variable=auxiliary_var_size, seed=seed + 1,
                                             number_generations_per_forward_call=number_generations_per_forward_call,
                                             base_measure=base_measure)
    else:
        if continue_training_net:
            net = load_net(nets_folder + f"net{name_postfix}.pth", DiscardWindowSizeDim, inner_net)
        else:
            net = DiscardWindowSizeDim(inner_net)
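# Note on the generative wrapper: ConditionalGenerativeModel (used above whenever wrap_net is True) draws auxiliary
# noise of size auxiliary_var_size from base_measure at each forward call and returns
# number_generations_per_forward_call samples per input window: the full ensemble_size for SR training and a single
# draw for the GAN methods. With the U-Net "dropout" noise method the stochasticity lives inside the network itself,
# so the wrapper is skipped.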
if method_is_gan:
if nn_model == "fcnn":
model_class = createCriticFCNN(input_size=(window_size + 1) * data_size,
end_sigmoid=method == "GAN")
if continue_training_net:
critic = load_net(nets_folder + f"critic{name_postfix}.pth", model_class)
else:
critic = model_class()
elif nn_model == "rnn":
gru_layers_critic = 1
gru_hidden_size = hidden_size_rnn
model_class = createCriticGRUNN(data_size, gru_hidden_size, gru_layers=gru_layers_critic,
end_sigmoid=method == "GAN")
if continue_training_net:
critic = load_net(nets_folder + f"critic{name_postfix}.pth", model_class)
else:
critic = model_class()
elif nn_model == "unet":
# Using PatchGanDiscriminator model popularised in the following work:https://arxiv.org/abs/1611.07004v3
# This discriminator seems to be popular for image based GAN work
if continue_training_net:
critic = load_net(nets_folder + f"critic{name_postfix}.pth",
PatchGANDiscriminator, in_channels=1 + data_size[0], last_layer_filters=32,
n_layers=3, end_sigmoid=method == "GAN",
norm_layer=nn.BatchNorm2d if method != "WGAN_GP" else LayerNormMine)
else:
critic = PatchGANDiscriminator(in_channels=1 + data_size[0], first_layer_filters=32, n_layers=3,
end_sigmoid=method == "GAN",
norm_layer=nn.BatchNorm2d if method != "WGAN_GP" else LayerNormMine)
# in general, we have to put in_channels = (window_size * fields_context) + fields_target
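# Critic input convention (as far as can be read off the constructors above): the FCNN critic takes the window_size
# context steps concatenated with the single target step, hence input_size=(window_size + 1) * data_size, while the
# PatchGAN critic stacks the target field on top of the context channels, hence in_channels=1 + data_size[0].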
# --- network tools ---
if cuda:
    net.cuda()
# optimizer
optimizer_kwargs = {"weight_decay": weight_decay} # l2 regularization
args_dict["weight_decay"] = optimizer_kwargs["weight_decay"]
optimizer = Adam(net.parameters(), lr=lr, **optimizer_kwargs)
# scheduler
scheduler_steps = 10
scheduler_gamma = scheduler_gamma
scheduler = lr_scheduler.StepLR(optimizer, scheduler_steps, gamma=scheduler_gamma, last_epoch=-1)
args_dict["scheduler_steps"] = scheduler_steps
args_dict["scheduler_gamma"] = scheduler_gamma
if method_is_gan:
    if cuda:
        critic.cuda()
    optimizer_kwargs = {}
    optimizer_c = Adam(critic.parameters(), lr=lr_c, **optimizer_kwargs)
    # dummy scheduler:
    scheduler_c = lr_scheduler.StepLR(optimizer_c, 8, gamma=1, last_epoch=-1)
string = f"Train {method} network for {model} model with lr {lr} "
if method == "SR":
string += f"using {scoring_rule} scoring rule"
if method_is_gan:
string += f"and critic lr {lr_c}"
print(string)
# --- train ---
start = time()
if method_is_gan:
    # load the previous losses if available:
    if continue_training_net:
        train_loss_list_g = np.load(nets_folder + f"train_loss_g{name_postfix}.npy").tolist()
        train_loss_list_c = np.load(nets_folder + f"train_loss_c{name_postfix}.npy").tolist()
    else:
        train_loss_list_g = train_loss_list_c = None
    kwargs = {}
    if method == "WGAN_GP":
        kwargs["lambda_gp"] = lambda_gp
    train_loss_list_g, train_loss_list_c = fit_adversarial(method, data_loader_train, net, critic, optimizer,
                                                           scheduler, optimizer_c, scheduler_c, epochs, cuda,
                                                           start_epoch_training=0, use_tqdm=use_tqdm,
                                                           critic_steps_every_generator_step=critic_steps_every_generator_step,
                                                           train_loss_list_g=train_loss_list_g,
                                                           train_loss_list_c=train_loss_list_c, **kwargs)
else:
    if continue_training_net:
        train_loss_list = np.load(nets_folder + f"train_loss{name_postfix}.npy").tolist()
        val_loss_list = np.load(nets_folder + f"val_loss{name_postfix}.npy").tolist()
    else:
        train_loss_list = val_loss_list = None
    train_loss_list, val_loss_list = fit(data_loader_train, net, loss_fn, optimizer, scheduler, epochs, cuda,
                                         val_loader=data_loader_val, early_stopping=early_stopping,
                                         start_epoch_early_stopping=0 if continue_training_net else start_epoch_early_stopping,
                                         epochs_early_stopping_interval=epochs_early_stopping_interval,
                                         start_epoch_training=0, use_tqdm=use_tqdm, train_loss_list=train_loss_list,
                                         test_loss_list=val_loss_list)
    # compute now the final validation loss achieved by the model; this repeats part of the work done above, but it
    # is easier to do it this way
    final_validation_loss = test_epoch(data_loader_val, net, loss_fn, cuda)
training_time = time() - start
if save_net_flag:
    save_dict_to_json(args_dict, nets_folder + f'config{name_postfix}.json')
    save_net(nets_folder + f"net{name_postfix}.pth", net)
    if method_is_gan:
        save_net(nets_folder + f"critic{name_postfix}.pth", critic)
        plot_losses(train_loss_list_g, train_loss_list_c, GAN=True)
        plt.savefig(nets_folder + f"losses{name_postfix}.png")
        np.save(nets_folder + f"train_loss_g{name_postfix}.npy", np.array(train_loss_list_g))
        np.save(nets_folder + f"train_loss_c{name_postfix}.npy", np.array(train_loss_list_c))
    else:
        plot_losses(train_loss_list, val_loss_list)
        plt.savefig(nets_folder + f"losses{name_postfix}.png")
        np.save(nets_folder + f"train_loss{name_postfix}.npy", np.array(train_loss_list))
        np.save(nets_folder + f"val_loss{name_postfix}.npy", np.array(val_loss_list))
        text_file = open(nets_folder + f"final_validation_loss{name_postfix}.txt", "w")
        string = f"Final validation loss (with the loss used for training): {final_validation_loss:.2f}"
        text_file.write(string + "\n")
        text_file.close()
    text_file = open(nets_folder + f"training_time{name_postfix}.txt", "w")
    string = "Training time: {:.2f} seconds.".format(training_time)
    text_file.write(string + "\n")
    text_file.close()