Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add AD (Enzyme) support via MeshIntegralsEnzymeExt #152

Merged
merged 24 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f1107cf
add Enzyme as a potential differentiation method for the jacobian
kylebeggs Dec 13, 2024
c776dc9
refactor check for enzyme support
kylebeggs Dec 13, 2024
9f3aadc
add FP to _default_diff_method
kylebeggs Dec 13, 2024
dfe12a1
add `using Enzyme` to benchmarks.jl
kylebeggs Dec 13, 2024
539500b
update CoordRefSystems.jl compat
kylebeggs Dec 13, 2024
875356f
add Enzyme to Benchmark Project.toml
kylebeggs Dec 13, 2024
c221137
fix Meshes compat in Benchmark Project.toml
kylebeggs Dec 13, 2024
8ecff85
use import Enzyme, not using Enzyme
kylebeggs Dec 13, 2024
76999ac
fix typo in Benchmarks Project.toml
kylebeggs Dec 13, 2024
bd8174d
remove Meshes version check in combinations.jl
kylebeggs Dec 13, 2024
1786576
Apply format suggestion
kylebeggs Dec 13, 2024
3747154
Update test/Project.toml
kylebeggs Dec 13, 2024
0d669a0
Bump compat of Enzyme to v0.13.19
JoshuaLampert Dec 13, 2024
f74aa98
test supports_autoenzyme to combinations; test both backends for wron…
kylebeggs Dec 13, 2024
08b9bf2
Restore recently-updated FiniteDifference constructors
mikeingold Dec 14, 2024
45949a1
Add docstrings, formatting
mikeingold Dec 14, 2024
e05c8ad
Formatting
mikeingold Dec 14, 2024
8d19413
Add test for two-arg jacobian
mikeingold Dec 14, 2024
570a53e
Use rest of MeshIntegrals namespace
mikeingold Dec 14, 2024
71215f1
Disambiguate use of jacobian
mikeingold Dec 14, 2024
d2eff83
fix test
JoshuaLampert Dec 14, 2024
c4f6b92
use `import Enzyme`
kylebeggs Dec 14, 2024
7f25b8a
use `import Enzyme`
kylebeggs Dec 14, 2024
65dfe9c
remove unneeded MeshIntegrals.jl
kylebeggs Dec 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# development related
.vscode
dev
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@ Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
MeshIntegralsEnzymeExt = "Enzyme"

[compat]
CliffordNumbers = "0.1.9"
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
CoordRefSystems = "0.15, 0.16"
Enzyme = "0.13.19"
FastGaussQuadrature = "1"
HCubature = "1.5"
LinearAlgebra = "1"
Meshes = "0.50, 0.51, 0.52"
Meshes = "0.51.20, 0.52"
QuadGK = "2.1.1"
Unitful = "1.19"
julia = "1.9"
4 changes: 3 additions & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
BenchmarkTools = "1.5"
Enzyme = "0.13.19"
LinearAlgebra = "1"
Meshes = "0.50, 0.51, 0.52"
Meshes = "0.51.20, 0.52"
Unitful = "1.19"
julia = "1.9"
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using LinearAlgebra
using Meshes
using MeshIntegrals
using Unitful
import Enzyme

const SUITE = BenchmarkGroup()

Expand Down
19 changes: 19 additions & 0 deletions ext/MeshIntegralsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module MeshIntegralsEnzymeExt

using MeshIntegrals: MeshIntegrals, AutoEnzyme
using Meshes: Meshes
using Enzyme: Enzyme

function MeshIntegrals.jacobian(
geometry::Meshes.Geometry,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
::AutoEnzyme
) where {T <: AbstractFloat}
Dim = Meshes.paramdim(geometry)
if Dim != length(ts)
throw(ArgumentError("ts must have same number of dimensions as geometry."))
end
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
end

end
2 changes: 1 addition & 1 deletion src/MeshIntegrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import QuadGK
import Unitful

include("differentiation.jl")
export DifferentiationMethod, FiniteDifference, jacobian
export DifferentiationMethod, FiniteDifference, AutoEnzyme, jacobian

