diff --git a/examples/eg1.jl b/examples/eg1.jl index 1663b94..b18c1d2 100644 --- a/examples/eg1.jl +++ b/examples/eg1.jl @@ -43,8 +43,13 @@ function main() stM, stK = device(stM), device(stK) - @btime CUDA.@sync $mlp($x, $pM, $stM) - @btime CUDA.@sync $kan($x, $pK, $stK) + if device isa LuxDeviceUtils.AbstractLuxGPUDevice + @btime CUDA.@sync $mlp($x, $pM, $stM) + @btime CUDA.@sync $kan($x, $pK, $stK) + else + @btime $mlp($x, $pM, $stM) + @btime $kan($x, $pK, $stK) + end nothing end