-
Notifications
You must be signed in to change notification settings - Fork 36
/
ChebyKANLayer.py
36 lines (32 loc) · 1.36 KB
/
ChebyKANLayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
# Inspired by Kolmogorov-Arnold Networks, but using Chebyshev polynomial coefficients instead of spline coefficients.
class ChebyKANLayer(nn.Module):
    """KAN-style layer whose edge functions are Chebyshev expansions.

    Every (input feature, output feature) pair owns ``degree + 1`` learnable
    coefficients. The layer squashes its input with ``tanh``, evaluates the
    Chebyshev polynomials of the first kind ``T_0 .. T_degree`` through the
    identity ``T_d(x) = cos(d * acos(x))``, and sums the weighted polynomials
    over input features and degrees.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
        output_dim: Number of output features.
        degree: Highest Chebyshev degree used (``degree + 1`` terms in total).
    """

    def __init__(self, input_dim, output_dim, degree):
        super().__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        # One coefficient per (input feature, output feature, degree).
        self.cheby_coeffs = nn.Parameter(
            torch.empty(input_dim, output_dim, degree + 1)
        )
        nn.init.normal_(
            self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1))
        )
        # Degrees 0..degree, stored as a buffer so they follow the module
        # across devices and are saved in the state dict.
        self.register_buffer("arange", torch.arange(0, degree + 1, 1))

    def forward(self, x):
        """Evaluate the layer.

        Args:
            x: Tensor whose total number of elements is a multiple of
                ``input_dim``; it is flattened to ``(batch_size, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, output_dim)``.
        """
        # Chebyshev polynomials are defined on [-1, 1]; tanh maps x into it.
        x = torch.tanh(x)

        # tanh saturates to exactly +/-1 for large |x|, where the derivative
        # of acos is infinite and backward would yield NaN (0 * inf). Staying
        # one machine epsilon inside the interval keeps gradients finite while
        # changing the forward value only negligibly.
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(-1.0 + eps, 1.0 - eps)

        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.inputdim)

        # theta * d for every degree d, broadcast over a new trailing axis:
        # shape = (batch_size, inputdim, degree + 1)
        angles = x.acos().unsqueeze(-1) * self.arange

        # T_d(x) = cos(d * acos(x))
        basis = angles.cos()

        # Weighted sum over input features and degrees:
        # shape = (batch_size, outdim)
        return torch.einsum("bid,iod->bo", basis, self.cheby_coeffs)