train to fit (commit c24cf75)

KindXiaoming committed Jul 14, 2024
2 parents 994d482 + e5abcb7
Showing 7 changed files with 322 additions and 26 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -1,7 +1,5 @@
 <img width="600" alt="kan_plot" src="https://github.com/KindXiaoming/pykan/assets/23551623/a2d2d225-b4d2-4c1e-823e-bc45c7ea96f9">
 
-## Notice: The current version uses MultKAN instead of KAN. The usage of arguments and methods is almost the same except for the class name. More documents soon; a quick tutorial on MultKAN is [here](https://github.com/KindXiaoming/pykan/blob/master/tutorials/MultKAN_tutorial.ipynb).
-
 # Kolmogorov-Arnold Networks (KANs)
 
 This is the GitHub repo for the paper ["KAN: Kolmogorov-Arnold Networks"](https://arxiv.org/abs/2404.19756). Find the documentation [here](https://kindxiaoming.github.io/pykan/). Here's the [author's note](https://github.com/KindXiaoming/pykan?tab=readme-ov-file#authors-note) responding to the current hype around KANs.
24 changes: 10 additions & 14 deletions kan/KAN.py
@@ -138,7 +138,7 @@ def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, scale_base_mu=0.0,
             # splines
             #scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base
             scale_base = scale_base_mu * 1 / np.sqrt(width[l]) + \
-                         scale_base_sigma * (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * 1 / np.sqrt(width[l])
+                         scale_base_sigma * (torch.randn(width[l], width[l + 1]) * 2 - 1) * 1 / np.sqrt(width[l])
             sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable,
                                 sb_trainable=sb_trainable, device=device)
             self.act_fun.append(sp_batch)
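The functional change in this hunk is the shape of the noise draw: the per-edge base scales are now sampled as an (in_dim, out_dim) matrix rather than a flat vector of length in_dim * out_dim, matching the 2-D scale_base parameter that KANLayer constructs. A standalone sketch of the corrected expression, with hypothetical sizes in place of width[l] and width[l + 1], mirroring the source formula:

```python
import numpy as np
import torch

# Hypothetical layer sizes for illustration (width[l], width[l + 1]).
in_dim, out_dim = 3, 2
scale_base_mu, scale_base_sigma = 0.0, 1.0

# mu / sqrt(in_dim) plus noise scaled by sigma / sqrt(in_dim),
# now shaped (in_dim, out_dim) so it aligns with KANLayer's scale_base.
scale_base = scale_base_mu * 1 / np.sqrt(in_dim) + \
             scale_base_sigma * (torch.randn(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim)
print(scale_base.shape)  # torch.Size([3, 2])
```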
@@ -757,7 +757,7 @@ def score2alpha(score):
         plt.gcf().get_axes()[0].text(0.5, y0 * (len(self.width) - 1) + 0.2, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
 
     def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=0., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, batch=-1,
-              small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'):
+              small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video'):
         '''
         training
@@ -793,8 +793,6 @@ def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lam
             threshold to determine large or small numbers (may want to apply larger penalty to smaller numbers)
         small_reg_factor : float
             penalty strength applied to small factors relative to large factors
-        device : str
-            device
         save_fig_freq : int
             save figure every (save_fig_freq) step
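With `device` dropped from the signature, `train` now reads the device from the model itself (`self.device`, fixed at construction), so the model and its data can no longer disagree about placement. A minimal usage sketch under that assumption; the target function `f` is hypothetical and `create_dataset` is pykan's dataset helper:

```python
import torch
from kan import *

# Hypothetical toy regression target.
f = lambda x: torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2
dataset = create_dataset(f, n_var=2)

# The device is chosen once, on the model; train() no longer takes one.
model = KAN(width=[2, 5, 1], grid=3, k=3, device='cpu')
model.train(dataset, opt='LBFGS', steps=20)
```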
@@ -865,19 +863,18 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
             batch_size_test = dataset['test_input'].shape[0]
         else:
             batch_size = batch
-            batch_size_test = batch
 
         global train_loss, reg_
 
         def closure():
             global train_loss, reg_
             optimizer.zero_grad()
-            pred = self.forward(dataset['train_input'][train_id].to(device))
+            pred = self.forward(dataset['train_input'][train_id].to(self.device))
             if sglr_avoid == True:
                 id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
-                train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
+                train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(self.device))
             else:
-                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
+                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
             reg_ = reg(self.acts_scale)
             objective = train_loss + lamb * reg_
             objective.backward()
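For context on why the loss computation lives inside `closure`: unlike Adam, PyTorch's LBFGS may re-evaluate the objective several times per `step`, so it expects a callable that zeroes gradients, recomputes the loss, backpropagates, and returns it. A self-contained sketch of the idiom (generic PyTorch, not pykan code):

```python
import torch

# Generic illustration of the LBFGS closure pattern used above.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=1.)
x, y = torch.randn(64, 2), torch.randn(64, 1)

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss  # LBFGS calls closure() itself and needs the loss back

for _ in range(10):
    optimizer.step(closure)
```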
@@ -890,28 +887,27 @@ def closure():
         for _ in pbar:
 
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
-            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
             if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
-                self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
+                self.update_grid_from_samples(dataset['train_input'][train_id].to(self.device))
 
             if opt == "LBFGS":
                 optimizer.step(closure)
 
             if opt == "Adam":
-                pred = self.forward(dataset['train_input'][train_id].to(device))
+                pred = self.forward(dataset['train_input'][train_id].to(self.device))
                 if sglr_avoid == True:
                     id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
-                    train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
+                    train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(self.device))
                 else:
-                    train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
+                    train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
                 reg_ = reg(self.acts_scale)
                 loss = train_loss + lamb * reg_
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'].to(self.device)), dataset['test_label'].to(self.device))
 
             if _ % log == 0:
                 pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
4 changes: 3 additions & 1 deletion kan/KANLayer.py
@@ -128,6 +128,8 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
             mask = sparse_mask(in_dim, out_dim)
         else:
             mask = 1.
+
+        scale_base = scale_base.to(device)
         self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim, device=device) * scale_base * mask).requires_grad_(sb_trainable)  # make scale trainable
         #else:
         #self.scale_base = torch.nn.Parameter(scale_base.to(device)).requires_grad_(sb_trainable)
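The added `scale_base.to(device)` guards against a device mismatch: `torch.ones(..., device=device)` lives on the target device, and multiplying it by a CPU-resident `scale_base` tensor would raise a RuntimeError on CUDA. A standalone sketch of the fix, with hypothetical sizes:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
in_dim, out_dim = 3, 2
mask = 1.

# scale_base may arrive on the CPU (e.g. built with plain torch.randn);
# move it first so the elementwise product stays on one device.
scale_base = torch.randn(in_dim, out_dim).to(device)
scale_base_param = torch.nn.Parameter(
    torch.ones(in_dim, out_dim, device=device) * scale_base * mask
).requires_grad_(True)
print(scale_base_param.device)
```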
@@ -224,7 +226,7 @@ def update_grid_from_samples(self, x):
         grid_adaptive = x_pos[ids, :].permute(1,0)
         margin = 0.01
         h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
-        grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :]
+        grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,).to(self.device)[None, :]
         grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
         self.grid.data = extend_grid(grid, k_extend=self.k)
         self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k, device=self.device)
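The fix here moves the `torch.arange` result onto the layer's device (it is allocated on the CPU by default, so the addition would mix devices on a CUDA model). As a standalone illustration of the surrounding arithmetic with hypothetical knots: `grid_uniform` spans the same range as the sample-adaptive grid with equal spacing, and `grid_eps` interpolates between the two:

```python
import torch

# Hypothetical adaptive knots for one input dimension, shape (1, num+1).
grid_adaptive = torch.tensor([[-2.0, -0.5, 0.0, 0.3, 2.0]])
num_interval = grid_adaptive.shape[1] - 1

h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval
grid_uniform = grid_adaptive[:, [0]] + h * torch.arange(num_interval + 1)[None, :]

grid_eps = 0.5  # 1.0 -> fully uniform, 0.0 -> fully adaptive
grid = grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive
print(grid)  # knots halfway between uniform and adaptive
```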
15 changes: 10 additions & 5 deletions kan/MultKAN.py
@@ -780,8 +780,8 @@ def reg(acts_scale):
         def closure():
             global train_loss, reg_
             optimizer.zero_grad()
-            pred = self.forward(dataset['train_input'][train_id].to(device), singularity_avoiding=singularity_avoiding, y_th=y_th)
-            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
+            pred = self.forward(dataset['train_input'][train_id].to(self.device), singularity_avoiding=singularity_avoiding, y_th=y_th)
+            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
             if self.save_plot_data:
                 if reg_metric == 'act':
                     reg_ = reg(self.acts_scale_spline)
@@ -803,15 +803,20 @@ def closure():
             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
             test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
 
-            if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
-                self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
+<<<<<<< HEAD
+            if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
+                self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
+=======
+            if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ > start_grid_update_step:
+                self.update_grid_from_samples(dataset['train_input'][train_id].to(self.device))
+>>>>>>> e5abcb74c6cdc6af70d9665ec5a7b4ccb94ee564
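Note that this merge commit checks in the raw conflict markers above, which leaves MultKAN.py syntactically invalid until a follow-up fix. A resolution consistent with the rest of the commit, keeping HEAD's `>=` bound and the incoming branch's `self.device` (my assumption, not the repository's recorded fix), would read:

```python
# Hypothetical conflict resolution (not this commit's actual content):
# keep HEAD's >= comparison and the incoming branch's self.device.
if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
    self.update_grid_from_samples(dataset['train_input'][train_id].to(self.device))
```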

             if opt == "LBFGS":
                 optimizer.step(closure)
 
             if opt == "Adam":
-                pred = self.forward(dataset['train_input'][train_id].to(device), singularity_avoiding=singularity_avoiding, y_th=y_th)
-                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
+                pred = self.forward(dataset['train_input'][train_id].to(self.device), singularity_avoiding=singularity_avoiding, y_th=y_th)
+                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
                 if self.save_plot_data:
                     if reg_metric == 'act':
                         reg_ = reg(self.acts_scale_spline)
@@ -825,7 +830,7 @@ def closure():
                 loss.backward()
                 optimizer.step()
 
-            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
+            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))
 
             if _ % log == 0:
                 pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
7 changes: 4 additions & 3 deletions kan/spline.py
@@ -60,8 +60,9 @@ def extend_grid(grid, k_extend=0):
     value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (
                 grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]'''
 
-    x = x.unsqueeze(dim=2)
-    grid = grid.unsqueeze(dim=0)
+    x = x.unsqueeze(dim=2).to(device)
+    grid = grid.unsqueeze(dim=0).to(device)
+
     if k == 0:
         value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
     else:
@@ -113,7 +114,7 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
    coef = coef.to(x_eval.dtype)
    y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))'''
 
-    b_splines = B_batch(x_eval, grid, k=k)  # (batch, in_dim, n_coef)
+    b_splines = B_batch(x_eval, grid, k=k).to(device)  # (batch, in_dim, n_coef)
     # coef (in_dim, out_dim, n_coef)
     #print(b_splines.shape, coef.shape)
     y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef)
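The einsum contracts the shared coefficient index k: bases of shape (batch, in_dim, n_coef) against coefficients of shape (in_dim, out_dim, n_coef) yield per-edge activations of shape (batch, in_dim, out_dim). A quick standalone shape check with hypothetical sizes:

```python
import torch

batch, in_dim, out_dim, n_coef = 8, 3, 2, 10
b_splines = torch.randn(batch, in_dim, n_coef)   # B-spline bases
coef = torch.randn(in_dim, out_dim, n_coef)      # spline coefficients

# 'ijk,jlk->ijl': sum over coefficient index k for every edge (j, l).
y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef)
print(y_eval.shape)  # torch.Size([8, 3, 2])
```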
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

 setuptools.setup(
     name="pykan",
-    version="0.1.1",
+    version="0.1.2",
     author="Ziming Liu",
     author_email="[email protected]",
     description="Kolmogorov Arnold Networks",
Expand Down
294 changes: 294 additions & 0 deletions tutorials/API_3_grid.ipynb

Large diffs are not rendered by default.
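The new notebook covers grid refinement. A minimal sketch of the coarse-to-fine workflow it presumably demonstrates, under the assumption that this era of the API exposes `initialize_from_another_model` (which transfers a trained model onto a finer grid); the target function is hypothetical:

```python
import torch
from kan import *

# Hypothetical target function for the refinement demo.
f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
dataset = create_dataset(f, n_var=2)

# Train on a coarse grid first, then hand off to a finer one.
model = KAN(width=[2, 5, 1], grid=5, k=3)
model.train(dataset, opt="LBFGS", steps=20)

fine = KAN(width=[2, 5, 1], grid=10, k=3)
fine = fine.initialize_from_another_model(model, dataset['train_input'])
fine.train(dataset, opt="LBFGS", steps=20)
```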