include("utils.jl")

Expand Down
16 changes: 11 additions & 5 deletions src/differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ A category of types used to specify the desired method for calculating derivativ
Derivatives are used to form Jacobian matrices when calculating the differential
element size throughout the integration region.

See also [`FiniteDifference`](@ref).
See also [`FiniteDifference`](@ref), [`AutoEnzyme`](@ref).
"""
abstract type DifferentiationMethod end

Expand All @@ -27,8 +27,14 @@ end
FiniteDifference{T}() where {T <: AbstractFloat} = FiniteDifference{T}(T(1e-6))
FiniteDifference() = FiniteDifference{Float64}()

"""
AutoEnzyme()

Use to specify use of the Enzyme.jl for calculating derivatives.
"""
struct AutoEnzyme <: DifferentiationMethod end

# Future Support:
# struct AutoEnzyme <: DifferentiationMethod end
# struct AutoZygote <: DifferentiationMethod end

################################################################################
Expand All @@ -52,7 +58,7 @@ function jacobian(
geometry::G,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}}
) where {G <: Geometry, T <: AbstractFloat}
return jacobian(geometry, ts, _default_diff_method(G))
return jacobian(geometry, ts, _default_diff_method(G, T))
end

function jacobian(
Expand All @@ -68,7 +74,7 @@ function jacobian(
# Get the partial derivative along the n'th axis via finite difference
# approximation, where ts is the current parametric position
function ∂ₙr(ts, n, ε)
# Build left/right parametric coordinates with non-allocating iterators
# Build left/right parametric coordinates with non-allocating iterators
left = Iterators.map(((i, t),) -> i == n ? t - ε : t, enumerate(ts))
right = Iterators.map(((i, t),) -> i == n ? t + ε : t, enumerate(ts))
# Select orientation of finite-diff
Expand Down Expand Up @@ -107,7 +113,7 @@ possible and finite difference approximations otherwise.
function differential(
geometry::G,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
diff_method::DifferentiationMethod = _default_diff_method(G)
diff_method::DifferentiationMethod = _default_diff_method(G, T)
) where {G <: Geometry, T <: AbstractFloat}
J = Iterators.map(_KVector, jacobian(geometry, ts, diff_method))
return LinearAlgebra.norm(foldl(∧, J))
Expand Down
16 changes: 11 additions & 5 deletions src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
################################################################################

"""
integral(f, geometry[, rule]; diff_method=_default_method(geometry), FP=Float64)
integral(f, geometry[, rule]; diff_method=_default_diff_method(geometry, FP), FP=Float64)

Numerically integrate a given function `f(::Point)` over the domain defined by
a `geometry` using a particular numerical integration `rule` with floating point
Expand All @@ -16,7 +16,7 @@ precision of type `FP`.
`GaussKronrod()` in 1D and `HAdaptiveCubature()` else)

