Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Towards fixing splatting properly #52

Merged
merged 5 commits into from
Nov 20, 2023
Merged

Towards fixing splatting properly #52

merged 5 commits into from
Nov 20, 2023

Conversation

willtebbutt
Copy link
Collaborator

It turns out that while #48 didn't break any of the tests in Umlaut, it did break some examples of splatting in code that I'm testing on. This lead me to dig around a little, and construct the additional test cases for Umlaut that I've added in this PR.

Notably:

  1. the current implementation fails when you splat any data type which doesn't support getfield -- this is most things other than Tuple and NamedTuple.
  2. the previous implementation worked for anything for which getindex works, which is a greater array of things than getfield
  3. the previous implementation fails for iterators which don't have getindex defined on them. I've created one of these by adding a test involving a zip, for which getindex doesn't work.

I'm not entirely sure what the right answer is here. I would really rather not revert to the getindex implementation of unsplat!, because it puts a (potentially) non-primitive on the tape, and doesn't handle all of the cases.

One option is to, in the general case, insert a Core._apply_iterate(itr, tuple, x...) onto the tape, which will output a Tuple -- we could then employ the current strategy on that Tuple (getfield calls etc). This is a little bit annoying, but it at least has the property that it should handle every case. Downstream consumers of the tapes produced by Umlaut could then handle whichever range of options they prefer. We could then add special handlers for Tuples, Vectors`, etc.

What are your thoughts @dfdx ?

@dfdx
Copy link
Owner

dfdx commented Nov 14, 2023

I had a similar problem in Yota.jl when I needed getfield primitive protected from backprop mechanism. I solved it by introducing a special _getfield() function with the same behavior but different "primitiveness". Can we do the same thing here? E.g. introduce some internal function which is guaranteed to be a primitive and does what we need for all the types we care about?

@willtebbutt
Copy link
Collaborator Author

willtebbutt commented Nov 15, 2023

I like this idea. How about the following: for each argument which is getting splatted, we proceed in two steps:

  1. insert a Call which converts to a type which supports getfield (i.e. Tuple). This would be a primitive.
  2. apply the current Base.getfield to produce individual arguments.

The first step can be specialised, depending on the type. For example, if the type is already a Tuple or NamedTuple, then there's nothing that needs doing.

This approach has the limitation that the conversion to a Tuple might be hard to write a rule for, which would make life hard for Yota and my AD project, but we can always refactor in the future if we can think of an improvement.

edit: we could call the conversion function __convert_to_tuple_for_splatting__, or something similarly verbose which makes it clear that something interesting is going on.

@dfdx
Copy link
Owner

dfdx commented Nov 15, 2023

Yes, __convert_to_tuple_for_splatting__ is one option. Maybe a bit more intuitive option is to introduce __getfield__ which gets elements directly from the splatted data type, but I don't have enough examples in my mind to evaluate corner cases. Perhaps we just need to try out something and see how it goes.

@willtebbutt
Copy link
Collaborator Author

I see your point -- the problem is that not every iterable naturally supports a getfield / getindex-like function. The interface that they support is iterate, which lets you get the next element. For example,

x = Iterators.takewhile(>(0), [0.1, 0.1, 0.1, -0.1, 0.1, 0.1, 0.1, 0.1])

You can definitely splat x:

julia> tuple(x...)
(0.1, 0.1, 0.1)

so it's within the scope of what we're interested in, but length(x) and getindex(x, 1) both yield MethodErrors, because you can't tell what length(x) is without iterating over x and seeing when it ends. So the reason that I proposed to do this using two steps is to ensure that we iterate over a collection at most once, and that it makes sense to call our getfield-equivalent function on a data type that it makes sense to do so.

Does this seem reasonable, or am I missing something?

@dfdx
Copy link
Owner

dfdx commented Nov 16, 2023

Ah, right, I didn't take into account that we will need to iterate over collection multiple times. So yes, your solution looks like the best option.

@willtebbutt
Copy link
Collaborator Author

Cool. I'll work on that today.

@codecov-commenter
Copy link

codecov-commenter commented Nov 20, 2023

Codecov Report

Attention: 8 lines in your changes are missing coverage. Please review.

Comparison is base (a82adaf) 1.27% compared to head (7868498) 0.74%.

Files Patch % Lines
src/trace.jl 0.00% 8 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@           Coverage Diff            @@
##            main     #52      +/-   ##
========================================
- Coverage   1.27%   0.74%   -0.53%     
========================================
  Files          8       7       -1     
  Lines        707     670      -37     
========================================
- Hits           9       5       -4     
+ Misses       698     665      -33     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@willtebbutt
Copy link
Collaborator Author

@dfdx I've implemented the proposal -- I think we're good to go if you're happy

@dfdx
Copy link
Owner

dfdx commented Nov 20, 2023

Perfect, thank you!

@dfdx dfdx merged commit 0263cdf into main Nov 20, 2023
3 of 4 checks passed
@dfdx dfdx deleted the wct/fix-splatting branch November 20, 2023 23:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants