taylor.py
import numpy as np
import matplotlib.pyplot as plt
import sys
import torch
import torch.nn as nn
import torch.optim as optim

sys.path.append('..')   # make the local animation helper importable
from animate import *   # provides Screen, plot, rendertext used below
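
# Training data: x sampled over [-pi, pi] (approximated as 3.14), with target
# relu(sin(x)) + relu(-cos(x)), a piecewise-smooth curve for the model to fit.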
X = torch.tensor(np.linspace(-3.14, 3.14, 1000), dtype=torch.float32)
X = X.reshape(-1, 1)  # shape (1000, 1)
y = torch.relu(torch.sin(X)) + torch.relu(-torch.cos(X))
x = X / 3.14  # normalize inputs to roughly [-1, 1]


class activate(nn.Module):
    """x * sin(x) activation (defined here but not used by the network below)."""
    def __init__(self):
        super(activate, self).__init__()

    def forward(self, x):
        return torch.sin(x) * x


activation = nn.LeakyReLU(-1.2)  # also unused; the coefficient network below builds its own
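
# The encoder maps each scalar input x to the feature vector [x^0, x^1, ..., x^(n-1)],
# i.e. the power basis of a degree-(n-1) Taylor/polynomial expansion.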
class taylor_encoder(nn.Module):
    def __init__(self, n):
        super(taylor_encoder, self).__init__()
        self.exponents = torch.arange(n)

    def forward(self, x):
        # broadcasting: (batch, 1) ** (n,) -> (batch, n)
        return torch.pow(x, self.exponents)
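
# The decoder takes the predicted coefficients a together with the encoded
# powers of x and evaluates the polynomial sum_k a_k * x^k for each sample.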
class taylor_decoder(nn.Module):
    def __init__(self, x):
        super(taylor_decoder, self).__init__()
        self.inputs = x  # encoded features, shape (batch, n)

    def forward(self, a):
        # weighted sum over the power basis -> (batch, 1)
        return torch.sum(a * self.inputs, dim=1, keepdim=True)
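
# Full model: encode x into its power basis, let user_seq predict per-sample
# Taylor coefficients from those features, then decode back to a function value.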
class taylor(nn.Module):
    def __init__(self, n, sequential):
        super(taylor, self).__init__()
        self.user_seq = sequential
        self.n = n
        self.encoder = taylor_encoder(n)  # build once instead of on every forward pass

    def raw_forward(self, x):
        # returns the raw coefficients, useful for inspecting the learned expansion
        features = self.encoder(x)
        return self.user_seq(features)

    def forward(self, x):
        features = self.encoder(x)
        coeff = self.user_seq(features)
        return taylor_decoder(features)(coeff)
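
# Coefficient network: a small MLP mapping the n polynomial features to n
# coefficients. Note the LeakyReLU negative slope of -1.2, which flips and
# amplifies negative inputs rather than damping them.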
n = 10  # number of polynomial features / coefficients
seq = nn.Sequential(
    nn.Linear(n, 16, bias=True),
    nn.LeakyReLU(-1.2),
    nn.Linear(16, 64, bias=True),
    nn.LeakyReLU(-1.2),
    nn.Linear(64, 16, bias=True),
    nn.LeakyReLU(-1.2),
    nn.Linear(16, n, bias=True),
)
model = taylor(n, seq)
loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.004)
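
# Visualization via the local animate helper: window size, title, and fps;
# the final (0, 10) argument is presumably the plotted value range.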
screen = Screen(1900, 1000, "test", 60, (0, 10))


def render_callback():
    global pred_y, X, y, epoch
    plot(X, y, (1, 0, 0))       # target curve in red
    plot(X, pred_y, (1, 0, 1))  # model prediction in magenta
    rendertext(f"epoch = {epoch}", (600, 800))
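
# Training loop: full-batch gradient descent on the MSE between the decoded
# polynomial and the target curve.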
for epoch in range(6000):
    print(f"epoch = {epoch + 1}", end='\r')
    pred_y = model(x)
    loss = loss_function(pred_y, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# drop the autograd graph before handing the tensor to the renderer,
# then hand control to the animation loop to display the final fit
pred_y = pred_y.detach()
screen.mainloop(render_callback)