Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpecialFunctions simple functions #384

Merged
merged 35 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
52c87fc
`SpecialFunctions` simple functions
glou-nes Dec 16, 2024
520f71f
review
glou-nes Dec 16, 2024
50d7e8b
missing Ext
glou-nes Dec 16, 2024
dfbd6bf
reviews
glou-nes Dec 16, 2024
3289169
format
glou-nes Dec 16, 2024
deb935f
feedbacks
glou-nes Dec 16, 2024
deee78f
remove usage of `ReactantFloat`, simplify signatures
glou-nes Dec 16, 2024
9f87dc2
fix
glou-nes Dec 16, 2024
381b6d4
real bound
glou-nes Dec 16, 2024
bc3a07f
Merge branch 'main' into special_functions
glou-nes Dec 19, 2024
8c3a991
feedback
glou-nes Dec 19, 2024
539eaca
remove assert
glou-nes Dec 19, 2024
ec93d18
Merge branch 'main' into special_functions
wsmoses Dec 19, 2024
4d22a58
update signature
glou-nes Dec 28, 2024
e57bf9d
Merge branch 'main' into special_functions
glou-nes Dec 28, 2024
0ddc365
add missing def for 1.10, increase ~ tolerence for MacOS
glou-nes Dec 28, 2024
667575a
missing def int, julia 1.10
glou-nes Dec 29, 2024
a7004eb
format
glou-nes Dec 29, 2024
bb191ab
format 2
glou-nes Dec 29, 2024
c370148
Merge branch 'main' into special_functions
glou-nes Dec 29, 2024
a38883c
error
glou-nes Dec 29, 2024
bd84cb2
simplify rounding
glou-nes Dec 29, 2024
05822d3
Revert "simplify rounding"
glou-nes Dec 29, 2024
dca9021
disable tests
glou-nes Dec 29, 2024
9a7f0b5
revert
glou-nes Dec 29, 2024
62fa41c
revert
glou-nes Dec 29, 2024
6880d16
test CI
glou-nes Dec 29, 2024
f60de70
good order
glou-nes Dec 29, 2024
35f0e9d
remove fancy call
glou-nes Dec 29, 2024
af777ac
test
glou-nes Dec 30, 2024
b3df340
new test
glou-nes Dec 30, 2024
f6bc8a0
round &co need float
glou-nes Dec 30, 2024
6c16535
Merge branch 'main' into special_functions
wsmoses Dec 31, 2024
0a35b0a
Merge branch 'main' into special_functions
glou-nes Jan 1, 2025
88b2a12
format
glou-nes Jan 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[sources.ReactantCore]
path = "lib/ReactantCore"
Expand All @@ -37,6 +38,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantSpecialFunctionsExt = "SpecialFunctions"
ReactantPythonCallExt = "PythonCall"
ReactantRandom123Ext = "Random123"
ReactantStatisticsExt = "Statistics"
Expand Down
118 changes: 118 additions & 0 deletions ext/ReactantSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
module ReactantSpecialFunctionsExt
using SpecialFunctions
using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt
using Reactant.TracedRNumberOverrides: float

for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)]
(fns, fno) = fn isa Tuple ? fn : (fn, fn)
@eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt})
return Ops.$fno(float(x))
end)
end

function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat})
return exp(Ops.lgamma(float(x)))
end

function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt})
return round(gamma(float(n)))
end

function SpecialFunctions.loggamma1p(x::TracedRNumber{<:ReactantFloat})
return loggamma(1 + x)
end

function SpecialFunctions.logfactorial(x::TracedRNumber{<:ReactantInt})
return loggamma(1 + x)
end

# SpecialFunctions.invdigamma

function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt})
return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition
end

function SpecialFunctions.polygamma(
n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt}
)
return Ops.polygamma(float(n), float(x))
end

# SpecialFunctions.gamma_inc

# SpecialFunctions.gamma_inc_inv

function SpecialFunctions.loggammadiv(
a::TracedRNumber{T}, b::TracedRNumber{T}
) where {T<:ReactantFloat}
return log(gamma(b) / gamma(a + b))
end

#SpecialFunctions.gamma ...

