Skip to content

Commit

Permalink
Make plotting work on GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
efaulhaber committed Dec 23, 2024
1 parent 0d99560 commit e75b9b9
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions src/visualization/recipes_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,45 @@ const TrixiParticlesODESolution = ODESolution{<:Any, <:Any, <:Any, <:Any, <:Any,
<:Semidiscretization}}

RecipesBase.@recipe function f(sol::TrixiParticlesODESolution)
# Redirect everything to the recipe
# Redirect everything to the next recipe
return sol.u[end].x..., sol.prob.p
end

# GPU version
RecipesBase.@recipe function f(sol::TrixiParticlesODESolution, semi::Semidiscretization)
# Move GPU data to the CPU
v_ode = Array(sol.u[end].x[1])
u_ode = Array(sol.u[end].x[2])
semi_ = Adapt.adapt(Array, sol.prob.p)

# Redirect everything to the next recipe
return v_ode, u_ode, semi_, particle_spacings(semi)
end

RecipesBase.@recipe function f(v_ode::AbstractGPUArray, u_ode, semi::Semidiscretization;
particle_spacings=nothing,
size=(600, 400), # Default size
xlims=(-Inf, Inf), ylims=(-Inf, Inf))
throw(ArgumentError("to plot GPU data, use `plot(sol, semi)`"))
end

RecipesBase.@recipe function f(v_ode, u_ode, semi::Semidiscretization;
particle_spacings=TrixiParticles.particle_spacings(semi),
size=(600, 400), # Default size
xlims=(-Inf, Inf), ylims=(-Inf, Inf))
systems_data = map(semi.systems) do system
return v_ode, u_ode, semi, particle_spacings
end

RecipesBase.@recipe function f(v_ode, u_ode, semi::Semidiscretization, particle_spacings;
size=(600, 400), # Default size
xlims=(-Inf, Inf), ylims=(-Inf, Inf))
systems_data = map(enumerate(semi.systems)) do (i, system)
u = wrap_u(u_ode, system, semi)
coordinates = active_coordinates(u, system)
x = collect(coordinates[1, :])
y = collect(coordinates[2, :])

particle_spacing = system.initial_condition.particle_spacing
particle_spacing = particle_spacings[i]
if particle_spacing < 0
particle_spacing = 0.0
end
Expand All @@ -41,6 +66,10 @@ RecipesBase.@recipe function f(v_ode, u_ode, semi::Semidiscretization;
return (semi, systems_data...)
end

function particle_spacings(semi::Semidiscretization)
return [system.initial_condition.particle_spacing for system in semi.systems]
end

RecipesBase.@recipe function f((initial_conditions::InitialCondition)...;
xlims=(Inf, Inf), ylims=(Inf, Inf))
idx = 0
Expand Down

0 comments on commit e75b9b9

Please sign in to comment.