-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtraining_cytoVAE.py
57 lines (46 loc) · 1.5 KB
/
training_cytoVAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# -*- coding: utf-8 -*-
"""
@author: Maxime W. Lafarge, (mlafarge); Eindhoven University of Technology, The Netherlands
@comment: For more details see "Capturing Single-Cell Phenotypic Variation via Unsupervised Representation Learning"; MW Lafarge et al.; MIDL 2019; PMLR 102:315-325
Master script to run the training procedure of the model.
"""
# -----------------------------------------------------------------------------
# 1) IMPORT CURRENT EXPERIMENT
# -----------------------------------------------------------------------------
from experiment import exp

dManager = exp.config.dManager  #-- Data manager; used below via generateBatch()
model = exp.config.model  #-- Model; must expose a Trainer class (used below)

# -----------------------------------------------------------------------------
# 2) INITIALIZE THE TRAINING CLASS
# -----------------------------------------------------------------------------
gpu_memory_fraction = exp.config.gpu_memory_fraction
trainer = model.Trainer(
    name=exp.config.name,
    path2restore=exp.config.path_to_restore,  #-- Model state recovery
    model=model,  #-- Imported model
    monitoring=True,
    is_training=True,
    gpu_memory_fraction=gpu_memory_fraction)

# -----------------------------------------------------------------------------
# 3) TRAINING ITERATIONS
# -----------------------------------------------------------------------------
for step in range(exp.config.maxIterations):
    #------
    #-- a) Booleans of the current iteration: True every `validationPeriod`
    #--    iterations, counting iterations from 1 (hence step + 1).
    isValidation = (step + 1) % exp.config.validationPeriod == 0

    #------
    #-- b) Draw two separate image batches from the data manager: the first
    #--    feeds `tensor_images`, the second feeds `tensor_images_disc`.
    tensor_images = dManager.generateBatch()
    tensor_discrimination = dManager.generateBatch()

    #------
    #-- c) Run a training iteration (VAE and Discriminator are trained in parallel)
    trainer.train(
        tensor_images=tensor_images,
        tensor_images_disc=tensor_discrimination)

    #------
    #-- d) Run a validation iteration
    if isValidation:
        # USER-FREE VALIDATION PROCEDURE: placeholder hook; nothing runs here yet.
        pass

print("Training done.")