A PyTorch-based library for functional basis representations and smooth function approximation.
- B-spline basis functions of arbitrary order
- Fourier basis functions
- Smooth function approximation with penalized regression
- Batch processing support
- GPU acceleration through PyTorch
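The sketch below gives a rough picture of what a Fourier basis computes. It evaluates the conventional orthonormal Fourier system on a domain of length L: a constant 1/√L, followed by alternating √(2/L)·sin(kωt) and √(2/L)·cos(kωt) terms with ω = 2π/L. Several details are assumptions made for illustration: the helper name `fourier_basis`, the period (taken equal to the domain length), the column order, and the normalization. `torchfuncbasis` may order or scale its functions differently.

```python
import math

import torch


def fourier_basis(x: torch.Tensor, n_basis: int, domain_range=(0.0, 1.0)) -> torch.Tensor:
    """Hypothetical helper: orthonormal Fourier basis evaluated at x.

    Columns are ordered as 1, sin(w t), cos(w t), sin(2 w t), cos(2 w t), ...
    with the period equal to the domain length. Returns shape (*x.shape, n_basis).
    """
    a, b = domain_range
    length = b - a
    omega = 2 * math.pi / length
    t = x - a  # shift so that the domain starts at 0
    cols = [torch.ones_like(t) / math.sqrt(length)]
    k = 1
    while len(cols) < n_basis:
        scale = math.sqrt(2 / length)
        cols.append(scale * torch.sin(k * omega * t))
        cols.append(scale * torch.cos(k * omega * t))
        k += 1
    # When n_basis is even, the last loop iteration adds one column too many; drop it.
    return torch.stack(cols[:n_basis], dim=-1)


# Check orthonormality numerically with the trapezoidal rule on a fine grid.
x = torch.linspace(0, 2, 2001, dtype=torch.float64)
B = fourier_basis(x, n_basis=5, domain_range=(0.0, 2.0))  # (2001, 5)
gram = torch.trapezoid(B.unsqueeze(-1) * B.unsqueeze(-2), x, dim=0)
print(torch.allclose(gram, torch.eye(5, dtype=torch.float64), atol=1e-8))  # True
```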
```bash
git clone https://github.com/mynanshan/TorchFuncBasis.git
cd TorchFuncBasis
pip install torchfuncbasis
```
```python
import torch
from torchfuncbasis.basis import BSplineBasis
from torchfuncbasis.smoother import points2basiscoefs

basis = BSplineBasis(n_basis=11, domain_range=(0, 1), order=4)
print(basis)
```
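`BSplineBasis(n_basis=11, domain_range=(0, 1), order=4)` describes a basis built from cubic pieces, since `order=4` means cubic. A B-spline basis with `n_basis` functions of a given `order` is defined by a knot vector of length `n_basis + order`, which is 15 here. The sketch below evaluates every basis function with the Cox-de Boor recursion. It assumes uniformly spaced breakpoints with each endpoint repeated `order` times; this is an assumption, not documented library behavior. `clamped_knots` and `bspline_basis` are hypothetical names.

```python
import torch


def clamped_knots(n_basis: int, order: int, domain_range=(0.0, 1.0)) -> torch.Tensor:
    """Hypothetical helper: uniform breakpoints with each endpoint repeated `order` times."""
    a, b = domain_range
    n_breaks = n_basis - order + 2  # number of breakpoints, including both endpoints
    breaks = torch.linspace(a, b, n_breaks, dtype=torch.float64)
    pad = order - 1
    return torch.cat([breaks[:1].repeat(pad), breaks, breaks[-1:].repeat(pad)])


def _safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
    # Cox-de Boor convention: a term whose knot span has zero length contributes 0.
    safe_den = torch.where(den > 0, den, torch.ones_like(den))
    return torch.where(den > 0, num / safe_den, torch.zeros_like(num))


def bspline_basis(x: torch.Tensor, knots: torch.Tensor, order: int) -> torch.Tensor:
    """Evaluate all B-splines of the given order at x via the Cox-de Boor recursion.

    x has shape (*batch, n_points); the result has shape (*batch, n_points, n_basis),
    where n_basis = len(knots) - order.
    """
    t = knots
    x = x.to(t.dtype).unsqueeze(-1)  # (..., 1), broadcasts against the knots
    n_basis = t.numel() - order
    # Order 1: indicator functions of the half-open intervals [t_i, t_{i+1}).
    B = ((x >= t[:-1]) & (x < t[1:])).to(t.dtype)
    # Assign the right endpoint of the domain to the last non-empty interval.
    last = torch.zeros(t.numel() - 1, dtype=t.dtype)
    last[n_basis - 1] = 1.0
    B = torch.where(x == t[-1], last, B)
    # Raise the order one step at a time; each step removes one column.
    for k in range(2, order + 1):
        left = _safe_div(x - t[:-k], t[k - 1:-1] - t[:-k])
        right = _safe_div(t[k:] - x, t[k:] - t[1:-k + 1])
        B = left * B[..., :-1] + right * B[..., 1:]
    return B


knots = clamped_knots(n_basis=11, order=4)
x = torch.linspace(0, 1, 101)
B = bspline_basis(x, knots, order=4)
print(B.shape)  # torch.Size([101, 11])
# B-splines form a partition of unity on the domain.
print(torch.allclose(B.sum(-1), torch.ones(101, dtype=torch.float64)))  # True
```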
A `Basis` object can be called to evaluate the basis matrix or its derivatives. The input may be a batch of points with shape `(*batch, n_points, *dim_domain)`, where:

- `*batch` is any number of leading dimensions (possibly none);
- `n_points` is the number of points in the domain;
- `dim_domain` is the dimension of the domain, which can be dropped if the domain is one-dimensional.
```python
x = torch.linspace(0, 1, 101).unsqueeze(0).repeat(5, 7, 1)
print(f"x's shape: {x.shape}")
basis_matrix = basis(x)
print(f"basis_matrix.shape: {basis_matrix.shape}")
basis_deriv_matrix = basis(x, derivative=1)
```
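The batch convention above can be reproduced for any per-point basis that only accepts 1-D input. The approach is to flatten the leading `*batch` dimensions, evaluate once, and restore them. The sketch below is a hypothetical helper (`batched_eval`) for one-dimensional domains, not the library's implementation. Bases built purely from elementwise tensor operations already broadcast over leading dimensions and do not need it.

```python
from typing import Callable

import torch


def batched_eval(fn: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: evaluate a 1-D basis on x of shape (*batch, n_points).

    `fn` maps a 1-D tensor of n points to an (n, n_basis) matrix. The leading
    batch dimensions are flattened away and restored afterwards, so the result
    has shape (*batch, n_points, n_basis).
    """
    batch_shape, n_points = x.shape[:-1], x.shape[-1]
    values = fn(x.reshape(-1))  # (prod(batch) * n_points, n_basis)
    return values.reshape(*batch_shape, n_points, values.shape[-1])


def quadratic_basis(t: torch.Tensor) -> torch.Tensor:
    # Toy basis 1, t, t^2, used only to demonstrate the shapes.
    return torch.stack([torch.ones_like(t), t, t ** 2], dim=-1)


x = torch.linspace(0, 1, 101).unsqueeze(0).repeat(5, 7, 1)  # (5, 7, 101)
print(batched_eval(quadratic_basis, x).shape)  # torch.Size([5, 7, 101, 3])
print(batched_eval(quadratic_basis, x[0, 0]).shape)  # empty batch: torch.Size([101, 3])
```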
```python
gram_matrix = basis.gram_matrix()
print(f"gram_matrix.shape: {gram_matrix.shape}")
gram_matrix_deriv = basis.gram_matrix(derivative=1)
```
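A Gram matrix collects the pairwise inner products of the basis functions, G[j, k] = ∫ φ_j(t) φ_k(t) dt over the domain. With `derivative=m`, the same integral is presumably taken over the m-th derivatives, which gives the usual roughness-penalty matrix. The sketch below approximates these integrals with the trapezoidal rule on a fine grid. It uses a toy monomial basis because that basis has a known exact Gram matrix. `gram_by_quadrature` and `monomials` are hypothetical, and the library may compute these matrices exactly instead.

```python
import math

import torch


def gram_by_quadrature(basis_fn, domain_range=(0.0, 1.0), n_grid=2001) -> torch.Tensor:
    """Hypothetical helper: approximate G[j, k] = ∫ f_j(t) f_k(t) dt with the trapezoidal rule.

    `basis_fn` maps a 1-D grid of shape (n_grid,) to an (n_grid, n_basis) matrix of
    basis values (or of their derivatives, for a roughness-penalty matrix).
    """
    a, b = domain_range
    t = torch.linspace(a, b, n_grid, dtype=torch.float64)
    values = basis_fn(t)                                     # (n_grid, n_basis)
    products = values.unsqueeze(-1) * values.unsqueeze(-2)   # (n_grid, n_basis, n_basis)
    return torch.trapezoid(products, t, dim=0)


def monomials(t: torch.Tensor, degree: int = 3, derivative: int = 0) -> torch.Tensor:
    # Toy basis 1, t, ..., t^degree and its derivatives: d^m/dt^m t^k = k!/(k-m)! t^(k-m).
    cols = []
    for k in range(degree + 1):
        if k < derivative:
            cols.append(torch.zeros_like(t))
        else:
            coef = math.factorial(k) / math.factorial(k - derivative)
            cols.append(coef * t ** (k - derivative))
    return torch.stack(cols, dim=-1)


gram = gram_by_quadrature(monomials)
gram_deriv = gram_by_quadrature(lambda t: monomials(t, derivative=1))
# For monomials on [0, 1], the exact Gram matrix is the Hilbert matrix 1 / (j + k + 1).
idx = torch.arange(4, dtype=torch.float64)
j, k = torch.meshgrid(idx, idx, indexing="ij")
print(torch.allclose(gram, 1.0 / (j + k + 1), atol=1e-6))  # True
```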
Suppose we have a set of points `(x, y)` and a basis object `basis`. `y` is expected to have shape `(*batch, n_points, *dim_response)`, where:

- `*batch` and `n_points` are the same as in `x`;
- `dim_response` is the dimension of the response variable, which can be dropped if the response is one-dimensional.

We can fit a smooth function by solving a penalized least-squares problem with `points2basiscoefs`. The returned value is the array of basis coefficients, with shape `(*batch, dim_response, n_basis)`.
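The formulation below is one common way to write such a problem. It applies to a single batch element with points $x_1, \dots, x_n$, basis vector $\phi(x) \in \mathbb{R}^K$, and coefficient matrix $C \in \mathbb{R}^{d \times K}$, where $d$ is `dim_response` and $K$ is `n_basis`. It rests on two assumptions: that `smoothing_param` plays the role of $\lambda$, and that the penalty integrates the squared $m$-th derivative of the fitted function. Neither the penalty order nor any scaling of the data term used by `points2basiscoefs` is stated here.

$$
\hat{C} = \arg\min_{C \in \mathbb{R}^{d \times K}} \sum_{i=1}^{n} \left\lVert y_i - C\,\phi(x_i) \right\rVert_2^2 + \lambda \int \left\lVert C\,\phi^{(m)}(t) \right\rVert_2^2 \, dt
$$

Define $\Phi \in \mathbb{R}^{n \times K}$ with $\Phi_{ik} = \phi_k(x_i)$, let $Y \in \mathbb{R}^{n \times d}$ stack the $y_i$ as rows, and let $R_{jk} = \int \phi_j^{(m)}(t)\,\phi_k^{(m)}(t)\,dt$. The objective then equals $\lVert Y - \Phi C^\top \rVert_F^2 + \lambda \operatorname{tr}(C R C^\top)$. Setting its gradient with respect to $C^\top$ to zero yields the normal equations

$$
\left(\Phi^\top \Phi + \lambda R\right) \hat{C}^\top = \Phi^\top Y ,
$$

so each row of $\hat{C}$ holds the `n_basis` coefficients of one response component.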
```python
x = torch.linspace(0, 1, 101).unsqueeze(0).repeat(5, 7, 1)
print(f"x's shape: {x.shape}")
y = torch.cat([
    (torch.sin(2 * torch.pi * x) + torch.randn_like(x) * 0.05).unsqueeze(-1),
    (torch.cos(2 * torch.pi * x) + torch.randn_like(x) * 0.05).unsqueeze(-1)
], dim=-1)
print(f"y's shape: {y.shape}")
coefs = points2basiscoefs(x, y, basis, smoothing_param=1e-4)
print(f"coefs.shape: {coefs.shape}")
```
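The same kind of fit can be written directly in PyTorch. The approach solves the normal equations (ΦᵀΦ + λR)Cᵀ = ΦᵀY with `torch.linalg.solve`, which handles the leading batch dimensions automatically. This is not `points2basiscoefs`. It is a hypothetical `fit_penalized` helper paired with an orthonormal Fourier basis on [0, 1] and a second-derivative penalty. That pairing was chosen because the penalty matrix then has a simple closed form. Whether `smoothing_param` corresponds exactly to `lam` here is an assumption.

```python
import math

import torch


def fit_penalized(Phi: torch.Tensor, Y: torch.Tensor, R: torch.Tensor, lam: float) -> torch.Tensor:
    """Hypothetical helper: solve (Phi^T Phi + lam R) C^T = Phi^T Y for every batch element.

    Phi: (*batch, n_points, n_basis), Y: (*batch, n_points, dim_response),
    R: (n_basis, n_basis). Returns C with shape (*batch, dim_response, n_basis).
    """
    Phi_t = Phi.transpose(-1, -2)
    A = Phi_t @ Phi + lam * R          # (*batch, n_basis, n_basis); R broadcasts
    rhs = Phi_t @ Y                    # (*batch, n_basis, dim_response)
    C_t = torch.linalg.solve(A, rhs)   # batched linear solve
    return C_t.transpose(-1, -2)


def fourier_design(x: torch.Tensor, n_basis: int) -> torch.Tensor:
    # Orthonormal Fourier basis on [0, 1]: 1, sqrt(2) sin(2 pi k t), sqrt(2) cos(2 pi k t), ...
    cols, k = [torch.ones_like(x)], 1
    while len(cols) < n_basis:
        cols += [math.sqrt(2) * torch.sin(2 * math.pi * k * x),
                 math.sqrt(2) * torch.cos(2 * math.pi * k * x)]
        k += 1
    return torch.stack(cols[:n_basis], dim=-1)


def fourier_penalty(n_basis: int) -> torch.Tensor:
    # Gram matrix of second derivatives for the basis above: diag(0, (2 pi k)^4, (2 pi k)^4, ...).
    freqs = [0] + [k for k in range(1, n_basis) for _ in range(2)]
    diag = [(2 * math.pi * f) ** 4 for f in freqs[:n_basis]]
    return torch.diag(torch.tensor(diag, dtype=torch.float64))


torch.manual_seed(0)
x = torch.linspace(0, 1, 101, dtype=torch.float64).unsqueeze(0).repeat(5, 7, 1)  # (5, 7, 101)
y = torch.stack([torch.sin(2 * math.pi * x), torch.cos(2 * math.pi * x)], dim=-1)
y = y + 0.05 * torch.randn_like(y)                                                # (5, 7, 101, 2)
Phi = fourier_design(x, n_basis=5)                                               # (5, 7, 101, 5)
coefs = fit_penalized(Phi, y, fourier_penalty(5), lam=1e-6)
print(coefs.shape)  # torch.Size([5, 7, 2, 5])
```

Because sin(2πt) = (1/√2)·√2 sin(2πt), the fitted coefficient of the first sine column for the first response should come out close to 1/√2 ≈ 0.707. Increasing `lam` shrinks the oscillating coefficients toward zero.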
- scikit-fda: a NumPy-based library for comprehensive functional data analysis, including basis functions and smoothing methods. `scikit-fda` is our main reference for code architecture.
- pykan/spline.py, for the PyTorch-based B-spline implementation.
- Other univariate basis functions
- Tensor-product basis functions
- A wrapper for smoothing