Skip to content

Commit

Permalink
Merge pull request #148 from Jim137/develop
Browse files Browse the repository at this point in the history
Fix dtype error in `curve2coef`
  • Loading branch information
KindXiaoming authored May 10, 2024
2 parents 41aca57 + e4f3f8b commit e6078bc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion kan/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"):
torch.Size([5, 13])
'''
# x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1)
mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1).to(y_eval.dtype)
coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 0] # sometimes 'cuda' version may diverge
return coef.to(device)

0 comments on commit e6078bc

Please sign in to comment.