From 907a0385dbd1eecedc199d9f04fc39ccf82c148e Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 13 May 2024 09:08:13 -0400 Subject: [PATCH] updates --- README.md | 8 ++++---- examples/eg1.jl | 6 +++--- src/type.jl | 20 ++++++-------------- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index b5ba83f..2748f63 100644 --- a/README.md +++ b/README.md @@ -65,12 +65,12 @@ pK, stK = Lux.setup(rng, kan) |> device 34.360 μs (175 allocations: 4.78 KiB) 155.781 μs (565 allocations: 17.50 KiB) ``` -With `use_base_activcation = false`, the performance of KAN effectively doubles +With `use_base_act = false`, the performance of KAN effectively doubles ```julia kan = Chain( - KDense(1, 8, 15; use_base_activation = false), - KDense(8, 8, 15; use_base_activation = false), - KDense(8, 1, 15; use_base_activation = false), + KDense(1, 8, 15; use_base_act = false), + KDense(8, 8, 15; use_base_act = false), + KDense(8, 1, 15; use_base_act = false), ) p, st = Lux.setup(rng, kan) |> device @btime CUDA.@sync $mlp($x, $p, $st) diff --git a/examples/eg1.jl b/examples/eg1.jl index ef6c897..1663b94 100644 --- a/examples/eg1.jl +++ b/examples/eg1.jl @@ -27,9 +27,9 @@ function main() ) kan = Chain( - KDense(1, 8, 15; use_base_act = false), - KDense(8, 8, 15; use_base_act = false), - KDense(8, 1, 15; use_base_act = false), + KDense(1, 8, 15; use_base_act = true), + KDense(8, 8, 15; use_base_act = true), + KDense(8, 1, 15; use_base_act = true), ) # display(mlp) diff --git a/src/type.jl b/src/type.jl index c33ddb5..69b4ccd 100644 --- a/src/type.jl +++ b/src/type.jl @@ -73,19 +73,7 @@ function LuxCore.parameterlength( len end -function (l::KDense{false})(x::AbstractVecOrMat, p, st) - K = size(x, 2) # [I, K] - x_resh = reshape(x, 1, :) # [1, I * K] - x_norm = l.normalizer(x_resh) # ∈ [-1, 1] - - basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K] - basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K] - y = p.W1 * basis # [O, K] - - y, st -end - -function (l::KDense{true})(x::AbstractVecOrMat, p, st) +function (l::KDense{use_base_act})(x::AbstractVecOrMat, p, st) where{use_base_act} K = size(x, 2) # [I, K] x_resh = reshape(x, 1, :) # [1, I * K] x_norm = l.normalizer(x_resh) # ∈ [-1, 1] @@ -94,7 +82,11 @@ function (l::KDense{true})(x::AbstractVecOrMat, p, st) basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K] spline = p.W1 * basis # [O, K] - y = spline + p.W2 * l.base_act.(x) + y = if use_base_act + spline + p.W2 * l.base_act.(x) + else + spline + end y, st end