Skip to content

Commit

Permalink
moving to Lux v1
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Oct 29, 2024
1 parent d2f99fb commit 8910776
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 13 deletions.
6 changes: 3 additions & 3 deletions examples/eg1.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#
using KolmogorovArnold
using Random, LinearAlgebra

# Add test dependencies to env stack
let
Expand All @@ -9,8 +8,9 @@ let
!(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath)
end

using Random, LinearAlgebra
using Zygote, Lux, ComponentArrays
using LuxDeviceUtils, CUDA, LuxCUDA
using MLDataDevices, CUDA, LuxCUDA
using BenchmarkTools

# configure BLAS
Expand Down Expand Up @@ -77,7 +77,7 @@ function main()
# @code_warntype pbM(x)
# @code_warntype pbK(x)

if device isa LuxDeviceUtils.AbstractLuxGPUDevice
if device isa MLDataDevices.AbstractGPUDevice
println("# FWD PASS")

@btime CUDA.@sync $mlp($x, $pM, $stM)
Expand Down
4 changes: 2 additions & 2 deletions examples/eg2.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#
using KolmogorovArnold
using Random, LinearAlgebra

# Add test dependencies to env stack
let
Expand All @@ -9,8 +8,9 @@ let
!(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath)
end

using Random, LinearAlgebra
using Plots, NNlib, Zygote, BenchmarkTools
using LuxDeviceUtils, CUDA, LuxCUDA
using MLDataDevices, CUDA, LuxCUDA

# configure BLAS
ncores = min(Sys.CPU_THREADS, length(Sys.cpu_info()))
Expand Down
5 changes: 2 additions & 3 deletions examples/eg3.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# evaluate on MNIST1D
using KolmogorovArnold
using Random, LinearAlgebra

# Add test dependencies to env stack
let
Expand All @@ -9,9 +8,9 @@ let
!(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath)
end

using Random, LinearAlgebra
using Zygote, Lux, ComponentArrays
using LuxDeviceUtils, CUDA, LuxCUDA
using MLUtils, MLDatasets
using MLDataDevices, CUDA, LuxCUDA

# configure BLAS
ncores = min(Sys.CPU_THREADS, length(Sys.cpu_info()))
Expand Down
2 changes: 1 addition & 1 deletion src/KolmogorovArnold.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module KolmogorovArnold

using Random
using LinearAlgebra
using Random: AbstractRNG

using NNlib
using LuxCore
Expand Down
2 changes: 1 addition & 1 deletion src/kdense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#======================================================#
# Kolmogorov-Arnold Layer
#======================================================#
@concrete struct KDense{use_base_act} <: LuxCore.AbstractExplicitLayer
@concrete struct KDense{use_base_act} <: LuxCore.AbstractLuxLayer
in_dims::Int
out_dims::Int
grid_len::Int
Expand Down
3 changes: 0 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using KolmogorovArnold
using Test

using Lux, Zygote
using Optimisers, OptimizationOptimJL

pkgpath = dirname(dirname(pathof(KolmogorovArnold)))

@testset "KolmogorovArnold.jl" begin
Expand Down

0 comments on commit 8910776

Please sign in to comment.