diff --git a/Project.toml b/Project.toml index 090585d34..6a6187760 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random123 = "74087812-796a-5b5d-8853-05524746bad3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [sources.ReactantCore] path = "lib/ReactantCore" @@ -37,6 +38,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" +ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantPythonCallExt = "PythonCall" ReactantRandom123Ext = "Random123" ReactantStatisticsExt = "Statistics" diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl new file mode 100644 index 000000000..9ed006e38 --- /dev/null +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -0,0 +1,118 @@ +module ReactantSpecialFunctionsExt +using SpecialFunctions +using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt +using Reactant.TracedRNumberOverrides: float + +for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)] + (fns, fno) = fn isa Tuple ? fn : (fn, fn) + @eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt}) + return Ops.$fno(float(x)) + end) +end + +function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat}) + return exp(Ops.lgamma(float(x))) +end + +function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt}) + return round(gamma(float(n))) +end + +function SpecialFunctions.loggamma1p(x::TracedRNumber{<:ReactantFloat}) + return loggamma(1 + x) +end + +function SpecialFunctions.logfactorial(x::TracedRNumber{<:ReactantInt}) + return loggamma(1 + x) +end + +# SpecialFunctions.invdigamma + +function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt}) + return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition +end + +function SpecialFunctions.polygamma( + n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt} +) + return Ops.polygamma(float(n), float(x)) +end + +# SpecialFunctions.gamma_inc + +# SpecialFunctions.gamma_inc_inv + +function SpecialFunctions.loggammadiv( + a::TracedRNumber{T}, b::TracedRNumber{T} +) where {T<:ReactantFloat} + return log(gamma(b) / gamma(a + b)) +end + +#SpecialFunctions.gamma ... + +function SpecialFunctions.beta( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return gamma(x) * gamma(y) / gamma(x + y) +end + +function SpecialFunctions.logbeta( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return log(abs(beta(x, y))) +end + +#TODO: sign function +#SpecialFunctions.logabsbeta +#SpecialFunctions.logabsbinomial + +#SpecialFunctions.beta... + +#utilities... + +function SpecialFunctions.erf( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return erf(y) - erf(x) +end + +#SpecialFunctions.erfcinv + +function SpecialFunctions.logerf( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return log(erf(x, y)) +end + +function SpecialFunctions.erfcx(x::TracedRNumber{<:ReactantFloatInt}) + return exp(float(x^2)) * erfc(x) +end + +function SpecialFunctions.logerfc(x::TracedRNumber{<:ReactantFloatInt}) + return log(erfc(x)) +end + +function SpecialFunctions.logerfcx(x::TracedRNumber{<:ReactantFloatInt}) + return log(erfcx(x)) +end + +#Unsupported complex +#SpecialFunctions.erfi + +#SpecialFunctions.erfinv +#SpecialFunctions.dawson +#SpecialFunctions.faddeeva + +#Airy and Related Functions + +#Bessel ... + +#Elliptic Integrals + +function SpecialFunctions.zeta( + z::TracedRNumber{T}, s::TracedRNumber{T} +) where {T<:ReactantFloatInt} + return Ops.zeta(z, s) +end + +end # module ReactantSpecialFunctionsExt diff --git a/src/Reactant.jl b/src/Reactant.jl index 8cd761c50..72476f8dc 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -18,42 +18,21 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end @static if isdefined(Core, :BFloat16) - const ReactantPrimitive = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Core.BFloat16, - Float32, - Float64, - Complex{Float32}, - Complex{Float64}, - } + const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64} else - const ReactantPrimitive = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Float32, - Float64, - Complex{Float32}, - Complex{Float64}, - } + const ReactantFloat = Union{Float16,Float32,Float64} end +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64} + +const ReactantFloatInt = Union{ + Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... +} + +const ReactantPrimitive = Union{ + Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64} +} + abstract type RNumber{T<:ReactantPrimitive} <: Number end abstract type RArray{T,N} <: AbstractArray{T,N} end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 1e5cfde55..2141d0836 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -210,8 +210,6 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.log1p), :log_plus_one), (:(Base.sqrt), :sqrt), - (:(Base.ceil), :ceil), - (:(Base.floor), :floor), ) @eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs) end @@ -237,6 +235,12 @@ function Base.float(x::TracedRNumber{T}) where {T} return TracedUtils.promote_to(TracedRNumber{float(T)}, x) end +using Reactant: ReactantFloat + +Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) + # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T} diff --git a/test/basic.jl b/test/basic.jl index 440dad9e1..8fe89dbf6 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -608,6 +608,12 @@ end end end +@testset "$op" for op in [:round, :ceil, :floor] + for x in (rand(Float32, (3, 3)), rand(Float64)) + @eval @test @jit($op.(ConcreteRNumber.($x))) == $op.($x) + end +end + @testset "dynamic indexing" begin x = randn(5, 3) x_ra = Reactant.to_rarray(x) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl new file mode 100644 index 000000000..60eedfbfd --- /dev/null +++ b/test/integration/special_functions.jl @@ -0,0 +1,101 @@ +using SpecialFunctions, Reactant + +macro ≈(a, b) + return quote + isapprox($a, $b; atol=1e-14) + end +end + +@testset "gamma" begin + @test SpecialFunctions.gamma(0.5) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.gamma(2) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(2))) +end + +@testset "loggamma" begin + @test SpecialFunctions.loggamma(0.5) ≈ + @jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.loggamma(2) ≈ @jit(SpecialFunctions.loggamma(ConcreteRNumber(2))) +end + +@testset "digamma" begin + @test SpecialFunctions.digamma(0.5) ≈ + @jit(SpecialFunctions.digamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.digamma(2) ≈ @jit(SpecialFunctions.digamma(ConcreteRNumber(2))) +end + +@testset "trigamma" begin + @test SpecialFunctions.trigamma(0.5) ≈ + @jit(SpecialFunctions.trigamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.trigamma(2) ≈ @jit(SpecialFunctions.trigamma(ConcreteRNumber(2))) +end + +@testset "beta" begin + @test SpecialFunctions.beta(0.5, 0.6) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.beta(2, 4) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "logbeta" begin + @test SpecialFunctions.logbeta(0.5, 0.6) ≈ + @jit(SpecialFunctions.logbeta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.logbeta(2, 4) ≈ + @jit(SpecialFunctions.logbeta(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erf" begin + @test SpecialFunctions.erf(0.5) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(0.5))) + @test SpecialFunctions.erf(2) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(2))) +end + +@testset "erf with 2 arguments" begin + @test SpecialFunctions.erf(0.5, 0.6) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.erf(2, 4) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erfc" begin + @test SpecialFunctions.erfc(0.5) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(0.5))) + @test SpecialFunctions.erfc(2) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(2))) +end + +@testset "logerf" begin + @test SpecialFunctions.logerf(0.5, 0.6) ≈ + @jit(SpecialFunctions.logerf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.logerf(2, 4) ≈ + @jit(SpecialFunctions.logerf(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erfcx" begin + @test SpecialFunctions.erfcx(0.5) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(0.5))) + @test SpecialFunctions.erfcx(2) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(2))) +end + +@testset "logerfc" begin + @test SpecialFunctions.logerfc(0.5) ≈ + @jit(SpecialFunctions.logerfc(ConcreteRNumber(0.5))) + @test SpecialFunctions.logerfc(2) ≈ @jit(SpecialFunctions.logerfc(ConcreteRNumber(2))) +end + +@testset "logerfcx" begin + @test SpecialFunctions.logerfcx(0.5) ≈ + @jit(SpecialFunctions.logerfcx(ConcreteRNumber(0.5))) + @test SpecialFunctions.logerfcx(2) ≈ @jit(SpecialFunctions.logerfcx(ConcreteRNumber(2))) +end + +@testset "loggamma1p" begin + @test SpecialFunctions.loggamma1p(0.5) ≈ + @jit SpecialFunctions.loggamma1p(ConcreteRNumber(0.5)) +end + +@testset "loggammadiv" begin + @test SpecialFunctions.loggammadiv(150, 20) ≈ + @jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20)) +end + +@testset "zeta" begin + s = ConcreteRArray([1.0, 2.0, 50.0]) + z = ConcreteRArray([1e-8, 0.001, 2.0]) + @test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit SpecialFunctions.zeta.(s, z) +end diff --git a/test/runtests.jl b/test/runtests.jl index 834d9b504..f0e9ea1f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) # @safetestset "CUDA" include("integration/cuda.jl") @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") + @safetestset "SpecialFunctions" include("integration/special_functions.jl") @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") end