Fix broadcasts which are type unstable with Dual numbers #1441
+53
−40
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Closes #1439.
This is acheived by moving a dispatch from the element type of the output of broadcasting, to each individual element within the pullback, along with ensuring non-concrete element type arrays take the same path as concrete Dual arrays. This should hopefully compile away when the eltype is concrete, and indeed some simple benchmarks show no loss of performance, but I've hardly been exhaustive.
I've also added a few tests covering various real / complex input / output combinations, and a specific case that produced errors before rather than silently failing.
While its nice that this works, it could be worth adding a note to the documentation about the performance of broadcasting which has a type stable forward pass but becomes type unstable on Dual inputs, and perhaps that likewise such Dual input stability is required for the code to work on the GPU. But I'm not sure where that could go.
Edit: just to note, I did also try merging the complex and real input branches into one function that dispatched according to the argument type, but while it worked well on CPU this seemed to stump the GPU compiler on some cases for reasons I don't understand.
PR Checklist