diff --git a/examples/eg1.jl b/examples/eg1.jl index ef203ff..6e570d6 100644 --- a/examples/eg1.jl +++ b/examples/eg1.jl @@ -1,6 +1,5 @@ # using KolmogorovArnold -using Random, LinearAlgebra # Add test dependencies to env stack let @@ -9,8 +8,9 @@ let !(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath) end +using Random, LinearAlgebra using Zygote, Lux, ComponentArrays -using LuxDeviceUtils, CUDA, LuxCUDA +using MLDataDevices, CUDA, LuxCUDA using BenchmarkTools # configure BLAS @@ -77,7 +77,7 @@ function main() # @code_warntype pbM(x) # @code_warntype pbK(x) - if device isa LuxDeviceUtils.AbstractLuxGPUDevice + if device isa MLDataDevices.AbstractGPUDevice println("# FWD PASS") @btime CUDA.@sync $mlp($x, $pM, $stM) diff --git a/examples/eg2.jl b/examples/eg2.jl index c8cc5d2..bb918ea 100644 --- a/examples/eg2.jl +++ b/examples/eg2.jl @@ -1,6 +1,5 @@ # using KolmogorovArnold -using Random, LinearAlgebra # Add test dependencies to env stack let @@ -9,8 +8,9 @@ let !(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath) end +using Random, LinearAlgebra using Plots, NNlib, Zygote, BenchmarkTools -using LuxDeviceUtils, CUDA, LuxCUDA +using MLDataDevices, CUDA, LuxCUDA # configure BLAS ncores = min(Sys.CPU_THREADS, length(Sys.cpu_info())) diff --git a/examples/eg3.jl b/examples/eg3.jl index 57118da..16fc225 100644 --- a/examples/eg3.jl +++ b/examples/eg3.jl @@ -1,6 +1,5 @@ # evaluate on MNIST1D using KolmogorovArnold -using Random, LinearAlgebra # Add test dependencies to env stack let @@ -9,9 +8,9 @@ let !(tstpath in LOAD_PATH) && push!(LOAD_PATH, tstpath) end +using Random, LinearAlgebra using Zygote, Lux, ComponentArrays -using LuxDeviceUtils, CUDA, LuxCUDA -using MLUtils, MLDatasets +using MLDataDevices, CUDA, LuxCUDA # configure BLAS ncores = min(Sys.CPU_THREADS, length(Sys.cpu_info())) diff --git a/src/KolmogorovArnold.jl b/src/KolmogorovArnold.jl index c0e1b0b..20a35ee 100644 --- a/src/KolmogorovArnold.jl +++ b/src/KolmogorovArnold.jl @@ -1,7 +1,7 @@ module KolmogorovArnold -using Random using LinearAlgebra +using Random: AbstractRNG using NNlib using LuxCore diff --git a/src/kdense.jl b/src/kdense.jl index 4a65533..42384b6 100644 --- a/src/kdense.jl +++ b/src/kdense.jl @@ -2,7 +2,7 @@ #======================================================# # Kolmogorov-Arnold Layer #======================================================# -@concrete struct KDense{use_base_act} <: LuxCore.AbstractExplicitLayer +@concrete struct KDense{use_base_act} <: LuxCore.AbstractLuxLayer in_dims::Int out_dims::Int grid_len::Int diff --git a/test/runtests.jl b/test/runtests.jl index 863605a..81a0561 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,6 @@ using KolmogorovArnold using Test -using Lux, Zygote -using Optimisers, OptimizationOptimJL - pkgpath = dirname(dirname(pathof(KolmogorovArnold))) @testset "KolmogorovArnold.jl" begin