Skip to content

Commit

Permalink
use_base_act
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed May 13, 2024
1 parent 39e96d3 commit 0232e9e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 28 deletions.
29 changes: 21 additions & 8 deletions examples/eg1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ let
!(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath)
end

using Lux
using Lux, ComponentArrays
using LuxDeviceUtils, CUDA, LuxCUDA
using BenchmarkTools

Expand All @@ -20,14 +20,28 @@ device = Lux.gpu_device()
function main()
x = rand32(rng, 1, 1000) |> device

mlp = Chain(Dense(1, 32), Dense(32, 32), Dense(32, 1),)
kan = Chain(KDense(1, 8, 15), KDense(8, 8, 15), KDense(8, 1, 15))
mlp = Chain(
Dense(1, 32, tanh),
Dense(32, 32, tanh),
Dense(32, 1),
)

pM, stM = Lux.setup(rng, mlp) |> device
pK, stK = Lux.setup(rng, kan) |> device
kan = Chain(
KDense(1, 8, 15; use_base_act = false),
KDense(8, 8, 15; use_base_act = false),
KDense(8, 1, 15; use_base_act = false),
)

display(mlp)
display(kan)
# display(mlp)
# display(kan)

pM, stM = Lux.setup(rng, mlp)
pK, stK = Lux.setup(rng, kan)

pM = ComponentArray(pM) |> device
pK = ComponentArray(pK) |> device

stM, stK = device(stM), device(stK)

@btime CUDA.@sync $mlp($x, $pM, $stM)
@btime CUDA.@sync $kan($x, $pK, $stK)
Expand All @@ -37,5 +51,4 @@ end

main()


nothing
57 changes: 37 additions & 20 deletions src/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#======================================================#
# Kolmogorov-Arnold Layer
#======================================================#
@concrete struct KDense <: LuxCore.AbstractExplicitLayer
@concrete struct KDense{use_base_act} <: LuxCore.AbstractExplicitLayer
in_dims::Int
out_dims::Int
grid_len::Int
Expand All @@ -11,7 +11,6 @@
base_act
init_W1
init_W2
use_base_activation
end

function KDense(
Expand All @@ -22,26 +21,31 @@ function KDense(
base_act = silu,
init_W1 = glorot_uniform,
init_W2 = glorot_uniform,
use_base_activation = true,
use_fast_activation::Bool = true,
use_base_act = true,
use_fast_act::Bool = true,
)
normalizer = use_fast_activation ? tanh_fast : tanh
normalizer = use_fast_act ? tanh_fast : tanh
base_act = use_fast_act ? NNlib.fast_act(base_act) : base_act

KDense(
KDense{use_base_act}(
in_dims, out_dims, grid_len,
denominator, normalizer, base_act,
init_W1, init_W2, use_base_activation,
init_W1, init_W2,
)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::KDense)
function LuxCore.initialparameters(
rng::AbstractRNG,
l::KDense{use_base_act}
) where{use_base_act}
p = (;
W1 = l.init_W1(rng, l.out_dims, l.grid_len * l.in_dims),
)

if l.use_base_activation
if use_base_act
p = (;
p..., W2 = l.init_W2(rng, l.out_dims, l.in_dims),
p...,
W2 = l.init_W2(rng, l.out_dims, l.in_dims),
)
end

Expand All @@ -54,30 +58,43 @@ function LuxCore.initialstates(::AbstractRNG, l::KDense,)
(; grid,)
end

LuxCore.statelength(l::KDense) = l.grid_len
function LuxCore.parameterlength(l::KDense)
function LuxCore.statelength(l::KDense)
l.grid_len
end

function LuxCore.parameterlength(
l::KDense{use_base_act},
) where{use_base_act}
len = l.in_dims * l.grid_len * l.out_dims
if l.use_base_activation
if use_base_act
len += l.in_dims * l.out_dims
end

len
end

function (l::KDense)(x::AbstractVecOrMat, p, st)
function (l::KDense{false})(x::AbstractVecOrMat, p, st)
K = size(x, 2) # [I, K]
x_resh = reshape(x, 1, :) # [1, I * K]
x_norm = l.normalizer(x_resh) # ∈ [-1, 1]

basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K]
basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K]
y = p.W1 * basis # [O, K]

y, st
end

function (l::KDense{true})(x::AbstractVecOrMat, p, st)
K = size(x, 2) # [I, K]
x_resh = reshape(x, 1, :) # [1, I * K]
x_norm = l.normalizer(x_resh) # ∈ [-1, 1]

basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K]
basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K]
basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K]
spline = p.W1 * basis # [O, K]

y = if l.use_base_activation
spline + p.W2 * l.base_act.(x)
else
spline
end
y = spline + p.W2 * l.base_act.(x)

y, st
end
Expand Down

0 comments on commit 0232e9e

Please sign in to comment.