Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added GKAN #9870

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions examples/gkan_ex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os.path as osp

import torch
from torch.nn import CrossEntropyLoss

from torch_geometric.datasets import Planetoid
from torch_geometric.nn.models.gkan import GKANModel

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Cora dataset
path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data", "Cora")
dataset = Planetoid(root=path, name="Cora")
data = dataset[0].to(device)

# Model configuration
input_dim = dataset.num_features
hidden_dim = 64
output_dim = dataset.num_classes
num = 5 # Number of grid intervals for KANLayer
k = 3 # Spline polynomial order
num_layers = 3
architecture = 2 # 1 for Aggregate->KAN, 2 for KAN->Aggregate

# Initialize the GKANModel
model = GKANModel(
input_dim=input_dim,
hidden_dim=hidden_dim,
output_dim=output_dim,
num_layers=num_layers,
num=num, # Grid intervals for KANLayer
k=k, # Spline polynomial order
architecture=architecture,
).to(device)

# Optimizer and Loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = CrossEntropyLoss()


def train():
"""Train the GKAN model."""
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index) # Forward pass
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()


@torch.no_grad()
def test():
"""Evaluate the GKAN model."""
model.eval()
out = model(data.x, data.edge_index)
pred = out.argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]
).sum().item() / data.test_mask.sum().item()
return acc


# Training Loop
print("Training GKANModel on Cora...")
for epoch in range(1, 201): # 200 epochs
loss = train()
print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")

# Testing
accuracy = test()
print(f"\nTest Accuracy: {accuracy:.4f}")
51 changes: 51 additions & 0 deletions test/nn/models/test_gkan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch

from torch_geometric.nn.models.gkan import GKANModel


def test_gkan() -> None:
num_nodes = 10
num_features = 128
num_classes = 3
num_layers = 3
num = 5 # Number of grid intervals
k = 3 # Polynomial order of splines
architecture = 2 # KAN -> Aggregate architecture

# Mock input data
x = torch.randn(num_nodes, num_features)
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
], dtype=torch.long)

# Initialize GKANModel
model = GKANModel(
input_dim=num_features,
hidden_dim=64,
output_dim=num_classes,
num_layers=num_layers,
num=num, # Spline grid intervals
k=k, # Spline polynomial order
architecture=architecture)

# Check model string representation
assert str(model).startswith('GKANModel')

# Check forward pass in training mode
model.train()
out = model(x, edge_index)
assert out.size() == (num_nodes, num_classes)
assert out.min().item() >= -10 and out.max().item() <= 10

# Check forward pass in evaluation mode
model.eval()
out = model(x, edge_index)
assert out.size() == (num_nodes, num_classes)
assert out.min().item() >= -5 and out.max().item() <= 5

print("GKANModel test passed successfully!")


if __name__ == "__main__":
test_gkan()
176 changes: 176 additions & 0 deletions torch_geometric/nn/models/gkan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing


# Spline-related utility functions
def B_batch(x, grid, k=0):
"""Evaluate x on B-spline bases.
"""
x = x.unsqueeze(dim=2)
grid = grid.unsqueeze(dim=0)

if k == 0:
value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
else:
B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)
value = (
(x - grid[:, :, :-(k + 1)]) /
(grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] +
(grid[:, :, k + 1:] - x) /
(grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:])

# Handle degenerate cases
value = torch.nan_to_num(value)
return value


def coef2curve(x_eval, grid, coef, k):
"""Convert B-spline coefficients to B-spline curves and evaluate x on them.
"""
b_splines = B_batch(x_eval, grid, k=k)
y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))
return y_eval


def curve2coef(x_eval, y_eval, grid, k):
"""Convert B-spline curves to B-spline coefficients using least squares.
"""
batch = x_eval.shape[0]
in_dim = x_eval.shape[1]
out_dim = y_eval.shape[2]
n_coef = grid.shape[1] - k - 1
mat = B_batch(x_eval, grid, k)
mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch,
n_coef)
y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)

# Solve least squares
coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
return coef


def extend_grid(grid, k_extend=0):
"""Extend the grid by k points on both sides for smoother splines.
"""
h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

for i in range(k_extend):
grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

return grid


# KANLayer
class KANLayer(nn.Module):
"""Implements a spline-based activation layer for GKAN.
"""
def __init__(self, in_dim, out_dim, num=5, k=3, noise_scale=0.5,
grid_range=[-1, 1], base_fun=torch.nn.SiLU()):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.num = num
self.k = k
self.base_fun = base_fun

# Initialize grid with k-extended values
grid = torch.linspace(grid_range[0], grid_range[1],
steps=num + 1)[None, :].expand(in_dim, num + 1)
extended_grid = extend_grid(grid, k_extend=k)
self.register_buffer('grid', extended_grid) # Non-trainable grid

# Initialize spline coefficients with noise
noises = (torch.rand(num + 1, in_dim, out_dim) -
0.5) * noise_scale / num
self.coef = nn.Parameter(
curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k))

# Trainable scaling parameters
self.scale_base = nn.Parameter(torch.ones(in_dim, out_dim))
self.scale_sp = nn.Parameter(torch.ones(in_dim, out_dim))

def forward(self, x):
"""Forward pass through the KANLayer.

Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_dim).

Returns:
torch.Tensor: Output tensor of shape (batch_size, out_dim).
"""
base = self.base_fun(x) # Residual function b(x)
y = coef2curve(x, grid=self.grid, coef=self.coef,
k=self.k) # Spline function
y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[
None, :, :] * y
return torch.sum(y, dim=1)


# GKAN Layer
class GKAN(MessagePassing):
"""Graph Kolmogorov-Arnold Network (GKAN) layer.
"""
def __init__(self, input_dim, hidden_dim, output_dim, num=5, k=3,
architecture=2):
super().__init__(aggr="add")
self.architecture = architecture

# KAN Layer (spline-based activation)
self.kan_layer = KANLayer(input_dim, hidden_dim, num=num, k=k)

# Linear transformation for output
self.linear = nn.Linear(hidden_dim, output_dim)

def forward(self, x, edge_index):
"""Forward pass for GKAN layer.
"""
if self.architecture == 1:
# Aggregate first, then apply KAN
x = self.propagate(edge_index, x=x)
x = self.kan_layer(x)
elif self.architecture == 2:
# Apply KAN first, then aggregate
x = self.kan_layer(x)
x = self.propagate(edge_index, x=x)
return self.linear(x)

def message(self, x_j):
return x_j

def update(self, aggr_out):
return aggr_out


# GKAN Model
class GKANModel(nn.Module):
"""Graph Kolmogorov-Arnold Network (GKAN) model for node classification.
"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, num=5,
k=3, architecture=2):
super().__init__()
self.layers = nn.ModuleList()

# Input layer
self.layers.append(
GKAN(input_dim, hidden_dim, hidden_dim, num=num, k=k,
architecture=architecture))

# Hidden layers
for _ in range(num_layers - 2):
self.layers.append(
GKAN(hidden_dim, hidden_dim, hidden_dim, num=num, k=k,
architecture=architecture))

# Output layer
self.layers.append(
GKAN(hidden_dim, hidden_dim, output_dim, num=num, k=k,
architecture=architecture))

def forward(self, x, edge_index):
for layer in self.layers:
x = F.relu(layer(x, edge_index))
return x
Loading