import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import math
import numpy as np
from helper import gaussian_2d
from config.GlobalVariables import *


class SynthesisNetwork(nn.Module):
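    """Multi-scale handwriting stroke synthesis with writer-style vectors.

    A hedged summary, inferred from the code below: stroke sequences are
    encoded at sentence, word, and segment granularity; an inference LSTM
    yields a character-conditioned style vector W_c at each character
    boundary, and a learned per-character weight matrix C (produced by
    char_lstm_1 -> char_vec_fc2_1) relates it to a writer vector W via
    C @ W = W_c. Several reconstruction variants (ORIGINAL, TYPE_A..TYPE_D)
    feed a shared mixture-density stroke decoder.
    """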
    def __init__(self, weight_dim=512, num_layers=3, scale_sd=1, clamp_mdn=0, sentence_loss=True, word_loss=True, segment_loss=True, TYPE_A=True, TYPE_B=True, TYPE_C=True, TYPE_D=True, ORIGINAL=True, REC=True):
        super(SynthesisNetwork, self).__init__()
        self.num_mixtures = 20
        self.num_layers = num_layers
        self.weight_dim = weight_dim
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sentence_loss = sentence_loss
        self.word_loss = word_loss
        self.segment_loss = segment_loss
        self.ORIGINAL = ORIGINAL
        self.TYPE_A = TYPE_A
        self.TYPE_B = TYPE_B
        self.TYPE_C = TYPE_C
        self.TYPE_D = TYPE_D
        self.REC = REC
        self.magic_lstm = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.char_vec_fc_1 = nn.Linear(len(CHARACTERS), self.weight_dim)
        self.char_vec_relu_1 = nn.LeakyReLU(negative_slope=0.1)
        self.char_lstm_1 = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.char_vec_fc2_1 = nn.Linear(self.weight_dim, self.weight_dim * self.weight_dim)
        # inference
        self.inf_state_fc1 = nn.Linear(3, self.weight_dim)
        self.inf_state_relu = nn.LeakyReLU(negative_slope=0.1)
        self.inf_state_lstm = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.W_lstm = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        # generation
        self.gen_state_fc1 = nn.Linear(3, self.weight_dim)
        self.gen_state_relu = nn.LeakyReLU(negative_slope=0.1)
        self.gen_state_lstm1 = nn.LSTM(self.weight_dim, self.weight_dim, batch_first=True, num_layers=self.num_layers)
        self.gen_state_lstm2 = nn.LSTM(self.weight_dim * 2, self.weight_dim * 2, batch_first=True, num_layers=self.num_layers)
        # MDN head: 1 end-of-stroke logit + 6 parameter groups per mixture
        self.gen_state_fc2 = nn.Linear(self.weight_dim * 2, self.num_mixtures * 6 + 1)
        self.term_fc1 = nn.Linear(self.weight_dim * 2, self.weight_dim)
        self.term_relu1 = nn.LeakyReLU(negative_slope=0.1)
        self.term_fc2 = nn.Linear(self.weight_dim, self.weight_dim)
        self.term_relu2 = nn.LeakyReLU(negative_slope=0.1)
        self.term_fc3 = nn.Linear(self.weight_dim, 1)
        self.term_sigmoid = nn.Sigmoid()
        self.mdn_sigmoid = nn.Sigmoid()
        self.mdn_tanh = nn.Tanh()
        self.mdn_softmax = nn.Softmax(dim=1)
        self.scale_sd = scale_sd  # how much to scale the standard deviation of the Gaussians
        self.clamp_mdn = clamp_mdn  # fraction of the distribution to allow sampling from
        self.mdn_bce_loss = nn.BCEWithLogitsLoss()
        self.term_bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, inputs):
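        """Accumulate termination / location / touch losses for each enabled variant.

        `inputs` is assumed (from the unpacking below) to be a list of 18
        per-user tensors: the six fields (stroke_in, stroke_out, stroke_length,
        term, char, char_length) at each of the sentence, word, and segment
        granularities.
        """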
        [sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_stroke_length,
         sentence_level_term, sentence_level_char, sentence_level_char_length,
         word_level_stroke_in, word_level_stroke_out, word_level_stroke_length,
         word_level_term, word_level_char, word_level_char_length,
         segment_level_stroke_in, segment_level_stroke_out, segment_level_stroke_length,
         segment_level_term, segment_level_char, segment_level_char_length] = inputs
        ALL_sentence_W_consistency_loss = []
        ALL_ORIGINAL_sentence_termination_loss = []
        ALL_ORIGINAL_sentence_loc_reconstruct_loss = []
        ALL_ORIGINAL_sentence_touch_reconstruct_loss = []
        ALL_TYPE_A_sentence_termination_loss = []
        ALL_TYPE_A_sentence_loc_reconstruct_loss = []
        ALL_TYPE_A_sentence_touch_reconstruct_loss = []
        ALL_TYPE_A_sentence_WC_reconstruct_loss = []
        ALL_TYPE_B_sentence_termination_loss = []
        ALL_TYPE_B_sentence_loc_reconstruct_loss = []
        ALL_TYPE_B_sentence_touch_reconstruct_loss = []
        ALL_TYPE_B_sentence_WC_reconstruct_loss = []
        ALL_word_W_consistency_loss = []
        ALL_ORIGINAL_word_termination_loss = []
        ALL_ORIGINAL_word_loc_reconstruct_loss = []
        ALL_ORIGINAL_word_touch_reconstruct_loss = []
        ALL_TYPE_A_word_termination_loss = []
        ALL_TYPE_A_word_loc_reconstruct_loss = []
        ALL_TYPE_A_word_touch_reconstruct_loss = []
        ALL_TYPE_A_word_WC_reconstruct_loss = []
        ALL_TYPE_B_word_termination_loss = []
        ALL_TYPE_B_word_loc_reconstruct_loss = []
        ALL_TYPE_B_word_touch_reconstruct_loss = []
        ALL_TYPE_B_word_WC_reconstruct_loss = []
        ALL_TYPE_C_word_termination_loss = []
        ALL_TYPE_C_word_loc_reconstruct_loss = []
        ALL_TYPE_C_word_touch_reconstruct_loss = []
        ALL_TYPE_C_word_WC_reconstruct_loss = []
        ALL_TYPE_D_word_termination_loss = []
        ALL_TYPE_D_word_loc_reconstruct_loss = []
        ALL_TYPE_D_word_touch_reconstruct_loss = []
        ALL_TYPE_D_word_WC_reconstruct_loss = []
        ALL_word_Wcs_reconstruct_TYPE_A = []
        ALL_word_Wcs_reconstruct_TYPE_B = []
        ALL_word_Wcs_reconstruct_TYPE_C = []
        ALL_word_Wcs_reconstruct_TYPE_D = []
        SUPER_ALL_segment_W_consistency_loss = []
        SUPER_ALL_ORIGINAL_segment_termination_loss = []
        SUPER_ALL_ORIGINAL_segment_loc_reconstruct_loss = []
        SUPER_ALL_ORIGINAL_segment_touch_reconstruct_loss = []
        SUPER_ALL_TYPE_A_segment_termination_loss = []
        SUPER_ALL_TYPE_A_segment_loc_reconstruct_loss = []
        SUPER_ALL_TYPE_A_segment_touch_reconstruct_loss = []
        SUPER_ALL_TYPE_A_segment_WC_reconstruct_loss = []
        SUPER_ALL_TYPE_B_segment_termination_loss = []
        SUPER_ALL_TYPE_B_segment_loc_reconstruct_loss = []
        SUPER_ALL_TYPE_B_segment_touch_reconstruct_loss = []
        SUPER_ALL_TYPE_B_segment_WC_reconstruct_loss = []
        SUPER_ALL_segment_Wcs_reconstruct_TYPE_A = []
        SUPER_ALL_segment_Wcs_reconstruct_TYPE_B = []
        for uid in range(len(sentence_level_stroke_in)):
            if self.sentence_loss:
                user_sentence_level_stroke_in = sentence_level_stroke_in[uid]
                user_sentence_level_stroke_out = sentence_level_stroke_out[uid]
                user_sentence_level_stroke_length = sentence_level_stroke_length[uid]
                user_sentence_level_term = sentence_level_term[uid]
                user_sentence_level_char = sentence_level_char[uid]
                user_sentence_level_char_length = sentence_level_char_length[uid]
                sentence_batch_size = len(user_sentence_level_stroke_in)
                sentence_inf_state_out = self.inf_state_fc1(user_sentence_level_stroke_out)
                sentence_inf_state_out = self.inf_state_relu(sentence_inf_state_out)
                sentence_inf_state_out, (c, h) = self.inf_state_lstm(sentence_inf_state_out)
                sentence_gen_state_out = self.gen_state_fc1(user_sentence_level_stroke_in)
                sentence_gen_state_out = self.gen_state_relu(sentence_gen_state_out)
                sentence_gen_state_out, (c, h) = self.gen_state_lstm1(sentence_gen_state_out)
                sentence_Ws = []
                sentence_Wc_rec_TYPE_ = []
                sentence_SPLITS = []
                sentence_Cs_1 = []
                sentence_unique_char_matrices_1 = []
                for sentence_batch_id in range(sentence_batch_size):
                    curr_seq_len = user_sentence_level_stroke_length[sentence_batch_id][0]
                    curr_char_len = user_sentence_level_char_length[sentence_batch_id][0]
                    char_vector = torch.eye(len(CHARACTERS))[user_sentence_level_char[sentence_batch_id][:curr_char_len]].to(self.device)
                    current_term = user_sentence_level_term[sentence_batch_id][:curr_seq_len].unsqueeze(-1)
                    split_ids = torch.nonzero(current_term)[:, 0]
                    char_vector_1 = self.char_vec_fc_1(char_vector)
                    char_vector_1 = self.char_vec_relu_1(char_vector_1)
                    unique_char_matrices_1 = []
                    for cid in range(len(char_vector)):
                        # Tower 1: per-character weight matrix from a single character
                        unique_char_vector_1 = char_vector_1[cid:cid + 1]
                        unique_char_input_1 = unique_char_vector_1.unsqueeze(0)
                        unique_char_out_1, (c, h) = self.char_lstm_1(unique_char_input_1)
                        unique_char_out_1 = unique_char_out_1.squeeze(0)
                        unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1)
                        unique_char_matrix_1 = unique_char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                        unique_char_matrix_1 = unique_char_matrix_1.squeeze(1)
                        unique_char_matrices_1.append(unique_char_matrix_1)
                    # Tower 1: weight matrices conditioned on the full character sequence
                    char_out_1 = char_vector_1.unsqueeze(0)
                    char_out_1, (c, h) = self.char_lstm_1(char_out_1)
                    char_out_1 = char_out_1.squeeze(0)
                    char_out_1 = self.char_vec_fc2_1(char_out_1)
                    char_matrix_1 = char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                    char_matrix_1 = char_matrix_1.squeeze(1)
                    char_matrix_inv_1 = torch.inverse(char_matrix_1)
                    W_c_t = sentence_inf_state_out[sentence_batch_id][:curr_seq_len]
                    W_c = torch.stack([W_c_t[i] for i in split_ids])
                    # C1 C2 C3 W = W_c  =>  W = C3^-1 C2^-1 C1^-1 W_c
                    # (with the single tower used here this reduces to W = C1^-1 W_c)
                    W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)
                    sentence_Ws.append(W)
                    sentence_Wc_rec_TYPE_.append(W_c)
                    sentence_Cs_1.append(char_matrix_1)
                    sentence_SPLITS.append(split_ids)
                    sentence_unique_char_matrices_1.append(unique_char_matrices_1)
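                # Sketch of the intent, as inferred from the code above: W_c holds
                # one style vector per character boundary (split_ids), and
                # W = C^-1 W_c recovers a writer vector that should be identical
                # across samples of the same writer; the consistency loss below
                # penalizes each W's squared deviation from the batch mean.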
                sentence_Ws_stacked = torch.cat(sentence_Ws, 0)
                sentence_Ws_reshaped = sentence_Ws_stacked.view([-1, self.weight_dim])
                sentence_W_mean = sentence_Ws_reshaped.mean(0)
                sentence_W_mean_repeat = sentence_W_mean.repeat(sentence_Ws_reshaped.size(0), 1)
                sentence_Ws_consistency_loss = torch.mean(torch.mean(torch.mul(sentence_W_mean_repeat - sentence_Ws_reshaped, sentence_W_mean_repeat - sentence_Ws_reshaped), -1))
                ALL_sentence_W_consistency_loss.append(sentence_Ws_consistency_loss)
                ORIGINAL_sentence_termination_loss = []
                ORIGINAL_sentence_loc_reconstruct_loss = []
                ORIGINAL_sentence_touch_reconstruct_loss = []
                TYPE_A_sentence_termination_loss = []
                TYPE_A_sentence_loc_reconstruct_loss = []
                TYPE_A_sentence_touch_reconstruct_loss = []
                TYPE_B_sentence_termination_loss = []
                TYPE_B_sentence_loc_reconstruct_loss = []
                TYPE_B_sentence_touch_reconstruct_loss = []
                sentence_Wcs_reconstruct_TYPE_A = []
                sentence_Wcs_reconstruct_TYPE_B = []
                for sentence_batch_id in range(sentence_batch_size):
                    sentence_level_gen_encoded = sentence_gen_state_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]]
                    sentence_level_target_eos = user_sentence_level_stroke_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]][:, 2]
                    sentence_level_target_x = user_sentence_level_stroke_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]][:, 0:1]
                    sentence_level_target_y = user_sentence_level_stroke_out[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]][:, 1:2]
                    sentence_level_target_term = user_sentence_level_term[sentence_batch_id][:user_sentence_level_stroke_length[sentence_batch_id][0]]
                    # ORIGINAL: decode from the encoder's own W_c (straight reconstruction)
                    if self.ORIGINAL:
                        sentence_W_lstm_in_ORIGINAL = []
                        curr_id = 0
                        for i in range(user_sentence_level_stroke_length[sentence_batch_id][0]):
                            sentence_W_lstm_in_ORIGINAL.append(sentence_Wc_rec_TYPE_[sentence_batch_id][curr_id])
                            if i in sentence_SPLITS[sentence_batch_id]:
                                curr_id += 1
                        sentence_W_lstm_in_ORIGINAL = torch.stack(sentence_W_lstm_in_ORIGINAL)
                        sentence_Wc_t_ORIGINAL = sentence_W_lstm_in_ORIGINAL
                        sentence_gen_lstm2_in_ORIGINAL = torch.cat([sentence_level_gen_encoded, sentence_Wc_t_ORIGINAL], -1)
                        sentence_gen_lstm2_in_ORIGINAL = sentence_gen_lstm2_in_ORIGINAL.unsqueeze(0)
                        sentence_gen_out_ORIGINAL, (c, h) = self.gen_state_lstm2(sentence_gen_lstm2_in_ORIGINAL)
                        sentence_gen_out_ORIGINAL = sentence_gen_out_ORIGINAL.squeeze(0)
                        mdn_out_ORIGINAL = self.gen_state_fc2(sentence_gen_out_ORIGINAL)
                        eos_ORIGINAL = mdn_out_ORIGINAL[:, 0:1]
                        [mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL, pi_ORIGINAL] = torch.split(mdn_out_ORIGINAL[:, 1:], self.num_mixtures, 1)
                        sig1_ORIGINAL = sig1_ORIGINAL.exp() + 1e-3
                        sig2_ORIGINAL = sig2_ORIGINAL.exp() + 1e-3
                        rho_ORIGINAL = self.mdn_tanh(rho_ORIGINAL)
                        pi_ORIGINAL = self.mdn_softmax(pi_ORIGINAL)
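                        # MDN head layout (per the split above): column 0 is the
                        # touch / end-of-stroke logit; the remaining 6 * num_mixtures
                        # columns are mu1, mu2, sig1, sig2, rho, pi per mixture.
                        # The location loss below is the bivariate Gaussian mixture
                        # NLL, -log(sum_k pi_k * N(x, y | mu_k, sig_k, rho_k) + 1e-5);
                        # exp() + 1e-3 keeps the sigmas positive, tanh bounds rho in
                        # (-1, 1), and softmax normalizes the mixture weights pi.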
                        term_out_ORIGINAL = self.term_fc1(sentence_gen_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu1(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc3(term_out_ORIGINAL)
                        term_pred_ORIGINAL = self.term_sigmoid(term_out_ORIGINAL)
                        gaussian_ORIGINAL = gaussian_2d(sentence_level_target_x, sentence_level_target_y, mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL)
                        loss_gaussian_ORIGINAL = -torch.log(torch.sum(pi_ORIGINAL * gaussian_ORIGINAL, dim=1) + 1e-5)
                        ORIGINAL_sentence_term_loss = self.term_bce_loss(term_out_ORIGINAL.squeeze(1), sentence_level_target_term)
                        ORIGINAL_sentence_loc_loss = torch.mean(loss_gaussian_ORIGINAL)
                        ORIGINAL_sentence_touch_loss = self.mdn_bce_loss(eos_ORIGINAL.squeeze(1), sentence_level_target_eos)
                        ORIGINAL_sentence_termination_loss.append(ORIGINAL_sentence_term_loss)
                        ORIGINAL_sentence_loc_reconstruct_loss.append(ORIGINAL_sentence_loc_loss)
                        ORIGINAL_sentence_touch_reconstruct_loss.append(ORIGINAL_sentence_touch_loss)
                    # TYPE A: rebuild W_c from the mean writer vector and the sequence-conditioned matrices
                    if self.TYPE_A:
                        sentence_C1 = sentence_Cs_1[sentence_batch_id]
                        sentence_Wc_rec_TYPE_A = torch.bmm(sentence_C1, sentence_W_mean.repeat(sentence_C1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        sentence_Wcs_reconstruct_TYPE_A.append(sentence_Wc_rec_TYPE_A)
                        sentence_W_lstm_in_TYPE_A = []
                        curr_id = 0
                        for i in range(user_sentence_level_stroke_length[sentence_batch_id][0]):
                            sentence_W_lstm_in_TYPE_A.append(sentence_Wc_rec_TYPE_A[curr_id])
                            if i in sentence_SPLITS[sentence_batch_id]:
                                curr_id += 1
                        sentence_Wc_t_rec_TYPE_A = torch.stack(sentence_W_lstm_in_TYPE_A)
                        sentence_gen_lstm2_in_TYPE_A = torch.cat([sentence_level_gen_encoded, sentence_Wc_t_rec_TYPE_A], -1)
                        sentence_gen_lstm2_in_TYPE_A = sentence_gen_lstm2_in_TYPE_A.unsqueeze(0)
                        sentence_gen_out_TYPE_A, (c, h) = self.gen_state_lstm2(sentence_gen_lstm2_in_TYPE_A)
                        sentence_gen_out_TYPE_A = sentence_gen_out_TYPE_A.squeeze(0)
                        mdn_out_TYPE_A = self.gen_state_fc2(sentence_gen_out_TYPE_A)
                        eos_TYPE_A = mdn_out_TYPE_A[:, 0:1]
                        [mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A, pi_TYPE_A] = torch.split(mdn_out_TYPE_A[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_A = sig1_TYPE_A.exp() + 1e-3
                        sig2_TYPE_A = sig2_TYPE_A.exp() + 1e-3
                        rho_TYPE_A = self.mdn_tanh(rho_TYPE_A)
                        pi_TYPE_A = self.mdn_softmax(pi_TYPE_A)
                        term_out_TYPE_A = self.term_fc1(sentence_gen_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu1(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc3(term_out_TYPE_A)
                        term_pred_TYPE_A = self.term_sigmoid(term_out_TYPE_A)
                        gaussian_TYPE_A = gaussian_2d(sentence_level_target_x, sentence_level_target_y, mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A)
                        loss_gaussian_TYPE_A = -torch.log(torch.sum(pi_TYPE_A * gaussian_TYPE_A, dim=1) + 1e-5)
                        TYPE_A_sentence_term_loss = self.term_bce_loss(term_out_TYPE_A.squeeze(1), sentence_level_target_term)
                        TYPE_A_sentence_loc_loss = torch.mean(loss_gaussian_TYPE_A)
                        TYPE_A_sentence_touch_loss = self.mdn_bce_loss(eos_TYPE_A.squeeze(1), sentence_level_target_eos)
                        TYPE_A_sentence_termination_loss.append(TYPE_A_sentence_term_loss)
                        TYPE_A_sentence_loc_reconstruct_loss.append(TYPE_A_sentence_loc_loss)
                        TYPE_A_sentence_touch_reconstruct_loss.append(TYPE_A_sentence_touch_loss)
                    # TYPE B: rebuild W_c from the mean writer vector and per-character matrices, smoothed by magic_lstm
                    if self.TYPE_B:
                        unique_char_matrix_1 = sentence_unique_char_matrices_1[sentence_batch_id]
                        unique_char_matrices_1 = torch.stack(unique_char_matrix_1)
                        unique_char_matrices_1 = unique_char_matrices_1.squeeze(1)
                        sentence_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices_1, sentence_W_mean.repeat(unique_char_matrices_1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        sentence_W_c_TYPE_B_RAW = sentence_W_c_TYPE_B_RAW.unsqueeze(0)
                        sentence_Wc_rec_TYPE_B, (c, h) = self.magic_lstm(sentence_W_c_TYPE_B_RAW)
                        sentence_Wc_rec_TYPE_B = sentence_Wc_rec_TYPE_B.squeeze(0)
                        sentence_Wcs_reconstruct_TYPE_B.append(sentence_Wc_rec_TYPE_B)
                        sentence_W_lstm_in_TYPE_B = []
                        curr_id = 0
                        for i in range(user_sentence_level_stroke_length[sentence_batch_id][0]):
                            sentence_W_lstm_in_TYPE_B.append(sentence_Wc_rec_TYPE_B[curr_id])
                            if i in sentence_SPLITS[sentence_batch_id]:
                                curr_id += 1
                        sentence_Wc_t_rec_TYPE_B = torch.stack(sentence_W_lstm_in_TYPE_B)
                        sentence_gen_lstm2_in_TYPE_B = torch.cat([sentence_level_gen_encoded, sentence_Wc_t_rec_TYPE_B], -1)
                        sentence_gen_lstm2_in_TYPE_B = sentence_gen_lstm2_in_TYPE_B.unsqueeze(0)
                        sentence_gen_out_TYPE_B, (c, h) = self.gen_state_lstm2(sentence_gen_lstm2_in_TYPE_B)
                        sentence_gen_out_TYPE_B = sentence_gen_out_TYPE_B.squeeze(0)
                        mdn_out_TYPE_B = self.gen_state_fc2(sentence_gen_out_TYPE_B)
                        eos_TYPE_B = mdn_out_TYPE_B[:, 0:1]
                        [mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B, pi_TYPE_B] = torch.split(mdn_out_TYPE_B[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_B = sig1_TYPE_B.exp() + 1e-3
                        sig2_TYPE_B = sig2_TYPE_B.exp() + 1e-3
                        rho_TYPE_B = self.mdn_tanh(rho_TYPE_B)
                        pi_TYPE_B = self.mdn_softmax(pi_TYPE_B)
                        term_out_TYPE_B = self.term_fc1(sentence_gen_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu1(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc3(term_out_TYPE_B)
                        term_pred_TYPE_B = self.term_sigmoid(term_out_TYPE_B)
                        gaussian_TYPE_B = gaussian_2d(sentence_level_target_x, sentence_level_target_y, mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B)
                        loss_gaussian_TYPE_B = -torch.log(torch.sum(pi_TYPE_B * gaussian_TYPE_B, dim=1) + 1e-5)
                        TYPE_B_sentence_term_loss = self.term_bce_loss(term_out_TYPE_B.squeeze(1), sentence_level_target_term)
                        TYPE_B_sentence_loc_loss = torch.mean(loss_gaussian_TYPE_B)
                        TYPE_B_sentence_touch_loss = self.mdn_bce_loss(eos_TYPE_B.squeeze(1), sentence_level_target_eos)
                        TYPE_B_sentence_termination_loss.append(TYPE_B_sentence_term_loss)
                        TYPE_B_sentence_loc_reconstruct_loss.append(TYPE_B_sentence_loc_loss)
                        TYPE_B_sentence_touch_reconstruct_loss.append(TYPE_B_sentence_touch_loss)
                if self.ORIGINAL:
                    ALL_ORIGINAL_sentence_termination_loss.append(torch.mean(torch.stack(ORIGINAL_sentence_termination_loss)))
                    ALL_ORIGINAL_sentence_loc_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_sentence_loc_reconstruct_loss)))
                    ALL_ORIGINAL_sentence_touch_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_sentence_touch_reconstruct_loss)))
                if self.TYPE_A:
                    ALL_TYPE_A_sentence_termination_loss.append(torch.mean(torch.stack(TYPE_A_sentence_termination_loss)))
                    ALL_TYPE_A_sentence_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_sentence_loc_reconstruct_loss)))
                    ALL_TYPE_A_sentence_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_sentence_touch_reconstruct_loss)))
                    if self.REC:
                        TYPE_A_sentence_WC_reconstruct_loss = []
                        for sentence_batch_id in range(len(sentence_Wc_rec_TYPE_)):
                            sentence_Wc_ORIGINAL = sentence_Wc_rec_TYPE_[sentence_batch_id]
                            sentence_Wc_TYPE_A = sentence_Wcs_reconstruct_TYPE_A[sentence_batch_id]
                            sentence_WC_reconstruct_loss_TYPE_A = torch.mean(torch.mean(torch.mul(sentence_Wc_ORIGINAL - sentence_Wc_TYPE_A, sentence_Wc_ORIGINAL - sentence_Wc_TYPE_A), -1))
                            TYPE_A_sentence_WC_reconstruct_loss.append(sentence_WC_reconstruct_loss_TYPE_A)
                        ALL_TYPE_A_sentence_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_sentence_WC_reconstruct_loss)))
                if self.TYPE_B:
                    ALL_TYPE_B_sentence_termination_loss.append(torch.mean(torch.stack(TYPE_B_sentence_termination_loss)))
                    ALL_TYPE_B_sentence_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_sentence_loc_reconstruct_loss)))
                    ALL_TYPE_B_sentence_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_sentence_touch_reconstruct_loss)))
                    if self.REC:
                        TYPE_B_sentence_WC_reconstruct_loss = []
                        for sentence_batch_id in range(len(sentence_Wc_rec_TYPE_)):
                            sentence_Wc_ORIGINAL = sentence_Wc_rec_TYPE_[sentence_batch_id]
                            sentence_Wc_TYPE_B = sentence_Wcs_reconstruct_TYPE_B[sentence_batch_id]
                            sentence_WC_reconstruct_loss_TYPE_B = torch.mean(torch.mean(torch.mul(sentence_Wc_ORIGINAL - sentence_Wc_TYPE_B, sentence_Wc_ORIGINAL - sentence_Wc_TYPE_B), -1))
                            TYPE_B_sentence_WC_reconstruct_loss.append(sentence_WC_reconstruct_loss_TYPE_B)
                        ALL_TYPE_B_sentence_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_sentence_WC_reconstruct_loss)))
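            # The word-level pass below mirrors the sentence-level pipeline,
            # additionally caching per-prefix W_c vectors (W_C_ORIGINALS /
            # W_C_SEGMENTS) for the TYPE_C and TYPE_D variants.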
            if self.word_loss:
                user_word_level_stroke_in = word_level_stroke_in[uid]
                user_word_level_stroke_out = word_level_stroke_out[uid]
                user_word_level_stroke_length = word_level_stroke_length[uid]
                user_word_level_term = word_level_term[uid]
                user_word_level_char = word_level_char[uid]
                user_word_level_char_length = word_level_char_length[uid]
                word_batch_size = len(user_word_level_stroke_in)
                word_inf_state_out = self.inf_state_fc1(user_word_level_stroke_out)
                word_inf_state_out = self.inf_state_relu(word_inf_state_out)
                word_inf_state_out, (c, h) = self.inf_state_lstm(word_inf_state_out)
                word_gen_state_out = self.gen_state_fc1(user_word_level_stroke_in)
                word_gen_state_out = self.gen_state_relu(word_gen_state_out)
                word_gen_state_out, (c, h) = self.gen_state_lstm1(word_gen_state_out)
                word_Ws = []
                word_Wc_rec_ORIGINAL = []
                word_SPLITS = []
                word_Cs_1 = []
                word_unique_char_matrices_1 = []
                W_C_ORIGINALS = []
                for word_batch_id in range(word_batch_size):
                    curr_seq_len = user_word_level_stroke_length[word_batch_id][0]
                    curr_char_len = user_word_level_char_length[word_batch_id][0]
                    char_vector = torch.eye(len(CHARACTERS))[user_word_level_char[word_batch_id][:curr_char_len]].to(self.device)
                    current_term = user_word_level_term[word_batch_id][:curr_seq_len].unsqueeze(-1)
                    split_ids = torch.nonzero(current_term)[:, 0]
                    char_vector_1 = self.char_vec_fc_1(char_vector)
                    char_vector_1 = self.char_vec_relu_1(char_vector_1)
                    unique_char_matrices_1 = []
                    for cid in range(len(char_vector)):
                        # Tower 1: per-character weight matrix from a single character
                        unique_char_vector_1 = char_vector_1[cid:cid + 1]
                        unique_char_input_1 = unique_char_vector_1.unsqueeze(0)
                        unique_char_out_1, (c, h) = self.char_lstm_1(unique_char_input_1)
                        unique_char_out_1 = unique_char_out_1.squeeze(0)
                        unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1)
                        unique_char_matrix_1 = unique_char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                        unique_char_matrix_1 = unique_char_matrix_1.squeeze(1)
                        unique_char_matrices_1.append(unique_char_matrix_1)
                    # Tower 1: weight matrices conditioned on the full character sequence
                    char_out_1 = char_vector_1.unsqueeze(0)
                    char_out_1, (c, h) = self.char_lstm_1(char_out_1)
                    char_out_1 = char_out_1.squeeze(0)
                    char_out_1 = self.char_vec_fc2_1(char_out_1)
                    char_matrix_1 = char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                    char_matrix_1 = char_matrix_1.squeeze(1)
                    char_matrix_inv_1 = torch.inverse(char_matrix_1)
                    W_c_t = word_inf_state_out[word_batch_id][:curr_seq_len]
                    W_c = torch.stack([W_c_t[i] for i in split_ids])
                    W_C_ORIGINAL = {}
                    for i in range(curr_char_len):
                        # key by the character prefix up to and including position i
                        sub_s = "".join(CHARACTERS[c] for c in user_word_level_char[word_batch_id][:i + 1])
                        W_C_ORIGINAL[sub_s] = [W_c[i]]
                    W_C_ORIGINALS.append(W_C_ORIGINAL)
                    W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)
                    word_Ws.append(W)
                    word_Wc_rec_ORIGINAL.append(W_c)
                    word_SPLITS.append(split_ids)
                    word_Cs_1.append(char_matrix_1)
                    word_unique_char_matrices_1.append(unique_char_matrices_1)
                word_Ws_stacked = torch.cat(word_Ws, 0)
                word_Ws_reshaped = word_Ws_stacked.view([-1, self.weight_dim])
                word_W_mean = word_Ws_reshaped.mean(0)
                word_Ws_reshaped_mean_repeat = word_W_mean.repeat(word_Ws_reshaped.size(0), 1)
                word_Ws_consistency_loss = torch.mean(torch.mean(torch.mul(word_Ws_reshaped_mean_repeat - word_Ws_reshaped, word_Ws_reshaped_mean_repeat - word_Ws_reshaped), -1))
                ALL_word_W_consistency_loss.append(word_Ws_consistency_loss)
                # word
                ORIGINAL_word_termination_loss = []
                ORIGINAL_word_loc_reconstruct_loss = []
                ORIGINAL_word_touch_reconstruct_loss = []
                TYPE_A_word_termination_loss = []
                TYPE_A_word_loc_reconstruct_loss = []
                TYPE_A_word_touch_reconstruct_loss = []
                TYPE_B_word_termination_loss = []
                TYPE_B_word_loc_reconstruct_loss = []
                TYPE_B_word_touch_reconstruct_loss = []
                TYPE_C_word_termination_loss = []
                TYPE_C_word_loc_reconstruct_loss = []
                TYPE_C_word_touch_reconstruct_loss = []
                TYPE_D_word_termination_loss = []
                TYPE_D_word_loc_reconstruct_loss = []
                TYPE_D_word_touch_reconstruct_loss = []
                word_Wcs_reconstruct_TYPE_A = []
                word_Wcs_reconstruct_TYPE_B = []
                word_Wcs_reconstruct_TYPE_C = []
                word_Wcs_reconstruct_TYPE_D = []
                # segment
                ALL_segment_W_consistency_loss = []
                ALL_ORIGINAL_segment_termination_loss = []
                ALL_ORIGINAL_segment_loc_reconstruct_loss = []
                ALL_ORIGINAL_segment_touch_reconstruct_loss = []
                ALL_TYPE_A_segment_termination_loss = []
                ALL_TYPE_A_segment_loc_reconstruct_loss = []
                ALL_TYPE_A_segment_touch_reconstruct_loss = []
                ALL_TYPE_A_segment_WC_reconstruct_loss = []
                ALL_TYPE_B_segment_termination_loss = []
                ALL_TYPE_B_segment_loc_reconstruct_loss = []
                ALL_TYPE_B_segment_touch_reconstruct_loss = []
                ALL_TYPE_B_segment_WC_reconstruct_loss = []
                ALL_segment_Wcs_reconstruct_TYPE_A = []
                ALL_segment_Wcs_reconstruct_TYPE_B = []
                W_C_SEGMENTS = []
                W_C_UNIQUES = []
                for word_batch_id in range(word_batch_size):
                    word_level_gen_encoded = word_gen_state_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]]
                    word_level_target_eos = user_word_level_stroke_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]][:, 2]
                    word_level_target_x = user_word_level_stroke_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]][:, 0:1]
                    word_level_target_y = user_word_level_stroke_out[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]][:, 1:2]
                    word_level_target_term = user_word_level_term[word_batch_id][:user_word_level_stroke_length[word_batch_id][0]]
                    # ORIGINAL: decode from the encoder's own W_c (straight reconstruction)
                    if self.ORIGINAL:
                        word_W_lstm_in_ORIGINAL = []
                        curr_id = 0
                        for i in range(user_word_level_stroke_length[word_batch_id][0]):
                            word_W_lstm_in_ORIGINAL.append(word_Wc_rec_ORIGINAL[word_batch_id][curr_id])
                            if i in word_SPLITS[word_batch_id]:
                                curr_id += 1
                        word_W_lstm_in_ORIGINAL = torch.stack(word_W_lstm_in_ORIGINAL)
                        word_Wc_t_ORIGINAL = word_W_lstm_in_ORIGINAL
                        word_gen_lstm2_in_ORIGINAL = torch.cat([word_level_gen_encoded, word_Wc_t_ORIGINAL], -1)
                        word_gen_lstm2_in_ORIGINAL = word_gen_lstm2_in_ORIGINAL.unsqueeze(0)
                        word_gen_out_ORIGINAL, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_ORIGINAL)
                        word_gen_out_ORIGINAL = word_gen_out_ORIGINAL.squeeze(0)
                        mdn_out_ORIGINAL = self.gen_state_fc2(word_gen_out_ORIGINAL)
                        eos_ORIGINAL = mdn_out_ORIGINAL[:, 0:1]
                        [mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL, pi_ORIGINAL] = torch.split(mdn_out_ORIGINAL[:, 1:], self.num_mixtures, 1)
                        sig1_ORIGINAL = sig1_ORIGINAL.exp() + 1e-3
                        sig2_ORIGINAL = sig2_ORIGINAL.exp() + 1e-3
                        rho_ORIGINAL = self.mdn_tanh(rho_ORIGINAL)
                        pi_ORIGINAL = self.mdn_softmax(pi_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc1(word_gen_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu1(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_relu2(term_out_ORIGINAL)
                        term_out_ORIGINAL = self.term_fc3(term_out_ORIGINAL)
                        term_pred_ORIGINAL = self.term_sigmoid(term_out_ORIGINAL)
                        gaussian_ORIGINAL = gaussian_2d(word_level_target_x, word_level_target_y, mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL)
                        loss_gaussian_ORIGINAL = -torch.log(torch.sum(pi_ORIGINAL * gaussian_ORIGINAL, dim=1) + 1e-5)
                        ORIGINAL_word_term_loss = self.term_bce_loss(term_out_ORIGINAL.squeeze(1), word_level_target_term)
                        ORIGINAL_word_loc_loss = torch.mean(loss_gaussian_ORIGINAL)
                        ORIGINAL_word_touch_loss = self.mdn_bce_loss(eos_ORIGINAL.squeeze(1), word_level_target_eos)
                        ORIGINAL_word_termination_loss.append(ORIGINAL_word_term_loss)
                        ORIGINAL_word_loc_reconstruct_loss.append(ORIGINAL_word_loc_loss)
                        ORIGINAL_word_touch_reconstruct_loss.append(ORIGINAL_word_touch_loss)
                    # TYPE A: rebuild W_c from the mean writer vector and the sequence-conditioned matrices
                    if self.TYPE_A:
                        word_C1 = word_Cs_1[word_batch_id]
                        word_Wc_rec_TYPE_A = torch.bmm(word_C1, word_W_mean.repeat(word_C1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        word_Wcs_reconstruct_TYPE_A.append(word_Wc_rec_TYPE_A)
                        word_W_lstm_in_TYPE_A = []
                        curr_id = 0
                        for i in range(user_word_level_stroke_length[word_batch_id][0]):
                            word_W_lstm_in_TYPE_A.append(word_Wc_rec_TYPE_A[curr_id])
                            if i in word_SPLITS[word_batch_id]:
                                curr_id += 1
                        word_Wc_t_rec_TYPE_A = torch.stack(word_W_lstm_in_TYPE_A)
                        word_gen_lstm2_in_TYPE_A = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_A], -1)
                        word_gen_lstm2_in_TYPE_A = word_gen_lstm2_in_TYPE_A.unsqueeze(0)
                        word_gen_out_TYPE_A, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_A)
                        word_gen_out_TYPE_A = word_gen_out_TYPE_A.squeeze(0)
                        mdn_out_TYPE_A = self.gen_state_fc2(word_gen_out_TYPE_A)
                        eos_TYPE_A = mdn_out_TYPE_A[:, 0:1]
                        [mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A, pi_TYPE_A] = torch.split(mdn_out_TYPE_A[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_A = sig1_TYPE_A.exp() + 1e-3
                        sig2_TYPE_A = sig2_TYPE_A.exp() + 1e-3
                        rho_TYPE_A = self.mdn_tanh(rho_TYPE_A)
                        pi_TYPE_A = self.mdn_softmax(pi_TYPE_A)
                        term_out_TYPE_A = self.term_fc1(word_gen_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu1(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_relu2(term_out_TYPE_A)
                        term_out_TYPE_A = self.term_fc3(term_out_TYPE_A)
                        term_pred_TYPE_A = self.term_sigmoid(term_out_TYPE_A)
                        gaussian_TYPE_A = gaussian_2d(word_level_target_x, word_level_target_y, mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A)
                        loss_gaussian_TYPE_A = -torch.log(torch.sum(pi_TYPE_A * gaussian_TYPE_A, dim=1) + 1e-5)
                        TYPE_A_word_term_loss = self.term_bce_loss(term_out_TYPE_A.squeeze(1), word_level_target_term)
                        TYPE_A_word_loc_loss = torch.mean(loss_gaussian_TYPE_A)
                        TYPE_A_word_touch_loss = self.mdn_bce_loss(eos_TYPE_A.squeeze(1), word_level_target_eos)
                        TYPE_A_word_termination_loss.append(TYPE_A_word_term_loss)
                        TYPE_A_word_loc_reconstruct_loss.append(TYPE_A_word_loc_loss)
                        TYPE_A_word_touch_reconstruct_loss.append(TYPE_A_word_touch_loss)
                    # TYPE B: rebuild W_c from the mean writer vector and per-character matrices, smoothed by magic_lstm
                    if self.TYPE_B:
                        unique_char_matrix_1 = word_unique_char_matrices_1[word_batch_id]
                        unique_char_matrices_1 = torch.stack(unique_char_matrix_1)
                        unique_char_matrices_1 = unique_char_matrices_1.squeeze(1)
                        word_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices_1, word_W_mean.repeat(unique_char_matrices_1.size(0), 1).unsqueeze(2)).squeeze(-1)
                        word_W_c_TYPE_B_RAW = word_W_c_TYPE_B_RAW.unsqueeze(0)
                        word_Wc_rec_TYPE_B, (c, h) = self.magic_lstm(word_W_c_TYPE_B_RAW)
                        word_Wc_rec_TYPE_B = word_Wc_rec_TYPE_B.squeeze(0)
                        word_Wcs_reconstruct_TYPE_B.append(word_Wc_rec_TYPE_B)
                        word_W_lstm_in_TYPE_B = []
                        curr_id = 0
                        for i in range(user_word_level_stroke_length[word_batch_id][0]):
                            word_W_lstm_in_TYPE_B.append(word_Wc_rec_TYPE_B[curr_id])
                            if i in word_SPLITS[word_batch_id]:
                                curr_id += 1
                        word_Wc_t_rec_TYPE_B = torch.stack(word_W_lstm_in_TYPE_B)
                        word_gen_lstm2_in_TYPE_B = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_B], -1)
                        word_gen_lstm2_in_TYPE_B = word_gen_lstm2_in_TYPE_B.unsqueeze(0)
                        word_gen_out_TYPE_B, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_B)
                        word_gen_out_TYPE_B = word_gen_out_TYPE_B.squeeze(0)
                        mdn_out_TYPE_B = self.gen_state_fc2(word_gen_out_TYPE_B)
                        eos_TYPE_B = mdn_out_TYPE_B[:, 0:1]
                        [mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B, pi_TYPE_B] = torch.split(mdn_out_TYPE_B[:, 1:], self.num_mixtures, 1)
                        sig1_TYPE_B = sig1_TYPE_B.exp() + 1e-3
                        sig2_TYPE_B = sig2_TYPE_B.exp() + 1e-3
                        rho_TYPE_B = self.mdn_tanh(rho_TYPE_B)
                        pi_TYPE_B = self.mdn_softmax(pi_TYPE_B)
                        term_out_TYPE_B = self.term_fc1(word_gen_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu1(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_relu2(term_out_TYPE_B)
                        term_out_TYPE_B = self.term_fc3(term_out_TYPE_B)
                        term_pred_TYPE_B = self.term_sigmoid(term_out_TYPE_B)
                        gaussian_TYPE_B = gaussian_2d(word_level_target_x, word_level_target_y, mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B)
                        loss_gaussian_TYPE_B = -torch.log(torch.sum(pi_TYPE_B * gaussian_TYPE_B, dim=1) + 1e-5)
                        TYPE_B_word_term_loss = self.term_bce_loss(term_out_TYPE_B.squeeze(1), word_level_target_term)
                        TYPE_B_word_loc_loss = torch.mean(loss_gaussian_TYPE_B)
                        TYPE_B_word_touch_loss = self.mdn_bce_loss(eos_TYPE_B.squeeze(1), word_level_target_eos)
                        TYPE_B_word_termination_loss.append(TYPE_B_word_term_loss)
                        TYPE_B_word_loc_reconstruct_loss.append(TYPE_B_word_loc_loss)
                        TYPE_B_word_touch_reconstruct_loss.append(TYPE_B_word_touch_loss)
                    # Segment level (computed per word; also feeds TYPE C below)
                    user_segment_level_stroke_in = segment_level_stroke_in[uid][word_batch_id]
                    user_segment_level_stroke_out = segment_level_stroke_out[uid][word_batch_id]
                    user_segment_level_stroke_length = segment_level_stroke_length[uid][word_batch_id]
                    user_segment_level_term = segment_level_term[uid][word_batch_id]
                    user_segment_level_char = segment_level_char[uid][word_batch_id]
                    user_segment_level_char_length = segment_level_char_length[uid][word_batch_id]
                    segment_batch_size = len(user_segment_level_stroke_in)
                    segment_inf_state_out = self.inf_state_fc1(user_segment_level_stroke_out)
                    segment_inf_state_out = self.inf_state_relu(segment_inf_state_out)
                    segment_inf_state_out, (c, h) = self.inf_state_lstm(segment_inf_state_out)
                    segment_gen_state_out = self.gen_state_fc1(user_segment_level_stroke_in)
                    segment_gen_state_out = self.gen_state_relu(segment_gen_state_out)
                    segment_gen_state_out, (c, h) = self.gen_state_lstm1(segment_gen_state_out)
                    segment_Ws = []
                    segment_Wc_rec_ORIGINAL = []
                    segment_SPLITS = []
                    segment_Cs_1 = []
                    segment_unique_char_matrices_1 = []
                    W_C_SEGMENT = {}
                    for segment_batch_id in range(segment_batch_size):
                        curr_seq_len = user_segment_level_stroke_length[segment_batch_id][0]
                        curr_char_len = user_segment_level_char_length[segment_batch_id][0]
                        char_vector = torch.eye(len(CHARACTERS))[user_segment_level_char[segment_batch_id][:curr_char_len]].to(self.device)
                        current_term = user_segment_level_term[segment_batch_id][:curr_seq_len].unsqueeze(-1)
                        split_ids = torch.nonzero(current_term)[:, 0]
                        char_vector_1 = self.char_vec_fc_1(char_vector)
                        char_vector_1 = self.char_vec_relu_1(char_vector_1)
                        unique_char_matrices_1 = []
                        for cid in range(len(char_vector)):
                            # Tower 1: per-character weight matrix from a single character
                            unique_char_vector_1 = char_vector_1[cid:cid + 1]
                            unique_char_input_1 = unique_char_vector_1.unsqueeze(0)
                            unique_char_out_1, (c, h) = self.char_lstm_1(unique_char_input_1)
                            unique_char_out_1 = unique_char_out_1.squeeze(0)
                            unique_char_out_1 = self.char_vec_fc2_1(unique_char_out_1)
                            unique_char_matrix_1 = unique_char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                            unique_char_matrix_1 = unique_char_matrix_1.squeeze(1)
                            unique_char_matrices_1.append(unique_char_matrix_1)
                        # Tower 1: weight matrices conditioned on the full character sequence
                        char_out_1 = char_vector_1.unsqueeze(0)
                        char_out_1, (c, h) = self.char_lstm_1(char_out_1)
                        char_out_1 = char_out_1.squeeze(0)
                        char_out_1 = self.char_vec_fc2_1(char_out_1)
                        char_matrix_1 = char_out_1.view([-1, 1, self.weight_dim, self.weight_dim])
                        char_matrix_1 = char_matrix_1.squeeze(1)
                        char_matrix_inv_1 = torch.inverse(char_matrix_1)
                        W_c_t = segment_inf_state_out[segment_batch_id][:curr_seq_len]
                        W_c = torch.stack([W_c_t[i] for i in split_ids])
                        for i in range(curr_char_len):
                            # key by the character prefix up to and including position i
                            sub_s = "".join(CHARACTERS[c] for c in user_segment_level_char[segment_batch_id][:i + 1])
                            if sub_s in W_C_SEGMENT:
                                W_C_SEGMENT[sub_s].append(W_c[i])
                            else:
                                W_C_SEGMENT[sub_s] = [W_c[i]]
                        W = torch.bmm(char_matrix_inv_1, W_c.unsqueeze(2)).squeeze(-1)
                        segment_Ws.append(W)
                        segment_Wc_rec_ORIGINAL.append(W_c)
                        segment_SPLITS.append(split_ids)
                        segment_Cs_1.append(char_matrix_1)
                        segment_unique_char_matrices_1.append(unique_char_matrices_1)
                    W_C_SEGMENTS.append(W_C_SEGMENT)
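                    # W_C_SEGMENT maps each character-prefix string of a segment to
                    # the list of W_c vectors observed for it; W_C_SEGMENTS collects
                    # one such dict per word. Presumably these caches serve the
                    # TYPE_D reconstruction later in this method (not shown here).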
                    if self.segment_loss:
                        segment_Ws_stacked = torch.cat(segment_Ws, 0)
                        segment_Ws_reshaped = segment_Ws_stacked.view([-1, self.weight_dim])
                        segment_W_mean = segment_Ws_reshaped.mean(0)
                        segment_Ws_reshaped_mean_repeat = segment_W_mean.repeat(segment_Ws_reshaped.size(0), 1)
                        segment_Ws_consistency_loss = torch.mean(torch.mean(torch.mul(segment_Ws_reshaped_mean_repeat - segment_Ws_reshaped, segment_Ws_reshaped_mean_repeat - segment_Ws_reshaped), -1))
                        ALL_segment_W_consistency_loss.append(segment_Ws_consistency_loss)
                        ORIGINAL_segment_termination_loss = []
                        ORIGINAL_segment_loc_reconstruct_loss = []
                        ORIGINAL_segment_touch_reconstruct_loss = []
                        TYPE_A_segment_termination_loss = []
                        TYPE_A_segment_loc_reconstruct_loss = []
                        TYPE_A_segment_touch_reconstruct_loss = []
                        TYPE_B_segment_termination_loss = []
                        TYPE_B_segment_loc_reconstruct_loss = []
                        TYPE_B_segment_touch_reconstruct_loss = []
                        segment_Wcs_reconstruct_TYPE_A = []
                        segment_Wcs_reconstruct_TYPE_B = []
                        for segment_batch_id in range(segment_batch_size):
                            segment_level_gen_encoded = segment_gen_state_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]]
                            segment_level_target_eos = user_segment_level_stroke_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]][:, 2]
                            segment_level_target_x = user_segment_level_stroke_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]][:, 0:1]
                            segment_level_target_y = user_segment_level_stroke_out[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]][:, 1:2]
                            segment_level_target_term = user_segment_level_term[segment_batch_id][:user_segment_level_stroke_length[segment_batch_id][0]]
                            if self.ORIGINAL:
                                segment_W_lstm_in_ORIGINAL = []
                                curr_id = 0
                                for i in range(user_segment_level_stroke_length[segment_batch_id][0]):
                                    segment_W_lstm_in_ORIGINAL.append(segment_Wc_rec_ORIGINAL[segment_batch_id][curr_id])
                                    if i in segment_SPLITS[segment_batch_id]:
                                        curr_id += 1
                                segment_W_lstm_in_ORIGINAL = torch.stack(segment_W_lstm_in_ORIGINAL)
                                segment_Wc_t_ORIGINAL = segment_W_lstm_in_ORIGINAL
                                segment_gen_lstm2_in_ORIGINAL = torch.cat([segment_level_gen_encoded, segment_Wc_t_ORIGINAL], -1)
                                segment_gen_lstm2_in_ORIGINAL = segment_gen_lstm2_in_ORIGINAL.unsqueeze(0)
                                segment_gen_out_ORIGINAL, (c, h) = self.gen_state_lstm2(segment_gen_lstm2_in_ORIGINAL)
                                segment_gen_out_ORIGINAL = segment_gen_out_ORIGINAL.squeeze(0)
                                mdn_out_ORIGINAL = self.gen_state_fc2(segment_gen_out_ORIGINAL)
                                eos_ORIGINAL = mdn_out_ORIGINAL[:, 0:1]
                                [mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL, pi_ORIGINAL] = torch.split(mdn_out_ORIGINAL[:, 1:], self.num_mixtures, 1)
                                sig1_ORIGINAL = sig1_ORIGINAL.exp() + 1e-3
                                sig2_ORIGINAL = sig2_ORIGINAL.exp() + 1e-3
                                rho_ORIGINAL = self.mdn_tanh(rho_ORIGINAL)
                                pi_ORIGINAL = self.mdn_softmax(pi_ORIGINAL)
                                term_out_ORIGINAL = self.term_fc1(segment_gen_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_relu1(term_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_fc2(term_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_relu2(term_out_ORIGINAL)
                                term_out_ORIGINAL = self.term_fc3(term_out_ORIGINAL)
                                term_pred_ORIGINAL = self.term_sigmoid(term_out_ORIGINAL)
                                gaussian_ORIGINAL = gaussian_2d(segment_level_target_x, segment_level_target_y, mu1_ORIGINAL, mu2_ORIGINAL, sig1_ORIGINAL, sig2_ORIGINAL, rho_ORIGINAL)
                                loss_gaussian_ORIGINAL = -torch.log(torch.sum(pi_ORIGINAL * gaussian_ORIGINAL, dim=1) + 1e-5)
                                ORIGINAL_segment_term_loss = self.term_bce_loss(term_out_ORIGINAL.squeeze(1), segment_level_target_term)
                                ORIGINAL_segment_loc_loss = torch.mean(loss_gaussian_ORIGINAL)
                                ORIGINAL_segment_touch_loss = self.mdn_bce_loss(eos_ORIGINAL.squeeze(1), segment_level_target_eos)
                                ORIGINAL_segment_termination_loss.append(ORIGINAL_segment_term_loss)
                                ORIGINAL_segment_loc_reconstruct_loss.append(ORIGINAL_segment_loc_loss)
                                ORIGINAL_segment_touch_reconstruct_loss.append(ORIGINAL_segment_touch_loss)
                            # TYPE A: rebuild W_c from the mean writer vector and the sequence-conditioned matrices
                            if self.TYPE_A:
                                segment_C1 = segment_Cs_1[segment_batch_id]
                                segment_Wc_rec_TYPE_A = torch.bmm(segment_C1, segment_W_mean.repeat(segment_C1.size(0), 1).unsqueeze(2)).squeeze(-1)
                                segment_Wcs_reconstruct_TYPE_A.append(segment_Wc_rec_TYPE_A)
                                segment_W_lstm_in_TYPE_A = []
                                curr_id = 0
                                for i in range(user_segment_level_stroke_length[segment_batch_id][0]):
                                    segment_W_lstm_in_TYPE_A.append(segment_Wc_rec_TYPE_A[curr_id])
                                    if i in segment_SPLITS[segment_batch_id]:
                                        curr_id += 1
                                segment_Wc_t_rec_TYPE_A = torch.stack(segment_W_lstm_in_TYPE_A)
                                segment_gen_lstm2_in_TYPE_A = torch.cat([segment_level_gen_encoded, segment_Wc_t_rec_TYPE_A], -1)
                                segment_gen_lstm2_in_TYPE_A = segment_gen_lstm2_in_TYPE_A.unsqueeze(0)
                                segment_gen_out_TYPE_A, (c, h) = self.gen_state_lstm2(segment_gen_lstm2_in_TYPE_A)
                                segment_gen_out_TYPE_A = segment_gen_out_TYPE_A.squeeze(0)
                                mdn_out_TYPE_A = self.gen_state_fc2(segment_gen_out_TYPE_A)
                                eos_TYPE_A = mdn_out_TYPE_A[:, 0:1]
                                [mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A, pi_TYPE_A] = torch.split(mdn_out_TYPE_A[:, 1:], self.num_mixtures, 1)
                                sig1_TYPE_A = sig1_TYPE_A.exp() + 1e-3
                                sig2_TYPE_A = sig2_TYPE_A.exp() + 1e-3
                                rho_TYPE_A = self.mdn_tanh(rho_TYPE_A)
                                pi_TYPE_A = self.mdn_softmax(pi_TYPE_A)
                                term_out_TYPE_A = self.term_fc1(segment_gen_out_TYPE_A)
                                term_out_TYPE_A = self.term_relu1(term_out_TYPE_A)
                                term_out_TYPE_A = self.term_fc2(term_out_TYPE_A)
                                term_out_TYPE_A = self.term_relu2(term_out_TYPE_A)
                                term_out_TYPE_A = self.term_fc3(term_out_TYPE_A)
                                term_pred_TYPE_A = self.term_sigmoid(term_out_TYPE_A)
                                gaussian_TYPE_A = gaussian_2d(segment_level_target_x, segment_level_target_y, mu1_TYPE_A, mu2_TYPE_A, sig1_TYPE_A, sig2_TYPE_A, rho_TYPE_A)
                                loss_gaussian_TYPE_A = -torch.log(torch.sum(pi_TYPE_A * gaussian_TYPE_A, dim=1) + 1e-5)
                                TYPE_A_segment_term_loss = self.term_bce_loss(term_out_TYPE_A.squeeze(1), segment_level_target_term)
                                TYPE_A_segment_loc_loss = torch.mean(loss_gaussian_TYPE_A)
                                TYPE_A_segment_touch_loss = self.mdn_bce_loss(eos_TYPE_A.squeeze(1), segment_level_target_eos)
                                TYPE_A_segment_termination_loss.append(TYPE_A_segment_term_loss)
                                TYPE_A_segment_loc_reconstruct_loss.append(TYPE_A_segment_loc_loss)
                                TYPE_A_segment_touch_reconstruct_loss.append(TYPE_A_segment_touch_loss)
                            # TYPE B: rebuild W_c from the mean writer vector and per-character matrices, smoothed by magic_lstm
                            if self.TYPE_B:
                                unique_char_matrix_1 = segment_unique_char_matrices_1[segment_batch_id]
                                unique_char_matrices_1 = torch.stack(unique_char_matrix_1)
                                unique_char_matrices_1 = unique_char_matrices_1.squeeze(1)
                                segment_W_c_TYPE_B_RAW = torch.bmm(unique_char_matrices_1, segment_W_mean.repeat(unique_char_matrices_1.size(0), 1).unsqueeze(2)).squeeze(-1)
                                segment_W_c_TYPE_B_RAW = segment_W_c_TYPE_B_RAW.unsqueeze(0)
                                segment_Wc_rec_TYPE_B, (c, h) = self.magic_lstm(segment_W_c_TYPE_B_RAW)
                                segment_Wc_rec_TYPE_B = segment_Wc_rec_TYPE_B.squeeze(0)
                                segment_Wcs_reconstruct_TYPE_B.append(segment_Wc_rec_TYPE_B)
                                segment_W_lstm_in_TYPE_B = []
                                curr_id = 0
                                for i in range(user_segment_level_stroke_length[segment_batch_id][0]):
                                    segment_W_lstm_in_TYPE_B.append(segment_Wc_rec_TYPE_B[curr_id])
                                    if i in segment_SPLITS[segment_batch_id]:
                                        curr_id += 1
                                segment_Wc_t_rec_TYPE_B = torch.stack(segment_W_lstm_in_TYPE_B)
                                segment_gen_lstm2_in_TYPE_B = torch.cat([segment_level_gen_encoded, segment_Wc_t_rec_TYPE_B], -1)
                                segment_gen_lstm2_in_TYPE_B = segment_gen_lstm2_in_TYPE_B.unsqueeze(0)
                                segment_gen_out_TYPE_B, (c, h) = self.gen_state_lstm2(segment_gen_lstm2_in_TYPE_B)
                                segment_gen_out_TYPE_B = segment_gen_out_TYPE_B.squeeze(0)
                                mdn_out_TYPE_B = self.gen_state_fc2(segment_gen_out_TYPE_B)
                                eos_TYPE_B = mdn_out_TYPE_B[:, 0:1]
                                [mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B, pi_TYPE_B] = torch.split(mdn_out_TYPE_B[:, 1:], self.num_mixtures, 1)
                                sig1_TYPE_B = sig1_TYPE_B.exp() + 1e-3
                                sig2_TYPE_B = sig2_TYPE_B.exp() + 1e-3
                                rho_TYPE_B = self.mdn_tanh(rho_TYPE_B)
                                pi_TYPE_B = self.mdn_softmax(pi_TYPE_B)
                                term_out_TYPE_B = self.term_fc1(segment_gen_out_TYPE_B)
                                term_out_TYPE_B = self.term_relu1(term_out_TYPE_B)
                                term_out_TYPE_B = self.term_fc2(term_out_TYPE_B)
                                term_out_TYPE_B = self.term_relu2(term_out_TYPE_B)
                                term_out_TYPE_B = self.term_fc3(term_out_TYPE_B)
                                term_pred_TYPE_B = self.term_sigmoid(term_out_TYPE_B)
                                gaussian_TYPE_B = gaussian_2d(segment_level_target_x, segment_level_target_y, mu1_TYPE_B, mu2_TYPE_B, sig1_TYPE_B, sig2_TYPE_B, rho_TYPE_B)
                                loss_gaussian_TYPE_B = -torch.log(torch.sum(pi_TYPE_B * gaussian_TYPE_B, dim=1) + 1e-5)
                                TYPE_B_segment_term_loss = self.term_bce_loss(term_out_TYPE_B.squeeze(1), segment_level_target_term)
                                TYPE_B_segment_loc_loss = torch.mean(loss_gaussian_TYPE_B)
                                TYPE_B_segment_touch_loss = self.mdn_bce_loss(eos_TYPE_B.squeeze(1), segment_level_target_eos)
                                TYPE_B_segment_termination_loss.append(TYPE_B_segment_term_loss)
                                TYPE_B_segment_loc_reconstruct_loss.append(TYPE_B_segment_loc_loss)
                                TYPE_B_segment_touch_reconstruct_loss.append(TYPE_B_segment_touch_loss)
                        if self.ORIGINAL:
                            ALL_ORIGINAL_segment_termination_loss.append(torch.mean(torch.stack(ORIGINAL_segment_termination_loss)))
                            ALL_ORIGINAL_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_segment_loc_reconstruct_loss)))
                            ALL_ORIGINAL_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(ORIGINAL_segment_touch_reconstruct_loss)))
                        if self.TYPE_A:
                            ALL_TYPE_A_segment_termination_loss.append(torch.mean(torch.stack(TYPE_A_segment_termination_loss)))
                            ALL_TYPE_A_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_segment_loc_reconstruct_loss)))
                            ALL_TYPE_A_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_segment_touch_reconstruct_loss)))
                            if self.REC:
                                TYPE_A_segment_WC_reconstruct_loss = []
                                for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                                    segment_Wc_ORIGINAL = segment_Wc_rec_ORIGINAL[segment_batch_id]
                                    segment_Wc_TYPE_A = segment_Wcs_reconstruct_TYPE_A[segment_batch_id]
                                    segment_WC_reconstruct_loss_TYPE_A = torch.mean(torch.mean(torch.mul(segment_Wc_ORIGINAL - segment_Wc_TYPE_A, segment_Wc_ORIGINAL - segment_Wc_TYPE_A), -1))
                                    TYPE_A_segment_WC_reconstruct_loss.append(segment_WC_reconstruct_loss_TYPE_A)
                                ALL_TYPE_A_segment_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_A_segment_WC_reconstruct_loss)))
                        if self.TYPE_B:
                            ALL_TYPE_B_segment_termination_loss.append(torch.mean(torch.stack(TYPE_B_segment_termination_loss)))
                            ALL_TYPE_B_segment_loc_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_segment_loc_reconstruct_loss)))
                            ALL_TYPE_B_segment_touch_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_segment_touch_reconstruct_loss)))
                            if self.REC:
                                TYPE_B_segment_WC_reconstruct_loss = []
                                for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                                    segment_Wc_ORIGINAL = segment_Wc_rec_ORIGINAL[segment_batch_id]
                                    segment_Wc_TYPE_B = segment_Wcs_reconstruct_TYPE_B[segment_batch_id]
                                    segment_WC_reconstruct_loss_TYPE_B = torch.mean(torch.mean(torch.mul(segment_Wc_ORIGINAL - segment_Wc_TYPE_B, segment_Wc_ORIGINAL - segment_Wc_TYPE_B), -1))
                                    TYPE_B_segment_WC_reconstruct_loss.append(segment_WC_reconstruct_loss_TYPE_B)
                                ALL_TYPE_B_segment_WC_reconstruct_loss.append(torch.mean(torch.stack(TYPE_B_segment_WC_reconstruct_loss)))
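                    # TYPE C (sketch of the intent, inferred from the code below):
                    # stitch the word-level W_c sequence back together from
                    # segment-level W_c's. The first segment's vectors are kept
                    # as-is; for each later segment, magic_lstm blends the last
                    # original word-level W_c seen so far with each segment W_c,
                    # and the final LSTM output is used as the stitched vector.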
                    if self.TYPE_C:
                        # target: the encoder's own word-level W_c for this sample
                        original_W_c = word_Wc_rec_ORIGINAL[word_batch_id]
                        word_Wc_rec_TYPE_C = []
                        for segment_batch_id in range(len(segment_Wc_rec_ORIGINAL)):
                            if segment_batch_id == 0:
                                for each_segment_Wc in segment_Wc_rec_ORIGINAL[segment_batch_id]:
                                    word_Wc_rec_TYPE_C.append(each_segment_Wc)
                                prev_id = len(word_Wc_rec_TYPE_C) - 1
                            else:
                                prev_original_W_c = original_W_c[prev_id]
                                for each_segment_Wc in segment_Wc_rec_ORIGINAL[segment_batch_id]:
                                    magic_inp = torch.stack([prev_original_W_c, each_segment_Wc])
                                    magic_inp = magic_inp.unsqueeze(0)
                                    type_c_out, (c, h) = self.magic_lstm(magic_inp)
                                    type_c_out = type_c_out.squeeze(0)
                                    word_Wc_rec_TYPE_C.append(type_c_out[-1])
                                    prev_id = len(word_Wc_rec_TYPE_C) - 1
                        word_Wc_rec_TYPE_C = torch.stack(word_Wc_rec_TYPE_C)
                        word_Wcs_reconstruct_TYPE_C.append(word_Wc_rec_TYPE_C)
                        if len(word_Wc_rec_TYPE_C) == len(word_SPLITS[word_batch_id]):
                            word_W_lstm_in_TYPE_C = []
                            curr_id = 0
                            for i in range(user_word_level_stroke_length[word_batch_id][0]):
                                word_W_lstm_in_TYPE_C.append(word_Wc_rec_TYPE_C[curr_id])
                                if i in word_SPLITS[word_batch_id]:
                                    curr_id += 1
                            word_Wc_t_rec_TYPE_C = torch.stack(word_W_lstm_in_TYPE_C)
                            word_gen_lstm2_in_TYPE_C = torch.cat([word_level_gen_encoded, word_Wc_t_rec_TYPE_C], -1)
                            word_gen_lstm2_in_TYPE_C = word_gen_lstm2_in_TYPE_C.unsqueeze(0)
                            word_gen_out_TYPE_C, (c, h) = self.gen_state_lstm2(word_gen_lstm2_in_TYPE_C)