From 52c87fce0e658b14fc93f024416f8bd371394a67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 04:01:38 +0100 Subject: [PATCH 01/29] `SpecialFunctions` simple functions --- Project.toml | 2 + ext/ReactantSpecialFunctions.jl | 125 ++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 ext/ReactantSpecialFunctions.jl diff --git a/Project.toml b/Project.toml index 9af7dafef..57528c8bf 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ path = "lib/ReactantCore" ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantNNlibExt = "NNlib" +ReactantSpecialFunctions = "SpecialFunctions" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -59,3 +60,4 @@ julia = "1.10" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/ext/ReactantSpecialFunctions.jl b/ext/ReactantSpecialFunctions.jl new file mode 100644 index 000000000..e98489de6 --- /dev/null +++ b/ext/ReactantSpecialFunctions.jl @@ -0,0 +1,125 @@ +using Reactant: Reactant, Ops +using Reactant.TracedUtils + +function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T} + x = promote_to(TracedRNumber{Float64}, x) + return exp(Ops.lgamma(x)) +end + +#TODO: add factorial function +#=function SpecialFunctions.gamma( + n::TracedRNumber{T} +) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}} + factorial(n) +end +=# + +function SpecialFunctions.loggamma(x::TracedRNumber{T}) where {T} + x = promote_to(TracedRNumber{Float64}, x) + return Ops.lgamma(x) +end + +function SpecialFunctions.loggamma1p(x::TracedRNumber{T}) where {T} + return loggamma(1 + x) +end + +function SpecialFunctions.logfactorial( + x::TracedRNumber{<:T} +) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}} + return loggamma(1 + x) +end + +function SpecialFunctions.digamma(x::TracedRNumber{T}) where {T} + x = promote_to(TracedRNumber{Float64}, x) + return Ops.digamma(x) +end + +# SpecialFunctions.invdigamma + +function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} + return Ops.polygamma(Ops.constant(2.0), x) +end + +function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} + x = promote_to(TracedRNumber{Float64}, x) + return Ops.polygamma(n, x) +end + +# SpecialFunctions.gamma_inc + +# SpecialFunctions.gamma_inc_inv + +function SpecialFunctions.loggammadiv(a::TracedRNumber{T}, b::TracedRNumber{T}) where {T} + return log(gamma(b) / gamma(a + b)) +end + +#SpecialFunctions.gamma ... + +function SpecialFunctions.beta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} + return gamma(x) * gamma(y) / gamma(x + y) +end + +function SpecialFunctions.logbeta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} + return log(abs(beta(x, y))) +end + +#TODO: sign function +#SpecialFunctions.logabsbeta +#SpecialFunctions.logabsbinomial + +#SpecialFunctions.beta... + +#utilities... + +function SpecialFunctions.erf(x::TracedRNumber{T}) where {T} + x = promote_to(TracedRNumber{Float64}, x) + return Ops.erf(x) +end + +function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} + return y - x +end + +function SpecialFunctions.erfc(x::TracedRNumber{T}) where {T} + x = promote_to(TracedRNumber{Float64}, x) + return Ops.erfc(x) +end + +function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} + return log(erf(x, y)) +end + +#SpecialFunctions.erfcinv + +function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} + return log(erf(x, y)) +end + +function SpecialFunctions.erfcx(x::TracedRNumber{T}) where {T} + return exp(x^2) * erfc(x) +end + +function SpecialFunctions.logerfc(x::TracedRNumber{T}) where {T} + return log(erfc(x)) +end + +function SpecialFunctions.logerfcx(x::TracedRNumber{T}) where {T} + return log(erfcx(x)) +end + +#Unsupported complex +#SpecialFunctions.erfi + +#SpecialFunctions.erfinv +#SpecialFunctions.dawson +#SpecialFunctions.faddeeva + +#Airy and Related Functions + +#Bessel ... + +#Elliptic Integrals + +function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T} + return Ops.zeta(z,s) +end \ No newline at end of file From 520f71f1d176e7fc0a421ed40ea0c07e270717a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 05:09:41 +0100 Subject: [PATCH 02/29] review --- ext/ReactantSpecialFunctions.jl | 12 +++------ test/integration/special_functions.jl | 35 +++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 40 insertions(+), 8 deletions(-) create mode 100644 test/integration/special_functions.jl diff --git a/ext/ReactantSpecialFunctions.jl b/ext/ReactantSpecialFunctions.jl index e98489de6..474b8cd2b 100644 --- a/ext/ReactantSpecialFunctions.jl +++ b/ext/ReactantSpecialFunctions.jl @@ -1,5 +1,5 @@ using Reactant: Reactant, Ops -using Reactant.TracedUtils +using Reactant.TracedUtils: promote_to function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T} x = promote_to(TracedRNumber{Float64}, x) @@ -37,7 +37,7 @@ end # SpecialFunctions.invdigamma function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} - return Ops.polygamma(Ops.constant(2.0), x) + return Ops.polygamma(Ops.constant(1.0), x) end function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} @@ -77,7 +77,7 @@ function SpecialFunctions.erf(x::TracedRNumber{T}) where {T} end function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} - return y - x + return erf(y) - erf(x) end function SpecialFunctions.erfc(x::TracedRNumber{T}) where {T} @@ -85,10 +85,6 @@ function SpecialFunctions.erfc(x::TracedRNumber{T}) where {T} return Ops.erfc(x) end -function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} - return log(erf(x, y)) -end - #SpecialFunctions.erfcinv function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} @@ -121,5 +117,5 @@ end #Elliptic Integrals function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T} - return Ops.zeta(z,s) + return Ops.zeta(z, s) end \ No newline at end of file diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl new file mode 100644 index 000000000..51aa230d0 --- /dev/null +++ b/test/integration/special_functions.jl @@ -0,0 +1,35 @@ +using SpecialFunctions, Reactant +@testset "Generic" begin + values = [0.5, 0.6] + for (op, n_args) in [ + (:gamma, 1), + (:loggamma, 1), + (:loggamma1p, 1), + (:digamma, 1), + (:trigamma, 1), + (:beta, 2), + (:logbeta, 2), + (:erf, 1), + (:erf, 2), + (:erfc, 1), + (:logerf, 2), + (:erfcx, 1), + (:logerfc, 1), + (:logerfcx, 1), + ] + x = values[1:n_args] + @eval @test @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) ≈ + SpecialFunctions.$op($x...) + end +end + +@testset "loggammadiv" begin + @test SpecialFunctions.loggammadiv(150, 20) ≈ + @jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20)) +end + +@testset "zeta" begin + s = ConcreteRArray([1.0, 2.0, 50.0]) + z = ConcreteRArray([1e-8, 0.001, 2.0]) + @test SpecialFunctions.zeta.(Array(s), Array(z)) ≈ @jit SpecialFunctions.zeta.(s, z) +end diff --git a/test/runtests.jl b/test/runtests.jl index fddc963ce..eeb9472df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") + @safetestset "SpecialFunctions" include("integration/special_functions.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 50d7e8bbb5ea7b1042e06931aee6fe600e7ff533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 05:23:06 +0100 Subject: [PATCH 03/29] missing Ext --- Project.toml | 2 +- src/Reactant.jl | 28 ++++++---------------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 57528c8bf..20c5f8aad 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ path = "lib/ReactantCore" ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantNNlibExt = "NNlib" -ReactantSpecialFunctions = "SpecialFunctions" +ReactantSpecialFunctionsExt = "SpecialFunctions" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" diff --git a/src/Reactant.jl b/src/Reactant.jl index ba2da588d..e42dbb112 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -16,25 +16,12 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end @static if isdefined(Core, :BFloat16) - const ReactantPrimitive = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Float16, - Core.BFloat16, - Float32, - Float64, - Complex{Float32}, - Complex{Float64}, - } + ReactantFloat = Union{Float16, BFloat16, Float32, Float64} else - const ReactantPrimitive = Union{ + ReactantFloat = Union{Float16, Float32, Float64} +end + +const ReactantPrimitive = Union{ Bool, Int8, UInt8, @@ -44,13 +31,10 @@ else UInt32, Int64, UInt64, - Float16, - Float32, - Float64, Complex{Float32}, Complex{Float64}, + ReactantFloat... } -end abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end abstract type RNumber{T<:ReactantPrimitive} <: Number end From dfbd6bf8ff18be9312e740371b5e654961b3de6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 05:51:06 +0100 Subject: [PATCH 04/29] reviews --- Project.toml | 3 +- ...ions.jl => ReactantSpecialFunctionsExt.jl} | 41 ++++++++++++------- src/Reactant.jl | 6 +-- 3 files changed, 31 insertions(+), 19 deletions(-) rename ext/{ReactantSpecialFunctions.jl => ReactantSpecialFunctionsExt.jl} (72%) diff --git a/Project.toml b/Project.toml index 20c5f8aad..ca576a39c 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [sources.ReactantCore] path = "lib/ReactantCore" @@ -60,4 +61,4 @@ julia = "1.10" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + diff --git a/ext/ReactantSpecialFunctions.jl b/ext/ReactantSpecialFunctionsExt.jl similarity index 72% rename from ext/ReactantSpecialFunctions.jl rename to ext/ReactantSpecialFunctionsExt.jl index 474b8cd2b..d490b7ce3 100644 --- a/ext/ReactantSpecialFunctions.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -1,8 +1,15 @@ -using Reactant: Reactant, Ops +module ReactantSpecialFunctionsExt +using SpecialFunctions +using Reactant: Ops, Reactant, ReactantFloat, TracedRNumber using Reactant.TracedUtils: promote_to -function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{Float64}, x) +for fn in [:gamma, :loggamma, :digamma, :erf, :erfc] + @eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Number}) + return $fn(promote_to(TracedRNumber{Float64}, x)) + end) +end + +function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T<:ReactantFloat} return exp(Ops.lgamma(x)) end @@ -14,8 +21,7 @@ end end =# -function SpecialFunctions.loggamma(x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{Float64}, x) +function SpecialFunctions.loggamma(x::TracedRNumber{T}) where {T<:ReactantFloat} return Ops.lgamma(x) end @@ -24,25 +30,30 @@ function SpecialFunctions.loggamma1p(x::TracedRNumber{T}) where {T} end function SpecialFunctions.logfactorial( - x::TracedRNumber{<:T} + x::TracedRNumber{T} ) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}} return loggamma(1 + x) end -function SpecialFunctions.digamma(x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{Float64}, x) +function SpecialFunctions.digamma(x::TracedRNumber{T}) where {T<:ReactantFloat} return Ops.digamma(x) end # SpecialFunctions.invdigamma function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} - return Ops.polygamma(Ops.constant(1.0), x) + return Ops.polygamma(Ops.constant(T(1)), x) +end + +function SpecialFunctions.polygamma( + n::TracedRNumber{T}, x::TracedRNumber{T} +) where {T<:ReactantFloat} + return Ops.polygamma(n, x) end function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} x = promote_to(TracedRNumber{Float64}, x) - return Ops.polygamma(n, x) + return polygamma(n, x) end # SpecialFunctions.gamma_inc @@ -71,8 +82,7 @@ end #utilities... -function SpecialFunctions.erf(x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{Float64}, x) +function SpecialFunctions.erf(x::TracedRNumber{T}) where {T<:ReactantFloat} return Ops.erf(x) end @@ -80,8 +90,7 @@ function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T return erf(y) - erf(x) end -function SpecialFunctions.erfc(x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{Float64}, x) +function SpecialFunctions.erfc(x::TracedRNumber{T}) where {T<:ReactantFloat} return Ops.erfc(x) end @@ -118,4 +127,6 @@ end function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T} return Ops.zeta(z, s) -end \ No newline at end of file +end + +end # module ReactantSpecialFunctionsExt \ No newline at end of file diff --git a/src/Reactant.jl b/src/Reactant.jl index e42dbb112..a8748b712 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -16,9 +16,9 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end @static if isdefined(Core, :BFloat16) - ReactantFloat = Union{Float16, BFloat16, Float32, Float64} + const ReactantFloat = Union{Float16, Core.BFloat16, Float32, Float64} else - ReactantFloat = Union{Float16, Float32, Float64} + const ReactantFloat = Union{Float16, Float32, Float64} end const ReactantPrimitive = Union{ @@ -33,7 +33,7 @@ const ReactantPrimitive = Union{ UInt64, Complex{Float32}, Complex{Float64}, - ReactantFloat... + Base.uniontypes(ReactantFloat)... } abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end From 3289169ef7c679c3f858904a53be49d07c1c720b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 05:55:48 +0100 Subject: [PATCH 05/29] format --- src/Reactant.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index a8748b712..fedf6c3ee 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -16,25 +16,25 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end @static if isdefined(Core, :BFloat16) - const ReactantFloat = Union{Float16, Core.BFloat16, Float32, Float64} + const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64} else - const ReactantFloat = Union{Float16, Float32, Float64} + const ReactantFloat = Union{Float16,Float32,Float64} end const ReactantPrimitive = Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Complex{Float32}, - Complex{Float64}, - Base.uniontypes(ReactantFloat)... - } + Bool, + Int8, + UInt8, + Int16, + UInt16, + Int32, + UInt32, + Int64, + UInt64, + Complex{Float32}, + Complex{Float64}, + Base.uniontypes(ReactantFloat)..., +} abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end abstract type RNumber{T<:ReactantPrimitive} <: Number end From deb935f7a34fe286fe3f619173901b32cde3c681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 11:29:03 +0100 Subject: [PATCH 06/29] feedbacks --- ext/ReactantSpecialFunctionsExt.jl | 12 ++++---- test/integration/special_functions.jl | 41 ++++++++++++++------------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index d490b7ce3..63db8de24 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -1,11 +1,11 @@ module ReactantSpecialFunctionsExt using SpecialFunctions using Reactant: Ops, Reactant, ReactantFloat, TracedRNumber -using Reactant.TracedUtils: promote_to +using Reactant.TracedRNumberOverrides: float -for fn in [:gamma, :loggamma, :digamma, :erf, :erfc] +for fn in [:gamma, :loggamma, :digamma, :trigamma, :erf, :erfc] @eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Number}) - return $fn(promote_to(TracedRNumber{Float64}, x)) + return $fn(float(x)) end) end @@ -41,7 +41,7 @@ end # SpecialFunctions.invdigamma -function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} +function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T<: ReactantFloat} return Ops.polygamma(Ops.constant(T(1)), x) end @@ -52,7 +52,7 @@ function SpecialFunctions.polygamma( end function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{Float64}, x) + x = promote_to(TracedRNumber{T}, x) return polygamma(n, x) end @@ -101,7 +101,7 @@ function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where end function SpecialFunctions.erfcx(x::TracedRNumber{T}) where {T} - return exp(x^2) * erfc(x) + return exp(float(x^2)) * erfc(x) end function SpecialFunctions.logerfc(x::TracedRNumber{T}) where {T} diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index 51aa230d0..fa620561e 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -1,28 +1,31 @@ using SpecialFunctions, Reactant -@testset "Generic" begin - values = [0.5, 0.6] - for (op, n_args) in [ - (:gamma, 1), - (:loggamma, 1), - (:loggamma1p, 1), - (:digamma, 1), - (:trigamma, 1), - (:beta, 2), - (:logbeta, 2), - (:erf, 1), - (:erf, 2), - (:erfc, 1), - (:logerf, 2), - (:erfcx, 1), - (:logerfc, 1), - (:logerfcx, 1), - ] - x = values[1:n_args] +@testset "$op" for (op, n_args) in [ + (:gamma, 1), + (:loggamma, 1), + (:digamma, 1), + (:trigamma, 1), + (:beta, 2), + (:logbeta, 2), + (:erf, 1), + (:erf, 2), + (:erfc, 1), + (:logerf, 2), + (:erfcx, 1), + (:logerfc, 1), + (:logerfcx, 1), +] + for data in ([0.5, 0.6], [2, 4]) + x = data[1:n_args] @eval @test @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) ≈ SpecialFunctions.$op($x...) end end +@testset "loggamma1p" begin + @test SpecialFunctions.loggamma1p(0.5) ≈ + @jit SpecialFunctions.loggamma1p(ConcreteRNumber(0.5)) +end + @testset "loggammadiv" begin @test SpecialFunctions.loggammadiv(150, 20) ≈ @jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20)) From deee78facc59350798334cd9ac2d7e78590fb65e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 14:37:29 +0100 Subject: [PATCH 07/29] remove usage of `ReactantFloat`, simplify signatures --- ext/ReactantSpecialFunctionsExt.jl | 46 +++++++++++------------------- 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index 63db8de24..8d3ff300e 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -1,16 +1,16 @@ module ReactantSpecialFunctionsExt using SpecialFunctions -using Reactant: Ops, Reactant, ReactantFloat, TracedRNumber +using Reactant: Ops, Reactant, TracedRNumber using Reactant.TracedRNumberOverrides: float -for fn in [:gamma, :loggamma, :digamma, :trigamma, :erf, :erfc] - @eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Number}) - return $fn(float(x)) +for fn in [:digamma, :erf, :erfc] + @eval(function SpecialFunctions.$fn(x::TracedRNumber) + return Ops.$fn(float(x)) end) end -function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T<:ReactantFloat} - return exp(Ops.lgamma(x)) +function SpecialFunctions.gamma(x::TracedRNumber) + return exp(Ops.lgamma(float(x))) end #TODO: add factorial function @@ -21,11 +21,11 @@ end end =# -function SpecialFunctions.loggamma(x::TracedRNumber{T}) where {T<:ReactantFloat} - return Ops.lgamma(x) +function SpecialFunctions.loggamma(x::TracedRNumber) + return Ops.lgamma(float(x)) end -function SpecialFunctions.loggamma1p(x::TracedRNumber{T}) where {T} +function SpecialFunctions.loggamma1p(x::TracedRNumber) return loggamma(1 + x) end @@ -35,20 +35,14 @@ function SpecialFunctions.logfactorial( return loggamma(1 + x) end -function SpecialFunctions.digamma(x::TracedRNumber{T}) where {T<:ReactantFloat} - return Ops.digamma(x) -end - # SpecialFunctions.invdigamma -function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T<: ReactantFloat} - return Ops.polygamma(Ops.constant(T(1)), x) +function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} + return Ops.polygamma(Ops.constant(T(1)), float(x)) end -function SpecialFunctions.polygamma( - n::TracedRNumber{T}, x::TracedRNumber{T} -) where {T<:ReactantFloat} - return Ops.polygamma(n, x) +function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber) + return Ops.polygamma(float(n), float(x)) end function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} @@ -82,33 +76,25 @@ end #utilities... -function SpecialFunctions.erf(x::TracedRNumber{T}) where {T<:ReactantFloat} - return Ops.erf(x) -end - function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} return erf(y) - erf(x) end -function SpecialFunctions.erfc(x::TracedRNumber{T}) where {T<:ReactantFloat} - return Ops.erfc(x) -end - #SpecialFunctions.erfcinv function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} return log(erf(x, y)) end -function SpecialFunctions.erfcx(x::TracedRNumber{T}) where {T} +function SpecialFunctions.erfcx(x::TracedRNumber) return exp(float(x^2)) * erfc(x) end -function SpecialFunctions.logerfc(x::TracedRNumber{T}) where {T} +function SpecialFunctions.logerfc(x::TracedRNumber) return log(erfc(x)) end -function SpecialFunctions.logerfcx(x::TracedRNumber{T}) where {T} +function SpecialFunctions.logerfcx(x::TracedRNumber) return log(erfcx(x)) end From 9f87dc224341edbb7390ef8e251ab227abd9087e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 17:15:42 +0100 Subject: [PATCH 08/29] fix --- ext/ReactantSpecialFunctionsExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index 8d3ff300e..e71553e5c 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -30,8 +30,8 @@ function SpecialFunctions.loggamma1p(x::TracedRNumber) end function SpecialFunctions.logfactorial( - x::TracedRNumber{T} -) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}} + x::TracedRNumber{<:Integer} +) return loggamma(1 + x) end @@ -41,7 +41,7 @@ function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} return Ops.polygamma(Ops.constant(T(1)), float(x)) end -function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber) +function SpecialFunctions.polygamma(n::TracedRNumber, x::TracedRNumber) return Ops.polygamma(float(n), float(x)) end From 381b6d40387e3a1ffacb17595de9093889aead0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 16 Dec 2024 21:31:48 +0100 Subject: [PATCH 09/29] real bound --- ext/ReactantSpecialFunctionsExt.jl | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index e71553e5c..a76a08a58 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -4,7 +4,7 @@ using Reactant: Ops, Reactant, TracedRNumber using Reactant.TracedRNumberOverrides: float for fn in [:digamma, :erf, :erfc] - @eval(function SpecialFunctions.$fn(x::TracedRNumber) + @eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Real}) return Ops.$fn(float(x)) end) end @@ -16,12 +16,12 @@ end #TODO: add factorial function #=function SpecialFunctions.gamma( n::TracedRNumber{T} -) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}} +) where {T<:Integer} factorial(n) end =# -function SpecialFunctions.loggamma(x::TracedRNumber) +function SpecialFunctions.loggamma(x::TracedRNumber{<:Real}) return Ops.lgamma(float(x)) end @@ -41,12 +41,11 @@ function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} return Ops.polygamma(Ops.constant(T(1)), float(x)) end -function SpecialFunctions.polygamma(n::TracedRNumber, x::TracedRNumber) +function SpecialFunctions.polygamma(n::TracedRNumber{<:Real}, x::TracedRNumber{<:Real}) return Ops.polygamma(float(n), float(x)) end function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} - x = promote_to(TracedRNumber{T}, x) return polygamma(n, x) end @@ -54,17 +53,17 @@ end # SpecialFunctions.gamma_inc_inv -function SpecialFunctions.loggammadiv(a::TracedRNumber{T}, b::TracedRNumber{T}) where {T} +function SpecialFunctions.loggammadiv(a::TracedRNumber{T}, b::TracedRNumber{T}) where {T<:Real} return log(gamma(b) / gamma(a + b)) end #SpecialFunctions.gamma ... -function SpecialFunctions.beta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} +function SpecialFunctions.beta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} return gamma(x) * gamma(y) / gamma(x + y) end -function SpecialFunctions.logbeta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} +function SpecialFunctions.logbeta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} return log(abs(beta(x, y))) end @@ -76,25 +75,25 @@ end #utilities... -function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} +function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} return erf(y) - erf(x) end #SpecialFunctions.erfcinv -function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T} +function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} return log(erf(x, y)) end -function SpecialFunctions.erfcx(x::TracedRNumber) +function SpecialFunctions.erfcx(x::TracedRNumber{<:Real}) return exp(float(x^2)) * erfc(x) end -function SpecialFunctions.logerfc(x::TracedRNumber) +function SpecialFunctions.logerfc(x::TracedRNumber{<:Real}) return log(erfc(x)) end -function SpecialFunctions.logerfcx(x::TracedRNumber) +function SpecialFunctions.logerfcx(x::TracedRNumber{<:Real}) return log(erfcx(x)) end @@ -111,7 +110,7 @@ end #Elliptic Integrals -function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T} +function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T<:Real} return Ops.zeta(z, s) end From 8c3a991aed23ed19fd43fc40f5a73b0ecdde68d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Thu, 19 Dec 2024 04:18:18 +0100 Subject: [PATCH 10/29] feedback --- ext/ReactantSpecialFunctionsExt.jl | 48 ++++++++++++------------------ src/TracedRNumber.jl | 19 ++++++++++-- test/basic.jl | 6 ++++ 3 files changed, 42 insertions(+), 31 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index a76a08a58..1bcf5eab8 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -1,59 +1,49 @@ module ReactantSpecialFunctionsExt using SpecialFunctions -using Reactant: Ops, Reactant, TracedRNumber +using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat using Reactant.TracedRNumberOverrides: float -for fn in [:digamma, :erf, :erfc] - @eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Real}) - return Ops.$fn(float(x)) +for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)] + (fns, fno) = fn isa Tuple ? fn : (fn, fn) + @eval(function SpecialFunctions.$fns(x::TracedRNumber{<:Real}) + return Ops.$fno(float(x)) end) end -function SpecialFunctions.gamma(x::TracedRNumber) +function SpecialFunctions.gamma(x::TracedRNumber{<:Real}) return exp(Ops.lgamma(float(x))) end -#TODO: add factorial function -#=function SpecialFunctions.gamma( - n::TracedRNumber{T} -) where {T<:Integer} - factorial(n) +function SpecialFunctions.gamma(n::TracedRNumber{<:Integer}) + return round(gamma(float(n))) end -=# -function SpecialFunctions.loggamma(x::TracedRNumber{<:Real}) - return Ops.lgamma(float(x)) -end - -function SpecialFunctions.loggamma1p(x::TracedRNumber) +function SpecialFunctions.loggamma1p(x::TracedRNumber{<:Real}) + @assert abs(x) < 1 return loggamma(1 + x) end -function SpecialFunctions.logfactorial( - x::TracedRNumber{<:Integer} -) +function SpecialFunctions.logfactorial(x::TracedRNumber{<:Integer}) return loggamma(1 + x) end # SpecialFunctions.invdigamma -function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T} - return Ops.polygamma(Ops.constant(T(1)), float(x)) +function SpecialFunctions.trigamma(x::TracedRNumber{<:Real}) + return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition end function SpecialFunctions.polygamma(n::TracedRNumber{<:Real}, x::TracedRNumber{<:Real}) return Ops.polygamma(float(n), float(x)) end -function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T} - return polygamma(n, x) -end - # SpecialFunctions.gamma_inc # SpecialFunctions.gamma_inc_inv -function SpecialFunctions.loggammadiv(a::TracedRNumber{T}, b::TracedRNumber{T}) where {T<:Real} +function SpecialFunctions.loggammadiv( + a::TracedRNumber{T}, b::TracedRNumber{T} +) where {T<:Real} return log(gamma(b) / gamma(a + b)) end @@ -85,15 +75,15 @@ function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where return log(erf(x, y)) end -function SpecialFunctions.erfcx(x::TracedRNumber{<:Real}) +function SpecialFunctions.erfcx(x::TracedRNumber{<:Real}) return exp(float(x^2)) * erfc(x) end -function SpecialFunctions.logerfc(x::TracedRNumber{<:Real}) +function SpecialFunctions.logerfc(x::TracedRNumber{<:Real}) return log(erfc(x)) end -function SpecialFunctions.logerfcx(x::TracedRNumber{<:Real}) +function SpecialFunctions.logerfcx(x::TracedRNumber{<:Real}) return log(erfcx(x)) end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index df664031e..24d558011 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -206,8 +206,6 @@ for (jlop, hloop) in ( (:(Base.log), :log), (:(Base.log1p), :log_plus_one), (:(Base.sqrt), :sqrt), - (:(Base.ceil), :ceil), - (:(Base.floor), :floor), ) @eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs) end @@ -233,6 +231,23 @@ function Base.float(x::TracedRNumber{T}) where {T} return TracedUtils.promote_to(TracedRNumber{float(T)}, x) end +using Reactant: ReactantFloat +function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where {R} + if R == :Nearest + Ops.round_nearest_even(A) + elseif R == :Up + Ops.ceil(A) + elseif R == :Down + Ops.floor(A) + else + @error "$R is unsupported" + end +end + +function Base.round(A::TracedRNumber{<:Integer}, _) + return A +end + # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T} diff --git a/test/basic.jl b/test/basic.jl index 5eff286ed..d97a733db 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -577,6 +577,12 @@ end end end +@testset "$op" for op in [:round, :ceil, :floor] + for x in (rand(Float32, (3, 3)), rand(Int), rand(Float64), rand(Int, (3, 3))) + @eval @test @jit($op.(ConcreteRNumber.($x))) == $op.($x) + end +end + @testset "dynamic indexing" begin x = randn(5, 3) x_ra = Reactant.to_rarray(x) From 539eaca8f76435e2a13f939391f7766d6b1baf89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Thu, 19 Dec 2024 04:22:41 +0100 Subject: [PATCH 11/29] remove assert --- ext/ReactantSpecialFunctionsExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index 1bcf5eab8..cc5b328de 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -19,7 +19,6 @@ function SpecialFunctions.gamma(n::TracedRNumber{<:Integer}) end function SpecialFunctions.loggamma1p(x::TracedRNumber{<:Real}) - @assert abs(x) < 1 return loggamma(1 + x) end From 4d22a588e16b783c42bcd88bc446adca66b52549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sat, 28 Dec 2024 01:06:38 +0100 Subject: [PATCH 12/29] update signature --- ext/ReactantSpecialFunctionsExt.jl | 34 +++++++++++++++--------------- src/Reactant.jl | 14 +++++++++--- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index cc5b328de..882b8397f 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -1,38 +1,38 @@ module ReactantSpecialFunctionsExt using SpecialFunctions -using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat +using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt using Reactant.TracedRNumberOverrides: float for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)] (fns, fno) = fn isa Tuple ? fn : (fn, fn) - @eval(function SpecialFunctions.$fns(x::TracedRNumber{<:Real}) + @eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt}) return Ops.$fno(float(x)) end) end -function SpecialFunctions.gamma(x::TracedRNumber{<:Real}) +function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat}) return exp(Ops.lgamma(float(x))) end -function SpecialFunctions.gamma(n::TracedRNumber{<:Integer}) +function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt}) return round(gamma(float(n))) end -function SpecialFunctions.loggamma1p(x::TracedRNumber{<:Real}) +function SpecialFunctions.loggamma1p(x::TracedRNumber{<:ReactantFloat}) return loggamma(1 + x) end -function SpecialFunctions.logfactorial(x::TracedRNumber{<:Integer}) +function SpecialFunctions.logfactorial(x::TracedRNumber{<:ReactantInt}) return loggamma(1 + x) end # SpecialFunctions.invdigamma -function SpecialFunctions.trigamma(x::TracedRNumber{<:Real}) +function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt}) return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition end -function SpecialFunctions.polygamma(n::TracedRNumber{<:Real}, x::TracedRNumber{<:Real}) +function SpecialFunctions.polygamma(n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt}) return Ops.polygamma(float(n), float(x)) end @@ -42,17 +42,17 @@ end function SpecialFunctions.loggammadiv( a::TracedRNumber{T}, b::TracedRNumber{T} -) where {T<:Real} +) where {T<:ReactantFloat} return log(gamma(b) / gamma(a + b)) end #SpecialFunctions.gamma ... -function SpecialFunctions.beta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} +function SpecialFunctions.beta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} return gamma(x) * gamma(y) / gamma(x + y) end -function SpecialFunctions.logbeta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} +function SpecialFunctions.logbeta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} return log(abs(beta(x, y))) end @@ -64,25 +64,25 @@ end #utilities... -function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} +function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} return erf(y) - erf(x) end #SpecialFunctions.erfcinv -function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:Real} +function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} return log(erf(x, y)) end -function SpecialFunctions.erfcx(x::TracedRNumber{<:Real}) +function SpecialFunctions.erfcx(x::TracedRNumber{<:ReactantFloatInt}) return exp(float(x^2)) * erfc(x) end -function SpecialFunctions.logerfc(x::TracedRNumber{<:Real}) +function SpecialFunctions.logerfc(x::TracedRNumber{<:ReactantFloatInt}) return log(erfc(x)) end -function SpecialFunctions.logerfcx(x::TracedRNumber{<:Real}) +function SpecialFunctions.logerfcx(x::TracedRNumber{<:ReactantFloatInt}) return log(erfcx(x)) end @@ -99,7 +99,7 @@ end #Elliptic Integrals -function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T<:Real} +function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T<:ReactantFloatInt} return Ops.zeta(z, s) end diff --git a/src/Reactant.jl b/src/Reactant.jl index 9a6556992..d1441244e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -23,8 +23,7 @@ else const ReactantFloat = Union{Float16,Float32,Float64} end -const ReactantPrimitive = Union{ - Bool, +const ReactantInt = Union{ Int8, UInt8, Int16, @@ -33,9 +32,18 @@ const ReactantPrimitive = Union{ UInt32, Int64, UInt64, +} + +const ReactantFloatInt = Union{ + Base.uniontypes(ReactantInt)..., + Base.uniontypes(ReactantFloat)... +} + +const ReactantPrimitive = Union{ + Bool, + Base.uniontypes(ReactantFloatInt)..., Complex{Float32}, Complex{Float64}, - Base.uniontypes(ReactantFloat)..., } abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end From 0ddc365486f7470fae531ce940a53438ce3923f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 00:38:06 +0100 Subject: [PATCH 13/29] add missing def for 1.10, increase ~ tolerence for MacOS --- src/TracedRNumber.jl | 5 +++++ test/integration/special_functions.jl | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 24d558011..f63795b65 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -232,6 +232,11 @@ function Base.float(x::TracedRNumber{T}) where {T} end using Reactant: ReactantFloat + +Base.round(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundNearest) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundDown) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundUp) + function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where {R} if R == :Nearest Ops.round_nearest_even(A) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index fa620561e..5ebb5b637 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -1,4 +1,7 @@ using SpecialFunctions, Reactant + +macro ≈(a,b) return quote isapprox($a, $b, atol=1e-14) end end + @testset "$op" for (op, n_args) in [ (:gamma, 1), (:loggamma, 1), @@ -16,8 +19,7 @@ using SpecialFunctions, Reactant ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - @eval @test @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) ≈ - SpecialFunctions.$op($x...) + @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...) end end From 667575a658dadc54e91f4d88d578b91d4f7d4be6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 03:17:57 +0100 Subject: [PATCH 14/29] missing def int, julia 1.10 --- src/TracedRNumber.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index f63795b65..6b70de0dc 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -249,9 +249,9 @@ function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where end end -function Base.round(A::TracedRNumber{<:Integer}, _) - return A -end +Base.round(A::TracedRNumber{<:Integer}) = A +Base.floor(A::TracedRNumber{<:Integer}) = A +Base.ceil(A::TracedRNumber{<:Integer}) = A # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) From a7004eb2d05b14e379a61e934e8020459b1cb430 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 03:43:18 +0100 Subject: [PATCH 15/29] format --- test/integration/special_functions.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index 5ebb5b637..b58a5f960 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -1,6 +1,10 @@ using SpecialFunctions, Reactant -macro ≈(a,b) return quote isapprox($a, $b, atol=1e-14) end end +macro ≈(a, b) + return quote + isapprox($a, $b; atol=1e-14) + end +end @testset "$op" for (op, n_args) in [ (:gamma, 1), @@ -19,7 +23,9 @@ macro ≈(a,b) return quote isapprox($a, $b, atol=1e-14) end end ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...) + @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op( + $x... + ) end end From bb191ab25d1992b5183bb43ce0059d8f303972c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 13:37:12 +0100 Subject: [PATCH 16/29] format 2 --- ext/ReactantSpecialFunctionsExt.jl | 24 ++++++++++++++++++------ src/Reactant.jl | 19 +++---------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index 882b8397f..db516c8cb 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -32,7 +32,9 @@ function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt}) return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition end -function SpecialFunctions.polygamma(n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt}) +function SpecialFunctions.polygamma( + n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt} +) return Ops.polygamma(float(n), float(x)) end @@ -48,11 +50,15 @@ end #SpecialFunctions.gamma ... -function SpecialFunctions.beta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} +function SpecialFunctions.beta( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} return gamma(x) * gamma(y) / gamma(x + y) end -function SpecialFunctions.logbeta(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} +function SpecialFunctions.logbeta( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} return log(abs(beta(x, y))) end @@ -64,13 +70,17 @@ end #utilities... -function SpecialFunctions.erf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} +function SpecialFunctions.erf( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} return erf(y) - erf(x) end #SpecialFunctions.erfcinv -function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where {T<:ReactantFloatInt} +function SpecialFunctions.logerf( + x::TracedRNumber{T}, y::TracedRNumber{T} +) where {T<:ReactantFloatInt} return log(erf(x, y)) end @@ -99,7 +109,9 @@ end #Elliptic Integrals -function SpecialFunctions.zeta(z::TracedRNumber{T}, s::TracedRNumber{T}) where {T<:ReactantFloatInt} +function SpecialFunctions.zeta( + z::TracedRNumber{T}, s::TracedRNumber{T} +) where {T<:ReactantFloatInt} return Ops.zeta(z, s) end diff --git a/src/Reactant.jl b/src/Reactant.jl index d1441244e..7f858fbea 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -23,27 +23,14 @@ else const ReactantFloat = Union{Float16,Float32,Float64} end -const ReactantInt = Union{ - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, -} +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64} const ReactantFloatInt = Union{ - Base.uniontypes(ReactantInt)..., - Base.uniontypes(ReactantFloat)... + Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... } const ReactantPrimitive = Union{ - Bool, - Base.uniontypes(ReactantFloatInt)..., - Complex{Float32}, - Complex{Float64}, + Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64} } abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end From a38883c76a594bfc6802dff4ec163ac14be615b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 15:26:17 +0100 Subject: [PATCH 17/29] error --- src/TracedRNumber.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 92926e4fe..a1ccb4ec7 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -243,7 +243,7 @@ function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where elseif R == :Down Ops.floor(A) else - @error "$R is unsupported" + error("$R is unsupported") end end From bd84cb295258e546d836992c22a68cf3fd91aba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 17:11:05 +0100 Subject: [PATCH 18/29] simplify rounding --- src/TracedRNumber.jl | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index a1ccb4ec7..d730c0bb6 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -231,21 +231,9 @@ end using Reactant: ReactantFloat -Base.round(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundNearest) -Base.floor(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundDown) -Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundUp) - -function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where {R} - if R == :Nearest - Ops.round_nearest_even(A) - elseif R == :Up - Ops.ceil(A) - elseif R == :Down - Ops.floor(A) - else - error("$R is unsupported") - end -end +Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) Base.round(A::TracedRNumber{<:Integer}) = A Base.floor(A::TracedRNumber{<:Integer}) = A From 05822d315f48388bbea214c1d17c1f8e1a983bc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 18:00:00 +0100 Subject: [PATCH 19/29] Revert "simplify rounding" This reverts commit bd84cb295258e546d836992c22a68cf3fd91aba0. --- src/TracedRNumber.jl | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index d730c0bb6..a1ccb4ec7 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -231,9 +231,21 @@ end using Reactant: ReactantFloat -Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) -Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) -Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) +Base.round(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundNearest) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundDown) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundUp) + +function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where {R} + if R == :Nearest + Ops.round_nearest_even(A) + elseif R == :Up + Ops.ceil(A) + elseif R == :Down + Ops.floor(A) + else + error("$R is unsupported") + end +end Base.round(A::TracedRNumber{<:Integer}) = A Base.floor(A::TracedRNumber{<:Integer}) = A From dca9021b7c915714e3a19c93fdea09e4250df6ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 18:01:37 +0100 Subject: [PATCH 20/29] disable tests --- test/integration/special_functions.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index b58a5f960..cda515846 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -23,9 +23,7 @@ end ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op( - $x... - ) + #@eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...) end end From 9a7f0b5ef0d4847f7813683e53e6916179774f5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 19:35:13 +0100 Subject: [PATCH 21/29] revert --- test/integration/special_functions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index cda515846..3455be37e 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -2,7 +2,7 @@ using SpecialFunctions, Reactant macro ≈(a, b) return quote - isapprox($a, $b; atol=1e-14) + isapprox(Array($a), $b; atol=1e-14) end end @@ -23,7 +23,7 @@ end ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - #@eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...) + @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...) end end From 62fa41cb0ccc9c51dbfd434bd04326899e753af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 20:19:10 +0100 Subject: [PATCH 22/29] revert --- test/integration/special_functions.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index 3455be37e..b58a5f960 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -2,7 +2,7 @@ using SpecialFunctions, Reactant macro ≈(a, b) return quote - isapprox(Array($a), $b; atol=1e-14) + isapprox($a, $b; atol=1e-14) end end @@ -23,7 +23,9 @@ end ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op($x...) + @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op( + $x... + ) end end From 6880d1693555aabdf880d0448a2e8ea840404cf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 29 Dec 2024 23:44:01 +0100 Subject: [PATCH 23/29] test CI --- test/integration/special_functions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index b58a5f960..01ac3e96d 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -23,7 +23,7 @@ end ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - @eval @test @≈ @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...)) SpecialFunctions.$op( + @eval @test @≈ @jit(float(SpecialFunctions.$op(ConcreteRNumber.($x)...))) SpecialFunctions.$op( $x... ) end From f60de70e2fc8c1666e41c8ee5fecb6e503a212a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 30 Dec 2024 00:14:58 +0100 Subject: [PATCH 24/29] good order --- test/integration/special_functions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index 01ac3e96d..f7c2aec61 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -23,7 +23,7 @@ end ] for data in ([0.5, 0.6], [2, 4]) x = data[1:n_args] - @eval @test @≈ @jit(float(SpecialFunctions.$op(ConcreteRNumber.($x)...))) SpecialFunctions.$op( + @eval @test @≈ float(@jit(SpecialFunctions.$op(ConcreteRNumber.($x)...))) SpecialFunctions.$op( $x... ) end From 35f0e9db3409b9a3121e011c54c8542ac1ca9c04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 30 Dec 2024 00:51:10 +0100 Subject: [PATCH 25/29] remove fancy call --- test/integration/special_functions.jl | 97 +++++++++++++++++++++------ 1 file changed, 76 insertions(+), 21 deletions(-) diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index f7c2aec61..60eedfbfd 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -6,27 +6,82 @@ macro ≈(a, b) end end -@testset "$op" for (op, n_args) in [ - (:gamma, 1), - (:loggamma, 1), - (:digamma, 1), - (:trigamma, 1), - (:beta, 2), - (:logbeta, 2), - (:erf, 1), - (:erf, 2), - (:erfc, 1), - (:logerf, 2), - (:erfcx, 1), - (:logerfc, 1), - (:logerfcx, 1), -] - for data in ([0.5, 0.6], [2, 4]) - x = data[1:n_args] - @eval @test @≈ float(@jit(SpecialFunctions.$op(ConcreteRNumber.($x)...))) SpecialFunctions.$op( - $x... - ) - end +@testset "gamma" begin + @test SpecialFunctions.gamma(0.5) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.gamma(2) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(2))) +end + +@testset "loggamma" begin + @test SpecialFunctions.loggamma(0.5) ≈ + @jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.loggamma(2) ≈ @jit(SpecialFunctions.loggamma(ConcreteRNumber(2))) +end + +@testset "digamma" begin + @test SpecialFunctions.digamma(0.5) ≈ + @jit(SpecialFunctions.digamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.digamma(2) ≈ @jit(SpecialFunctions.digamma(ConcreteRNumber(2))) +end + +@testset "trigamma" begin + @test SpecialFunctions.trigamma(0.5) ≈ + @jit(SpecialFunctions.trigamma(ConcreteRNumber(0.5))) + @test SpecialFunctions.trigamma(2) ≈ @jit(SpecialFunctions.trigamma(ConcreteRNumber(2))) +end + +@testset "beta" begin + @test SpecialFunctions.beta(0.5, 0.6) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.beta(2, 4) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "logbeta" begin + @test SpecialFunctions.logbeta(0.5, 0.6) ≈ + @jit(SpecialFunctions.logbeta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.logbeta(2, 4) ≈ + @jit(SpecialFunctions.logbeta(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erf" begin + @test SpecialFunctions.erf(0.5) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(0.5))) + @test SpecialFunctions.erf(2) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(2))) +end + +@testset "erf with 2 arguments" begin + @test SpecialFunctions.erf(0.5, 0.6) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.erf(2, 4) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erfc" begin + @test SpecialFunctions.erfc(0.5) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(0.5))) + @test SpecialFunctions.erfc(2) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(2))) +end + +@testset "logerf" begin + @test SpecialFunctions.logerf(0.5, 0.6) ≈ + @jit(SpecialFunctions.logerf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) + @test SpecialFunctions.logerf(2, 4) ≈ + @jit(SpecialFunctions.logerf(ConcreteRNumber(2), ConcreteRNumber(4))) +end + +@testset "erfcx" begin + @test SpecialFunctions.erfcx(0.5) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(0.5))) + @test SpecialFunctions.erfcx(2) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(2))) +end + +@testset "logerfc" begin + @test SpecialFunctions.logerfc(0.5) ≈ + @jit(SpecialFunctions.logerfc(ConcreteRNumber(0.5))) + @test SpecialFunctions.logerfc(2) ≈ @jit(SpecialFunctions.logerfc(ConcreteRNumber(2))) +end + +@testset "logerfcx" begin + @test SpecialFunctions.logerfcx(0.5) ≈ + @jit(SpecialFunctions.logerfcx(ConcreteRNumber(0.5))) + @test SpecialFunctions.logerfcx(2) ≈ @jit(SpecialFunctions.logerfcx(ConcreteRNumber(2))) end @testset "loggamma1p" begin From af777ac89f9bc05d55cc02da3614c06c59d745b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 30 Dec 2024 01:35:22 +0100 Subject: [PATCH 26/29] test --- src/TracedRNumber.jl | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index a1ccb4ec7..07303a1f8 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -229,27 +229,15 @@ function Base.float(x::TracedRNumber{T}) where {T} return TracedUtils.promote_to(TracedRNumber{float(T)}, x) end -using Reactant: ReactantFloat - -Base.round(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundNearest) -Base.floor(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundDown) -Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Base.round(A, RoundUp) - -function Base.round(A::TracedRNumber{<:ReactantFloat}, ::RoundingMode{R}) where {R} - if R == :Nearest - Ops.round_nearest_even(A) - elseif R == :Up - Ops.ceil(A) - elseif R == :Down - Ops.floor(A) - else - error("$R is unsupported") - end -end +using Reactant: ReactantFloat, ReactantInt + +Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) -Base.round(A::TracedRNumber{<:Integer}) = A -Base.floor(A::TracedRNumber{<:Integer}) = A -Base.ceil(A::TracedRNumber{<:Integer}) = A +Base.round(A::TracedRNumber{<:ReactantInt}) = A +Base.floor(A::TracedRNumber{<:ReactantInt}) = A +Base.ceil(A::TracedRNumber{<:ReactantInt}) = A # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) From b3df3401015853effdffc8ad1a6b68726b997f39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 30 Dec 2024 02:36:13 +0100 Subject: [PATCH 27/29] new test --- src/TracedRNumber.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 07303a1f8..79c042a7a 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -229,15 +229,11 @@ function Base.float(x::TracedRNumber{T}) where {T} return TracedUtils.promote_to(TracedRNumber{float(T)}, x) end -using Reactant: ReactantFloat, ReactantInt +using Reactant: ReactantFloatInt -Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) -Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) -Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) - -Base.round(A::TracedRNumber{<:ReactantInt}) = A -Base.floor(A::TracedRNumber{<:ReactantInt}) = A -Base.ceil(A::TracedRNumber{<:ReactantInt}) = A +Base.round(A::TracedRNumber{<:ReactantFloatInt}) = Ops.round_nearest_even(A) +Base.floor(A::TracedRNumber{<:ReactantFloatInt}) = Ops.floor(A) +Base.ceil(A::TracedRNumber{<:ReactantFloatInt}) = Ops.ceil(A) # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) From f6bc8a0d58a7fa37a15ccbaf8ac7c143eb10f834 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Mon, 30 Dec 2024 03:02:04 +0100 Subject: [PATCH 28/29] round &co need float --- src/TracedRNumber.jl | 8 ++++---- test/basic.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 79c042a7a..aeb76e0e1 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -229,11 +229,11 @@ function Base.float(x::TracedRNumber{T}) where {T} return TracedUtils.promote_to(TracedRNumber{float(T)}, x) end -using Reactant: ReactantFloatInt +using Reactant: ReactantFloat -Base.round(A::TracedRNumber{<:ReactantFloatInt}) = Ops.round_nearest_even(A) -Base.floor(A::TracedRNumber{<:ReactantFloatInt}) = Ops.floor(A) -Base.ceil(A::TracedRNumber{<:ReactantFloatInt}) = Ops.ceil(A) +Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A) +Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A) +Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A) # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) diff --git a/test/basic.jl b/test/basic.jl index 58297f286..e6bfed797 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -578,7 +578,7 @@ end end @testset "$op" for op in [:round, :ceil, :floor] - for x in (rand(Float32, (3, 3)), rand(Int), rand(Float64), rand(Int, (3, 3))) + for x in (rand(Float32, (3, 3)), rand(Float64)) @eval @test @jit($op.(ConcreteRNumber.($x))) == $op.($x) end end From 88b2a12b3ea67fbd7ae24ea21f9424f6330cc2c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 1 Jan 2025 02:43:19 +0100 Subject: [PATCH 29/29] format --- ext/ReactantSpecialFunctionsExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantSpecialFunctionsExt.jl b/ext/ReactantSpecialFunctionsExt.jl index db516c8cb..9ed006e38 100644 --- a/ext/ReactantSpecialFunctionsExt.jl +++ b/ext/ReactantSpecialFunctionsExt.jl @@ -115,4 +115,4 @@ function SpecialFunctions.zeta( return Ops.zeta(z, s) end -end # module ReactantSpecialFunctionsExt \ No newline at end of file +end # module ReactantSpecialFunctionsExt