Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed May 13, 2024
1 parent 0232e9e commit 907a038
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 21 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ pK, stK = Lux.setup(rng, kan) |> device
34.360 μs (175 allocations: 4.78 KiB)
155.781 μs (565 allocations: 17.50 KiB)
```
With `use_base_activcation = false`, the performance of KAN effectively doubles
With `use_base_act = false`, the performance of KAN effectively doubles
```julia
kan = Chain(
KDense(1, 8, 15; use_base_activation = false),
KDense(8, 8, 15; use_base_activation = false),
KDense(8, 1, 15; use_base_activation = false),
KDense(1, 8, 15; use_base_act = false),
KDense(8, 8, 15; use_base_act = false),
KDense(8, 1, 15; use_base_act = false),
)
p, st = Lux.setup(rng, kan) |> device
@btime CUDA.@sync $mlp($x, $p, $st)
Expand Down
6 changes: 3 additions & 3 deletions examples/eg1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ function main()
)

kan = Chain(
KDense(1, 8, 15; use_base_act = false),
KDense(8, 8, 15; use_base_act = false),
KDense(8, 1, 15; use_base_act = false),
KDense(1, 8, 15; use_base_act = true),
KDense(8, 8, 15; use_base_act = true),
KDense(8, 1, 15; use_base_act = true),
)

# display(mlp)
Expand Down
20 changes: 6 additions & 14 deletions src/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,7 @@ function LuxCore.parameterlength(
len
end

function (l::KDense{false})(x::AbstractVecOrMat, p, st)
K = size(x, 2) # [I, K]
x_resh = reshape(x, 1, :) # [1, I * K]
x_norm = l.normalizer(x_resh) # ∈ [-1, 1]

basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K]
basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K]
y = p.W1 * basis # [O, K]

y, st
end

function (l::KDense{true})(x::AbstractVecOrMat, p, st)
function (l::KDense{use_base_act})(x::AbstractVecOrMat, p, st) where{use_base_act}
K = size(x, 2) # [I, K]
x_resh = reshape(x, 1, :) # [1, I * K]
x_norm = l.normalizer(x_resh) # ∈ [-1, 1]
Expand All @@ -94,7 +82,11 @@ function (l::KDense{true})(x::AbstractVecOrMat, p, st)
basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K]
spline = p.W1 * basis # [O, K]

y = spline + p.W2 * l.base_act.(x)
y = if use_base_act
spline + p.W2 * l.base_act.(x)
else
spline
end

y, st
end
Expand Down

0 comments on commit 907a038

Please sign in to comment.