Skip to content

Commit

Permalink
Try out TypeGroupedNamedTuple
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 4, 2024
1 parent d4cb583 commit e2a88c3
Showing 1 changed file with 76 additions and 1 deletion.
77 changes: 76 additions & 1 deletion src/cache/precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,78 @@
import Thermodynamics as TD
import ClimaCore: Spaces, Fields

struct UniformNamedTuple{K, T, N}
k::NTuple{N, K}
v::NTuple{N, T}
end
# @inline Base.getproperty(nt::UniformNamedTuple, sym::Symbol) =
# getproperty(nt, Val(sym))
to_named_tuple(nt::UniformNamedTuple) = (; zip(nt.k, nt.v)...)
# Does this need to be compile-time known?
@inline function Base.getproperty(nt::UniformNamedTuple, sym::Symbol)
i = findfirst(s -> s == sym, getfield(nt, :k))
if isnothing(i)
error("No property $sym found in $(to_named_tuple(nt))")
else
@inbounds getproperty(getfield(nt, :v), i)
end
end
Base.eltype(::UniformNamedTuple{K, T, N}) where {K, T, N} = K
Base.length(::UniformNamedTuple{K, T, N}) where {K, T, N} = N
function to_uniform_named_tuple(nt)
k = Tuple(keys(nt))
v = Tuple(values(nt))
K = typeof(first(k))
T = typeof(first(v))
N = length(v)
return UniformNamedTuple{K, T, N}(k, v)
end

struct TypeGroupedNamedTuple{M, C}
tnmap::M
cache::C
end
TypeGroupedNamedTuple() =
TypeGroupedNamedTuple{Nothing, Nothing}(nothing, nothing)

function Base.getproperty(nt::TypeGroupedNamedTuple, sym::Symbol)
cache = getfield(nt, :cache)
tnmap = getfield(nt, :tnmap)
type_grouped_cache = getfield(cache, tnmap[sym])
getproperty(type_grouped_cache, sym)
end
process_key(k) = Symbol(k)

to_named_tuple(dict::Dict) = (; (process_key(k) => v for (k, v) in dict)...)

function type_grouped_named_tuple(flat_cache::NamedTuple)
type_grouped_named_tuple(collect(zip(keys(flat_cache), values(flat_cache))))
end

Base.propertynames(nt::TypeGroupedNamedTuple) =
Tuple(keys(getfield(nt, :tnmap)))
function type_grouped_named_tuple(flat_cache)
type_grouped_cache = Dict{Symbol, Any}()
for (i, (sym, c)) in enumerate(flat_cache)
H = process_key(hash(typeof(c)))
entry = Pair(sym, c)
if haskey(type_grouped_cache, H)
type_grouped_cache[H] = (; type_grouped_cache[H]..., sym => c)
else
type_grouped_cache[H] = (; sym => c)
end
end
type_grouped_cache =
map(x -> to_uniform_named_tuple(x), to_named_tuple(type_grouped_cache))
K = map(symc -> symc[1], flat_cache)
V = map(symc -> process_key(hash(typeof(symc[2]))), flat_cache)
tnmap = Dict(pairs((; zip(K, V)...)))
return TypeGroupedNamedTuple{typeof(tnmap), typeof(type_grouped_cache)}(
tnmap,
type_grouped_cache,
)
end

"""
precomputed_quantities(Y, atmos)
Expand Down Expand Up @@ -156,7 +228,7 @@ function precomputed_quantities(Y, atmos)
ᶜqᵣ = similar(Y.c, FT),
ᶜqₛ = similar(Y.c, FT),
) : (;)
return (;
nt = (;
gs_quantities...,
sgs_quantities...,
advective_sgs_quantities...,
Expand All @@ -165,6 +237,9 @@ function precomputed_quantities(Y, atmos)
precipitation_quantities...,
cloud_diagnostics_tuple,
)
tgnt = type_grouped_named_tuple(nt)
@show typeof(tgnt)
return tgnt
end

# Interpolates the third contravariant component of Y.c.uₕ to cell faces.
Expand Down

0 comments on commit e2a88c3

Please sign in to comment.