function SpecialFunctions.beta(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return gamma(x) * gamma(y) / gamma(x + y)
end

function SpecialFunctions.logbeta(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return log(abs(beta(x, y)))
end

#TODO: sign function
#SpecialFunctions.logabsbeta
#SpecialFunctions.logabsbinomial

#SpecialFunctions.beta...

#utilities...

function SpecialFunctions.erf(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return erf(y) - erf(x)
end

#SpecialFunctions.erfcinv

function SpecialFunctions.logerf(
x::TracedRNumber{T}, y::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return log(erf(x, y))
end

function SpecialFunctions.erfcx(x::TracedRNumber{<:ReactantFloatInt})
return exp(float(x^2)) * erfc(x)
end

function SpecialFunctions.logerfc(x::TracedRNumber{<:ReactantFloatInt})
return log(erfc(x))
end

function SpecialFunctions.logerfcx(x::TracedRNumber{<:ReactantFloatInt})
return log(erfcx(x))
end

#Unsupported complex
#SpecialFunctions.erfi

#SpecialFunctions.erfinv
#SpecialFunctions.dawson
#SpecialFunctions.faddeeva

#Airy and Related Functions

#Bessel ...

#Elliptic Integrals

function SpecialFunctions.zeta(
z::TracedRNumber{T}, s::TracedRNumber{T}
) where {T<:ReactantFloatInt}
return Ops.zeta(z, s)
end

end # module ReactantSpecialFunctionsExt
45 changes: 12 additions & 33 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,21 @@ using Enzyme
struct ReactantABI <: Enzyme.EnzymeCore.ABI end

@static if isdefined(Core, :BFloat16)
const ReactantPrimitive = Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Float16,
Core.BFloat16,
Float32,
Float64,
Complex{Float32},
Complex{Float64},
}
const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64}
else
const ReactantPrimitive = Union{
Bool,
Int8,
UInt8,
Int16,
UInt16,
Int32,
UInt32,
Int64,
UInt64,
Float16,
Float32,
Float64,
Complex{Float32},
Complex{Float64},
}
const ReactantFloat = Union{Float16,Float32,Float64}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64}
}

abstract type RNumber{T<:ReactantPrimitive} <: Number end

abstract type RArray{T,N} <: AbstractArray{T,N} end
Expand Down
24 changes: 22 additions & 2 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ for (jlop, hloop) in (
(:(Base.log), :log),
(:(Base.log1p), :log_plus_one),
(:(Base.sqrt), :sqrt),
(:(Base.ceil), :ceil),
(:(Base.floor), :floor),
)
@eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs)
end
Expand All @@ -231,6 +229,28 @@ function Base.float(x::TracedRNumber{T}) where {T}
return TracedUtils.promote_to(TracedRNumber{float(T)}, x)
end

using Reactant: ReactantFloat

Base.round(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundNearest)
Base.floor(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundDown)
Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundUp)

function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where {R}
if R == :Nearest
Ops.round_nearest_even(A)
elseif R == :Up
Ops.ceil(A)
elseif R == :Down
Ops.floor(A)
else
error("$R is unsupported")
end
end

Base.round(A::TracedRNumber{<:Integer}) = A
Base.floor(A::TracedRNumber{<:Integer}) = A
Base.ceil(A::TracedRNumber{<:Integer}) = A

# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...)
function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T}
Expand Down
6 changes: 6 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,12 @@ end
end
end

@testset "$op" for op in [:round, :ceil, :floor]
for x in (rand(Float32, (3, 3)), rand(Int), rand(Float64), rand(Int, (3, 3)))
@eval @test @jit($op.(ConcreteRNumber.($x))) == $op.($x)
end
end

@testset "dynamic indexing" begin
x = randn(5, 3)
x_ra = Reactant.to_rarray(x)
Expand Down
44 changes: 44 additions & 0 deletions test/integration/special_functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using SpecialFunctions, Reactant

macro ≈(a, b)
return quote
isapprox(Array($a), $b; atol=1e-14)
end
end

@testset "$op" for (op, n_args) in [
(:gamma, 1),
(:loggamma, 1),
(:digamma, 1),
(:trigamma, 1),
(:beta, 2),
(:logbeta, 2),
(:erf, 1),
(:erf, 2),
(:erfc, 1),
(:logerf, 2),
(:erfcx, 1),
(:logerfc, 1),
(:logerfcx, 1),
]
for data in ([0.5, 0.6], [2, 4])
x = data[1:n_args]
@eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...)
glou-nes marked this conversation as resolved.
Show resolved Hide resolved
end
end

@testset "loggamma1p" begin
@test SpecialFunctions.loggamma1p(0.5) ≈
@jit SpecialFunctions.loggamma1p(ConcreteRNumber(0.5))
end

@testset "loggammadiv" begin
@test SpecialFunctions.loggammadiv(150, 20) ≈
@jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20))
end

@testset "zeta" begin
s = ConcreteRArray([1.0, 2.0, 50.0])
z = ConcreteRArray([1e-8, 0.001, 2.0])
@test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit SpecialFunctions.zeta.(s, z)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
# @safetestset "CUDA" include("integration/cuda.jl")
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
@safetestset "Random" include("integration/random.jl")
@safetestset "Python" include("integration/python.jl")
end
Expand Down
Loading