Skip to content

Commit

Permalink
add Enzyme as a potential differentiation method for the jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebeggs committed Dec 13, 2024
1 parent c7c0a47 commit f1107cf
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 31 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# development related
.vscode
dev
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@ Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
MeshIntegralsEnzymeExt = "Enzyme"

[compat]
CliffordNumbers = "0.1.9"
CoordRefSystems = "0.12, 0.13, 0.14, 0.15, 0.16"
Enzyme = "0.13.22"
FastGaussQuadrature = "1"
HCubature = "1.5"
LinearAlgebra = "1"
Expand Down
15 changes: 15 additions & 0 deletions ext/MeshIntegralsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module MeshIntegralsEnzymeExt

using MeshIntegrals: MeshIntegrals, AutoEnzyme
using Meshes: Meshes
using Enzyme: Enzyme

function MeshIntegrals.jacobian(
geometry::Meshes.Geometry,
ts::Union{AbstractVector{T}, Tuple{T, Vararg{T}}},
::AutoEnzyme
) where {T <: AbstractFloat}
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
end

end
2 changes: 1 addition & 1 deletion src/MeshIntegrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import QuadGK
import Unitful

include("differentiation.jl")
export DifferentiationMethod, FiniteDifference, jacobian
export DifferentiationMethod, FiniteDifference, AutoEnzyme, jacobian

include("utils.jl")

