-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #126 from agdestein/tensorclosure
Tensor closure
- Loading branch information
Showing
15 changed files
with
451 additions
and
107 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,13 @@ authors = ["Syver Døving Agdestein <[email protected]>"] | |
version = "1.0.0" | ||
|
||
[deps] | ||
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" | ||
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" | ||
IncompressibleNavierStokes = "5e318141-6589-402b-868d-77d7df8c442e" | ||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" | ||
|
@@ -20,23 +22,25 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | |
NeuralClosure = "099dac27-d7f2-4047-93d5-0baee36b9c25" | ||
Observables = "510215fc-4207-5dde-b226-833fc4488ee2" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" | ||
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[sources.IncompressibleNavierStokes] | ||
path = "../.." | ||
|
||
[sources.NeuralClosure] | ||
path = "../NeuralClosure" | ||
[sources] | ||
IncompressibleNavierStokes = {path = "../.."} | ||
NeuralClosure = {path = "../NeuralClosure"} | ||
|
||
[compat] | ||
Accessors = "0.1" | ||
Adapt = "4" | ||
CUDA = "5" | ||
CairoMakie = "0.12" | ||
ChainRulesCore = "1.25.0" | ||
ChainRulesCore = "1" | ||
ComponentArrays = "0.15" | ||
Dates = "1" | ||
FFTW = "1" | ||
IncompressibleNavierStokes = "2" | ||
JLD2 = "0.5" | ||
|
@@ -48,6 +52,8 @@ NNlib = "0.9" | |
NeuralClosure = "1" | ||
Observables = "0.5" | ||
Optimisers = "0.3, 0.4" | ||
StaticArrays = "1.9.8" | ||
ParameterSchedulers = "0.4" | ||
StaticArrays = "1" | ||
WGLMakie = "0.10" | ||
Zygote = "0.6, 0.7" | ||
julia = "1.9" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,36 @@ | ||
module SymmetryClosure | ||
|
||
using StaticArrays | ||
using Accessors | ||
using Adapt | ||
using ChainRulesCore | ||
using CUDA | ||
using Dates | ||
using IncompressibleNavierStokes | ||
using JLD2 | ||
using KernelAbstractions | ||
using LinearAlgebra | ||
using IncompressibleNavierStokes | ||
using Lux | ||
using NeuralClosure | ||
using NNlib | ||
using Optimisers | ||
using Random | ||
using StaticArrays | ||
|
||
include("tensorclosure.jl") | ||
include("setup.jl") | ||
include("cases.jl") | ||
include("train.jl") | ||
|
||
export tensorclosure, polynomial | ||
export tensorclosure, polynomial, create_cnn | ||
export slurm_vars, | ||
time_info, | ||
hardware, | ||
splatfileparts, | ||
getdatafile, | ||
namedtupleload, | ||
splitseed, | ||
getsetup, | ||
testcase | ||
export trainpost | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,27 @@ | ||
function testcase() | ||
# Choose where to put output | ||
function testcase(backend) | ||
basedir = haskey(ENV, "DEEPDIP") ? ENV["DEEPDIP"] : joinpath(@__DIR__, "..") | ||
outdir = mkpath(joinpath(basedir, "output", "kolmogorov")) | ||
|
||
outdir = mkpath(joinpath(basedir, "output", "Kolmogorov2D")) | ||
plotdir = mkpath(joinpath(outdir, "plots")) | ||
seed_dns = 123 | ||
ntrajectory = 8 | ||
T = Float32 | ||
|
||
params = (; | ||
D = 2, | ||
lims = (T(0), T(1)), | ||
Re = T(6e3), | ||
tburn = T(0.5), | ||
tsim = T(5), | ||
tsim = T(3), | ||
savefreq = 50, | ||
ndns = 4096, | ||
ndns = 2048, | ||
nles = [32, 64, 128], | ||
filters = [FaceAverage()], | ||
backend, | ||
icfunc = (setup, psolver, rng) -> | ||
random_field(setup, T(0); kp = 20, psolver, rng), | ||
method = RKMethods.LMWray3(; T), | ||
method = LMWray3(; T), | ||
bodyforce = (dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y), | ||
issteadybodyforce = true, | ||
processors = (; log = timelogger(; nupdate = 100)), | ||
) | ||
(; outdir, plotdir, seed_dns, ntrajectory, params) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Some script utils | ||
|
||
function slurm_vars() | ||
jobid = haskey(ENV, "SLURM_JOB_ID") ? parse(Int, ENV["SLURM_JOB_ID"]) : nothing | ||
taskid = | ||
haskey(ENV, "SLURM_ARRAY_TASK_ID") ? parse(Int, ENV["SLURM_ARRAY_TASK_ID"]) : | ||
nothing | ||
isnothing(jobid) || @info "Running on SLURM" jobid taskid | ||
(; jobid, taskid) | ||
end | ||
|
||
function time_info() | ||
@info "Starting at $(Dates.now())" | ||
@info """ | ||
Last commit: | ||
$(cd(() -> read(`git log -n 1`, String), @__DIR__)) | ||
""" | ||
end | ||
|
||
hardware() = | ||
if CUDA.functional() | ||
@info "Running on CUDA" | ||
CUDA.allowscalar(false) | ||
backend = CUDABackend() | ||
device = x -> adapt(backend, x) | ||
clean = () -> (GC.gc(); CUDA.reclaim()) | ||
(; backend, device, clean) | ||
else | ||
@warn """ | ||
Running on CPU. | ||
Consider reducing the size of DNS, LES, and CNN layers if | ||
you want to test run on a laptop. | ||
""" | ||
(; backend = CPU(), device = identity, clean = () -> nothing) | ||
end | ||
|
||
function splatfileparts(args...; kwargs...) | ||
sargs = string.(args) | ||
skwargs = map((k, v) -> string(k) * "=" * string(v), keys(kwargs), values(kwargs)) | ||
s = [sargs..., skwargs...] | ||
join(s, "_") | ||
end | ||
|
||
getdatafile(outdir, nles, filter, seed) = | ||
joinpath(outdir, "data", splatfileparts(; seed = repr(seed), filter, nles) * ".jld2") | ||
|
||
function namedtupleload(file) | ||
dict = load(file) | ||
k, v = keys(dict), values(dict) | ||
pairs = @. Symbol(k) => v | ||
(; pairs...) | ||
end | ||
|
||
getsetup(; params, nles) = Setup(; | ||
x = ntuple(α -> range(params.lims..., nles + 1), params.D), | ||
params.Re, | ||
params.backend, | ||
params.bodyforce, | ||
params.issteadybodyforce, | ||
) |
Oops, something went wrong.