From 0232e9e92e4dc3165df32560a900c1fb16cfcff2 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 13 May 2024 09:04:06 -0400 Subject: [PATCH] use_base_act --- examples/eg1.jl | 29 ++++++++++++++++++------- src/type.jl | 57 ++++++++++++++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 28 deletions(-) diff --git a/examples/eg1.jl b/examples/eg1.jl index 0aa7b5c..ef6c897 100644 --- a/examples/eg1.jl +++ b/examples/eg1.jl @@ -9,7 +9,7 @@ let !(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath) end -using Lux +using Lux, ComponentArrays using LuxDeviceUtils, CUDA, LuxCUDA using BenchmarkTools @@ -20,14 +20,28 @@ device = Lux.gpu_device() function main() x = rand32(rng, 1, 1000) |> device - mlp = Chain(Dense(1, 32), Dense(32, 32), Dense(32, 1),) - kan = Chain(KDense(1, 8, 15), KDense(8, 8, 15), KDense(8, 1, 15)) + mlp = Chain( + Dense(1, 32, tanh), + Dense(32, 32, tanh), + Dense(32, 1), + ) - pM, stM = Lux.setup(rng, mlp) |> device - pK, stK = Lux.setup(rng, kan) |> device + kan = Chain( + KDense(1, 8, 15; use_base_act = false), + KDense(8, 8, 15; use_base_act = false), + KDense(8, 1, 15; use_base_act = false), + ) - display(mlp) - display(kan) + # display(mlp) + # display(kan) + + pM, stM = Lux.setup(rng, mlp) + pK, stK = Lux.setup(rng, kan) + + pM = ComponentArray(pM) |> device + pK = ComponentArray(pK) |> device + + stM, stK = device(stM), device(stK) @btime CUDA.@sync $mlp($x, $pM, $stM) @btime CUDA.@sync $kan($x, $pK, $stK) @@ -37,5 +51,4 @@ end main() - nothing diff --git a/src/type.jl b/src/type.jl index 4692aae..c33ddb5 100644 --- a/src/type.jl +++ b/src/type.jl @@ -2,7 +2,7 @@ #======================================================# # Kolmogorov-Arnold Layer #======================================================# -@concrete struct KDense <: LuxCore.AbstractExplicitLayer +@concrete struct KDense{use_base_act} <: LuxCore.AbstractExplicitLayer in_dims::Int out_dims::Int grid_len::Int @@ -11,7 +11,6 @@ base_act init_W1 init_W2 - use_base_activation end function KDense( @@ -22,26 +21,31 @@ function KDense( base_act = silu, init_W1 = glorot_uniform, init_W2 = glorot_uniform, - use_base_activation = true, - use_fast_activation::Bool = true, + use_base_act = true, + use_fast_act::Bool = true, ) - normalizer = use_fast_activation ? tanh_fast : tanh + normalizer = use_fast_act ? tanh_fast : tanh + base_act = use_fast_act ? NNlib.fast_act(base_act) : base_act - KDense( + KDense{use_base_act}( in_dims, out_dims, grid_len, denominator, normalizer, base_act, - init_W1, init_W2, use_base_activation, + init_W1, init_W2, ) end -function LuxCore.initialparameters(rng::AbstractRNG, l::KDense) +function LuxCore.initialparameters( + rng::AbstractRNG, + l::KDense{use_base_act} +) where{use_base_act} p = (; W1 = l.init_W1(rng, l.out_dims, l.grid_len * l.in_dims), ) - if l.use_base_activation + if use_base_act p = (; - p..., W2 = l.init_W2(rng, l.out_dims, l.in_dims), + p..., + W2 = l.init_W2(rng, l.out_dims, l.in_dims), ) end @@ -54,30 +58,43 @@ function LuxCore.initialstates(::AbstractRNG, l::KDense,) (; grid,) end -LuxCore.statelength(l::KDense) = l.grid_len -function LuxCore.parameterlength(l::KDense) +function LuxCore.statelength(l::KDense) + l.grid_len +end + +function LuxCore.parameterlength( + l::KDense{use_base_act}, +) where{use_base_act} len = l.in_dims * l.grid_len * l.out_dims - if l.use_base_activation + if use_base_act len += l.in_dims * l.out_dims end len end -function (l::KDense)(x::AbstractVecOrMat, p, st) +function (l::KDense{false})(x::AbstractVecOrMat, p, st) + K = size(x, 2) # [I, K] + x_resh = reshape(x, 1, :) # [1, I * K] + x_norm = l.normalizer(x_resh) # ∈ [-1, 1] + + basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K] + basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K] + y = p.W1 * basis # [O, K] + + y, st +end + +function (l::KDense{true})(x::AbstractVecOrMat, p, st) K = size(x, 2) # [I, K] x_resh = reshape(x, 1, :) # [1, I * K] x_norm = l.normalizer(x_resh) # ∈ [-1, 1] - basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K] + basis = rbf.(x_norm, st.grid, l.denominator) # [G, I * K] basis = reshape(basis, l.grid_len * l.in_dims, K) # [G * I, K] spline = p.W1 * basis # [O, K] - y = if l.use_base_activation - spline + p.W2 * l.base_act.(x) - else - spline - end + y = spline + p.W2 * l.base_act.(x) y, st end