Confusing error / silent failure with broadcasted functions with type instability #1439
Comments
An error would be better than the present state, e.g. When |
An array of zeros doesn't seem quite right: in the first MWE above, that would lead to incorrect zero gradients, if I understand correctly. Assuming Dual seems like it might work, since calling partials on a real or complex number simply returns 0.0 anyway, although it might require reworking the branching on complex inputs. I don't think the compiler will remove that work when none of the elements is a Dual, though, so the quicker branch should be kept for the case where the compiler can confirm that the eltype isn't Dual.
Perhaps the dispatch on complex outputs could be moved inside the _broadcast_forward and _broadcast_forward_complex loops using another internal function? E.g. on this line
Zygote.jl/src/lib/broadcast.jl Line 298 in 2f49370
split on a complex o1 and instead do
Zygote.jl/src/lib/broadcast.jl Line 311 in 2f49370
when required, so that it is dispatched element-wise. That should produce the same code when the eltype is uniform?
Small update: I think I have a fix written for this; I just need to add tests.
When a function that is type-unstable for Dual inputs is broadcast, there is a good chance that the element type of the resulting output will be abstract, which breaks the logic at
Zygote.jl/src/lib/broadcast.jl
Line 284 in 2f49370
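To illustrate the mechanism only (this is not the MWE from this issue, and it does not use Zygote), here is a minimal sketch in plain Julia. `ToyDual` and `unstable` are hypothetical names standing in for a dual number type and a type-unstable function; the assumption is that the logic at the referenced line depends on the output eltype being a Dual type, so an abstract eltype sends it down the wrong path.

```julia
# Hypothetical stand-in for a dual number; this is NOT ForwardDiff.Dual.
struct ToyDual <: Real
    value::Float64
    partial::Float64
end

# Type-unstable: returns a ToyDual on one branch and a plain Float64 on the other.
unstable(x::ToyDual) = x.value > 0 ? x : 0.0

xs = [ToyDual(1.0, 1.0), ToyDual(-1.0, 1.0)]
ys = unstable.(xs)

# The output mixes ToyDual and Float64 elements, so broadcast has to widen
# the element type to something that is no longer a ToyDual type.
@show eltype(ys)
@show isconcretetype(eltype(ys))
@show eltype(ys) <: ToyDual
```

Some elements of `ys` still carry derivative information, but a check on the eltype alone cannot see that, which is consistent with the silent failure described here.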
An MWE of the silent failure on Julia 1.9.0, in a temporary environment with only Zygote:
In contrast with the expected behaviour:
An MWE of the error, using repeat with the inner keyword as an example of a function that doesn't allow the Dual to leak:
results in:
which leaves it unclear where the Duals originate from, since the forward pass succeeds with incorrect outputs:
In the long run it would be better to fix this, however, in the short term simply adding an error before
Zygote.jl/src/lib/broadcast.jl
Line 284 in 2f49370