# Keyword Arguments
- `diff_method::DifferentiationMethod = _default_method(geometry)`: the method to
- `diff_method::DifferentiationMethod = _default_diff_method(geometry, FP)`: the method to
use for calculating Jacobians that are used to calculate differential elements
- `FP = Float64`: the floating point precision desired.
"""
Expand All @@ -42,8 +42,10 @@ function _integral(
geometry,
rule::GaussKronrod;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
diff_method::DM = _default_diff_method(geometry, FP)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

# Implementation depends on number of parametric dimensions over which to integrate
N = Meshes.paramdim(geometry)
if N == 1
Expand All @@ -70,8 +72,10 @@ function _integral(
geometry,
rule::GaussLegendre;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
diff_method::DM = _default_diff_method(geometry, FP)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

N = Meshes.paramdim(geometry)

# Get Gauss-Legendre nodes and weights of type FP for a region [-1,1]ᴺ
Expand Down Expand Up @@ -99,8 +103,10 @@ function _integral(
geometry,
rule::HAdaptiveCubature;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
diff_method::DM = _default_diff_method(geometry, FP)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

N = Meshes.paramdim(geometry)

integrand(ts) = f(geometry(ts...)) * differential(geometry, ts, diff_method)
Expand Down
8 changes: 6 additions & 2 deletions src/specializations/BezierCurve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@ function integral(
curve::Meshes.BezierCurve,
rule::IntegrationRule;
alg::Meshes.BezierEvalMethod = Meshes.Horner(),
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(curve, FP),
kwargs...
)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(curve, diff_method)

# Generate a _ParametricGeometry whose parametric function auto-applies the alg kwarg
param_curve = _ParametricGeometry(_parametric(curve, alg), Meshes.paramdim(curve))

# Integrate the _ParametricGeometry using the standard methods
return _integral(f, param_curve, rule; kwargs...)
return _integral(f, param_curve, rule; diff_method = diff_method, FP = FP, kwargs...)
end

################################################################################
Expand Down
12 changes: 8 additions & 4 deletions src/specializations/CylinderSurface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,22 @@ function integral(
f,
cyl::Meshes.CylinderSurface,
rule::I;
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(cyl, FP),
kwargs...
) where {I <: IntegrationRule}
) where {I <: IntegrationRule, DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(cyl, diff_method)

# The generic method only parametrizes the sides
sides = _integral(f, cyl, rule; kwargs...)
sides = _integral(f, cyl, rule; diff_method = diff_method, FP = FP, kwargs...)

# Integrate the Disk at the top
disk_top = Meshes.Disk(cyl.top, cyl.radius)
top = _integral(f, disk_top, rule; kwargs...)
top = _integral(f, disk_top, rule; diff_method = diff_method, FP = FP, kwargs...)

# Integrate the Disk at the bottom
disk_bottom = Meshes.Disk(cyl.bot, cyl.radius)
bottom = _integral(f, disk_bottom, rule; kwargs...)
bottom = _integral(f, disk_bottom, rule; diff_method = diff_method, FP = FP, kwargs...)

return sides + top + bottom
end
47 changes: 41 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,50 @@ end
# DifferentiationMethod
################################################################################

# Return the default DifferentiationMethod instance for a particular geometry type
"""
supports_autoenzyme(geometry)

Return whether a geometry (or geometry type) has a parametric function that can be
differentiated with Enzyme. See GitHub Issue #154 for more information.
"""
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G)

"""
_check_diff_method_support(::Geometry, ::DifferentiationMethod) -> nothing

Throw an error if incompatible geometry-diff_method combination detected.
"""
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing
function _check_diff_method_support(geometry::Geometry, ::AutoEnzyme)
if !supports_autoenzyme(geometry)
throw(ArgumentError("AutoEnzyme not supported for this geometry."))
end
end

"""
_default_diff_method(geometry, FP)

Return an instance of the default DifferentiationMethod for a particular geometry
(or geometry type) and floating point type.
"""
function _default_diff_method(
g::Type{G}
) where {G <: Geometry}
return FiniteDifference()
g::Type{G}, FP::Type{T}
) where {G <: Geometry, T <: AbstractFloat}
if supports_autoenzyme(g) && FP <: Union{Float32, Float64}
AutoEnzyme()
else
FiniteDifference()
end
end

# Return the default DifferentiationMethod instance for a particular geometry instance
_default_diff_method(g::G) where {G <: Geometry} = _default_diff_method(G)
function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat}
_default_diff_method(G, T)
end

################################################################################
# Numerical Tools
Expand Down
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CoordRefSystems = "b46f11dc-f210-4604-bfba-323c1ec968cb"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
kylebeggs marked this conversation as resolved.
Show resolved Hide resolved
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Expand All @@ -12,10 +13,11 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Aqua = "0.7, 0.8"
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
CoordRefSystems = "0.15, 0.16"
Enzyme = "0.13.19"
ExplicitImports = "1.6.0"
LinearAlgebra = "1"
Meshes = "0.50, 0.51, 0.52"
Meshes = "0.51.20, 0.52"
SpecialFunctions = "2"
TestItemRunner = "1"
TestItems = "1"
Expand Down
Loading
Loading