-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathCycleGAN.py
122 lines (98 loc) · 4.88 KB
/
CycleGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from tensorflow import keras
import tensorflow as tf
class CycleGan(keras.Model):
    """CycleGAN: two generators and two discriminators trained jointly.

    Naming as established by ``train_step``: ``s_gen`` maps domain-C images
    to domain S (``fake_s = s_gen(real_c)``) and ``c_gen`` maps domain-S
    images to domain C (``fake_c = c_gen(real_s)``); ``s_disc`` / ``c_disc``
    judge images of their respective domains.  Training combines an
    adversarial loss, a cycle-consistency loss and an identity loss,
    following the standard TensorFlow CycleGAN recipe.
    """

    def __init__(
        self,
        c_generator,
        s_generator,
        c_discriminator,
        s_discriminator,
        lambda_cycle=10,
    ):
        """Store the four sub-networks and the cycle-loss weight.

        Args:
            c_generator: Generator producing domain-C images (S -> C).
            s_generator: Generator producing domain-S images (C -> S).
            c_discriminator: Discriminator for domain-C images.
            s_discriminator: Discriminator for domain-S images.
            lambda_cycle: Weight passed to both the cycle-consistency and
                identity loss functions.
        """
        super().__init__()
        self.c_gen = c_generator
        self.s_gen = s_generator
        self.c_disc = c_discriminator
        self.s_disc = s_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        c_gen_optimizer,
        s_gen_optimizer,
        c_disc_optimizer,
        s_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Attach one optimizer per sub-network and the four loss functions.

        Args:
            c_gen_optimizer / s_gen_optimizer: Optimizers for the generators.
            c_disc_optimizer / s_disc_optimizer: Optimizers for the
                discriminators.
            gen_loss_fn: Adversarial generator loss; called with the
                discriminator output on fake images.
            disc_loss_fn: Discriminator loss; called with (real_output,
                fake_output).
            cycle_loss_fn: Cycle-consistency loss; called with
                (real, cycled, lambda_cycle).
            identity_loss_fn: Identity-mapping loss; called with
                (real, same, lambda_cycle).
        """
        super().compile()
        self.c_gen_optimizer = c_gen_optimizer
        self.s_gen_optimizer = s_gen_optimizer
        self.c_disc_optimizer = c_disc_optimizer
        self.s_disc_optimizer = s_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimization step on a paired batch.

        Args:
            batch_data: Tuple ``(real_s, real_c)`` of image batches, one per
                domain.

        Returns:
            Dict of the four scalar losses for logging.  (Note: the "DomA"/
            "DomB" key names correspond to the s/c domains respectively.)
        """
        real_s, real_c = batch_data
        # persistent=True: we take four separate gradients from this tape.
        with tf.GradientTape(persistent=True) as tape:
            # Forward cycle C -> S -> C and backward cycle S -> C -> S.
            fake_s = self.s_gen(real_c, training=True)
            cycled_c = self.c_gen(fake_s, training=True)
            fake_c = self.c_gen(real_s, training=True)
            cycled_s = self.s_gen(fake_c, training=True)
            # Identity mapping: each generator fed an image already in its
            # output domain should reproduce it.
            same_s = self.s_gen(real_s, training=True)
            same_c = self.c_gen(real_c, training=True)
            # Discriminator outputs on real images ...
            disc_real_s = self.s_disc(real_s, training=True)
            disc_real_c = self.c_disc(real_c, training=True)
            # ... and on generated (fake) images.
            disc_fake_s = self.s_disc(fake_s, training=True)
            disc_fake_c = self.c_disc(fake_c, training=True)
            # Adversarial generator losses.
            s_gen_loss = self.gen_loss_fn(disc_fake_s)
            c_gen_loss = self.gen_loss_fn(disc_fake_c)
            # Total cycle-consistency loss, shared by both generators.
            total_cycle_loss = (
                self.cycle_loss_fn(real_s, cycled_s, self.lambda_cycle)
                + self.cycle_loss_fn(real_c, cycled_c, self.lambda_cycle)
            )
            # Total generator losses: adversarial + cycle + identity.
            total_s_gen_loss = (
                s_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_s, same_s, self.lambda_cycle)
            )
            total_c_gen_loss = (
                c_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_c, same_c, self.lambda_cycle)
            )
            # Discriminator losses (real vs. fake).
            s_disc_loss = self.disc_loss_fn(disc_real_s, disc_fake_s)
            c_disc_loss = self.disc_loss_fn(disc_real_c, disc_fake_c)
        # Gradients for each sub-network from the shared persistent tape.
        s_generator_gradients = tape.gradient(
            total_s_gen_loss, self.s_gen.trainable_variables)
        c_generator_gradients = tape.gradient(
            total_c_gen_loss, self.c_gen.trainable_variables)
        s_discriminator_gradients = tape.gradient(
            s_disc_loss, self.s_disc.trainable_variables)
        c_discriminator_gradients = tape.gradient(
            c_disc_loss, self.c_disc.trainable_variables)
        # A persistent tape is not freed automatically; release it now that
        # all gradients have been taken (per tf.GradientTape docs).
        del tape
        # Apply each gradient set with its dedicated optimizer.
        self.s_gen_optimizer.apply_gradients(
            zip(s_generator_gradients, self.s_gen.trainable_variables))
        self.c_gen_optimizer.apply_gradients(
            zip(c_generator_gradients, self.c_gen.trainable_variables))
        self.s_disc_optimizer.apply_gradients(
            zip(s_discriminator_gradients, self.s_disc.trainable_variables))
        self.c_disc_optimizer.apply_gradients(
            zip(c_discriminator_gradients, self.c_disc.trainable_variables))
        return {
            "DomA_gen_loss": total_s_gen_loss,
            "DomB_gen_loss": total_c_gen_loss,
            "DomA_disc_loss": s_disc_loss,
            "DomB_disc_loss": c_disc_loss
        }

    def call(self, x):
        """No forward pass is defined; this model is driven via `train_step`.

        NOTE(review): calling the model directly returns None — inference
        should go through `self.s_gen` / `self.c_gen` instead.
        """
        pass