-
Notifications
You must be signed in to change notification settings - Fork 36
/
ChebyKANLayer_unoptimized.py
30 lines (27 loc) · 1.38 KB
/
ChebyKANLayer_unoptimized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
# This is inspired by Kolmogorov-Arnold Networks, but uses Chebyshev polynomial coefficients instead of spline coefficients.
# It's an unoptimized version of ChebyKANLayer.py; see issue #3 for more details.
class ChebyKANLayer(nn.Module):
    """KAN-style layer whose edge functions are Chebyshev series.

    Every (input feature, output feature) pair owns ``degree + 1`` learnable
    coefficients. The output for one sample is the sum, over all input
    features and all polynomial orders, of ``coeff * T_d(tanh(x_i))``.

    Args:
        input_dim: Number of input features; inputs are reshaped to
            ``(-1, input_dim)``.
        output_dim: Number of output features.
        degree: Highest Chebyshev order used (orders ``0 .. degree``).
    """

    def __init__(self, input_dim, output_dim, degree):
        super().__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        # One coefficient per (input feature, output feature, polynomial order).
        self.cheby_coeffs = nn.Parameter(
            torch.empty(input_dim, output_dim, degree + 1)
        )
        nn.init.normal_(
            self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1))
        )

    def forward(self, x):
        """Evaluate the layer.

        Args:
            x: Tensor whose elements can be reshaped to ``(-1, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, output_dim)``, where ``batch_size``
            is the number of rows after the reshape.
        """
        x = torch.reshape(x, (-1, self.inputdim))  # (batch_size, inputdim)

        # Chebyshev polynomials are defined on [-1, 1]; tanh squashes the
        # input into that interval.
        x = torch.tanh(x)

        # Build T_0 .. T_degree with the recurrence
        #   T_0 = 1, T_1 = x, T_n = 2 * x * T_{n-1} - T_{n-2}.
        # The terms are collected in a list and stacked afterwards, so the
        # basis inherits the dtype and device of x (a preallocated float32
        # buffer would not match float64/half inputs or parameters) and no
        # in-place writes are made into a tensor tracked by autograd.
        terms = [torch.ones_like(x)]
        if self.degree > 0:
            terms.append(x)
        for _ in range(2, self.degree + 1):
            terms.append(2 * x * terms[-1] - terms[-2])
        cheby = torch.stack(terms, dim=-1)  # (batch_size, inputdim, degree + 1)

        # Contract over input features and polynomial orders; the result is
        # already (batch_size, outdim), so no further reshape is needed.
        return torch.einsum('bid,iod->bo', cheby, self.cheby_coeffs)