
SpecialFunctions simple functions #384

Merged: 35 commits, merged on Jan 1, 2025

@glou-nes (Contributor)

Simple case of #381.

@wsmoses (Member) commented Dec 16, 2024

amazing!

@mofeing (Collaborator) left a comment

so regarding using Union to constrain types in methods, i don't think it's the best idea. i'm in favor of @oxinabox's idea that types in method signatures should be used just for dispatching, not as type constraints.

the reasons why are:

  1. HLO dialects and XLA can (and will) change: check out, for example, the bug i reported about stablehlo.cbrt on complex numbers; they are probably going to remove support for it. hardcoded constraints force us to be super cautious about every update (which we kinda already are, but without them we wouldn't need to update the code).
  2. it goes a lil bit against composability: a user can come, implement their own reduced float type in Julia, tell Reactant how to lower it to MLIR, and it should work. Unions have the problem that they are not extensible.

this is how i developed Ops in principle.

also note that, as of the current implementation, there are some methods that will fail. take SpecialFunctions.gamma:

function SpecialFunctions.gamma(x::TracedRNumber{T}) where {T<:ReactantFloat}
    return exp(Ops.lgamma(x))
end

it will be correctly dispatched on TracedRNumber{Float64}, but what will happen with TracedRNumber{Int64} or TracedRNumber{ComplexF64}? if SpecialFunctions.gamma's generic definition is wide enough (e.g. no type constraint; i'm not sure that's the case, it's just an example), then it will try to trace it, won't throw a MethodError, and... it will probably do nasty things inside.

IMO it's not bad if it fails in MLIR, as long as it doesn't segfault. and if it fails, you can use the with_debug function to get the full Julia + MLIR stacktrace.
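To make point 2 concrete, here is a minimal plain-Julia sketch with no Reactant involved. `Traced`, `SupportedFloat`, `closed_square`, `open_square` and `MyFloat` are hypothetical stand-ins, not Reactant's API. It shows that a method constrained by a fixed `Union` cannot accept a user-defined float type, while a method dispatching on the abstract type `AbstractFloat` picks it up without any change to the library code.

```julia
# Hypothetical stand-in for a traced scalar such as TracedRNumber{T}.
struct Traced{T<:Number}
    val::T
end

# Closed approach: a hardcoded Union of element types (illustrative only).
const SupportedFloat = Union{Float32,Float64}
closed_square(x::Traced{T}) where {T<:SupportedFloat} = Traced(x.val * x.val)

# Open approach: dispatch on the abstract type, so any subtype qualifies.
open_square(x::Traced{T}) where {T<:AbstractFloat} = Traced(x.val * x.val)

# Hypothetical user-defined reduced float type, written after the library.
struct MyFloat <: AbstractFloat
    v::Float64
end
Base.:*(a::MyFloat, b::MyFloat) = MyFloat(a.v * b.v)

x = Traced(MyFloat(3.0))
y = open_square(x)   # works: MyFloat <: AbstractFloat
# closed_square(x) would throw a MethodError: MyFloat is not in SupportedFloat,
# and a Union cannot be extended from outside the defining code.
```

The flip side, which the rest of this thread debates, is that the open signature also accepts element types the backend may not actually support.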


function SpecialFunctions.logfactorial(
    x::TracedRNumber{T}
) where {T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}}
Collaborator

@glou-nes @avik-pal @wsmoses i think it's better not to hardcode limits like this, because the spec and the XLA implementation can change. we could use Integer here if you prefer?

Member

I mean, this is where I think we should have a ReactantInteger = Union{actual types}, which we can update in one place if the spec/support changes.
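A minimal plain-Julia sketch of that single-point-of-update idea. `Traced`, `SupportedInteger`, `double`, `successor` and `is_supported_integer` are hypothetical names; the member list is simply the one from the `logfactorial` signature above, not a statement of what StableHLO actually supports.

```julia
# Hypothetical stand-in for TracedRNumber{T}.
struct Traced{T<:Number}
    val::T
end

# Single source of truth for supported integer element types. The list is
# copied from the logfactorial signature above, for illustration only.
const SupportedInteger = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}

# Hypothetical methods that all reference the alias: editing the alias
# definition changes which element types every one of them accepts.
double(x::Traced{T}) where {T<:SupportedInteger} = Traced(x.val + x.val)
successor(x::Traced{T}) where {T<:SupportedInteger} = Traced(x.val + one(T))

# Query support for an element type without calling any method.
is_supported_integer(::Type{T}) where {T} = T <: SupportedInteger
```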

@glou-nes (Contributor, Author) commented Dec 16, 2024


Thanks for the answer! First, I want to point out that your example does support TracedRNumber{Int64}, through the fallback method here. Isn't that enough to get composability back?

for fn in [:gamma, :loggamma, :digamma, :erf, :erfc]
    # generic fallback: convert the input with `float`, then call the function again
    @eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Number})
        return $fn(float(x))
    end)
end

My main concern here is the mismatch between SpecialFunctions and StableHLO semantics. It's very similar to stablehlo.cbrt: complex numbers and integers are not supported by StableHLO, but they are supported by SpecialFunctions. I tried to follow SpecialFunctions' typing (which is somewhat strict, using Number and Integer, for instance) and to raise an error in this file for, e.g., complex numbers. I assume this file has to reconcile both semantics to integrate smoothly with SpecialFunctions, acting as a kind of opaque interface. I understand the idea that Ops needs to be completely unconstrained. But isn't it better to get a Julia dispatch error than an XLA runtime error for that kind of external integration?
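For readers unfamiliar with the `@eval` pattern in the fallback above, here is a self-contained sketch of how it generates one method per function name. `Traced`, `myexp` and `mysqrt` are hypothetical stand-ins for `TracedRNumber` and the float-only `SpecialFunctions` overloads. One deliberate deviation: the generated fallback is restricted to `Traced{<:Integer}` rather than `<:Number`, because in this mock a complex input would otherwise bounce between `float` and the fallback forever (`float` of a complex value is still complex).

```julia
# Hypothetical stand-in for TracedRNumber{T}.
struct Traced{T<:Number}
    val::T
end
Base.float(x::Traced) = Traced(float(x.val))

# Stand-ins for the float-only overloads that would lower to Ops.
myexp(x::Traced{<:AbstractFloat}) = Traced(exp(x.val))
mysqrt(x::Traced{<:AbstractFloat}) = Traced(sqrt(x.val))

# Same @eval pattern as the fallback above: one generated method per name.
# Integer inputs are promoted with `float` and re-dispatched to the methods
# above; float inputs go straight to the AbstractFloat methods.
for fn in (:myexp, :mysqrt)
    @eval $fn(x::Traced{<:Integer}) = $fn(float(x))
end
```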

@mofeing (Collaborator) commented Dec 16, 2024

> Firstly, I want to add that for your example, it supports TracedRNumber{Int64} using the fallback method [...] Is it not enough to get composability back?

okay, but keep in mind that float on Complex returns a Complex{<:AbstractFloat}.

how about this: you remove the type constraint of the method i pointed out (or restrict it to just TracedRNumber or TracedRNumber{<:Real}), and you call float inside. no need for Unions that way.

> My major problem here is the difference between SpecialFunctions and StableHLO semantics. It's really similar to stablehlo.cbrt: complex numbers and integers are not supported by StableHLO but are supported by SpecialFunctions.

keep in mind that the problem is that the StableHLO spec supports complex numbers for cbrt but XLA doesn't. the StableHLO spec should reflect XLA, so it's a bug in the spec, and it seems they are going to remove it from the spec.

> I understand the idea that Ops needs to be completely unconstrained. Isn't it better to get a Julia dispatch error than an XLA runtime error for that kind of external integration?

so i think there are different visions in the project in this regard, and we probably need to discuss it in the next meeting, but undoubtedly it's better to have a Julia error than an XLA error. my opinion is that when choosing between better errors on the Julia side and better dispatch/composability (at least in Ops), we should prefer better dispatch/composability. but there can be different opinions on this.
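A sketch of the suggestion above (one generic method constrained to `<:Real`, with `float` called inside), using hypothetical plain-Julia stand-ins: `float_only_op` plays the role of an Ops call that only accepts floating-point element types, and `log1p` is an arbitrary placeholder for the real computation. It also shows why `<:Real` matters: `float` keeps complex values complex, so with this signature complex inputs fail at Julia dispatch instead of reaching the backend.

```julia
# Hypothetical stand-in for TracedRNumber{T}.
struct Traced{T<:Number}
    val::T
end
Base.float(x::Traced) = Traced(float(x.val))

# Stand-in for an Ops call that only accepts floating-point element types;
# log1p is an arbitrary placeholder for the real computation.
float_only_op(x::Traced{<:AbstractFloat}) = Traced(log1p(x.val))

# One generic method, no Union: any real input is converted with `float` inside.
# Complex inputs do not match Traced{<:Real}, so they fail at Julia dispatch.
myfun(x::Traced{<:Real}) = float_only_op(float(x))
```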

@glou-nes (Contributor, Author)

A completely generic fallback on TracedRNumber is better!
Both versions generate a runtime error with complex numbers anyway; at least now it's simpler:

error: 'chlo.lgamma' op operand #0 must be tensor of f4E2M1FN type or f6E2M3FN type or f6E3M2FN type or f8E3M4 type or f8E4M3 type or f8E4M3FN type or f8E4M3FNUZ type or f8E4M3B11FNUZ type or f8E5M2 type or f8E5M2FNUZ type or f8E8M0FNU type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f64>>'

@mofeing (Collaborator) commented Dec 16, 2024

> A completely generic fallback on TracedRNumber is better! Both versions generate a runtime error with complex numbers anyway.

you can put a constraint on <:Real in that case

@mofeing (Collaborator) commented Dec 16, 2024

also, now that you do the conversion, I would add tests with different input types.
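One way to structure such tests is to loop over element types inside a `@testset` and compare against the plain-Julia result. The sketch below uses the standard `Test` library with hypothetical stand-ins (`Traced`, `float_only_sqrt`, `mysqrt`); in the actual suite the traced call would presumably go through `@jit`, as in the CI log further down, rather than being called directly.

```julia
using Test

# Hypothetical stand-ins for TracedRNumber{T} and a float-only overload.
struct Traced{T<:Number}
    val::T
end
Base.float(x::Traced) = Traced(float(x.val))
float_only_sqrt(x::Traced{<:AbstractFloat}) = Traced(sqrt(x.val))
mysqrt(x::Traced{<:Real}) = float_only_sqrt(float(x))

@testset "mysqrt over input element types" begin
    for T in (Int32, Int64, UInt8, Float32, Float64)
        x = T(4)
        y = mysqrt(Traced(x)).val
        # value matches plain Julia, and the result type follows float(T)
        @test y ≈ sqrt(float(x))
        @test y isa typeof(float(x))
    end
    # complex inputs are rejected by dispatch, not by the backend
    @test_throws MethodError mysqrt(Traced(4.0 + 0im))
end
```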

@wsmoses (Member) commented Dec 16, 2024

@mofeing I disagree with explicitly marking broader support here (e.g. of Number). In particular, we only want to override methods which we know we can implement fully with Ops (aka with stablehlo).

See #384 (comment) for example

@wsmoses (Member) commented Dec 18, 2024

yeah, this looks like it needs a rebase (and also needs to use concrete types, per above)

@wsmoses (Member) commented Dec 28, 2024

The round/ceil/floor tests now fail and should be fixed, but otherwise LGTM to merge!

@wsmoses (Member) commented Dec 29, 2024

hmm CI still does not seem to like it:

  Test threw exception
  Expression: #= /home/runner/work/Reactant.jl/Reactant.jl/test/basic.jl:582 =# @jit(floor.(ConcreteRNumber.([5617217728973374046 -3007515435434531812 -5171351284935898551; -157483289222511147 -424687138072795833 -2567171450103491194; 6843301087142890975 1766303807932526665 -9188973971342681186]))) == floor.([5617217728973374046 -3007515435434531812 -5171351284935898551; -157483289222511147 -424687138072795833 -2567171450103491194; 6843301087142890975 1766303807932526665 -9188973971342681186])
  MethodError: no method matching floor(::Reactant.TracedRNumber{Int64})
  
  Closest candidates are:
    floor(::Type{Bool}, ::AbstractFloat)
     @ Base float.jl:389
    floor(::Type{Dates.Date}, ::Union{Dates.Day, Dates.Week, Dates.TimePeriod, Dates.TimeType}, ::Type{P}) where P<:Dates.Period
     @ Dates /opt/hostedtoolcache/julia/1.10.7/x64/share/julia/stdlib/v1.10/Dates/src/rounding.jl:287
    floor(::Missing; sigdigits, digits, base)
     @ Base missing.jl:155
    ...
  
  Stacktrace:
    [1] macro expansion
      @ ~/work/Reactant.jl/Reactant.jl/src/utils.jl:0 [inlined]
    [2] call_with_reactant(::typeof(floor), ::Reactant.TracedRNumber{Int64})
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/utils.jl:527
    [3] _broadcast_getindex_evalf

@wsmoses (Member) commented Dec 29, 2024

@glou-nes can you rebase?

@wsmoses (Member) commented Dec 29, 2024

@glou-nes it looks like some test is segfaulting here, can you isolate it into a MWE to debug?

@glou-nes (Contributor, Author) commented Jan 1, 2025

It seems that the segfault came from this code. Anyway, I removed the second part; it's not needed.

# floating-point element types lower to the corresponding Ops calls
Base.round(A::TracedRNumber{<:ReactantFloat}) = Ops.round_nearest_even(A)
Base.floor(A::TracedRNumber{<:ReactantFloat}) = Ops.floor(A)
Base.ceil(A::TracedRNumber{<:ReactantFloat}) = Ops.ceil(A)

# integer element types are already integral, so rounding is the identity
Base.round(A::TracedRNumber{<:Integer}) = A
Base.floor(A::TracedRNumber{<:Integer}) = A
Base.ceil(A::TracedRNumber{<:Integer}) = A
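A plain-Julia mock of the dispatch structure in the snippet above. `Traced` is a hypothetical stand-in for `TracedRNumber`, and Base's own `round`/`floor`/`ceil` on the wrapped value stand in for `Ops.round_nearest_even`, `Ops.floor` and `Ops.ceil`. Julia's `RoundNearest` mode rounds ties to even, which matches the semantics the name `round_nearest_even` suggests. The integer methods turn rounding of an integer-valued input into the identity.

```julia
# Hypothetical stand-in for TracedRNumber{T}.
struct Traced{T<:Number}
    val::T
end

# Float element types: Base rounding on the wrapped value stands in for
# Ops.round_nearest_even, Ops.floor and Ops.ceil.
Base.round(A::Traced{<:AbstractFloat}) = Traced(round(A.val, RoundNearest))  # ties to even
Base.floor(A::Traced{<:AbstractFloat}) = Traced(floor(A.val))
Base.ceil(A::Traced{<:AbstractFloat}) = Traced(ceil(A.val))

# Integer element types are already integral, so all three return the input.
Base.round(A::Traced{<:Integer}) = A
Base.floor(A::Traced{<:Integer}) = A
Base.ceil(A::Traced{<:Integer}) = A
```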

@wsmoses wsmoses merged commit 0343a39 into EnzymeAD:main Jan 1, 2025
28 of 38 checks passed

4 participants