Commit

Updated: Added GKAN and tests
danielshin1 committed Dec 15, 2024
1 parent 2b1b327 commit 1fca3a9
Showing 3 changed files with 288 additions and 0 deletions.
69 changes: 69 additions & 0 deletions examples/gkan_ex.py
@@ -0,0 +1,69 @@
import os.path as osp
import torch
from torch.nn import CrossEntropyLoss
from torch_geometric.datasets import Planetoid
from torch_geometric.nn.models.gkan import GKANModel

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Cora dataset
path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data", "Cora")
dataset = Planetoid(root=path, name="Cora")
data = dataset[0].to(device)

# Model configuration
input_dim = dataset.num_features
hidden_dim = 64
output_dim = dataset.num_classes
num = 5 # Number of grid intervals for KANLayer
k = 3 # Spline polynomial order
num_layers = 3
architecture = 2 # 1 for Aggregate->KAN, 2 for KAN->Aggregate

# Initialize the GKANModel
model = GKANModel(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    num_layers=num_layers,
    num=num,  # Grid intervals for KANLayer
    k=k,  # Spline polynomial order
    architecture=architecture,
).to(device)

# Optimizer and Loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = CrossEntropyLoss()


def train():
    """Train the GKAN model."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)  # Forward pass
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Evaluate the GKAN model."""
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum().item()
    return correct / data.test_mask.sum().item()


# Training Loop
print("Training GKANModel on Cora...")
for epoch in range(1, 201):  # 200 epochs
    loss = train()
    print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

# Testing
accuracy = test()
print(f"\nTest Accuracy: {accuracy:.4f}")
51 changes: 51 additions & 0 deletions test/nn/models/test_gkan.py
@@ -0,0 +1,51 @@
import torch
from torch_geometric.nn.models.gkan import GKANModel


def test_gkan() -> None:
    num_nodes = 10
    num_features = 128
    num_classes = 3
    num_layers = 3
    num = 5  # Number of grid intervals
    k = 3  # Polynomial order of splines
    architecture = 2  # KAN -> Aggregate architecture

    # Mock input data
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
    ], dtype=torch.long)

    # Initialize GKANModel
    model = GKANModel(
        input_dim=num_features,
        hidden_dim=64,
        output_dim=num_classes,
        num_layers=num_layers,
        num=num,  # Spline grid intervals
        k=k,  # Spline polynomial order
        architecture=architecture,
    )

    # Check model string representation
    assert str(model).startswith('GKANModel')

    # Check forward pass in training mode
    model.train()
    out = model(x, edge_index)
    assert out.size() == (num_nodes, num_classes)
    assert out.min().item() >= -10 and out.max().item() <= 10

    # Check forward pass in evaluation mode
    model.eval()
    out = model(x, edge_index)
    assert out.size() == (num_nodes, num_classes)
    assert out.min().item() >= -5 and out.max().item() <= 5

    print("GKANModel test passed successfully!")


if __name__ == "__main__":
    test_gkan()
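
A further check one might add alongside this test (a sketch, not part of the commit): verify that gradients reach the spline coefficients. It reuses the imports at the top of the test file; the attribute names layers, kan_layer, and coef come from the implementation in torch_geometric/nn/models/gkan.py below.

def test_gkan_backward() -> None:
    x = torch.randn(10, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
    model = GKANModel(input_dim=16, hidden_dim=8, output_dim=3, num_layers=2,
                      num=5, k=3, architecture=2)

    out = model(x, edge_index)
    out.sum().backward()

    # The spline coefficients of the first GKAN layer should receive a gradient.
    assert model.layers[0].kan_layer.coef.grad is not None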
168 changes: 168 additions & 0 deletions torch_geometric/nn/models/gkan.py
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing


# Spline-related utility functions
def B_batch(x, grid, k=0):
    """
    Evaluate x on B-spline bases of order k (Cox-de Boor recursion).
    """
    x = x.unsqueeze(dim=2)
    grid = grid.unsqueeze(dim=0)

    if k == 0:
        # Order-0 bases: indicator functions of the grid intervals.
        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
    else:
        # Recurse on the order-(k - 1) bases.
        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)
        value = (
            (x - grid[:, :, :-(k + 1)])
            / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1]
            + (grid[:, :, k + 1:] - x)
            / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
        )

    # Handle degenerate cases (division by zero on repeated knots).
    value = torch.nan_to_num(value)
    return value


def coef2curve(x_eval, grid, coef, k):
    """
    Convert B-spline coefficients to B-spline curves and evaluate x on them.
    """
    b_splines = B_batch(x_eval, grid, k=k)  # (batch, in_dim, n_coef)
    y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
    return y_eval  # (batch, in_dim, out_dim)


def curve2coef(x_eval, y_eval, grid, k):
    """
    Convert B-spline curves to B-spline coefficients using least squares.
    """
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1
    mat = B_batch(x_eval, grid, k)
    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch,
                                                     n_coef)
    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)

    # Solve the batched least-squares problem for the coefficients.
    coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
    return coef


def extend_grid(grid, k_extend=0):
    """
    Extend the grid by k points on both sides for smoother splines.
    """
    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

    for _ in range(k_extend):
        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

    return grid


# KANLayer
class KANLayer(nn.Module):
    """
    Implements a spline-based activation layer for GKAN.
    """
    def __init__(self, in_dim, out_dim, num=5, k=3, noise_scale=0.5,
                 grid_range=[-1, 1], base_fun=torch.nn.SiLU()):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num = num
        self.k = k
        self.base_fun = base_fun

        # Initialize grid with k-extended values
        grid = torch.linspace(grid_range[0], grid_range[1],
                              steps=num + 1)[None, :].expand(in_dim, num + 1)
        extended_grid = extend_grid(grid, k_extend=k)
        self.register_buffer('grid', extended_grid)  # Non-trainable grid

        # Initialize spline coefficients with noise
        noises = (torch.rand(num + 1, in_dim, out_dim) - 0.5) * noise_scale / num
        self.coef = nn.Parameter(
            curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))

        # Trainable scaling parameters
        self.scale_base = nn.Parameter(torch.ones(in_dim, out_dim))
        self.scale_sp = nn.Parameter(torch.ones(in_dim, out_dim))

    def forward(self, x):
        """
        Forward pass through the KANLayer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_dim).
        """
        base = self.base_fun(x)  # Residual function b(x)
        y = coef2curve(x, grid=self.grid, coef=self.coef, k=self.k)  # Spline
        y = (self.scale_base[None, :, :] * base[:, :, None]
             + self.scale_sp[None, :, :] * y)
        return torch.sum(y, dim=1)  # Sum contributions over input dimensions


# GKAN Layer
class GKAN(MessagePassing):
    """
    Graph Kolmogorov-Arnold Network (GKAN) layer.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num=5, k=3,
                 architecture=2):
        super().__init__(aggr="add")
        self.architecture = architecture

        # KAN Layer (spline-based activation)
        self.kan_layer = KANLayer(input_dim, hidden_dim, num=num, k=k)

        # Linear transformation for output
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """
        Forward pass for the GKAN layer.
        """
        if self.architecture == 1:
            # Aggregate first, then apply KAN
            x = self.propagate(edge_index, x=x)
            x = self.kan_layer(x)
        elif self.architecture == 2:
            # Apply KAN first, then aggregate
            x = self.kan_layer(x)
            x = self.propagate(edge_index, x=x)
        return self.linear(x)

    def message(self, x_j):
        # Identity message: neighbor features are passed through unchanged.
        return x_j

    def update(self, aggr_out):
        # Identity update: the summed neighbor features are returned as-is.
        return aggr_out


# GKAN Model
class GKANModel(nn.Module):
    """
    Graph Kolmogorov-Arnold Network (GKAN) model for node classification.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, num=5,
                 k=3, architecture=2):
        super().__init__()
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(
            GKAN(input_dim, hidden_dim, hidden_dim, num=num, k=k,
                 architecture=architecture))

        # Hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(
                GKAN(hidden_dim, hidden_dim, hidden_dim, num=num, k=k,
                     architecture=architecture))

        # Output layer
        self.layers.append(
            GKAN(hidden_dim, hidden_dim, output_dim, num=num, k=k,
                 architecture=architecture))

    def forward(self, x, edge_index):
        # Apply ReLU between layers but return raw logits from the final
        # layer, so the output can be fed directly to CrossEntropyLoss.
        for layer in self.layers[:-1]:
            x = F.relu(layer(x, edge_index))
        return self.layers[-1](x, edge_index)
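
For reference, in standard notation the spline utilities above implement the Cox-de Boor recursion, and KANLayer.forward combines a SiLU residual with a learned spline per (input, output) pair. The symbols below (knots t_i, residual b, weights w, coefficients c) are notation introduced here, not identifiers from the code:

B_{i,0}(x) = \mathbf{1}\{ t_i \le x < t_{i+1} \}, \qquad
B_{i,k}(x) = \frac{x - t_i}{t_{i+k} - t_i}\, B_{i,k-1}(x)
           + \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}}\, B_{i+1,k-1}(x)

\phi_o(x) = \sum_{j} \Big( w^{\mathrm{base}}_{j,o}\, b(x_j)
          + w^{\mathrm{sp}}_{j,o} \sum_{i} c_{j,o,i}\, B_{i,k}(x_j) \Big)

Here w^base corresponds to scale_base, w^sp to scale_sp, b to base_fun (SiLU by default), and c to the coef parameter fitted by curve2coef.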
