forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
optim_baseline.h
1358 lines (1335 loc) · 77 KB
/
optim_baseline.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// @generated from optim_baseline.py
#include <torch/types.h>
#include <vector>
namespace expected_parameters {
inline std::vector<std::vector<torch::Tensor>> LBFGS() {
return {
{
torch::tensor({-0.20959197386869663, -0.49580870398532073, -0.1313442585372408, -0.3287331939506787, -0.24613947168465267, 0.705889510763571}),
torch::tensor({-0.10412662274500666, -0.2644705062031845, 0.7102859961803084}),
torch::tensor({-0.19787984636009417, -0.5320223708266223, -0.5396083236337847}),
torch::tensor({-0.43108206822505857}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
{
torch::tensor({0.4377600774755075, 0.3828823919505896, 0.5308031277873992, 0.5752746453369446, 0.23943592910168343, 1.3739197373644627}),
torch::tensor({2.209263823172053, 2.154134023426646, 2.534834254325867}),
torch::tensor({-4.091952315741579, -4.67916063385269, -4.781279234594454}),
torch::tensor({-4.776742087865583}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> LBFGS_with_line_search() {
return {
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
{
torch::tensor({-0.2108988568338871, -0.4975560466422629, -0.14129216202471762, -0.3420288967865903, -0.2523635082803723, 0.6975570255493777}),
torch::tensor({-0.10853121966252377, -0.297948499687533, 0.6892015099955717}),
torch::tensor({-0.05080313011597659, -0.39413518751058996, -0.28433759745928844}),
torch::tensor({-0.07113812430174116}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adam() {
return {
{
torch::tensor({0.7890972864438472, 0.5024410688121617, 0.8587073313055582, 0.6579707241208395, 0.7476356819075531, 1.697556420651692}),
torch::tensor({0.891467636010675, 0.7020513497567501, 1.6892012709428947}),
torch::tensor({-1.0508030958460797, -1.3941351509567657, -1.284337577714353}),
torch::tensor({-1.071138110298716}),
},
{
torch::tensor({8.233039313231828, 7.971150747377481, 6.6436209506776, 6.470977407900541, 6.170125488259256, 7.1507391033435015}),
torch::tensor({8.417695070103735, 6.597188212844593, 7.23175710827678}),
torch::tensor({-6.729624357635757, -7.09743493108154, -6.753301896575352}),
torch::tensor({-6.435639096011218}),
},
{
torch::tensor({8.233424596059296, 7.971537360032308, 6.643920150720394, 6.47127807553724, 6.170405874224489, 7.151021086137982}),
torch::tensor({8.418084791214294, 6.597493171180545, 7.232043740621598}),
torch::tensor({-6.729918250724671, -7.097730102046093, -6.753584809755359}),
torch::tensor({-6.4359165566974985}),
},
{
torch::tensor({8.233424610557648, 7.971537374586563, 6.643920161995285, 6.471278086877829, 6.170405884785074, 7.151021096766405}),
torch::tensor({8.418084805901902, 6.597493182713584, 7.2320437514477875}),
torch::tensor({-6.72991829363266, -7.097730147102975, -6.753584838821182}),
torch::tensor({-6.435916580217771}),
},
{
torch::tensor({8.233424610575101, 7.971537374611125, 6.643920162027962, 6.471278086923278, 6.170405884809245, 7.15102109680004}),
torch::tensor({8.418084805946389, 6.597493182796847, 7.232043751509309}),
torch::tensor({-6.729918332327653, -7.097730188349552, -6.753584861205486}),
torch::tensor({-6.435916596115672}),
},
{
torch::tensor({8.233424610594858, 7.971537374639166, 6.643920162065571, 6.471278086975759, 6.170405884836981, 7.1510210968387975}),
torch::tensor({8.418084805997614, 6.59749318289335, 7.232043751580523}),
torch::tensor({-6.72991837738045, -7.097730236373201, -6.753584887267492}),
torch::tensor({-6.43591661462546}),
},
{
torch::tensor({8.233424610617288, 7.971537374671012, 6.643920162108285, 6.471278087035362, 6.170405884868481, 7.151021096882811}),
torch::tensor({8.418084806055795, 6.59749318300295, 7.232043751661401}),
torch::tensor({-6.729918428547273, -7.09773029091405, -6.753584916866329}),
torch::tensor({-6.4359166356471755}),
},
{
torch::tensor({8.233424610642352, 7.9715373747065925, 6.6439201621560064, 6.471278087101955, 6.1704058849036745, 7.151021096931989}),
torch::tensor({8.418084806120799, 6.597493183125404, 7.232043751751764}),
torch::tensor({-6.729918485714688, -7.0977303518511805, -6.753584949936365}),
torch::tensor({-6.43591665913422}),
},
{
torch::tensor({8.233424610670035, 7.97153737474589, 6.6439201622087145, 6.471278087175502, 6.170405884942545, 7.151021096986302}),
torch::tensor({8.418084806192592, 6.597493183260647, 7.232043751851564}),
torch::tensor({-6.729918548853505, -7.097730419153473, -6.753584986460725}),
torch::tensor({-6.435916685074594}),
},
{
torch::tensor({8.233424610700348, 7.971537374788922, 6.643920162266433, 6.4712780872560405, 6.17040588498511, 7.151021097045779}),
torch::tensor({8.418084806271214, 6.597493183408747, 7.232043751960854}),
torch::tensor({-6.7299186179943, -7.097730492853521, -6.753585026457088}),
torch::tensor({-6.435916713480863}),
},
{
torch::tensor({8.233424610733326, 7.971537374835737, 6.643920162329225, 6.471278087343659, 6.170405885031416, 7.151021097110483}),
torch::tensor({8.418084806356743, 6.597493183569867, 7.232043752079749}),
torch::tensor({-6.729918693213275, -7.097730573032567, -6.753585069969552}),
torch::tensor({-6.43591674438434}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay() {
return {
{
torch::tensor({0.7890990163499767, 0.5024427688479549, 0.858707365154099, 0.65797076763247, 0.7476358193232038, 1.6975559791029715}),
torch::tensor({0.8914677624298939, 0.7020513562204098, 1.6892012237887575}),
torch::tensor({-1.050803095786311, -1.3941351504224309, -1.2843375776028747}),
torch::tensor({-1.0711381102847533}),
},
{
torch::tensor({0.17835734655765323, 0.2542117171890537, 0.19681971909229715, 0.23522651199260597, 0.17806083719648957, 0.22943655675307303}),
torch::tensor({0.6227676931552837, 0.6058596954431213, 0.6077176546857177}),
torch::tensor({-1.4259755901844118, -1.4333355461952704, -1.408545526635006}),
torch::tensor({-2.0710783081666215}),
},
{
torch::tensor({0.17965695035191162, 0.24254352340441693, 0.17964663531482672, 0.24250834976541322, 0.17962893833698693, 0.24249920074277215}),
torch::tensor({0.6287144967638043, 0.6286955805603279, 0.6286563093833837}),
torch::tensor({-1.4123887230853596, -1.4124007126659273, -1.4122701589749163}),
torch::tensor({-2.063357041247863}),
},
{
torch::tensor({0.1796366651819, 0.24250861931831874, 0.17963731759793083, 0.24250861142436989, 0.1796372002681969, 0.24250890248031373}),
torch::tensor({0.6287221269294724, 0.6287225821354421, 0.6287220274975922}),
torch::tensor({-1.4123466103044011, -1.4123465669572683, -1.4123462614739388}),
torch::tensor({-2.063368365143669}),
},
{
torch::tensor({0.17963666103165563, 0.24250882317446784, 0.17963665831217887, 0.24250882481082656, 0.17963666029066117, 0.24250882426223175}),
torch::tensor({0.6287216329900817, 0.6287216340515608, 0.6287216326960158}),
torch::tensor({-1.4123467542623926, -1.4123467542350234, -1.4123467478191443}),
torch::tensor({-2.0633690432440437}),
},
{
torch::tensor({0.17963666098500394, 0.24250882442377164, 0.17963666099348902, 0.2425088244120223, 0.1796366609725109, 0.24250882441058697}),
torch::tensor({0.6287216343798432, 0.6287216343800675, 0.6287216343742645}),
torch::tensor({-1.4123467490742723, -1.412346749072554, -1.4123467490678536}),
torch::tensor({-2.0633690434425396}),
},
{
torch::tensor({0.17963666098407144, 0.24250882442250174, 0.17963666098407347, 0.2425088244224233, 0.17963666098409325, 0.2425088244225157}),
torch::tensor({0.6287216343836609, 0.6287216343836147, 0.6287216343836255}),
torch::tensor({-1.412346749067226, -1.412346749067243, -1.412346749067169}),
torch::tensor({-2.063369043434909}),
},
{
torch::tensor({0.17963666098406988, 0.2425088244224408, 0.17963666098407077, 0.24250882442244073, 0.17963666098407008, 0.2425088244224409}),
torch::tensor({0.6287216343837067, 0.6287216343837065, 0.6287216343837069}),
torch::tensor({-1.4123467490671706, -1.412346749067171, -1.4123467490671713}),
torch::tensor({-2.0633690434349057}),
},
{
torch::tensor({0.17963666098407038, 0.24250882442244104, 0.17963666098407027, 0.24250882442244104, 0.17963666098407025, 0.24250882442244098}),
torch::tensor({0.6287216343837067, 0.628721634383707, 0.6287216343837067}),
torch::tensor({-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}),
torch::tensor({-2.0633690434349052}),
},
{
torch::tensor({0.1796366609840706, 0.24250882442244143, 0.17963666098407047, 0.24250882442244096, 0.17963666098407025, 0.24250882442244098}),
torch::tensor({0.6287216343837069, 0.6287216343837067, 0.6287216343837067}),
torch::tensor({-1.4123467490671706, -1.4123467490671706, -1.4123467490671708}),
torch::tensor({-2.0633690434349052}),
},
{
torch::tensor({0.1796366609840692, 0.24250882442244046, 0.17963666098407022, 0.24250882442244082, 0.17963666098407, 0.24250882442244104}),
torch::tensor({0.6287216343837063, 0.6287216343837068, 0.6287216343837067}),
torch::tensor({-1.4123467490671708, -1.4123467490671706, -1.4123467490671708}),
torch::tensor({-2.0633690434349052}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay_and_amsgrad() {
return {
{
torch::tensor({0.7890972867575196, 0.5024410692260988, 0.8587073313091852, 0.6579707241257546, 0.7476356819241026, 1.6975564206261673}),
torch::tensor({0.8914676360248869, 0.7020513497574256, 1.6892012709389561}),
torch::tensor({-1.050803095846074, -1.3941351509567128, -1.284337577714342}),
torch::tensor({-1.0711381102987145}),
},
{
torch::tensor({6.790598887618061, 6.914995398136696, 6.41533478566264, 6.297644005485053, 5.845162499872375, 6.862229173597117}),
torch::tensor({7.958707058914726, 6.511338624975532, 7.100969502256063}),
torch::tensor({-6.690689640539306, -7.056584601121166, -6.72114879738572}),
torch::tensor({-6.406608295022552}),
},
{
torch::tensor({4.707506618354547, 5.291519064582759, 6.0451502264500006, 6.024403702678936, 5.309533822430375, 6.388110918107735}),
torch::tensor({7.200495189000188, 6.398387074819269, 6.904125817198589}),
torch::tensor({-6.664150387053514, -7.026716194929788, -6.705821732490459}),
torch::tensor({-6.396310969025695}),
},
{
torch::tensor({2.9508632109188633, 3.7657755643994775, 5.60741774331852, 5.6957903028180565, 4.70145185833677, 5.835064148607041}),
torch::tensor({6.343524400109462, 6.258242740866945, 6.663022973860484}),
torch::tensor({-6.630461605133603, -6.988854886932907, -6.686194352796841}),
torch::tensor({-6.383010489575922}),
},
{
torch::tensor({1.7128944635692829, 2.536365345915568, 5.140416924817106, 5.33803266121343, 4.083921806116116, 5.254596369238127}),
torch::tensor({5.477917349690043, 6.100068192681452, 6.394747035239918}),
torch::tensor({-6.591742325353548, -6.945355504749947, -6.663584873152447}),
torch::tensor({-6.367675348676994}),
},
{
torch::tensor({0.9341502247285258, 1.6339620765410685, 4.6679910940755835, 4.967688979298023, 3.4933141073198866, 4.678195347615295}),
torch::tensor({4.655117321178743, 5.9294836450698245, 6.110061909503652}),
torch::tensor({-6.549116242899458, -6.897486328511809, -6.638629863638681}),
torch::tensor({-6.350731975696792}),
},
{
torch::tensor({0.483518223008081, 1.014261673266094, 4.205928052060015, 4.596195204751035, 2.9502123780175378, 4.125826973031755}),
torch::tensor({3.90359317709392, 5.750505227817309, 5.816617309738603}),
torch::tensor({-6.503371263778851, -6.846137347295816, -6.611773185243846}),
torch::tensor({-6.3324769650576656}),
},
{
torch::tensor({0.2393576446248765, 0.6100241101779533, 3.764601561942264, 4.231602962540335, 2.4647709193637635, 3.6096961114614476}),
torch::tensor({3.236721556734532, 5.566168160977344, 5.520085344708356}),
torch::tensor({-6.455100840665729, -6.791979259673106, -6.583347815648856}),
torch::tensor({-6.3131323905801136}),
},
{
torch::tensor({0.11401265024593016, 0.3570521972760832, 3.350517925954931, 3.8795333009419823, 2.0402068661130683, 3.136550602110189}),
torch::tensor({2.6579570162250215, 5.378834741966309, 5.224742933241745}),
torch::tensor({-6.4047717084756375, -6.7355397098532155, -6.55361495837462}),
torch::tensor({-6.2928722155069945}),
},
{
torch::tensor({0.05251515193791458, 0.20410212473600725, 2.9673680881961273, 3.543794405777883, 1.6752677855061209, 2.709287985107431}),
torch::tensor({2.164479166686583, 5.190372657839918, 4.9338240234040756}),
torch::tensor({-6.352761531270841, -6.6772456859648175, -6.52278570167088}),
torch::tensor({-6.271836898876738}),
},
{
torch::tensor({0.023489947480426834, 0.11428338573638941, 2.616797245623764, 3.226821571853439, 1.3659994589608537, 2.328136084816453}),
torch::tensor({1.74978620664146, 5.002269811977871, 4.649756802441968}),
torch::tensor({-6.299382007948917, -6.617449564196286, -6.491034254081261}),
torch::tensor({-6.250142306631259}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> AdamW() {
return {
{
torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}),
torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
torch::tensor({-1.0704267290556988}),
},
{
torch::tensor({3.3165329599188507, 3.223120441823618, 2.665544565239194, 2.6044341406663225, 2.479859063483047, 2.836831717112226}),
torch::tensor({3.3885192024669744, 2.6544147219174556, 2.8709245656887328}),
torch::tensor({-2.70172647102137, -2.836731459490802, -2.69652471546253}),
torch::tensor({-2.575239255076019}),
},
{
torch::tensor({2.231471944853865, 2.3549328325971755, 1.5699078054795328, 1.6160272935884685, 1.5339085081403547, 1.7397405105941612}),
torch::tensor({2.8552579170807926, 1.8369866847839356, 1.9735168512425862}),
torch::tensor({-2.6042083360293855, -2.6996673713262336, -1.8976087706977893}),
torch::tensor({-1.6180915942867784}),
},
{
torch::tensor({2.084688381515552, 2.3141612674892946, 1.4850714710140511, 1.5961047256668386, 1.440300645879787, 1.6065354941586025}),
torch::tensor({3.0111385685659444, 1.955556497153507, 1.9596562467797627}),
torch::tensor({-2.889337305884852, -2.965249100126337, -1.7721676671605975}),
torch::tensor({-1.4001341655590005}),
},
{
torch::tensor({2.0465343456006604, 2.311613891239368, 1.4666717526896398, 1.601383980913499, 1.4223660595993763, 1.5711552625612757}),
torch::tensor({3.07151984580744, 2.0112690538174802, 1.9592484602763875}),
torch::tensor({-3.0186469726426863, -3.093855445542849, -1.7367953899738784}),
torch::tensor({-1.3299011560804312}),
},
{
torch::tensor({2.039659777412556, 2.3178034179536273, 1.4654302718412722, 1.6094701969162322, 1.4230510816446773, 1.565168902852383}),
torch::tensor({3.1007583934270064, 2.039757113618415, 1.9652096140698696}),
torch::tensor({-3.0880626664330832, -3.166705422245348, -1.73538367534238}),
torch::tensor({-1.3130428735015893}),
},
{
torch::tensor({2.0413773043991963, 2.3251469369586366, 1.4690808101517236, 1.6174065798291044, 1.4280274009117935, 1.5682418226469732}),
torch::tensor({3.118843540209399, 2.057729936485249, 1.9742319629710936}),
torch::tensor({-3.1331019663177013, -3.2154332694373107, -1.7459831639793468}),
torch::tensor({-1.3148644134154366}),
},
{
torch::tensor({2.0452604138357113, 2.332074253989847, 1.4738845773449165, 1.6246403004735728, 1.4335712611625357, 1.573826630920094}),
torch::tensor({3.1324088069784093, 2.0711763619826575, 1.9841582498316732}),
torch::tensor({-3.16737058959847, -3.2529206463859146, -1.7602788393925501}),
torch::tensor({-1.32281766461531}),
},
{
torch::tensor({2.0495243704493262, 2.338413341249581, 1.4787599440132637, 1.631210274009555, 1.438849155552895, 1.5798736919537595}),
torch::tensor({3.1438209015414227, 2.0823943437659658, 1.9940075805973108}),
torch::tensor({-3.19628690363529, -3.2845941643030367, -1.7752333900055153}),
torch::tensor({-1.332456718314933}),
},
{
torch::tensor({2.0536979895206295, 2.3442520601250334, 1.4834272584224222, 1.6372462654983486, 1.4437517398490174, 1.585780877834892}),
torch::tensor({3.1540081461072447, 2.0923381262560454, 2.0034284957296107}),
torch::tensor({-3.222142968201519, -3.312867602521477, -1.7898220261118043}),
torch::tensor({-1.3422692037690986}),
},
{
torch::tensor({2.0576784836825315, 2.3496934395759377, 1.4878413407927933, 1.6428612479757005, 1.4483225979568104, 1.5914034339763325}),
torch::tensor({3.163383747232199, 2.101446878895216, 2.012344413569353}),
torch::tensor({-3.246000281299229, -3.338904166978488, -1.8037666936489785}),
torch::tensor({-1.3517884775416527}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> AdamW_without_weight_decay() {
return {
{
torch::tensor({0.7890972864438476, 0.5024410688121617, 0.858707331305558, 0.6579707241208395, 0.7476356819075531, 1.6975564206516922}),
torch::tensor({0.891467636010675, 0.70205134975675, 1.689201270942895}),
torch::tensor({-1.0508030958460797, -1.3941351509567654, -1.284337577714353}),
torch::tensor({-1.071138110298716}),
},
{
torch::tensor({8.233039313231831, 7.971150747377481, 6.643620950677599, 6.47097740790054, 6.170125488259256, 7.150739103343502}),
torch::tensor({8.417695070103738, 6.597188212844593, 7.23175710827678}),
torch::tensor({-6.729624357635757, -7.09743493108154, -6.753301896575352}),
torch::tensor({-6.435639096011218}),
},
{
torch::tensor({8.233424596059299, 7.971537360032308, 6.643920150720393, 6.471278075537239, 6.170405874224489, 7.151021086137983}),
torch::tensor({8.418084791214298, 6.597493171180545, 7.232043740621598}),
torch::tensor({-6.729918250724671, -7.097730102046093, -6.753584809755359}),
torch::tensor({-6.4359165566974985}),
},
{
torch::tensor({8.233424610557652, 7.971537374586563, 6.643920161995284, 6.471278086877828, 6.170405884785074, 7.151021096766406}),
torch::tensor({8.418084805901906, 6.597493182713584, 7.2320437514477875}),
torch::tensor({-6.72991829363266, -7.097730147102975, -6.753584838821182}),
torch::tensor({-6.435916580217771}),
},
{
torch::tensor({8.233424610575105, 7.971537374611125, 6.643920162027961, 6.471278086923277, 6.170405884809245, 7.151021096800041}),
torch::tensor({8.418084805946393, 6.597493182796847, 7.232043751509309}),
torch::tensor({-6.729918332327653, -7.097730188349552, -6.753584861205486}),
torch::tensor({-6.435916596115672}),
},
{
torch::tensor({8.233424610594861, 7.971537374639166, 6.64392016206557, 6.471278086975758, 6.170405884836981, 7.151021096838798}),
torch::tensor({8.418084805997617, 6.59749318289335, 7.232043751580523}),
torch::tensor({-6.72991837738045, -7.097730236373201, -6.753584887267492}),
torch::tensor({-6.43591661462546}),
},
{
torch::tensor({8.233424610617291, 7.971537374671012, 6.643920162108284, 6.471278087035361, 6.170405884868481, 7.151021096882812}),
torch::tensor({8.418084806055798, 6.59749318300295, 7.232043751661401}),
torch::tensor({-6.729918428547273, -7.09773029091405, -6.753584916866329}),
torch::tensor({-6.4359166356471755}),
},
{
torch::tensor({8.233424610642356, 7.9715373747065925, 6.643920162156006, 6.471278087101954, 6.1704058849036745, 7.15102109693199}),
torch::tensor({8.418084806120802, 6.597493183125404, 7.232043751751764}),
torch::tensor({-6.729918485714688, -7.0977303518511805, -6.753584949936365}),
torch::tensor({-6.43591665913422}),
},
{
torch::tensor({8.233424610670038, 7.97153737474589, 6.643920162208714, 6.471278087175501, 6.170405884942545, 7.151021096986303}),
torch::tensor({8.418084806192596, 6.597493183260647, 7.232043751851564}),
torch::tensor({-6.729918548853505, -7.097730419153473, -6.753584986460725}),
torch::tensor({-6.435916685074594}),
},
{
torch::tensor({8.233424610700352, 7.971537374788922, 6.643920162266432, 6.47127808725604, 6.17040588498511, 7.1510210970457795}),
torch::tensor({8.418084806271217, 6.597493183408747, 7.232043751960854}),
torch::tensor({-6.7299186179943, -7.097730492853521, -6.753585026457088}),
torch::tensor({-6.435916713480863}),
},
{
torch::tensor({8.23342461073333, 7.971537374835737, 6.643920162329224, 6.471278087343658, 6.170405885031416, 7.151021097110484}),
torch::tensor({8.418084806356747, 6.597493183569867, 7.232043752079749}),
torch::tensor({-6.729918693213275, -7.097730573032567, -6.753585069969552}),
torch::tensor({-6.43591674438434}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> AdamW_with_amsgrad() {
return {
{
torch::tensor({0.7912062750121864, 0.5074166292785842, 0.8601202529258052, 0.6613910130887053, 0.7501593169903569, 1.6905808503961983}),
torch::tensor({0.8925529482073002, 0.7050308347536254, 1.682309255842939}),
torch::tensor({-1.05029506454492, -1.3901937990816595, -1.2814942017397601}),
torch::tensor({-1.0704267290556988}),
},
{
torch::tensor({3.3017259270507915, 3.2082991753694565, 2.653930978510442, 2.5927674339810585, 2.4689608790182933, 2.825873703467739}),
torch::tensor({3.373698198112671, 2.6425942964586664, 2.8597930424244304}),
torch::tensor({-2.690360632302962, -2.8253191596069525, -2.6855499873057473}),
torch::tensor({-2.5644658591929406}),
},
{
torch::tensor({2.222607725541013, 2.3447188854637004, 1.5614270655258826, 1.606610018462357, 1.5260497191448619, 1.7309643622674138}),
torch::tensor({2.84137462783552, 1.824806600633721, 1.9620493659996037}),
torch::tensor({-2.576642773625787, -2.6706153846815766, -1.8799876863754623}),
torch::tensor({-1.6044722984810953}),
},
{
torch::tensor({2.0739558768648205, 2.3008338863863496, 1.4738888208638767, 1.5829485271829449, 1.4296176764284294, 1.5939984909850073}),
torch::tensor({2.9908013612792415, 1.936590940953305, 1.941691630199464}),
torch::tensor({-2.846562884997548, -2.9195962101501203, -1.746484716887341}),
torch::tensor({-1.381525131003179}),
},
{
torch::tensor({2.0333926094256953, 2.294977109171754, 1.452870514716895, 1.584853677999522, 1.4086299433181402, 1.5548201727855224}),
torch::tensor({3.0454817801193976, 1.9867169062383696, 1.935312753106444}),
torch::tensor({-2.9612116762746394, -3.0322275992001084, -1.7026114905180725}),
torch::tensor({-1.30563541393247}),
},
{
torch::tensor({2.02392417168201, 2.2977279587859, 1.4488511120131309, 1.5894646930743725, 1.406253686759073, 1.5450647949022756}),
torch::tensor({3.069025602955343, 2.0096872138967488, 1.935438546309299}),
torch::tensor({-3.016103148166836, -3.0893062953033583, -1.6925290685615872}),
torch::tensor({-1.282870120405012}),
},
{
torch::tensor({2.0230257817348316, 2.3016167065040647, 1.4496901978629444, 1.5939034289777392, 1.4082421794430946, 1.5444538756003756}),
torch::tensor({3.0814016132787954, 2.022150201844143, 1.9387429991308658}),
torch::tensor({-3.0466485946438406, -3.1223144611322446, -1.6944083009127773}),
torch::tensor({-1.2786980736911064}),
},
{
torch::tensor({2.024305922065404, 2.305101549510461, 1.4516863420588493, 1.5976447954882376, 1.4108596097183552, 1.5464236284303425}),
torch::tensor({3.0892574233545065, 2.0300944858242236, 1.943040321845021}),
torch::tensor({-3.0664189567897306, -3.1440820888166425, -1.6999750448893618}),
torch::tensor({-1.2806281811826203}),
},
{
torch::tensor({2.0259854116237364, 2.3080169152218293, 1.4537680296813915, 1.6007369432392426, 1.4132529823064277, 1.5489035046525346}),
torch::tensor({3.0949717180851337, 2.0358251379915764, 1.9473249654893}),
torch::tensor({-3.0808231377905426, -3.160021699873689, -1.7062031001273494}),
torch::tensor({-1.28423401369705}),
},
{
torch::tensor({2.0275923948348638, 2.3104512601389637, 1.455657715721078, 1.6033123357613526, 1.4153003204463288, 1.5512775896116622}),
torch::tensor({3.099479021299846, 2.0403012223048775, 1.9512285931847464}),
torch::tensor({-3.092151979336299, -3.1725453680885267, -1.7120689614428697}),
torch::tensor({-1.2880095517062655}),
},
{
torch::tensor({2.029022468328371, 2.3125066985045892, 1.4573100228823295, 1.605484259933419, 1.417038245960655, 1.5533932056240227}),
torch::tensor({3.103195011518616, 2.043964003458376, 1.9546640840748621}),
torch::tensor({-3.1014680843131184, -3.1828179298513968, -1.7172933346797972}),
torch::tensor({-1.2914899987134136}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adagrad() {
return {
{
torch::tensor({0.7891011045987429, 0.502443924512199, 0.8587078329085825, 0.6579710994224826, 0.7476364836215006, 1.697557019500397}),
torch::tensor({0.8914687688941954, 0.7020514988069096, 1.6892015076050444}),
torch::tensor({-1.0508031297732776, -1.3941351871450518, -1.284337597261839}),
torch::tensor({-1.071138124161711}),
},
{
torch::tensor({2.4079229696892583, 2.2346803754764286, 1.6967885588547365, 1.552279695827649, 1.2259044248443602, 2.221279696180243}),
torch::tensor({2.9334079162217193, 1.7619824934767887, 2.3464577179091473}),
torch::tensor({-2.221396083069719, -2.549950976011168, -1.9709315957317095}),
torch::tensor({-1.5858816837541876}),
},
{
torch::tensor({2.510404433941812, 2.3522584510262887, 1.7921695110761213, 1.657755825836846, 1.2891186618593045, 2.291878516133922}),
torch::tensor({3.092171180776419, 1.8971624370952997, 2.438734251283465}),
torch::tensor({-2.437641633486504, -2.7704264590526573, -2.0949471699460225}),
torch::tensor({-1.6769121890401757}),
},
{
torch::tensor({2.5652648968109415, 2.4155313947260972, 1.844241233613541, 1.7156513351246399, 1.3245206506797171, 2.3315409972138825}),
torch::tensor({3.178399916514377, 1.9721945764936502, 2.4909037706250428}),
torch::tensor({-2.5658710403147933, -2.901921821645266, -2.168560672193225}),
torch::tensor({-1.7307903926154131}),
},
{
torch::tensor({2.6021584494332592, 2.4582101324909065, 1.8796060082750778, 1.7550965207414717, 1.3489253597999988, 2.3589345190118247}),
torch::tensor({3.2368674310041516, 2.0236468833666894, 2.52707132741292}),
torch::tensor({-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}),
torch::tensor({-1.7692090167089707}),
},
{
torch::tensor({2.629700772579208, 2.4901377017698683, 1.906173477530586, 1.7847957161833832, 1.3674517119505822, 2.3797578857769905}),
torch::tensor({3.2807643102638546, 2.062561811940094, 2.5546379424362775}),
torch::tensor({-2.7286379977755035, -3.0695109399636236, -2.262081199960513}),
torch::tensor({-1.7990936323432214}),
},
{
torch::tensor({2.6515471766995247, 2.51550257362603, 1.927341363452414, 1.8084994719811576, 1.3823309942932445, 2.3964995243914373}),
torch::tensor({3.3157334001309473, 2.093728023484945, 2.5768468697402924}),
torch::tensor({-2.786981763434855, -3.129746439571402, -2.29562487034177}),
torch::tensor({-1.8235564908139104}),
},
{
torch::tensor({2.6695780544837886, 2.53646401614724, 1.9448721033433505, 1.828157582353901, 1.3947329882074622, 2.4104657178934947}),
torch::tensor({3.344694775590452, 2.1196465761628516, 2.5954050923596252}),
torch::tensor({-2.8363936812536537, -3.1808219609745194, -2.32404190866147}),
torch::tensor({-1.8442667636913117}),
},
{
torch::tensor({2.684883801533072, 2.5542762192735515, 1.9597939532350015, 1.844909608012419, 1.4053459079217485, 2.4224257790968386}),
torch::tensor({3.369349515259956, 2.1417845308976795, 2.611319989214332}),
torch::tensor({-2.879251075341889, -3.225165734647855, -2.3486956737228057}),
torch::tensor({-1.86222449978646}),
},
{
torch::tensor({2.698151012423769, 2.56972998600169, 1.972757472697587, 1.8594775691681182, 1.4146081751022495, 2.43287021079559}),
torch::tensor({3.390772758897601, 2.1610741754331757, 2.6252349489549824}),
torch::tensor({-2.917092322961074, -3.264351563375218, -2.370468664387175}),
torch::tensor({-1.8780765115117757}),
},
{
torch::tensor({2.7098389356033783, 2.5833548721723747, 1.9841994925173085, 1.8723468731726323, 1.4228158926355312, 2.4421305315945085}),
torch::tensor({3.4096859099156673, 2.178143852041279, 2.6375854547611364}),
torch::tensor({-2.9509704554208467, -3.2994581338995044, -2.3899651139415874}),
torch::tensor({-1.8922653655195538}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay() {
return {
{
torch::tensor({0.7891011218979068, 0.5024439415126254, 0.8587078332470682, 0.6579710998575992, 0.7476364849956589, 1.6975570150849029}),
torch::tensor({0.8914687701583902, 0.7020514988715463, 1.6892015071335027}),
torch::tensor({-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}),
torch::tensor({-1.0711381241615712}),
},
{
torch::tensor({0.1846116678522213, 0.24944077103107917, 0.18651745437755768, 0.25219093533041764, 0.18712037968446713, 0.25289206444055234}),
torch::tensor({0.6482869597891656, 0.6580215784646755, 0.6581256007663537}),
torch::tensor({-1.454709711443681, -1.4748063405174818, -1.4811625946604765}),
torch::tensor({-1.905292836544363}),
},
{
torch::tensor({0.18059895999281475, 0.2438515539257779, 0.18067177884778182, 0.24397186395008694, 0.18168388351830797, 0.24533853846052017}),
torch::tensor({0.6325250261983028, 0.6331827793513023, 0.6366659383355598}),
torch::tensor({-1.420803333750877, -1.4215627240541653, -1.4320264544533396}),
torch::tensor({-2.030135641848322}),
},
{
torch::tensor({0.17981392697398363, 0.2427571544305695, 0.17981150414451733, 0.24275725992310523, 0.18014798619115763, 0.2432144956227816}),
torch::tensor({0.6294321320817985, 0.6294873737410742, 0.6306958589251878}),
torch::tensor({-1.4139253354785764, -1.413902680470981, -1.4173628530293867}),
torch::tensor({-2.056210117690093}),
},
{
torch::tensor({0.17967006242163747, 0.24255582734557277, 0.17966873677301953, 0.24255462870545766, 0.1797588230898851, 0.24267729072765756}),
torch::tensor({0.6288576295241085, 0.6288643132826753, 0.6291921485342001}),
torch::tensor({-1.4126465879787569, -1.4126335126907266, -1.4135586793353685}),
torch::tensor({-2.0618018405404825}),
},
{
torch::tensor({0.17964321284685653, 0.24251808241139364, 0.17964291377171066, 0.24251779104198598, 0.1796651574178102, 0.2425481059085456}),
torch::tensor({0.628748693136779, 0.6287498167193976, 0.6288312441271762}),
torch::tensor({-1.412405895385289, -1.4124029484481164, -1.4126313051315378}),
torch::tensor({-2.0630223163099304}),
},
{
torch::tensor({0.1796379973927849, 0.2425107191246215, 0.1796379363134342, 0.2425106592536174, 0.17964321802205502, 0.24251786094585864}),
torch::tensor({0.6287272170354161, 0.6287274414587727, 0.6287468362309863}),
torch::tensor({-1.4123588626342634, -1.412358263650784, -1.412412480569672}),
torch::tensor({-2.0632918101480255}),
},
{
torch::tensor({0.17963694231402444, 0.24250922426893615, 0.1796369298071621, 0.2425092121007451, 0.17963815759528073, 0.24251088666939838}),
torch::tensor({0.6287228195255881, 0.6287228675172439, 0.628727383936762}),
torch::tensor({-1.4123493065102781, -1.412349184462438, -1.4123617872243597}),
torch::tensor({-2.063351765096138}),
},
{
torch::tensor({0.1796367215904632, 0.24250891070978667, 0.17963671897003045, 0.24250890818318074, 0.17963700091084855, 0.24250929278196054}),
torch::tensor({0.6287218911936107, 0.6287219017313679, 0.6287229399204574}),
torch::tensor({-1.4123473011084142, -1.412347275640343, -1.41235016959507}),
torch::tensor({-2.0633651674043505}),
},
{
torch::tensor({0.17963667424978783, 0.24250884333120687, 0.17963667368829764, 0.24250884279379462, 0.17963673796131557, 0.24250893047794023}),
torch::tensor({0.6287216908150736, 0.6287216931558691, 0.6287219299749583}),
torch::tensor({-1.4123468700596724, -1.4123468646187736, -1.4123475243360133}),
torch::tensor({-2.0633681724342527}),
},
{
torch::tensor({0.1796366639185348, 0.24250882860835257, 0.17963666379614568, 0.24250882849182892, 0.17963667838367053, 0.2425088483939741}),
torch::tensor({0.6287216468984888, 0.6287216474215305, 0.6287217011907862}),
torch::tensor({-1.4123467758545658, -1.412346774671038, -1.4123469244007658}),
torch::tensor({-2.0633688474977467}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay_and_lr_decay() {
return {
{
torch::tensor({0.7891011046018798, 0.5024439245163383, 0.8587078329086189, 0.6579710994225316, 0.747636483621666, 1.697557019500142}),
torch::tensor({0.8914687688943375, 0.7020514988069164, 1.6892015076050049}),
torch::tensor({-1.0508031297732776, -1.3941351871450511, -1.284337597261839}),
torch::tensor({-1.0711381241617108}),
},
{
torch::tensor({2.346218944110103, 2.191939439502003, 1.683355201740813, 1.5405520021635604, 1.2137800230828062, 2.205283463717303}),
torch::tensor({2.9090564593404, 1.7509657336815554, 2.336166413186925}),
torch::tensor({-2.206159683368316, -2.5344318233445415, -1.9622783535807609}),
torch::tensor({-1.5796101463783623}),
},
{
torch::tensor({2.3889328781057233, 2.2678221038007296, 1.7667624725138267, 1.6358015176639822, 1.2655767687152566, 2.261088056711282}),
torch::tensor({3.045569451994985, 1.8770196253823253, 2.4192707519566765}),
torch::tensor({-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}),
torch::tensor({-1.664722108226537}),
},
{
torch::tensor({2.3886137557806384, 2.2922158071009178, 1.8078384116424007, 1.684352474440932, 1.290353948335789, 2.2870715509706496}),
torch::tensor({3.111110355394278, 1.9438501730282314, 2.4630249355872826}),
torch::tensor({-2.5226122034499263, -2.857315093916292, -2.143964860243905}),
torch::tensor({-1.7130685809905042}),
},
{
torch::tensor({2.374703352203156, 2.298804499257456, 1.8330249458212446, 1.7151661013307244, 1.3048586226945842, 2.3017650590464274}),
torch::tensor({3.150318222034133, 1.9877926185369321, 2.491399976401679}),
torch::tensor({-2.601415913361488, -2.938203895113964, -2.1892988334550028}),
torch::tensor({-1.7462964261966805}),
},
{
torch::tensor({2.3553658567303812, 2.297191758042688, 1.8501154749072124, 1.736836058688188, 1.3141313000193942, 2.3107452592153854}),
torch::tensor({3.1762315339155434, 2.0197585204578647, 2.5117041377790197}),
torch::tensor({-2.6606644002288697, -2.9991216074293856, -2.223413376189609}),
torch::tensor({-1.7712905233118807}),
},
{
torch::tensor({2.3338052201696207, 2.2913023710914993, 1.8624163948044772, 1.7530300731725454, 1.3203313209234842, 2.3163969478854747}),
torch::tensor({3.1943525925688934, 2.044447386769377, 2.527109724607397}),
torch::tensor({-2.7076634717294894, -3.047500808469036, -2.250495807208967}),
torch::tensor({-1.7911288238757486}),
},
{
torch::tensor({2.3114979154644892, 2.2830501835377808, 1.871616142999356, 1.765632597660841, 1.324565631636651, 2.319939234205203}),
torch::tensor({3.2074779925809085, 2.0642940833670544, 2.5392671301471235}),
torch::tensor({-2.7463093287485925, -3.087315541134716, -2.272780318857348}),
torch::tensor({-1.8074516661537263}),
},
{
torch::tensor({2.2891841627387346, 2.2734716995793693, 1.8786818999895825, 1.7757301317117602, 1.3274682997719436, 2.3220679353993825}),
torch::tensor({3.2172019619454075, 2.0807140893178175, 2.5491374815141876}),
torch::tensor({-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}),
torch::tensor({-1.821234722948421}),
},
{
torch::tensor({2.2672498238343066, 2.2631678037928893, 1.8842131287032622, 1.7840007705383882, 1.3294311820750493, 2.323211243034543}),
torch::tensor({3.224507488068445, 2.094598223519413, 2.5573257155791715}),
torch::tensor({-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}),
torch::tensor({-1.8331040438272388}),
},
{
torch::tensor({2.2458961718688957, 2.2525031725114775, 1.8886034384961112, 1.7908930341267952, 1.3307102435291205, 2.323647497600462}),
torch::tensor({3.230036385413078, 2.1065407459636134, 2.5642349249609664}),
torch::tensor({-2.83151249424399, -3.1751926295316566, -2.3219682378974036}),
torch::tensor({-1.8434843744626483}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> RMSprop() {
return {
{
torch::tensor({0.7890625772821005, 0.502415108650816, 0.8587027713011453, 0.657967312300643, 0.7476283936579036, 1.6975509766054537}),
torch::tensor({0.8914573371873159, 0.7020499947573374, 1.6891991194739453}),
torch::tensor({-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}),
torch::tensor({-1.0711379842715099}),
},
{
torch::tensor({2.448571858277443, 2.2809152044417678, 1.7346424449151965, 1.5940004770230667, 1.250761131839982, 2.248993270255382}),
torch::tensor({2.994661478530102, 1.8150485290864256, 2.382542610897819}),
torch::tensor({-2.3036981738757825, -2.6337299521275646, -2.018370122358821}),
torch::tensor({-1.620787559800898}),
},
{
torch::tensor({2.5837582475607785, 2.4365737242301537, 1.8622886519354538, 1.7357065282848232, 1.3369695670141974, 2.3454934716983695}),
torch::tensor({3.2061266499381618, 1.9981112525417788, 2.5092495986614}),
torch::tensor({-2.6110809365525958, -2.9484807193016787, -2.194898560798439}),
torch::tensor({-1.7501043480625826}),
},
{
torch::tensor({2.669969051134511, 2.536559412710799, 1.9456091681389671, 1.828914948091767, 1.3952956766999587, 2.4110816686341923}),
torch::tensor({3.343672975593657, 2.1204057198913002, 2.5961524902119497}),
torch::tensor({-2.8372329851331006, -3.1817729538857207, -2.3249971853996954}),
torch::tensor({-1.8450422173907486}),
},
{
torch::tensor({2.7375365004059122, 2.6153071545358633, 2.0117493624534313, 1.9033001982031035, 1.4427501882445097, 2.4646213743186127}),
torch::tensor({3.452912454199796, 2.2190451524127535, 2.667552790123282}),
torch::tensor({-3.0329479456731505, -3.384582488936652, -2.4377299824997136}),
torch::tensor({-1.9271014784118226}),
},
{
torch::tensor({2.7952372917068753, 2.682820220375722, 2.0687223272686994, 1.967654548778711, 1.4844410726622166, 2.5117888904510117}),
torch::tensor({3.5471904628565745, 2.305113548262141, 2.7307948248967304}),
torch::tensor({-3.2141190290332537, -3.572944633614449, -2.5421970206546827}),
torch::tensor({-2.0029976985219666}),
},
{
torch::tensor({2.8467333937519483, 2.7432785177110395, 2.119898810135385, 2.0256805416741255, 1.5225256464280221, 2.554983108087885}),
torch::tensor({3.632098323876194, 2.383289304179778, 2.7889864719999222}),
torch::tensor({-3.387579944167926, -3.7537658010839294, -2.6423123266260427}),
torch::tensor({-2.075616951445725}),
},
{
torch::tensor({2.8938600498060203, 2.798776984183615, 2.16697381475156, 2.079238538430203, 1.5580820887115123, 2.5954023023969692}),
torch::tensor({3.71043881435304, 2.455919099321953, 2.8436784441941008}),
torch::tensor({-3.5567368287146417, -3.930484868709691, -2.740026479434574}),
torch::tensor({-2.146398256871758}),
},
{
torch::tensor({2.937649394399939, 2.850492965312311, 2.210902777232446, 2.1293746183147633, 1.5917084661873866, 2.6337100421079533}),
torch::tensor({3.7837853328443516, 2.5243155701130604, 2.8957265009949373}),
torch::tensor({-3.7234268485210475, -4.104949193318518, -2.836390693799751}),
torch::tensor({-2.216118773360611}),
},
{
torch::tensor({2.9787316798887558, 2.8991451078473207, 2.252272487010975, 2.1767288729202767, 1.6237627746697592, 2.670302507579268}),
torch::tensor({3.8530980655045086, 2.589275531553025, 2.9456388178450936}),
torch::tensor({-3.8886856459619636, -4.278191888396593, -2.9319964928350313}),
torch::tensor({-2.285217699505124}),
},
{
torch::tensor({3.017515620579049, 2.9452004251453268, 2.291469925522962, 2.2217217025782916, 1.6544730927621272, 2.705431441204422}),
torch::tensor({3.9190041004420166, 2.651317624465938, 2.993736489599992}),
torch::tensor({-4.053111559341913, -4.450801801162238, -3.0271845519131957}),
torch::tensor({-2.3539498905973444}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay() {
return {
{
torch::tensor({0.7890798754118442, 0.5024321083861885, 0.8587031097835685, 0.6579677474141494, 0.7476297677960806, 1.6975465611838714}),
torch::tensor({0.891458601354904, 0.7020500593937647, 1.6891986479348047}),
torch::tensor({-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}),
torch::tensor({-1.0711379841318796}),
},
{
torch::tensor({0.2139892652405453, 0.2779011713353896, 0.18684802794665187, 0.2507569370785562, 0.19145335235130007, 0.2557687813140708}),
torch::tensor({0.6720959116689083, 0.6480734848064099, 0.654263007004671}),
torch::tensor({-1.4357633640899097, -1.4493557950073235, -1.4619011018357073}),
torch::tensor({-1.9673083558727926}),
},
{
torch::tensor({0.23961935744660673, 0.30354236865888945, 0.19567694278583514, 0.2544696440133763, 0.21982879020261814, 0.27495711471979495}),
torch::tensor({0.6927895724658635, 0.638015535479105, 0.6523245375960234}),
torch::tensor({-1.413722583500382, -1.4170291001633526, -1.4166977298480703}),
torch::tensor({-2.0626651115437147}),
},
{
torch::tensor({0.2506635865272117, 0.314639511428635, 0.25116892910818034, 0.30431399579592144, 0.25219625048710015, 0.3160110008170742}),
torch::tensor({0.7051419232960522, 0.6699011906397543, 0.699097284678438}),
torch::tensor({-1.4206083241624232, -1.4257037444100107, -1.4171061826065132}),
torch::tensor({-2.075537874763694}),
},
{
torch::tensor({0.23285924743063605, 0.29652494304777544, 0.2335322002738168, 0.2969991261380461, 0.23358272245229555, 0.2973997498166104}),
torch::tensor({0.6855589594925036, 0.6796983775695974, 0.6864174803983276}),
torch::tensor({-1.43110762794651, -1.4334934742818164, -1.422739552145125}),
torch::tensor({-2.0842642493046184}),
},
{
torch::tensor({0.23356397699389828, 0.29737142391985355, 0.23367622061822368, 0.29749447597160267, 0.23418481357395918, 0.29818122925156104}),
torch::tensor({0.6866530583001205, 0.6858933385102559, 0.6883944045412603}),
torch::tensor({-1.4564955509607018, -1.4583548131500643, -1.4418225445708595}),
torch::tensor({-2.1064103749186183}),
},
{
torch::tensor({0.2318717301174723, 0.2952904159872858, 0.23194024439476665, 0.29537019824987687, 0.2316421336904657, 0.2951041425983894}),
torch::tensor({0.6834813130194509, 0.6834401711464199, 0.6837275457100463}),
torch::tensor({-1.4647835805276763, -1.4653452408179053, -1.4571142112777709}),
torch::tensor({-2.1209598505912086}),
},
{
torch::tensor({0.2308683396504178, 0.2940474629750448, 0.23089067678260966, 0.29407615110959306, 0.23064069314214175, 0.29379043611390243}),
torch::tensor({0.6815062281792611, 0.6815233687209215, 0.6812759203026146}),
torch::tensor({-1.4643013018530682, -1.4644523635284246, -1.4617493939684878}),
torch::tensor({-2.1247293635678854}),
},
{
torch::tensor({0.23066464201462678, 0.29376059273730426, 0.23067069245857366, 0.2937690399784267, 0.23057551211606675, 0.2936477517373107}),
torch::tensor({0.6809028781780304, 0.6809134105028244, 0.6807404613096301}),
torch::tensor({-1.4637927352177986, -1.4638374228010727, -1.46299287102643}),
torch::tensor({-2.1258082720638107}),
},
{
torch::tensor({0.23062625079199173, 0.2936990787425707, 0.2306278924729115, 0.2937014834661651, 0.23059813368157003, 0.29366073890476396}),
torch::tensor({0.6807251804689082, 0.6807295616357246, 0.6806523640328994}),
torch::tensor({-1.4635790398985618, -1.463592926902286, -1.4633272688236565}),
torch::tensor({-2.1261396358141798}),
},
{
torch::tensor({0.23061701122700193, 0.293683983817782, 0.23061747865501653, 0.29368467998690806, 0.23060855595638208, 0.2936719507340021}),
torch::tensor({0.6806714673830832, 0.6806730903175793, 0.6806434720800856}),
torch::tensor({-1.4635008778278134, -1.4635052859178375, -1.4634208375068285}),
torch::tensor({-2.1262432969587723}),
},
};
}
inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay_and_centered() {
return {
{
torch::tensor({0.7941000061626792, 0.507452636734552, 0.8637405354185987, 0.663005089317529, 0.7526661272860107, 1.7025887305065852}),
torch::tensor({0.8964950370033696, 0.7070877948157552, 1.6942369105467197}),
torch::tensor({-1.055840599214661, -1.3991726335388424, -1.2893752132746332}),
torch::tensor({-1.0761757981162612}),
},
{
torch::tensor({2.3762999876885833, 2.239095829416783, 1.726175067071914, 1.5891569459230444, 1.2410074108588462, 2.2345431036725723}),
torch::tensor({2.990896455635836, 1.8152108764849464, 2.377985429759037}),
torch::tensor({-2.3071822180635286, -2.636859516619699, -2.0198181394256642}),
torch::tensor({-1.622583045791722}),
},
{
torch::tensor({2.372800588647971, 2.3022753207224254, 1.836028714221617, 1.7190937269287105, 1.3068955839895078, 2.3035835673200364}),
torch::tensor({3.1656599892042343, 1.9942937608209466, 2.4947143457182657}),
torch::tensor({-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}),
torch::tensor({-1.7513053380188806}),
},
{
torch::tensor({2.2398453700818455, 2.2513384246965904, 1.8892176431436287, 1.7921873754661686, 1.3310951408713538, 2.3236392222350397}),
torch::tensor({3.240166119454613, 2.109742813600189, 2.5651614461576973}),
torch::tensor({-2.8388734382997454, -3.1824200770676123, -2.324831397600949}),
torch::tensor({-1.8460315737386976}),
},
{
torch::tensor({1.9829606312242465, 2.097356567850692, 1.9050263843525033, 1.8325835415812346, 1.3222762370713104, 2.3024963133870147}),
torch::tensor({3.2465360572089974, 2.1967266045869915, 2.6091992649970672}),
torch::tensor({-3.0326878099587207, -3.3827004807595005, -2.436989182250496}),
torch::tensor({-1.928273216206344}),
},
{
torch::tensor({1.6051175329080525, 1.8332107491649114, 1.8794767349053179, 1.8403588051948856, 1.273824111314107, 2.2296571379436823}),
torch::tensor({3.1814362940910437, 2.263019214072847, 2.6273016977574013}),
torch::tensor({-3.210932646440219, -3.567153254014387, -2.541016943923914}),
torch::tensor({-2.0049155134617154}),
},
{
torch::tensor({1.1588059349082709, 1.477861379523226, 1.7992410089026634, 1.806460009198667, 1.1739931551629919, 2.08647960875392}),
torch::tensor({3.03843703712275, 2.308203068375877, 2.6125393914734083}),
torch::tensor({-3.379830678608588, -3.741970414470626, -2.6410082400846546}),
torch::tensor({-2.079294995910487}),
},
{
torch::tensor({0.7701433312419088, 1.1105026677424745, 1.646507516936639, 1.71625269098179, 1.013748545414221, 1.8532966501655352}),
torch::tensor({2.827176875885245, 2.327401948159928, 2.5535309398603405}),
torch::tensor({-3.54193329850986, -3.9096652952123145, -2.739408870192437}),
torch::tensor({-2.1537939241668997}),
},
{
torch::tensor({0.5598923129351211, 0.8460500042788701, 1.4084175549165017, 1.5547314210944563, 0.8019580519338424, 1.5258384663629627}),
torch::tensor({2.5774950379490265, 2.313101306699127, 2.4388695757441745}),
torch::tensor({-3.6974974230160087, -4.070190514312716, -2.8378932675718405}),
torch::tensor({-2.2307225014430423}),
},
{
torch::tensor({0.5016784472836648, 0.7258690889265433, 1.0976902935953956, 1.319949187972513, 0.5853930356154851, 1.1446978015944624}),
torch::tensor({2.3235249877284945, 2.2592840970420176, 2.2681461698609375}),
torch::tensor({-3.8444921272569115, -4.22021051361099, -2.9373192115434263}),
torch::tensor({-2.312733063937045}),
},
{
torch::tensor({0.4875468895095056, 0.6878747871467128, 0.7787871237567606, 1.0462592546102176, 0.4416468896022397, 0.8122992916762792}),
torch::tensor({2.1078734515587483, 2.17034337037527, 2.0666325968568535}),
torch::tensor({-3.9782695475825216, -4.352093055115415, -3.0377809502927033}),
torch::tensor({-2.403496388200805}),
},
};