Expand Down
33 changes: 27 additions & 6 deletions src/differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ A category of types used to specify the desired method for calculating derivativ
Derivatives are used to form Jacobian matrices when calculating the differential
element size throughout the integration region.
See also [`FiniteDifference`](@ref).
See also [`FiniteDifference`](@ref), [`Analytical`](@ref).
"""
abstract type DifferentiationMethod end

Expand All @@ -23,12 +23,33 @@ struct FiniteDifference{T <: AbstractFloat} <: DifferentiationMethod
ε::T
end

# Default constructors
FiniteDifference{T}() where {T <: AbstractFloat} = FiniteDifference{T}(T(1e-6))
FiniteDifference() = FiniteDifference{Float64}()
# If ε not specified, default to 1e-6
FiniteDifference() = FiniteDifference(1e-6)

"""
Analytical()
Use to specify use of analytically-derived solutions for calculating derivatives.
These solutions are currently defined only for a subset of geometry types.
# Supported Geometries:
- `BezierCurve`
- `Line`
- `Plane`
- `Ray`
- `Tetrahedron`
- `Triangle`
"""
struct Analytical <: DifferentiationMethod end

"""
AutoEnzyme()
Use to specify use of the Enzyme.jl for calculating derivatives.
"""
struct AutoEnzyme <: DifferentiationMethod end

# Future Support:
# struct AutoEnzyme <: DifferentiationMethod end
# struct AutoZygote <: DifferentiationMethod end

################################################################################
Expand Down Expand Up @@ -68,7 +89,7 @@ function jacobian(
# Get the partial derivative along the n'th axis via finite difference
# approximation, where ts is the current parametric position
function ∂ₙr(ts, n, ε)
# Build left/right parametric coordinates with non-allocating iterators
# Build left/right parametric coordinates with non-allocating iterators
left = Iterators.map(((i, t),) -> i == n ? t - ε : t, enumerate(ts))
right = Iterators.map(((i, t),) -> i == n ? t + ε : t, enumerate(ts))
# Select orientation of finite-diff
Expand Down
10 changes: 8 additions & 2 deletions src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
################################################################################

"""
integral(f, geometry[, rule]; diff_method=_default_method(geometry), FP=Float64)
integral(f, geometry[, rule]; diff_method=_default_diff_method(geometry), FP=Float64)
Numerically integrate a given function `f(::Point)` over the domain defined by
a `geometry` using a particular numerical integration `rule` with floating point
Expand All @@ -16,7 +16,7 @@ precision of type `FP`.
`GaussKronrod()` in 1D and `HAdaptiveCubature()` else)
# Keyword Arguments
- `diff_method::DifferentiationMethod = _default_method(geometry)`: the method to
- `diff_method::DifferentiationMethod = _default_diff_method(geometry)`: the method to
use for calculating Jacobians that are used to calculate differential elements
- `FP = Float64`: the floating point precision desired.
"""
Expand Down Expand Up @@ -44,6 +44,8 @@ function _integral(
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

# Implementation depends on number of parametric dimensions over which to integrate
N = Meshes.paramdim(geometry)
if N == 1
Expand Down Expand Up @@ -72,6 +74,8 @@ function _integral(
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

N = Meshes.paramdim(geometry)

# Get Gauss-Legendre nodes and weights of type FP for a region [-1,1]ᴺ
Expand Down Expand Up @@ -101,6 +105,8 @@ function _integral(
FP::Type{T} = Float64,
diff_method::DM = _default_diff_method(geometry)
) where {DM <: DifferentiationMethod, T <: AbstractFloat}
_check_diff_method_support(geometry, diff_method)

N = Meshes.paramdim(geometry)

integrand(ts) = f(geometry(ts...)) * differential(geometry, ts, diff_method)
Expand Down
7 changes: 5 additions & 2 deletions src/specializations/BezierCurve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@ function integral(
curve::Meshes.BezierCurve,
rule::IntegrationRule;
alg::Meshes.BezierEvalMethod = Meshes.Horner(),
diff_method::DM = _default_diff_method(curve),
kwargs...
)
) where {DM <: DifferentiationMethod}
_check_diff_method_support(curve, diff_method)

# Generate a _ParametricGeometry whose parametric function auto-applies the alg kwarg
param_curve = _ParametricGeometry(_parametric(curve, alg), Meshes.paramdim(curve))

# Integrate the _ParametricGeometry using the standard methods
return _integral(f, param_curve, rule; kwargs...)
return _integral(f, param_curve, rule; diff_method = diff_method, kwargs...)
end

################################################################################
Expand Down
11 changes: 7 additions & 4 deletions src/specializations/CylinderSurface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@ function integral(
f,
cyl::Meshes.CylinderSurface,
rule::I;
diff_method::DM = _default_diff_method(cyl),
kwargs...
) where {I <: IntegrationRule}
) where {I <: IntegrationRule, DM <: DifferentiationMethod}
_check_diff_method_support(cyl, diff_method)

# The generic method only parametrizes the sides
sides = _integral(f, cyl, rule; kwargs...)
sides = _integral(f, cyl, rule; diff_method = diff_method, kwargs...)

# Integrate the Disk at the top
disk_top = Meshes.Disk(cyl.top, cyl.radius)
top = _integral(f, disk_top, rule; kwargs...)
top = _integral(f, disk_top, rule; diff_method = diff_method, kwargs...)

# Integrate the Disk at the bottom
disk_bottom = Meshes.Disk(cyl.bot, cyl.radius)
bottom = _integral(f, disk_bottom, rule; kwargs...)
bottom = _integral(f, disk_bottom, rule; diff_method = diff_method, kwargs...)

return sides + top + bottom
end
11 changes: 10 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ end
# DifferentiationMethod
################################################################################

_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing

# Return the default DifferentiationMethod instance for a particular geometry type
function _default_diff_method(
g::Type{G}
Expand All @@ -20,7 +22,14 @@ function _default_diff_method(
end

# Return the default DifferentiationMethod instance for a particular geometry instance
_default_diff_method(g::G) where {G <: Geometry} = _default_diff_method(G)
_default_diff_method(::G) where {G <: Geometry} = _default_diff_method(G)

non_enzyme_types = (:BezierCurve, :CylinderSurface, :Cylinder, :ParametrizedCurve)
for geometry_type in non_enzyme_types
@eval function _check_diff_method_support(::Meshes.$geometry_type, ::AutoEnzyme)
throw(ArgumentError("Differentiation method AutoEnzyme not supported for $(string(Meshes.$geometry_type))."))
end
end

################################################################################
# Numerical Tools
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CoordRefSystems = "b46f11dc-f210-4604-bfba-323c1ec968cb"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Expand Down
41 changes: 26 additions & 15 deletions test/combinations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ This file includes tests for:
using Meshes
using MeshIntegrals
using Unitful
using Enzyme

# Used for testing callable objects as integrand functions
struct Callable{F <: Function}
Expand All @@ -43,27 +44,27 @@ This file includes tests for:
gausslegendre::Bool
hadaptivecubature::Bool
# DifferentiationMethods
# autoenzyme::Bool
autoenzyme::Bool
end

# Shortcut constructor for geometries with typical support structure
function SupportStatus(geometry::Geometry)
function SupportStatus(geometry::Geometry, autoenzyme = true)
if paramdim(geometry) == 1
aliases = Bool.((1, 0, 0))
rules = Bool.((1, 1, 1))
return SupportStatus(aliases..., rules...)
return SupportStatus(aliases..., rules..., autoenzyme)
elseif paramdim(geometry) == 2
aliases = Bool.((0, 1, 0))
rules = Bool.((1, 1, 1))
return SupportStatus(aliases..., rules...)
return SupportStatus(aliases..., rules..., autoenzyme)
elseif paramdim(geometry) == 3
aliases = Bool.((0, 0, 1))
rules = Bool.((0, 1, 1))
return SupportStatus(aliases..., rules...)
return SupportStatus(aliases..., rules..., autoenzyme)
else
aliases = Bool.((0, 0, 0))
rules = Bool.((0, 1, 1))
return SupportStatus(aliases..., rules...)
return SupportStatus(aliases..., rules..., autoenzyme)
end
end

Expand Down Expand Up @@ -110,15 +111,19 @@ This file includes tests for:
end
end # for

#=
iter_diff_methods = (
(supports.autoenzyme, AutoEnzyme()),
)

for (supported, diff_method) in iter_diff_methods
@test integral(testable.integrand, testable.geometry; diff_method=diff_method)≈sol rtol=rtol
end
=#
if supported
@test integral(
testable.integrand, testable.geometry; diff_method = diff_method)testable.solution rtol=rtol
else
@test_throws "not supported" integral(
testable.integrand, testable.geometry; diff_method = diff_method)
end
end # for
end # function
end #testsnippet

Expand Down Expand Up @@ -180,7 +185,8 @@ end

# Package and run tests
testable = TestableGeometry(integrand, curve, solution)
runtests(testable; rtol = 0.5e-2)
supports = SupportStatus(curve, false)
runtests(testable, supports; rtol = 0.5e-2)
end

@testitem "Meshes.Box 1D" setup=[Combinations] begin
Expand Down Expand Up @@ -321,7 +327,8 @@ end

# Package and run tests
testable = TestableGeometry(integrand, cyl, solution)
runtests(testable)
supports = SupportStatus(cyl, false)
runtests(testable, supports)
end

@testitem "Meshes.CylinderSurface" setup=[Combinations] begin
Expand All @@ -336,7 +343,8 @@ end

# Package and run tests
testable = TestableGeometry(integrand, cyl, solution)
runtests(testable)
supports = SupportStatus(cyl, false)
runtests(testable, supports)
end

@testitem "Meshes.Disk" setup=[Combinations] begin
Expand Down Expand Up @@ -476,9 +484,12 @@ end

# Package and run tests
testable_cart = TestableGeometry(integrand, curve_cart, solution)
runtests(testable_cart)
supports_cart = SupportStatus(curve_cart, false)
runtests(testable_cart, supports_cart)

testable_polar = TestableGeometry(integrand, curve_polar, solution)
runtests(testable_polar)
supports_polar = SupportStatus(curve_polar, false)
runtests(testable_polar, supports_polar)
end
end

Expand Down

0 comments on commit f1107cf

Please sign in to comment.