Skip to content

Commit

Permalink
Update files
Browse files Browse the repository at this point in the history
  • Loading branch information
agdestein committed Nov 3, 2023
1 parent 95eb820 commit f58dc63
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 64 deletions.
58 changes: 33 additions & 25 deletions scratch/train_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ size(data_train.u[1][1][1])
o = Observable(data_train.u[1][end][1])
heatmap(o)
for i = 1:501
# o[] = data_train.u[1][i][1]
o[] = data_train.cF[1][i][1]
o[] = data_train.u[1][i][1]
# o[] = data_train.cF[1][i][1]
sleep(0.001)
end

Expand Down Expand Up @@ -98,7 +98,7 @@ closure, θ₀ = cnn(

# Bias
[true, true, true, false];
)
);

# closure, θ₀ = fno(
# setup,
Expand All @@ -116,28 +116,38 @@ closure, θ₀ = cnn(
# gelu,
# );

@info "Closure model has $(length(θ₀)) parameters"
closure.NN

# Create input/output arrays
function create_io_arrays(data, setup)
nsample = length(data.u)
nt = length(data.u[1]) - 1
D = setup.grid.dimension()
T = eltype(data.u[1][1][1])
(; N) = setup.grid
u = zeros(T, (N .- 2)..., D, nt + 1, nsample)
c = zeros(T, (N .- 2)..., D, nt + 1, nsample)
ifield = ntuple(Returns(:), D)
for i = 1:nsample, j = 1:nt+1, α = 1:D
copyto!(view(u, ifield..., α, j, i), view(data.u[i][j][α], setup.grid.Iu[α]))
copyto!(view(c, ifield..., α, j, i), view(data.cF[i][j][α], setup.grid.Iu[α]))
end
reshape(u, (N .- 2)..., D, :), reshape(c, (N .- 2)..., D, :)
end

# Test data
u_test = device(reshape(data_test.u[:, 1:20, 1:2], :, 40))
c_test = device(reshape(data_test.cF[:, 1:20, 1:2], :, 40))
io_train = create_io_arrays(data_train, setup)
io_valid = create_io_arrays(data_valid, setup)
io_test = create_io_arrays(data_test, setup)

# Prepare training
θ = 1.0f-1 * device(θ₀)
# θ = device(θ₀)
θ = 1.0f-1 * cu(θ₀)
# θ = cu(θ₀)
opt = Optimisers.setup(Adam(1.0f-3), θ)
callbackstate = Point2f[]
randloss = create_randloss(
mean_squared_error,
closure,
data_train.V,
data_train.cF;
nuse = 50,
device,
)
randloss = create_randloss(mean_squared_error, closure, io_train...; nuse = 50, device = cu)

# Warm-up
randloss(θ);
randloss(θ)
@time randloss(θ);
first(gradient(randloss, θ));
@time first(gradient(randloss, θ));
Expand All @@ -153,23 +163,21 @@ first(gradient(randloss, θ));
niter = 2000,
ncallback = 10,
callbackstate,
callback = create_callback(closure, u_test, c_test; state = callbackstate),
)
callback = create_callback(closure, cu(io_valid)...; state = callbackstate),
);
GC.gc()
CUDA.reclaim()

Array(θ)

# # Save trained parameters
# jldsave("output/forced/theta_cnn.jld2"; θ = Array(θ))
# jldsave("output/forced/theta_fno.jld2"; θ = Array(θ))
# jldsave("output/forced/theta_cnn.jld2"; theta = Array(θ))
# jldsave("output/forced/theta_fno.jld2"; theta = Array(θ))

# # Load trained parameters
# θθ = load("output/theta_cnn.jld2")
# θθ = load("output/theta_fno.jld2")
# θθ = θθ["θ"]
# θθ = cu(θθ)
# θ .= θθ
# copyto!(θ, θθ["theta"])

relative_error(closure(device(data_train.V[:, 1, :]), θ), device(data_train.cF[:, 1, :]))
relative_error(
Expand Down
17 changes: 9 additions & 8 deletions src/closures/create_les_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ _filter_saver(dns, les, comp; nupdate = 1) = processor(function (state)
push!(_t, t)
push!(_u, Array.(ubar))
# push!(_p, Array(pbar))
push!(_F, Array.(Fubar))
# push!(_F, Array.(Fubar))
# push!(_FG, Array.(FGbar))
push!(_cF, Array.(cF))
# push!(_cFG, Array.(cFG))
Expand All @@ -77,7 +77,7 @@ _filter_saver(dns, les, comp; nupdate = 1) = processor(function (state)
t = _t,
u = _u,
# p = _p,
F = _F,
# F = _F,
# FG = _FG,
cF = _cF,
# cFG = _cFG,
Expand Down Expand Up @@ -119,17 +119,18 @@ function create_les_data(
Δt,
u = fill(fill(ntuple-> zeros(T, N...), D), 0), 0),
# p = fill(fill(zeros(T, N...), 0), 0),
F = fill(fill(ntuple-> zeros(T, N...), D), 0), 0),
# F = fill(fill(ntuple(α -> zeros(T, N...), D), 0), 0),
# FG = fill(fill(ntuple(α -> zeros(T, N...), D), 0), 0),
cF = fill(fill(ntuple-> zeros(T, N...), D), 0), 0),
# cFG = fill(fill(ntuple(α -> zeros(T, N...), D), 0), 0),
# force = fill(fill(ntuple(α -> zeros(T, N...), D), 0), 0),
)

@info "Generating $(Base.summarysize(filtered) / 1e6) Mb of LES data"
# @info "Generating $(Base.summarysize(filtered) / 1e6) Mb of LES data"
@info "Generating $(nsim * (nt + 1) * nles * 3 * 2 * length(bitstring(zero(T))) / 8 / 1e6) Mb of LES data"

for isim = 1:nsim
@info "Generating data for simulation $isim of $nsim"
# @info "Generating data for simulation $isim of $nsim"

# Initial conditions
u₀, p₀ = random_field(dns, T(0); pressure_solver)
Expand All @@ -155,7 +156,7 @@ function create_les_data(
p₀,
(T(0), tburn);
Δt,
processors = (step_logger(; nupdate = 10),),
# processors = (step_logger(; nupdate = 10),),
pressure_solver,
)

