diff --git a/kan/spline.py b/kan/spline.py index 6f64d883..1992729d 100644 --- a/kan/spline.py +++ b/kan/spline.py @@ -133,6 +133,6 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"): torch.Size([5, 13]) ''' # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar - mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1) + mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1).to(y_eval.dtype) coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 0] # sometimes 'cuda' version may diverge return coef.to(device)