Skip to content

Commit

Permalink
on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Nov 12, 2024
1 parent 38616fc commit 0fc3498
Showing 1 changed file with 28 additions and 23 deletions.
51 changes: 28 additions & 23 deletions examples/eg3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,32 @@ let
end

using Random, LinearAlgebra
using Zygote, Lux, ComponentArrays
using Reactant, Enzyme, Optimisers, Functors
using MLDataDevices
using BenchmarkTools
using MLDataDevices, BenchmarkTools
using Enzyme, Zygote, Lux, ComponentArrays

# configure BLAS
ncores = min(Sys.CPU_THREADS, length(Sys.cpu_info()))
BLAS.set_num_threads(ncores)

# configure CUDA
# using CUDA, LuxCUDA
# CUDA.allowscalar(false)
#
Reactant.set_default_backend("cpu")
using CUDA, LuxCUDA
CUDA.allowscalar(false)

# configure Reactant
using Reactant
Reactant.set_default_backend("gpu")

rng = Random.default_rng()
Random.seed!(rng, 0)

device_zy = cpu_device()
device_zy = gpu_device()
device_ra = reactant_device()

function Reactant.synchronize(x::ComponentArray)
Reactant.synchronize(getdata(x))
ComponentArray(getdata(x), getaxes(x))
end

#======================================================#
function main()

Expand Down Expand Up @@ -88,7 +93,7 @@ function main()

function grad_zy(model, ps, st, x, y)
lossfun = ps -> loss(model, ps, st, x, y)
only(zyote.gradient(lossfun, ps))
only(Zygote.gradient(lossfun, ps))
end

#------------------------#
Expand Down Expand Up @@ -116,28 +121,28 @@ function main()

println("\n# FWD Vanilla\n")

@btime $mlp( $x_zy, $pM_zy , $stM_zy )
@btime $kan1($x_zy, $pK1_zy, $stK1_zy)
@btime $kan2($x_zy, $pK2_zy, $stK2_zy)
@btime CUDA.@sync $mlp( $x_zy, $pM_zy , $stM_zy )
@btime CUDA.@sync $kan1($x_zy, $pK1_zy, $stK1_zy)
@btime CUDA.@sync $kan2($x_zy, $pK2_zy, $stK2_zy)

println("\n# FWD Reactant\n")

@btime $mlp_comp( $x_ra, $pM_ra , $stM_ra )
@btime $kan1_comp($x_ra, $pK1_ra, $stK1_ra)
@btime $kan2_comp($x_ra, $pK2_ra, $stK2_ra)
@btime Reactant.synchronize($mlp_comp( $x_ra, $pM_ra , $stM_ra )[1])
@btime Reactant.synchronize($kan1_comp($x_ra, $pK1_ra, $stK1_ra)[1])
@btime Reactant.synchronize($kan2_comp($x_ra, $pK2_ra, $stK2_ra)[1])

#------------------------#
println("\n# BWD Zygote\n")

@btime $grad_zy($mlp , $pM , $stM , $x, $y)
@btime $grad_zy($kan1, $pK1, $stK1, $x, $y)
@btime $grad_zy($kan2, $pK2, $stK2, $x, $y)
@btime CUDA.@sync $grad_zy($mlp , $pM , $stM , $x, $y)
@btime CUDA.@sync $grad_zy($kan1, $pK1, $stK1, $x, $y)
@btime CUDA.@sync $grad_zy($kan2, $pK2, $stK2, $x, $y)

println("\n# BWD Reactant\n")

@btime $grad_ra_comp_M( $mlp , $pM_ra , $stM_ra , $x_ra, $y_ra)
@btime $grad_ra_comp_K1($kan1, $pK1_ra, $stK1_ra, $x_ra, $y_ra)
@btime $grad_ra_comp_K2($kan2, $pK2_ra, $stK2_ra, $x_ra, $y_ra)
@btime Reactant.synchronize($grad_ra_comp_M( $mlp , $pM_ra , $stM_ra , $x_ra, $y_ra))
@btime Reactant.synchronize($grad_ra_comp_K1($kan1, $pK1_ra, $stK1_ra, $x_ra, $y_ra))
@btime Reactant.synchronize($grad_ra_comp_K2($kan2, $pK2_ra, $stK2_ra, $x_ra, $y_ra))
#------------------------#

return
Expand Down

0 comments on commit 0fc3498

Please sign in to comment.