SpecialFunctions
simple functions
#384
amazing! |
so regarding using Union to constrain types in methods, i don't think it's the best idea. i'm in favor of the idea by @oxinabox that types in method signatures should be used just for dispatch, not as type constraints.
the reasons why are:
- HLO dialects and XLA can (and will) change: check out, for example, the bug i reported about stablehlo.cbrt on complex numbers; they are probably going to remove support for it. this forces us to be super cautious about every update to it (which we kind of are already, but we wouldn't need to update the code).
- it goes a little bit against composability: a user can come, implement their own reduced float type in Julia, tell Reactant how to lower it to MLIR, and it should work. Unions have the problem that they are not extendable.
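A minimal sketch of the extendability point, in plain Julia without Reactant. `SketchTraced`, `closed_op`, `open_op`, and `MyFloat8` are hypothetical names for illustration only, not part of Reactant's API:

```julia
# Hypothetical stand-in for a traced scalar; not Reactant's actual type.
struct SketchTraced{T}
    value::T
end

# Closed set: only the element types listed in the Union are accepted.
const ClosedFloat = Union{Float32,Float64}
closed_op(x::SketchTraced{T}) where {T<:ClosedFloat} = x.value

# Open set: any subtype of AbstractFloat is accepted, including ones defined later.
open_op(x::SketchTraced{T}) where {T<:AbstractFloat} = x.value

# A user-defined reduced-precision float (hypothetical).
struct MyFloat8 <: AbstractFloat
    bits::UInt8
end

x = SketchTraced(MyFloat8(0x3c))

# The abstract-type signature picks up the new type; the Union does not,
# and a user cannot add MyFloat8 to ClosedFloat after the fact.
open_ok = applicable(open_op, x)
closed_ok = applicable(closed_op, x)
```

Here `open_ok` is `true` and `closed_ok` is `false`: the `Union`-constrained method cannot be extended to the user's type without editing the alias itself.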
this is how i developed Ops in principle
also, note that as of the current implementation, there are some methods that will fail. take SpecialFunctions.gamma:
function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T<:ReactantFloat}
    return exp(Ops.lgamma(x))
end
it will be correctly dispatched on TracedRNumber{Float64}, but what will happen with TracedRNumber{Int64} or TracedRNumber{ComplexF64}? if SpecialFunctions.gamma's definition is wide enough (e.g. no type constraint; i'm not sure that's the case, but it's an example), then it will try to trace, won't throw a MethodError, and... it will probably do nasty things inside.
IMO it's not bad that it fails in MLIR as long as it doesn't segfault. and if it fails, you can use the with_debug function to get the full Julia + MLIR stacktrace.
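The fall-through described above can be sketched in plain Julia. `WideTraced` and `wide_gamma` are hypothetical stand-ins; whether SpecialFunctions.gamma actually has such a wide definition is, as the comment says, not confirmed:

```julia
# Hypothetical traced scalar, declared as a Number so a generic method can accept it.
struct WideTraced{T} <: Number
    value::T
end

# Stand-in for a library definition with a loose signature (an assumption).
wide_gamma(x::Number) = :generic_fallback

# Specialization restricted to floating-point element types,
# analogous to the T<:ReactantFloat constraint.
wide_gamma(x::WideTraced{T}) where {T<:AbstractFloat} = :traced_method

float_result = wide_gamma(WideTraced(1.0))  # hits the specialization
int_result = wide_gamma(WideTraced(1))      # silently falls through, no MethodError
```

The integer case does not fail at dispatch time; it lands in the generic method, which is where the unexpected behavior during tracing would come from.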
ext/ReactantSpecialFunctionsExt.jl
function SpecialFunctions.logfactorial(
    x::TracedRNumber{T}
) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}}
I mean this is where I think we should have a ReactantInteger = Union{actual types}, which we can update once if the spec / support changes throughout
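A rough sketch of that suggestion, assuming the alias simply lists the integer types from the signature above. `AliasTraced`, `is_integer_traced`, and `element_bits` are hypothetical and only illustrate sharing one alias across methods:

```julia
# Hypothetical stand-in for Reactant's traced scalar type.
struct AliasTraced{T}
    value::T
end

# Assumed definition of the proposed alias; the real set would follow the spec.
const ReactantInteger = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}

# Two illustrative methods that share the alias instead of repeating the Union.
is_integer_traced(::AliasTraced{T}) where {T<:ReactantInteger} = true
is_integer_traced(::AliasTraced) = false

element_bits(::AliasTraced{T}) where {T<:ReactantInteger} = 8 * sizeof(T)
```

If the supported set changes, only the `const` line needs updating, and every signature that uses the alias follows.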
Thanks for the answer! Firstly, I want to add that, for your example, it supports TracedRNumber{Int64} using the fallback method here: Is it not enough to get composability back?
My major problem here is the difference between |
okay, but keep in mind that how about this: you remove the type constraint of the method i pointed (or leave it to just
keep in mind that the problem is that StableHLO spec supports complex numbers for
so i think there are different visions in the project in this sense and we probably need to discuss it in the next meeting, but undoubtedly it's better to have a Julia error than an XLA error. my opinion is that when choosing between better errors on the Julia side or better dispatch/composability (at least in Ops), we should prefer better dispatch/composability. but there can be different opinions on this.
a completely generic fallback with
you can put a constraint on |
also, now that you do the conversion, I would add tests with different input types. |
@mofeing I disagree with explicitly marking broader support here (e.g. of Number). In particular, we only want to override methods which we know we can implement fully with Ops (aka with stablehlo). See #384 (comment) for example |
yeah this looks like it needs a rebase (and also using concrete types per above) |
The round, ceil, and floor tests now fail and should be fixed, but otherwise LGTM to merge!
hmm, CI still does not seem to like it:
Test threw exception
Expression: #= /home/runner/work/Reactant.jl/Reactant.jl/test/basic.jl:582 =# @jit(floor.(ConcreteRNumber.([5617217728973374046 -3007515435434531812 -5171351284935898551; -157483289222511147 -424687138072795833 -2567171450103491194; 6843301087142890975 1766303807932526665 -9188973971342681186]))) == floor.([5617217728973374046 -3007515435434531812 -5171351284935898551; -157483289222511147 -424687138072795833 -2567171450103491194; 6843301087142890975 1766303807932526665 -9188973971342681186])
MethodError: no method matching floor(::Reactant.TracedRNumber{Int64})
Closest candidates are:
floor(::Type{Bool}, ::AbstractFloat)
@ Base float.jl:389
floor(::Type{Dates.Date}, ::Union{Dates.Day, Dates.Week, Dates.TimePeriod, Dates.TimeType}, ::Type{P}) where P<:Dates.Period
@ Dates /opt/hostedtoolcache/julia/1.10.7/x64/share/julia/stdlib/v1.10/Dates/src/rounding.jl:287
floor(::Missing; sigdigits, digits, base)
@ Base missing.jl:155
...
Stacktrace:
[1] macro expansion
@ ~/work/Reactant.jl/Reactant.jl/src/utils.jl:0 [inlined]
[2] call_with_reactant(::typeof(floor), ::Reactant.TracedRNumber{Int64})
@ Reactant ~/work/Reactant.jl/Reactant.jl/src/utils.jl:527
[3] _broadcast_getindex_evalf |
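For context on this MethodError: Base defines `floor(x::Integer) = x`, so for integer element types the operation is the identity. One possible shape of a fix, sketched with a hypothetical `IntTraced` type rather than Reactant's `TracedRNumber` (and not necessarily what this PR ended up doing):

```julia
# Hypothetical stand-in for a traced scalar with an integer element type.
struct IntTraced{T}
    value::T
end

# For integer element types, floor is the identity, mirroring Base's floor(x::Integer) = x.
Base.floor(x::IntTraced{T}) where {T<:Integer} = x

xs = [IntTraced(5), IntTraced(-3)]
ys = floor.(xs)  # broadcasting now finds a method for each element
```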
@glou-nes can you rebase? |
@glou-nes it looks like some test is segfaulting here, can you isolate it into a MWE to debug? |
It seems that the segfault was from this: Anyway, I removed the second part; it's not needed.
Simple case of #381.