feat: overlay eltype conversion #425
I think I prefer this over #287.
For some reason this overlay isn't working correctly in Lux:
Error During Test at /mnt/software/lux/Lux.jl/test/reactant/training_tests.jl:48
Got exception outside of a @test
MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(::Type{T})(::T) where T<:Number
@ Core boot.jl:900
Float32(::IrrationalConstants.Log4π)
@ IrrationalConstants /mnt/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
Float32(::IrrationalConstants.Sqrt3)
@ IrrationalConstants /mnt/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:113
...
Stacktrace:
[1] apply!(o::Lux.ReactantCompatibleOptimisers.ReactantAdamW{Reactant.TracedRNumber{Float32}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, state::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}, x::Reactant.TracedRArray{Float32, 2}, dx::Reactant.TracedRArray{Float32, 2})
@ Lux.ReactantCompatibleOptimisers /mnt/software/lux/Lux.jl/src/helpers/optimizers.jl:157
[2] #_update!#10
@ /mnt/.julia/packages/Optimisers/a4OnF/src/interface.jl:96 [inlined]
[3] var"#_update!#10"(none::IdDict{Optimisers.Leaf, Any}, none::IdDict{Any, Any}, none::typeof(Optimisers._update!), none::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantAdamW{Reactant.TracedRNumber{Float32}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, none::Reactant.TracedRArray{Float32, 2})
@ Reactant ./<missing>:0
[4] #_update!#10
@ /mnt/.julia/packages/Optimisers/a4OnF/src/interface.jl:93 [inlined]
[5] call_with_reactant(::Optimisers.var"##_update!#10", ::IdDict{Optimisers.Leaf, Any}, ::IdDict{Any, Any}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{Lux.ReactantCompatibleOptimisers.ReactantAdamW{Reactant.TracedRNumber{Float32}, Tuple{Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Reactant.TracedRNumber{Float64}, Reactant.TracedRNumber{Float64}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, ::Reactant.TracedRArray{Float32, 2})
@ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:0
Need to investigate this a bit.
Changing the return type is extremely dangerous and should be avoided if possible, since the IR can then get into an inconsistent state and crash.
Should we overload the regular function (outside the interpreter) instead? That would be somewhat inconsistent with how convert generally works.
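A Julia-level analogue of the return-type concern may help (it does not model Reactant's IR or its overlay mechanism). Overloading the regular Float32 constructor, outside any interpreter, so that it returns a traced wrapper breaks the usual expectation that Float32(x) isa Float32. Generic code built on that expectation, such as storing into a Vector{Float32} (which goes through convert), then fails. MockTraced below is a hypothetical stand-in, not Reactant's TracedRNumber.

```julia
# Hypothetical stand-in for a traced scalar; NOT Reactant's TracedRNumber.
struct MockTraced{T<:Number} <: Number
    val::T
end

# Overload the regular constructor (outside any interpreter) so that
# Float32(::MockTraced) returns another MockTraced instead of a Float32.
Base.Float32(x::MockTraced) = MockTraced(Float32(x.val))

y = Float32(MockTraced(1.5))
println(y isa Float32)              # false: the usual `T(x) isa T` expectation is broken
println(y isa MockTraced{Float32})  # true

# Generic code that relies on that expectation now breaks. Storing into a
# Float32 array calls convert(Float32, ::MockTraced), which ends up calling
# the overloaded Float32 and gets a MockTraced back, so the store throws.
v = zeros(Float32, 1)
err = try
    v[1] = MockTraced(2.0f0)
    nothing
catch e
    e
end
println(err !== nothing)  # true
```

The overload itself succeeds; the failure only surfaces later, in unrelated generic code that trusted the constructor's return type.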
Alternative thought: what if we made TracedRArray{T} <: AbstractArray{TracedRNumber{T}}?
This is problematic, since many times you are going to expect that. Check out what I described in #308 (comment) and #287 (comment).
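One illustrative consequence of the subtyping proposal (not necessarily the specific concern described in #308 and #287): generic code reads the scalar type from eltype or from an AbstractArray{T} signature, so T would become the traced scalar type rather than a plain float. The sketch below uses hypothetical types (MockTraced, ArrA, ArrB) and a hypothetical helper uses_float_path, none of which come from Reactant. It shows a method restricted to T<:AbstractFloat that matches the plain-eltype array but silently falls through for the proposed one.

```julia
# Hypothetical types; NOT Reactant's TracedRNumber / TracedRArray.
struct MockTraced{T<:Number} <: Number
    val::T
end

# Variant A: the array advertises its plain element type T.
struct ArrA{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(x::ArrA) = size(x.data)
Base.getindex(x::ArrA, i::Int...) = x.data[i...]

# Variant B (the proposal): the array advertises MockTraced{T} as its element type.
struct ArrB{T<:Number,N} <: AbstractArray{MockTraced{T},N}
    data::Array{T,N}
end
Base.size(x::ArrB) = size(x.data)
Base.getindex(x::ArrB, i::Int...) = MockTraced(x.data[i...])

# Hypothetical helper that, like many generic methods, restricts the
# element type to a plain floating-point type.
uses_float_path(::AbstractArray{T}) where {T<:AbstractFloat} = true
uses_float_path(::AbstractArray) = false

a = ArrA(zeros(Float32, 2, 2))
b = ArrB(zeros(Float32, 2, 2))

println(eltype(a))           # Float32
println(eltype(b))           # MockTraced{Float32}
println(uses_float_path(a))  # true
println(uses_float_path(b))  # false
```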
This is a very common use case; see https://github.com/FluxML/Optimisers.jl/blob/34250b20103cdd42cb0b3eeae0ce8ee118576eb1/src/rules.jl#L133 for an example.
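The linked line is not reproduced here. The failing call in the trace (Float32 applied to a traced hyperparameter inside apply!) is consistent with a pattern of this shape: a rule takes the element type T from x::AbstractArray{T} and converts its hyperparameters with T(...). The sketch below uses a hypothetical rule MyDescent, a hypothetical function apply_rule, and the MockTraced stand-in. None of these are the Optimisers.jl, Lux, or Reactant code. Plain numbers work, while a "traced" hyperparameter hits the same kind of MethodError.

```julia
# Hypothetical stand-in for a traced scalar; NOT Reactant's TracedRNumber.
struct MockTraced{T<:Number} <: Number
    val::T
end

# Hypothetical optimiser rule; NOT the actual Optimisers.jl or Lux code.
struct MyDescent{E}
    eta::E
end

# The parameters' element type T is used to convert the hyperparameter.
function apply_rule(o::MyDescent, x::AbstractArray{T}, dx::AbstractArray{T}) where {T}
    η = T(o.eta)  # Float32(o.eta) when x is a Vector{Float32}
    return x .- η .* dx
end

x  = Float32[1, 2]
dx = Float32[1, 1]

# Plain Float64 hyperparameter: T(o.eta) is Float32(0.5), which works.
println(apply_rule(MyDescent(0.5), x, dx))  # Float32[0.5, 1.5]

# "Traced" hyperparameter: T(o.eta) is Float32(::MockTraced{Float32}),
# for which no method exists, so a MethodError is thrown.
err = try
    apply_rule(MyDescent(MockTraced(0.5f0)), x, dx)
catch e
    e
end
println(err isa MethodError)  # true
```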