Skip to content

Commit

Permalink
choose normalizations
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed May 15, 2024
1 parent e95e827 commit dc3d7eb
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 35 deletions.
26 changes: 15 additions & 11 deletions examples/eg1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ device = Lux.gpu_device()

function main()
x = rand32(rng, 1, 1000) |> device
y = rand32(rng, 1, 1000) |> device

mlp = Chain(
Dense(1, 32, tanh),
Expand All @@ -35,9 +34,9 @@ function main()
)

kan = Chain(
KDense( 1, 10, 10; use_base_act = true),
KDense(10, 10, 10; use_base_act = true),
KDense(10, 1, 10; use_base_act = true),
KDense( 1, 10, 10; use_base_act = false),
KDense(10, 10, 10; use_base_act = false),
KDense(10, 1, 10; use_base_act = false),
)

display(mlp)
Expand All @@ -54,24 +53,31 @@ function main()
f_mlp(p) = mlp(x, p, stM)[1] |> sum
f_kan(p) = kan(x, p, stK)[1] |> sum

# # Zygote is type unstable - consider using generated functions
# _, pbM = Zygote.pullback(f_mlp, pM)
# _, pbK = Zygote.pullback(f_kan, pK)

# @code_warntype pbM(x)
# @code_warntype pbK(x)

if device isa LuxDeviceUtils.AbstractLuxGPUDevice
println("# FWD PASS")

@btime CUDA.@sync $mlp($x, $pM, $stM)
@btime CUDA.@sync $kan($x, $pK, $stK)

println("# BWD PASS")

@btime CUDA.@sync Zygote.gradient($f_mlp, $pM)
@btime CUDA.@sync Zygote.gradient($f_kan, $pK)
else
println("# FWD PASS")

@btime $mlp($x, $pM, $stM)
@btime $kan($x, $pK, $stK)

println("# BWD PASS")

@btime Zygote.gradient($f_mlp, $pM)
@btime Zygote.gradient($f_kan, $pK)
end
Expand All @@ -80,5 +86,3 @@ function main()
end

main()

nothing
72 changes: 72 additions & 0 deletions src/alternate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#
#======================================================#
# An alternate KAN Dense layer.
# Had to implement this to confirm that it doesn't train well
#======================================================#
export KDense1
# Configuration for the alternate KAN dense layer. `use_base_act` is a type
# parameter; the forward pass reads it to decide whether the base-activation
# term is added to the spline term. Field order matters: the keyword
# constructor fills these fields positionally.
@concrete struct KDense1{use_base_act} <: LuxCore.AbstractExplicitLayer
# Number of input features; second dimension of both `C` and `W`.
in_dims::Int
# Number of output features; first dimension of `W`.
out_dims::Int
# Number of grid points; first dimension of `C` and length of the `grid` state.
grid_len::Int
# Passed as the third argument to `rbf` in the forward pass.
denominator
# Applied to the input before the basis functions are evaluated.
normalizer
# Activation applied to the raw input when `use_base_act` is true.
base_act
# Initializer called as `init_C(rng, grid_len, in_dims)`.
init_C
# Initializer called as `init_W(rng, out_dims, in_dims)`.
init_W
end

# Keyword constructor for `KDense1`.
#
# Arguments
#   in_dims, out_dims : input / output feature counts.
#   grid_len          : number of grid points on [-1, 1].
#
# Keywords
#   denominator  : third argument handed to `rbf`; defaults to the spacing of
#                  `grid_len` evenly spaced points on [-1, 1], as Float32.
#   base_act     : activation used for the base term. Defaults to `swish`
#                  (x * sigmoid(x)), matching the default used by `KDense`, so
#                  the layer does not rely on a package-local `silu` helper.
#   init_C, init_W : initializers for the `C` and `W` parameters.
#   use_base_act : lifted into the type parameter of `KDense1`.
#   use_fast_act : when true, use `tanh_fast` as the normalizer and the
#                  `NNlib.fast_act` variant of `base_act`.
function KDense1(
    in_dims::Int,
    out_dims::Int,
    grid_len::Int;
    denominator = Float32(2 / (grid_len - 1)),
    base_act = swish,
    init_C = glorot_uniform,
    init_W = glorot_uniform,
    use_base_act = true,
    use_fast_act::Bool = true,
)
    normalizer = use_fast_act ? tanh_fast : tanh
    # Use a fresh local instead of rebinding the keyword argument.
    act = use_fast_act ? NNlib.fast_act(base_act) : base_act

    return KDense1{use_base_act}(
        in_dims, out_dims, grid_len,
        denominator, normalizer, act,
        init_C, init_W,
    )
end

# Draw the trainable parameters of a `KDense1` layer:
#   C :: [grid_len, in_dims]  -- produced by `init_C`
#   W :: [out_dims, in_dims]  -- produced by `init_W`
# `C` is drawn before `W`, so the RNG is consumed in that order.
function LuxCore.initialparameters(rng::AbstractRNG, l::KDense1)
    C = l.init_C(rng, l.grid_len, l.in_dims)
    W = l.init_W(rng, l.out_dims, l.in_dims)
    return (; C, W)
end

# Build the non-trainable state: `grid_len` evenly spaced points on [-1, 1],
# stored as a Float32 vector under the key `grid`. The RNG is unused.
function LuxCore.initialstates(::AbstractRNG, l::KDense1)
    knots = collect(LinRange(-1, 1, l.grid_len))
    return (; grid = Float32.(knots))
end

# Number of state entries: one per point of the `grid` vector.
function LuxCore.statelength(l::KDense1)
    return l.grid_len
end
# Total number of trainable scalars: `C` is [grid_len, in_dims] and
# `W` is [out_dims, in_dims].
function LuxCore.parameterlength(l::KDense1)
    n_C = l.grid_len * l.in_dims
    n_W = l.out_dims * l.in_dims
    return n_C + n_W
