Adding other KAN layers #3

Merged
17 commits merged on Oct 31, 2024
33 changes: 32 additions & 1 deletion .gitignore
@@ -1 +1,32 @@
Manifest.toml
# Files generated by invoking Julia with --code-coverage
*.jl.cov
*.jl.*.cov

# Files generated by invoking Julia with --track-allocation
*.jl.mem

# System-specific files and directories generated by the BinaryProvider and BinDeps packages
# They contain absolute paths specific to the host computer, and so should not be committed
deps/deps.jl
deps/build.log
deps/downloads/
deps/usr/
deps/src/

# Build artifacts for creating documentation generated by the Documenter package
docs/build/
docs/site/

# File generated by Pkg, the package manager, based on a corresponding Project.toml
# It records a fixed state of all packages used by the project. As such, it should not be
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest*.toml

# Julia data files
*.jld
*.jld2

# # Figures
# *.gif
# *.png
2 changes: 2 additions & 0 deletions Project.toml
@@ -10,9 +10,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[compat]
TensorOperations = "5.1.0"
LuxCore = "1"
julia = "1.6"

60 changes: 46 additions & 14 deletions README.md
@@ -2,6 +2,11 @@

[![Build Status](https://github.com/vpuri3/KolmogorovArnold.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/vpuri3/KolmogorovArnold.jl/actions/workflows/CI.yml?query=branch%3Amaster)

Julia implementation of [FourierKAN](https://github.com/GistNoesis/FourierKAN)

Julia implementation of [ChebyKAN](https://github.com/SynodicMonth/ChebyKAN)


Julia implementation of the [Kolmogorov-Arnold network](https://arxiv.org/abs/2404.19756)
for the [`Lux.jl`](https://lux.csail.mit.edu/stable/) framework.
This implementation is based on [efficient-kan](https://github.com/Blealtan/efficient-kan)
@@ -23,7 +28,7 @@ x = rand32(rng, in_dim, 10)
y = layer(x, p, st)
```
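
The new Fourier (`FDense`) and Chebyshev (`CDense`) layers follow the same Lux calling convention. Below is a minimal sketch, assuming the `(in_dim, out_dim, grid_len)` constructor signature used in `examples/eg1.jl`; note that in the examples these layers are fed inputs laid out as `(batch, in_dim)`, i.e. transposed relative to `KDense`.
```julia
using Lux, KolmogorovArnold, Random

rng = Random.default_rng()
G = 10  # grid size / number of basis terms

flayer = FDense(1, 10, G)  # Fourier-basis KAN layer
clayer = CDense(1, 10, G)  # Chebyshev-basis KAN layer

pF, stF = Lux.setup(rng, flayer)
pC, stC = Lux.setup(rng, clayer)

# examples/eg1.jl and eg4.jl pass (batch, in_dim)-shaped inputs to these layers.
x = rand32(rng, 50, 1)
yF, _ = flayer(x, pF, stF)
yC, _ = clayer(x, pC, stC)
```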

We compare the performance of KAN with an MLP that has the same number of parameters (see `examples/eg1.jl`).
We compare the performance of different implementations of KAN with an MLP that has the same number of parameters (see `examples/eg1.jl`).
```julia
using Lux, KolmogorovArnold
using CUDA, LuxDeviceUtils
@@ -34,8 +39,8 @@ device = Lux.gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

x = rand32(rng, 1, 1000) |> device

x = rand32(rng, 1, 1000) |> device
x₀ = rand32(rng, 1000, 1) |> device
# define MLP, KANs

mlp = Chain(
@@ -59,35 +64,62 @@ kan2 = Chain(
KDense(40, 1, 10; use_base_act = false, basis_func, normalizer),
) # 16_800 parameters plus 30 states.

kan3 = Chain(
CDense( 1, 40, G),
CDense(40, 40, G),
CDense(40, 1, G),
) # 18_561 parameters plus 0 states.

kan4 = Chain(
FDense( 1, 30, G),
FDense(30, 30, G),
FDense(30, 1, G),
) # 19_261 parameters plus 0 states.

# set up experiment
pM , stM = Lux.setup(rng, mlp)
pM, stM = Lux.setup(rng, mlp)
pK1, stK1 = Lux.setup(rng, kan1)
pK2, stK2 = Lux.setup(rng, kan2)
pK3, stK3 = Lux.setup(rng, kan3)
pK4, stK4 = Lux.setup(rng, kan4)


pM = ComponentArray(pM) |> device
pK1 = ComponentArray(pK1) |> device
pK2 = ComponentArray(pK2) |> device
pK3 = ComponentArray(pK3) |> device
pK4 = ComponentArray(pK4) |> device

stM, stK1, stK2 = device(stM), device(stK1), device(stK2)
stM, stK1, stK2, stK3, stK4 = device(stM), device(stK1), device(stK2), device(stK3), device(stK4)

# Forward pass
@btime CUDA.@sync $mlp( $x, $pM , $stM) # 46.645 μs (267 allocations: 6.88 KiB)
@btime CUDA.@sync $kan1($x, $pK1, $stK1) # 244.895 μs (1298 allocations: 31.16 KiB)
@btime CUDA.@sync $kan2($x, $pK2, $stK2) # 148.830 μs (887 allocations: 21.08 KiB)

@btime CUDA.@sync $mlp($x, $pM, $stM) # 31.611 μs (248 allocations: 5.45 KiB)
@btime CUDA.@sync $kan1($x, $pK1, $stK1) # 125.790 μs (1034 allocations: 21.97 KiB)
@btime CUDA.@sync $kan2($x, $pK2, $stK2) # 87.585 μs (1335 allocations: 13.95 KiB)
@btime CUDA.@sync $kan3($x', $pK3, $stK3) # 210.785 μs (1335 allocations: 31.03 KiB)
@btime CUDA.@sync $kan4($x', $pK4, $stK4) # 2.392 ms (1642 allocations: 34.56 KiB)


# Backward pass

f_mlp(p) = mlp( x, p, stM )[1] |> sum
f_mlp(p) = mlp(x, p, stM)[1] |> sum
f_kan1(p) = kan1(x, p, stK1)[1] |> sum
f_kan2(p) = kan2(x, p, stK2)[1] |> sum
f_kan3(p) = kan3(x₀, p, stK3)[1] |> sum
f_kan4(p) = kan4(x₀, p, stK4)[1] |> sum


@btime CUDA.@sync Zygote.gradient($f_mlp, $pM) # 268.074 μs (1971 allocations: 57.03 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan1, $pK1) # 831.888 μs (5015 allocations: 123.25 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan2, $pK2) # 658.578 μs (3314 allocations: 87.16 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan3, $pK3) # 1.647 ms (7138 allocations: 180.45 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan4, $pK4) # 7.028 ms (8745 allocations: 199.42 KiB)

@btime CUDA.@sync Zygote.gradient($f_mlp , $pM) # 541.759 μs (2343 allocations: 70.77 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan1, $pK1) # 1.471 ms (6396 allocations: 171.08 KiB)
@btime CUDA.@sync Zygote.gradient($f_kan2, $pK2) # 1.046 ms (4314 allocations: 123.08 KiB)

```
The performance of KANs improves significantly with `use_base_act = false`.
Although KANs are currently 2-3x slower than an MLPs with the same number of parameters,
The performance of KAN with radial basis functions improves significantly with `use_base_act = false`.
Although KANs are currently significantly slower than MLPs with the same number of parameters,
the promise with this architecture is that a small KAN can potentially do the work of a much bigger MLP.
More experiments need to be done to assess the validity of this claim.
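
As a quick sanity check, the parameter and state counts quoted in the comments above can be verified with LuxCore's `parameterlength`/`statelength` utilities (re-exported by Lux); a small sketch:
```julia
for (name, model) in (("mlp", mlp), ("kan1", kan1), ("kan2", kan2),
                      ("kan3", kan3), ("kan4", kan4))
    println(name, ": ", Lux.parameterlength(model), " parameters, ",
            Lux.statelength(model), " states")
end
```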

52 changes: 41 additions & 11 deletions examples/eg1.jl
@@ -25,10 +25,11 @@ Random.seed!(rng, 0)
device = Lux.gpu_device()

#======================================================#
function main()
x = rand32(rng, 1, 1000) |> device
function main(N=1000)
x = rand32(rng, 1, N) |> device
x₀ = rand32(rng, N, 1) |> device

wM, wK, G = 128, 40, 10 # MLP width, KAN width, grid size
wM, wK, wK2, G = 128, 40, 30, 10 # MLP width, KAN width, Fourier KAN width, grid size

mlp = Chain(
Dense(1, wM, tanh),
@@ -51,23 +52,44 @@ function main()
KDense(wK, 1, G; use_base_act = false, basis_func, normalizer),
)

kan3 = Chain(
CDense( 1, wK, G),
CDense(wK, wK, G),
CDense(wK, 1, G),
)

kan4 = Chain(
FDense( 1, wK2, G),
FDense(wK2, wK2, G),
FDense(wK2, 1, G),
)

display(mlp)
display(kan1)
display(kan2)
display(kan3)
display(kan4)

pM, stM = Lux.setup(rng, mlp)
pK1, stK1 = Lux.setup(rng, kan1)
pK2, stK2 = Lux.setup(rng, kan2)
pK3, stK3 = Lux.setup(rng, kan3)
pK4, stK4 = Lux.setup(rng, kan4)

pM = ComponentArray(pM) |> device

pM = ComponentArray(pM) |> device
pK1 = ComponentArray(pK1) |> device
pK2 = ComponentArray(pK2) |> device
pK3 = ComponentArray(pK3) |> device
pK4 = ComponentArray(pK4) |> device

stM, stK1, stK2 = device(stM), device(stK1), device(stK2)
stM, stK1, stK2, stK3, stK4 = device(stM), device(stK1), device(stK2), device(stK3), device(stK4)

f_mlp(p) = mlp(x, p, stM)[1] |> sum
f_mlp(p) = mlp(x, p, stM)[1] |> sum
f_kan1(p) = kan1(x, p, stK1)[1] |> sum
f_kan2(p) = kan2(x, p, stK2)[1] |> sum
f_kan3(p) = kan3(x₀, p, stK3)[1] |> sum
f_kan4(p) = kan4(x₀, p, stK4)[1] |> sum

# # Zygote is type unstable - consider using generated functions
# _, pbM = Zygote.pullback(f_mlp, pM)
@@ -83,24 +105,32 @@ function main()
@btime CUDA.@sync $mlp($x, $pM, $stM)
@btime CUDA.@sync $kan1($x, $pK1, $stK1)
@btime CUDA.@sync $kan2($x, $pK2, $stK2)

@btime CUDA.@sync $kan3($x₀, $pK3, $stK3)
@btime CUDA.@sync $kan4($x₀, $pK4, $stK4)

println("# BWD PASS")

@btime CUDA.@sync Zygote.gradient($f_mlp, $pM)
@btime CUDA.@sync Zygote.gradient($f_kan1, $pK1)
@btime CUDA.@sync Zygote.gradient($f_kan2, $pK2)
@btime CUDA.@sync Zygote.gradient($f_kan3, $pK3)
@btime CUDA.@sync Zygote.gradient($f_kan4, $pK4)
else
println("# FWD PASS")

@btime $mlp($x, $pM, $stM)
@btime $kan1($x, $pK1, $stK1)
@btime $kan2($x, $pK2, $stK2)

@btime $kan3($x₀, $pK3, $stK3)
@btime $kan4($x₀, $pK4, $stK4)

println("# BWD PASS")

@btime Zygote.gradient($f_mlp, $pM)
@btime Zygote.gradient($f_kan1, $pK1)
@btime Zygote.gradient($f_kan2, $pK2)
@btime Zygote.gradient($f_kan3, $pK3)
@btime Zygote.gradient($f_kan4, $pK4)
end

nothing
79 changes: 79 additions & 0 deletions examples/eg4.jl
@@ -0,0 +1,79 @@
using KolmogorovArnold
using MLDataDevices
using Lux, Zygote, Random, Statistics, Plots, CUDA, LuxCUDA, ComponentArrays
using Optimisers, OptimizationOptimJL
using Test # needed for the @test assertions at the end of this file
using cuTENSOR

cpud = cpu_device()
gpud = gpu_device()
rng = Random.default_rng()

"""
Fit `model` to a 1D test curve on `device` and return the number of epochs taken (capped at 2×10⁴).
"""
function fit(name, model, device, i_shape, i_d)

x₀ = sort!(rand(Float32, i_shape...), dims=i_d)
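# Target curve: y = cos(πx) + sin(4πx)·tanh(x) + x⁴, evaluated on the sorted inputs.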
yₜ = cos.(π .* x₀) .+ sin.(4 .* π .* x₀) .* tanh.(x₀) .+ x₀.^4
x₀ = x₀ |> device
yₜ = yₜ |> device

# Initiate model
parameters, layer_states = Lux.setup(rng, model)
parameters = ComponentArray(parameters) |> device
layer_states = layer_states |> device

# Initial Prediction
yᵢ, layer_states = model(x₀, parameters, layer_states)

# Set up the optimizer
opt_state = Optimisers.setup(Optimisers.Adam(0.0003), parameters)

# Define the loss function
function loss_fn(pa, ls)
yₚ, new_ls = model(x₀, pa, ls)
l = mean((yₚ .- yₜ).^2)
return l, new_ls
end

loss = 10.0f0 # initial dummy loss so the loop is entered
epoch = 0
while loss > 1e-4 && epoch < 2e4
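# One optimisation step: Zygote's pullback returns the loss (and updated layer
# states) together with a closure for the backward pass. Seed the closure with
# 1 for the loss and `nothing` for the states, clamp each gradient entry to
# [-3, 3], then apply the Adam update.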
(loss, layer_states), back = pullback(loss_fn, parameters, layer_states)
grad, _ = back((1.0, nothing))
grad = map(g -> clamp(g, -3, 3), grad)
opt_state, parameters = Optimisers.update(opt_state, parameters, grad)
#print("\rName: $name - Epoch: $epoch, Loss: $loss")
epoch += 1
end

#println()
## Getting the final evaluation for the test
#yₑ, layer_states = model(x₀, parameters, layer_states)
## Plotting the truth and the initial / final predictions
#x₀ = vec(x₀) |> cpud
#yᵢ = vec(yᵢ) |> cpud
#yₑ = vec(yₑ) |> cpud
#yₜ = vec(yₜ) |> cpud
#plot(x₀, yₜ, label="truth", color="red")
#plot!(x₀, yᵢ, label="initial approx", color="blue")
#plot!(x₀, yₑ, label="final approx", color="green")
#scatter!(x₀, yₜ, color="red", label=false)
#scatter!(x₀, yᵢ, color="blue", label=false)
#scatter!(x₀, yₑ, color="green", label=false)
#savefig("$name.png")

epoch
end



@test fit("fKAN_cpu", Chain(FDense(1, 10, 10), FDense(10, 1, 10)), cpud, (50, 1), 1) <= 2e4
@test fit("cKAN_cpu", Chain(CDense(1, 20, 50), CDense(20, 1, 50)), cpud, (50, 1), 1) <= 2e4
@test fit("rKAN_cpu", Chain(KDense(1, 10, 10), KDense(10, 1, 10)), cpud, (1, 50), 2) <= 2e4

if gpud isa MLDataDevices.AbstractGPUDevice
@test fit("fKAN_gpu", Chain(FDense(1, 10, 10), FDense(10, 1, 10)), gpud, (50, 1), 1) <= 2e4
@test fit("cKAN_gpu", Chain(CDense(1, 20, 50), CDense(20, 1, 50)), gpud, (50, 1), 1) <= 2e4
@test fit("rKAN_gpu", Chain(KDense(1, 10, 10), KDense(10, 1, 10)), gpud, (1, 50), 2) <= 2e4
end
10 changes: 9 additions & 1 deletion src/KolmogorovArnold.jl
@@ -8,15 +8,23 @@ using LuxCore
using WeightInitializers
using ConcreteStructs

using TensorOperations

using ChainRulesCore
const CRC = ChainRulesCore

include("utils.jl")
export rbf, rswaf, iqf
export rbf, rswaf, iqf, batched_mul

include("kdense.jl")
export KDense

include("fdense.jl")
export FDense

include("cdense.jl")
export CDense

# include("explicit")
# export GDense