Expand All @@ -169,7 +170,7 @@ function create_les_data(
Δt,
processors = (
_filter_saver(_dns, _les, compression),
step_logger(; nupdate = 10),
# step_logger(; nupdate = 10),
),
pressure_solver,
)
Expand All @@ -178,7 +179,7 @@ function create_les_data(
# Store result for current IC
push!(filtered.u, f.u)
# push!(filtered.p, f.p)
push!(filtered.F, f.F)
# push!(filtered.F, f.F)
# push!(filtered.FG, f.FG)
push!(filtered.cF, f.cF)
# push!(filtered.cFG, f.cFG)
Expand Down
6 changes: 2 additions & 4 deletions src/closures/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ The function `loss` should take inputs like `loss(f, x, y, θ)`.
The batch is moved to `device` before the loss is evaluated.
"""
function create_randloss(loss, f, x, y; nuse = size(x, 2), device = identity)
x = reshape(x, size(x, 1), :)
y = reshape(y, size(y, 1), :)
nsample = size(x, 2)
function create_randloss(loss, f, x, y; nuse = 50, device = identity)
nsample = size(x)[end]
d = ndims(x)
function randloss(θ)
i = Zygote.@ignore sort(shuffle(1:nsample)[1:nuse])
Expand Down
28 changes: 14 additions & 14 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,25 +317,25 @@ function laplacian!(L, p, setup)
end

"""
interpolate_u_p(setup, u)
interpolate_u_p(u, setup)
Interpolate velocity to pressure points.
"""
interpolate_u_p(setup, u) = interpolate_u_p!(
setup,
interpolate_u_p(u, setup) = interpolate_u_p!(
ntuple(
α -> KernelAbstractions.zeros(get_backend(u[1]), eltype(u[1]), setup.grid.N),
setup.grid.dimension(),
),
u,
setup,
)

"""
interpolate_u_p!(setup, up, u)
interpolate_u_p!(up, u, setup)
Interpolate velocity to pressure points.
"""
function interpolate_u_p!(setup, up, u)
function interpolate_u_p!(up, u, setup)
(; boundary_conditions, grid, Re, bodyforce) = setup
(; dimension, Np, Ip) = grid
D = dimension()
Expand All @@ -354,29 +354,29 @@ function interpolate_u_p!(setup, up, u)
end

"""
interpolate_ω_p(setup, ω)
interpolate_ω_p(ω, setup)
Interpolate vorticity to pressure points.
"""
interpolate_ω_p(setup, ω) = interpolate_ω_p!(
setup,
interpolate_ω_p(ω, setup) = interpolate_ω_p!(
setup.grid.dimension() == 2 ?
KernelAbstractions.zeros(get_backend(ω), eltype(ω), setup.grid.N) :
ntuple(
α -> KernelAbstractions.zeros(get_backend(ω[1]), eltype(ω[1]), setup.grid.N),
setup.grid.dimension(),
),
ω,
setup,
)

"""
interpolate_ω_p!(setup, ωp, ω)
interpolate_ω_p!(ωp, ω, setup)
Interpolate vorticity to pressure points.
"""
interpolate_ω_p!(setup, ωp, ω) = interpolate_ω_p!(setup.grid.dimension, setup, ωp, ω)
interpolate_ω_p!(ωp, ω, setup) = interpolate_ω_p!(setup.grid.dimension, ωp, ω, setup)

function interpolate_ω_p!(::Dimension{2}, setup, ωp, ω)
function interpolate_ω_p!(::Dimension{2}, ωp, ω, setup)
(; boundary_conditions, grid, Re, bodyforce) = setup
(; dimension, Np, Ip) = grid
D = dimension()
Expand All @@ -392,7 +392,7 @@ function interpolate_ω_p!(::Dimension{2}, setup, ωp, ω)
ωp
end

function interpolate_ω_p!(::Dimension{3}, setup, ωp, ω)
function interpolate_ω_p!(::Dimension{3}, ωp, ω, setup)
(; boundary_conditions, grid, Re) = setup
(; dimension, Np, Ip) = grid
D = dimension()
Expand Down Expand Up @@ -521,10 +521,10 @@ Qfield(u, setup) = Qfield!(
Compute total kinetic energy. The velocity components are interpolated to the
volume centers and squared.
"""
function kinetic_energy(setup, u)
function kinetic_energy(u, setup)
(; dimension, Ω, Ip) = setup.grid
D = dimension()
up = interpolate_u_p(setup, u)
up = interpolate_u_p(u, setup)
E = zero(eltype(up[1]))
for α = 1:D
# E += sum(I -> Ω[I] * up[α][I]^2, Ip)
Expand Down
4 changes: 2 additions & 2 deletions src/postprocess/plot_velocity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function plot_velocity(::Dimension{2}, setup, u; kwargs...)
T = eltype(xp[1])

# Get velocity at pressure points
up = interpolate_u_p(setup, u)
up = interpolate_u_p(u, setup)
# qp = map((u, v) -> √(u^2 + v^2), up, vp)
qp = sqrt.(up[1] .^ 2 .+ up[2] .^ 2)

Expand Down Expand Up @@ -43,7 +43,7 @@ function plot_velocity(::Dimension{3}, setup, u; kwargs...)
(; xp) = setup.grid

# Get velocity at pressure points
up = interpolate_u_p(setup, u)
up = interpolate_u_p(u, setup)
qp = map((u, v, w) -> sum(u^2 + v^2 + w^2), up...)

# Levels
Expand Down
4 changes: 2 additions & 2 deletions src/postprocess/plot_vorticity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function plot_vorticity(::Dimension{2}, setup, u; kwargs...)

# Get fields
ω = vorticity(u, setup)
ωp = interpolate_ω_p(setup, ω)
ωp = interpolate_ω_p(ω, setup)
ωp = Array(ωp)[Ip]

# Levels
Expand Down Expand Up @@ -50,7 +50,7 @@ function plot_vorticity(::Dimension{3}, setup, u; kwargs...)
(; grid) = setup
(; xp) = grid

ωp = interpolate_ω_p(setup, vorticity(u, setup))
ωp = interpolate_ω_p(vorticity(u, setup), setup)
qp = map((u, v, w) -> sum(u^2 + v^2 + w^2), ωp...)

# Levels
Expand Down
4 changes: 2 additions & 2 deletions src/postprocess/save_vtk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ function save_vtk(setup, u, p, filename = "output/solution")
D = dimension()
xp = Array.(xp)
vtk_grid(filename, xp...) do vtk
up = interpolate_u_p(setup, u)
ωp = interpolate_ω_p(setup, vorticity(u, setup))
up = interpolate_u_p(u, setup)
ωp = interpolate_ω_p(vorticity(u, setup), setup)
if D == 2
# ParaView prefers 3D vectors. Add zero z-component.
up3 = zero(up[1])
Expand Down
14 changes: 7 additions & 7 deletions src/processors/real_time_plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ function field_plot(

(; u, p, t) = state[]
_f = if fieldname == :velocity
up = interpolate_u_p(setup, u)
up = interpolate_u_p(u, setup)
elseif fieldname == :vorticity
ω = vorticity(u, setup)
ωp = interpolate_ω_p(setup, ω)
ωp = interpolate_ω_p(ω, setup)
elseif fieldname == :streamfunction
ψ = get_streamfunction(setup, u, t)
elseif fieldname == :pressure
Expand All @@ -72,12 +72,12 @@ function field_plot(
isnothing(sleeptime) || sleep(sleeptime)
(; u, p, t) = $state
f = if fieldname == :velocity
interpolate_u_p!(setup, up, u)
interpolate_u_p!(up, u, setup)
map((u, v) -> sum(u^2 + v^2), up...)
elseif fieldname == :vorticity
apply_bc_u!(u, t, setup)
vorticity!(ω, u, setup)
interpolate_ω_p!(setup, ωp, ω)
interpolate_ω_p!(ωp, ω, setup)
elseif fieldname == :streamfunction
get_streamfunction!(setup, ψ, u, t)
elseif fieldname == :pressure
Expand Down Expand Up @@ -181,10 +181,10 @@ function field_plot(
isnothing(sleeptime) || sleep(sleeptime)
(; u, p, t) = $state
f = if fieldname == :velocity
up = interpolate_u_p(setup, u)
up = interpolate_u_p(u, setup)
map((u, v, w) -> sum(u^2 + v^2 + w^2), up...)
elseif fieldname == :vorticity
ωp = interpolate_ω_p(setup, vorticity(u, setup))
ωp = interpolate_ω_p(vorticity(u, setup), setup)
map((u, v, w) -> sum(u^2 + v^2 + w^2), ωp...)
elseif fieldname == :streamfunction
get_streamfunction(setup, u, t)
Expand Down Expand Up @@ -275,7 +275,7 @@ function energy_spectrum_plot(setup, state; displayfig = true)
k = Array(reshape(k, :))
ehat = @lift begin
(; u, p, t) = $state
up = interpolate_u_p(setup, u)
up = interpolate_u_p(u, setup)
e = sum(up -> up[Ip] .^ 2, up)
Array(reshape(abs.(fft(e)[ntuple-> kx[α] .+ 1, D)...]), :))
end
Expand Down

0 comments on commit f58dc63

Please sign in to comment.