end

# Forward pass of the alternate KAN dense layer.
#
# Arguments
#   x  : array whose first dimension has length `l.in_dims`; all trailing
#        dimensions are flattened into one batch dimension of size K.
#   p  : parameters with fields `C` ([G, I]) and `W` ([O, I]).
#   st : state with field `grid` (length G).
#
# Returns `(y, st)`, where `y` has size `(l.out_dims, size(x)[2:end]...)` and
# `st` is passed through unchanged.
function (l::KDense1{use_base_act})(x::AbstractArray, p, st) where {use_base_act}
    size_in = size(x)                                   # [I, ..., batch]
    size_out = (l.out_dims, size_in[2:end]...)          # [O, ..., batch]

    x = reshape(x, l.in_dims, :)                        # [I, K]
    K = size(x, 2)

    # The normalizer is a scalar function and must be broadcast. Calling it on
    # the whole array would request a matrix function (or raise an error)
    # instead of squashing each entry.
    x_norm = l.normalizer.(x)                           # [I, K]
    x_resh = reshape(x_norm, 1, l.in_dims, K)           # [1, I, K]
    basis = rbf.(x_resh, st.grid, l.denominator)        # [G, I, K]

    # Weight each basis function by its coefficient and sum over the grid axis.
    spline = dropdims(sum(p.C .* basis; dims = 1); dims = 1)   # [I, K]

    # `use_base_act` is a type parameter, so this branch is resolved per type.
    y = use_base_act ? (spline + l.base_act.(x)) : spline      # [I, K]
    z = p.W * y                                         # [O, K]

    return reshape(z, size_out), st
end
#======================================================#
#
68 changes: 49 additions & 19 deletions src/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,60 @@
in_dims::Int
out_dims::Int
grid_len::Int
denominator
#
normalizer
grid_lims
denominator
#
base_act
init_W1
init_W2
init_C
init_W
end

function KDense(
in_dims::Int,
out_dims::Int,
grid_len::Int;
#
normalizer = tanh,
grid_lims::NTuple{2, Real} = (-1.0f0, 1.0f0),
denominator = Float32(2 / (grid_len - 1)),
base_act = silu,
init_W1 = glorot_uniform,
init_W2 = glorot_uniform,
#
base_act = swish,
use_base_act = true,
#
init_C = glorot_uniform,
init_W = glorot_uniform,
use_fast_act::Bool = true,
)
normalizer = use_fast_act ? tanh_fast : tanh
base_act = use_fast_act ? NNlib.fast_act(base_act) : base_act
T = promote_type(eltype.(grid_lims)...)

if isnothing(grid_lims)
grid_lims = if normalizer (sigmoid, sigmoid_fast)
(0, 1)
elseif normalizer (tanh, tanh_fast, softsign)
(-1, 1)
else
(-1, 1)
end
end

grid_span = grid_lims[2] > grid_lims[1]
@assert grid_span > 0

if isnothing(denominator)
denominator = grid_span / (grid_len - 1)
end

if use_fast_act
base_act = NNlib.fast_act(base_act)
normalizer = NNlib.fast_act(normalizer)
end

KDense{use_base_act}(
in_dims, out_dims, grid_len,
denominator, normalizer, base_act,
init_W1, init_W2,
normalizer, T.(grid_lims), T(denominator),
base_act, init_C, init_W,
)
end

Expand All @@ -39,24 +68,24 @@ function LuxCore.initialparameters(
l::KDense{use_base_act}
) where{use_base_act}
p = (;
W1 = l.init_W1(rng, l.out_dims, l.grid_len * l.in_dims),
C = l.init_C(rng, l.out_dims, l.grid_len * l.in_dims),
)
# W1 = l.init_W1(rng, l.out_dims, l.grid_len, l.in_dims),
# C = l.init_C(rng, l.out_dims, l.grid_len, l.in_dims),

if use_base_act
p = (;
p...,
W2 = l.init_W2(rng, l.out_dims, l.in_dims),
W = l.init_W(rng, l.out_dims, l.in_dims),
)
end

p
end

function LuxCore.initialstates(::AbstractRNG, l::KDense,)
grid = collect(LinRange(-1, 1, l.grid_len)) .|> Float32

(; grid,)
(;
grid = collect(LinRange(l.grid_lims..., l.grid_len))
)
end

function LuxCore.statelength(l::KDense)
Expand All @@ -81,19 +110,20 @@ function (l::KDense{use_base_act})(x::AbstractArray, p, st) where{use_base_act}
x = reshape(x, l.in_dims, :)
K = size(x, 2)

x_norm = l.normalizer(x) # ∈ [-1, 1]
x_norm = l.normalizer.(x) # ∈ [-1, 1]
x_resh = reshape(x_norm, 1, :) # [1, K]
basis = rbf.(x_resh, st.grid, l.denominator) # [G, I * K]
basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K]
spline = p.W1 * basis # [O, K]
spline = p.C * basis # [O, K]

y = if use_base_act
base = p.W2 * l.base_act.(x)
base = p.W * l.base_act.(x)
spline + base
else
spline
end

reshape(y, size_out), st
end
#======================================================#
#
5 changes: 0 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,3 @@ function CRC.rrule(::typeof(gaussian1D), x)
return y, ∇gaussian1D
end

# from https://github.com/LuxDL/Lux.jl/pull/627
@inline silu(x) = x * sigmoid(x)
@inline silu_fast(x) = x * sigmoid_fast(x)
@inline NNlib.fast_act(::typeof(silu), ::AbstractArray=1:0) = silu_fast

0 comments on commit dc3d7eb

Please sign in to comment.