Skip to content

Commit

Permalink
Use KA.zeros instead of casting manually
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Oct 12, 2024
1 parent 90bc20d commit 54e4640
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/FastDEC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ export dec_wedge_product, cache_wedge, dec_c_wedge_product, dec_c_wedge_product!
# Wedge Product
#--------------

# Cache terms to be used by wedge product kernels.
function cache_wedge_kernel(::Type{Tuple{0,1}}, sd::Union{EmbeddedDeltaDualComplex1D, EmbeddedDeltaDualComplex2D})
# Cache coefficients to be used by wedge product kernels.
function wedge_kernel_coeffs(::Type{Tuple{0,1}}, sd::Union{EmbeddedDeltaDualComplex1D, EmbeddedDeltaDualComplex2D})
(hcat(convert(Vector{Int32}, sd[:∂v0])::Vector{Int32}, convert(Vector{Int32}, sd[:∂v1])::Vector{Int32}),
simplices(1, sd))
ne(sd))
end

function cache_wedge_kernel(::Type{Tuple{0,2}}, sd::EmbeddedDeltaDualComplex2D{Bool, float_type, _p}) where {float_type, _p}
function wedge_kernel_coeffs(::Type{Tuple{0,2}}, sd::EmbeddedDeltaDualComplex2D{Bool, float_type, _p}) where {float_type, _p}
verts = Array{Int32}(undef, 6, ntriangles(sd))
coeffs = Array{float_type}(undef, 6, ntriangles(sd))
shift::Int = ntriangles(sd)
Expand All @@ -52,10 +52,10 @@ function cache_wedge_kernel(::Type{Tuple{0,2}}, sd::EmbeddedDeltaDualComplex2D{B
coeffs[dt, t] = sd[dt_real, :dual_area] / sd[t, :area]
end
end
(verts, coeffs, triangles(sd))
(verts, coeffs, ntriangles(sd))
end

function cache_wedge_kernel(::Type{Tuple{1,1}}, sd::EmbeddedDeltaDualComplex2D{Bool, float_type, _p}) where {float_type, _p}
function wedge_kernel_coeffs(::Type{Tuple{1,1}}, sd::EmbeddedDeltaDualComplex2D{Bool, float_type, _p}) where {float_type, _p}
coeffs = Array{float_type}(undef, 3, ntriangles(sd))
shift = ntriangles(sd)
@inbounds for i in 1:ntriangles(sd)
Expand All @@ -66,27 +66,29 @@ function cache_wedge_kernel(::Type{Tuple{1,1}}, sd::EmbeddedDeltaDualComplex2D{B
end
e = Array{Int32}(undef, 3, ntriangles(sd))
e[1, :], e[2, :], e[3, :] = (2, 0, sd), (2, 1, sd), (2, 2, sd)
(e, coeffs, triangles(sd))
(e, coeffs, ntriangles(sd))
end

# Grab the float type of the volumes of the complex.
# 1D-complex method: reads `float_type` from the type parameters of `sd` and forwards
# to the positional method `cache_wedge(m, n, sd, float_type, arr_cons, cast_float)`.
# `arr_cons` defaults to `identity` and `cast_float` defaults to `nothing`.
# NOTE(review): `backend` is accepted but is not used in this body.
function cache_wedge(::Type{Tuple{m,n}}, sd::EmbeddedDeltaDualComplex1D{Bool, float_type, _p}, backend, arr_cons=identity, cast_float=nothing) where {float_type,_p,m,n}
cache_wedge(m, n, sd, float_type, arr_cons, cast_float)
end
# 2D-complex method: reads `float_type` from the type parameters of `sd` and forwards
# to the positional method `cache_wedge(m, n, sd, float_type, arr_cons, cast_float)`.
# `arr_cons` defaults to `identity` and `cast_float` defaults to `nothing`.
# NOTE(review): `backend` is accepted but is not used in this body.
function cache_wedge(::Type{Tuple{m,n}}, sd::EmbeddedDeltaDualComplex2D{Bool, float_type, _p}, backend, arr_cons=identity, cast_float=nothing) where {float_type,_p,m,n}
cache_wedge(m, n, sd, float_type, arr_cons, cast_float)
end
# Grab wedge kernel coeffs and cast.
function cache_wedge(m::Int, n::Int, sd::HasDeltaSet1D, float_type::DataType, arr_cons, cast_float::Union{Nothing, DataType})
ft = isnothing(cast_float) ? float_type : cast_float
wc = cache_wedge_kernel(Tuple{m,n}, sd)
wc = wedge_kernel_coeffs(Tuple{m,n}, sd)
if wc[2] isa Matrix
(arr_cons(wc[1]), arr_cons(Matrix{ft}(wc[2])), wc[3])
else
(arr_cons.(wc[1:end-1])..., wc[end])
end
end

# XXX: This assumes that the dual vertex on an edge is always the midpoint.
# TODO: Add options to change 0.5 to a different float
# XXX: 0.5 implies the dual vertex on an edge is the midpoint.
# TODO: Add options to change 0.5 to a different value.
@kernel function wedge_kernel_01!(res, @Const(f), @Const(α), @Const(p), @Const(simples))
@uniform half = eltype(f)(0.5)
i = @index(Global)
Expand All @@ -109,17 +111,17 @@ end
@inbounds res[i] = (c1 * (ae2 * be1 - ae1 * be2) + c2 * (ae2 * be0 - ae0 * be2) + c3 * (ae1 * be0 - ae0 * be1))
end

function auto_select_backend(kernel_function, res, f, α, p, c)
function auto_select_backend(kernel_function, res, α, β, p, c)
backend = get_backend(res)
kernel = kernel_function(backend, backend == CPU() ? 64 : 256)
kernel(res, f, α, p, c, ndrange=size(res))
kernel(res, α, β, p, c, ndrange=size(res))
res
end

# Manually dispatch, since CUDA.jl kernels cannot.
# An alternative would wrap each wedge_kernel separately.
# Alternatively, wrap each wedge_kernel separately.
function dec_c_wedge_product!(::Type{Tuple{j,k}}, res, α, β, p, c) where {j,k}
kernel = if (j,k) == (0,1)
kernel_function = if (j,k) == (0,1)
wedge_kernel_01!
elseif (j,k) == (0,2)
wedge_kernel_02!
Expand All @@ -128,13 +130,12 @@ function dec_c_wedge_product!(::Type{Tuple{j,k}}, res, α, β, p, c) where {j,k}
else
error("Unsupported combination of degrees $j and $k. Ensure that their sum is not greater than the degree of the complex, and the degree of the first is ≤ the degree of the second.")
end
auto_select_backend(kernel, res, α, β, p, c)
auto_select_backend(kernel_function, res, α, β, p, c)
end

# The last item in the wedge_cache is the number of simplices, i.e. the length of `res`.
function dec_c_wedge_product(::Type{Tuple{m,n}}, α, β, wedge_cache) where {m,n}
arr_type = typeof(α isa SimplexForm ? α.data : α)
res = arr_type(zeros(eltype(α isa SimplexForm ? α.data : α), last(last(wedge_cache))))
α_data = α isa SimplexForm ? α.data : α
res = KernelAbstractions.zeros(get_backend(α_data), eltype(α_data), last(wedge_cache))
dec_c_wedge_product!(Tuple{m,n}, res, α, β, wedge_cache[1], wedge_cache[2])
end

Expand All @@ -145,6 +146,13 @@ Return a function that computes the wedge product between a primal `m`-form and
It is assumed...
... for the 0-1 wedge product, that the dual vertex on an edge is at the midpoint.
... for the 1-1 wedge product, that the dual mesh simplices are in the default order as returned by the dual complex constructor.
# Arguments:
`Tuple{m,n}`: the degrees of the differential forms.
`sd`: the simplicial complex.
`backend=Val{:CPU}`: a value-type to select special backend logic, if implemented.
`arr_cons=identity`: a constructor of the desired array type on the appropriate backend e.g. `MtlArray`.
`cast_float=nothing`: a specific Float type to use e.g. `Float32`. Otherwise, the type of the first differential form will be used.
"""
function dec_wedge_product(::Type{Tuple{m,n}}, sd::HasDeltaSet, backend=Val{:CPU}, arr_cons=identity, cast_float=nothing) where {m,n}
error("Unsupported combination of degrees $m and $n. Ensure that their sum is not greater than the degree of the complex.")
Expand Down

0 comments on commit 54e4640

Please sign in to comment.