Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where integral fails when Enzyme ext not loaded #160

Merged
merged 19 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ext/MeshIntegralsEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,12 @@ function MeshIntegrals.jacobian(
return Meshes.to.(Enzyme.jacobian(Enzyme.Forward, geometry, ts...))
end

# Supports all geometries except for those that throw errors
# See GitHub Issue #154 for more information
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
MeshIntegrals.supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false

end
43 changes: 29 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,63 @@ end
################################################################################

"""
supports_autoenzyme(geometry)
supports_autoenzyme(geometry::Geometry)
supports_autoenzyme(type::Type{<:Geometry})

Return whether a geometry (or geometry type) has a parametric function that can be
differentiated with Enzyme. See GitHub Issue #154 for more information.
"""
supports_autoenzyme(::Type{<:Meshes.Geometry}) = true
supports_autoenzyme(::Type{<:Meshes.BezierCurve}) = false
supports_autoenzyme(::Type{<:Meshes.CylinderSurface}) = false
supports_autoenzyme(::Type{<:Meshes.Cylinder}) = false
supports_autoenzyme(::Type{<:Meshes.ParametrizedCurve}) = false
function supports_autoenzyme end

# Returns false for all geometries when Enzyme extension is not loaded
supports_autoenzyme(::Type{<:Any}) = false

# If provided a geometry instance, re-run with the type as argument
supports_autoenzyme(::G) where {G <: Geometry} = supports_autoenzyme(G)

"""
_check_diff_method_support(::Geometry, ::DifferentiationMethod) -> nothing

Throw an error if incompatible geometry-diff_method combination detected.
Throw an error if incompatible combination {geometry, diff_method} detected.
"""
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing
function _check_diff_method_support end

# If diff_method == Enzyme, then perform check
function _check_diff_method_support(geometry::Geometry, ::AutoEnzyme)
if !supports_autoenzyme(geometry)
throw(ArgumentError("AutoEnzyme not supported for this geometry."))
end
end

# If diff_method != AutoEnzyme, then do nothing
_check_diff_method_support(::Geometry, ::DifferentiationMethod) = nothing

"""
_default_diff_method(geometry, FP)

Return an instance of the default DifferentiationMethod for a particular geometry
(or geometry type) and floating point type.
"""
function _default_diff_method(
g::Type{G}, FP::Type{T}
::Type{G},
FP::Type{T}
) where {G <: Geometry, T <: AbstractFloat}
if supports_autoenzyme(g) && FP <: Union{Float32, Float64}
AutoEnzyme()
# Enzyme only works with these FP types
uses_Enzyme_supported_FP_type = (FP <: Union{Float32, Float64})

if supports_autoenzyme(G) && uses_Enzyme_supported_FP_type
return AutoEnzyme()
else
FiniteDifference()
return FiniteDifference()
end
end

function _default_diff_method(::G, ::Type{T}) where {G <: Geometry, T <: AbstractFloat}
_default_diff_method(G, T)
# If provided a geometry instance, re-run with the type as argument
function _default_diff_method(
::G,
::Type{T}
) where {G <: Geometry, T <: AbstractFloat}
return _default_diff_method(G, T)
end

################################################################################
Expand Down
40 changes: 25 additions & 15 deletions test/combinations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ This file includes tests for:
end

# Shortcut constructor for geometries with typical support structure
function SupportStatus(g::Geometry, autoenzyme = MeshIntegrals.supports_autoenzyme(g))
function SupportStatus(
g::Geometry;
autoenzyme = true # Assume supported unless explicitly flagged
)
N = Meshes.paramdim(g)
if N == 1
# line/curve
Expand All @@ -73,11 +76,15 @@ This file includes tests for:
end
end

# Generate applicable tests for this geometry
function runtests(
testable::TestableGeometry,
supports::SupportStatus = SupportStatus(testable.geometry);
testable::TestableGeometry;
autoenzyme = true, # Assume supported unless explicitly flagged
rtol = sqrt(eps())
)
# Determine support matrix for this geometry
supports = SupportStatus(testable.geometry; autoenzyme = autoenzyme)

# Test alias functions
for alias in (lineintegral, surfaceintegral, volumeintegral)
# if supports.alias
Expand Down Expand Up @@ -117,17 +124,20 @@ This file includes tests for:
end # for

iter_diff_methods = (
(supports.autoenzyme, AutoEnzyme()),
(true, FiniteDifference()),
(supports.autoenzyme, AutoEnzyme())
)

for (supported, diff_method) in iter_diff_methods
for (supported, method) in iter_diff_methods
f = testable.integrand
geometry = testable.geometry
sol = testable.solution

if supported
@test integral(
testable.integrand, testable.geometry; diff_method = diff_method)≈testable.solution rtol=rtol
@test integral(f, geometry; diff_method = method)≈sol rtol=rtol
@test MeshIntegrals.supports_autoenzyme(testable.geometry) == true
JoshuaLampert marked this conversation as resolved.
Show resolved Hide resolved
JoshuaLampert marked this conversation as resolved.
Show resolved Hide resolved
else
@test_throws "not supported" integral(
testable.integrand, testable.geometry; diff_method = diff_method)
@test_throws "not supported" integral(f, geometry; diff_method = method)
@test MeshIntegrals.supports_autoenzyme(testable.geometry) == false
end
end # for
Expand Down Expand Up @@ -192,7 +202,7 @@ end

# Package and run tests
testable = TestableGeometry(integrand, curve, solution)
runtests(testable; rtol = 0.5e-2)
runtests(testable; autoenzyme = false, rtol = 0.5e-2)
end

@testitem "Meshes.Box 1D" setup=[Combinations] begin
Expand Down Expand Up @@ -333,7 +343,7 @@ end

# Package and run tests
testable = TestableGeometry(integrand, cyl, solution)
runtests(testable)
runtests(testable; autoenzyme = false)
end

@testitem "Meshes.CylinderSurface" setup=[Combinations] begin
Expand All @@ -348,7 +358,7 @@ end

# Package and run tests
testable = TestableGeometry(integrand, cyl, solution)
runtests(testable)
runtests(testable; autoenzyme = false)
end

@testitem "Meshes.Disk" setup=[Combinations] begin
Expand Down Expand Up @@ -465,7 +475,7 @@ end
runtests(testable)
end

@testitem "ParametrizedCurve" setup=[Combinations] begin
@testitem "Meshes.ParametrizedCurve" setup=[Combinations] begin
using CoordRefSystems: Polar

# Geometries
Expand All @@ -485,9 +495,9 @@ end

# Package and run tests
testable_cart = TestableGeometry(integrand, curve_cart, solution)
runtests(testable_cart)
runtests(testable_cart; autoenzyme = false)
testable_polar = TestableGeometry(integrand, curve_polar, solution)
runtests(testable_polar)
runtests(testable_polar; autoenzyme = false)
end

@testitem "Meshes.Plane" setup=[Combinations] begin
Expand Down
34 changes: 28 additions & 6 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,22 @@ end
@test _ones(Float32, 2) == (1.0f0, 1.0f0)
end

@testitem "Differentiation" setup=[Utils] begin
# _default_diff_method
sphere = Sphere(Point(0, 0, 0), 1.0)
@test _default_diff_method(Meshes.Sphere, Float64) isa AutoEnzyme
@test _default_diff_method(sphere, Float64) isa AutoEnzyme
@test _default_diff_method(sphere, BigFloat) isa FiniteDifference
@testitem "Differentiation (Enzyme extension loaded)" setup=[Utils] begin
# _default_diff_method -- using type or instance, Enzyme-supported combination
let sphere = Sphere(Point(0, 0, 0), 1.0)
@test _default_diff_method(Meshes.Sphere, Float64) isa AutoEnzyme
@test _default_diff_method(sphere, Float64) isa AutoEnzyme
end

# _default_diff_method -- Enzyme-unsupported FP types
@test _default_diff_method(Meshes.Sphere, Float16) isa FiniteDifference
@test _default_diff_method(Meshes.Sphere, BigFloat) isa FiniteDifference

# _default_diff_method -- geometries that currently error with AutoEnzyme
@test _default_diff_method(Meshes.BezierCurve, Float64) isa FiniteDifference
@test _default_diff_method(Meshes.CylinderSurface, Float64) isa FiniteDifference
@test _default_diff_method(Meshes.Cylinder, Float64) isa FiniteDifference
@test _default_diff_method(Meshes.ParametrizedCurve, Float64) isa FiniteDifference

# FiniteDifference
@test FiniteDifference().ε ≈ 1e-6
Expand All @@ -45,6 +55,18 @@ end
@test_throws ArgumentError jacobian(box, zeros(3), AutoEnzyme())
end

@testitem "Differentiation (Enzyme extension not loaded)" begin
using Meshes
using MeshIntegrals
using MeshIntegrals: _default_diff_method

# _default_diff_method -- using type or instance, Enzyme-supported combination
let sphere = Sphere(Point(0, 0, 0), 1.0)
@test _default_diff_method(Meshes.Sphere, Float64) isa FiniteDifference
@test _default_diff_method(sphere, Float64) isa FiniteDifference
end
end

@testitem "_ParametricGeometry" setup=[Utils] begin
pt_n = Point(0, 3, 0)
pt_w = Point(-7, 0, 0)
Expand Down
Loading