
Add explicit train!, unify update!, and auto-translate the two Adams #2082

Merged (18 commits) on Nov 20, 2022

Conversation

@mcabbott (Member) commented Oct 13, 2022:

This is a non-breaking alternative to #2029, and part of #1986's goal to kill off Flux.Optimise.

It adds an explicit train!(loss, model, data, opt), which uses Optimisers.jl inside, and needs opt = Flux.setup(Adam(), model) before use (a brief usage sketch follows the list below). It tries to provide a friendly upgrade path:

  • setup accepts the old Adam(), which is unchanged & can still be used with implicit train!. I hope this makes it easy to translate working code, as new & old will work in the same session. Then v0.14 will export the new Adam() instead, but code won't change.
  • If you forget to call setup, it will still train, but will warn you if state is being lost.
  • Most wrong mixtures of implicit and explicit arguments should produce helpful errors. An error is also thrown for any d in data which isn't a Tuple. This avoids the weird batchmemaybe thing, which sometimes splats and sometimes doesn't.
  • Flux.update! === Optimisers.update!, so you can't use the wrong one. (At present there are two setups.)
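
For concreteness, here is a minimal sketch of the new explicit style (my example, not code from the PR; the model, data and loss are placeholders, and the only API assumed is setup plus the new train! method described above):

```julia
using Flux

model = Dense(2 => 1)
data  = [(randn(Float32, 2, 16), randn(Float32, 1, 16)) for _ in 1:10]  # iterates tuples

opt = Flux.setup(Adam(), model)            # new step: wraps the rule in per-parameter state

Flux.train!(model, data, opt) do m, x, y   # loss takes the model first; each tuple is splatted after it
    Flux.Losses.mse(m(x), y)
end
```

The old implicit style (Flux.params with train!(loss, ps, data, opt)) keeps working in the same session, which is the point of the upgrade path above.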

Needs doc updates, but probably not this PR. I think this would free us to remove every mention of implicit parameters from the docs, during 0.13.

Then 0.14 can delete Flux.Optimise and the train!(..., ::Params, ...) function completely.

This train! would like to have methods like mse(m, x, y) = mse(m(x), y) for all the loss functions, to allow train!(mse, model, data, opt) rather than defining a trivial wrapper function every time. Not this PR though. (Now #2090.)

The meaning of data is still the same as before -- it's an iterator, usually over tuples, which are usually splatted into the loss. This is (IMO) a confusing feature of train!, and perhaps the implicit / explicit break is the time to fix that too. One possibility would be to take arrays rather than an iterator, i.e. train!(mse, model, X, Y; opt):

function train!(loss, model, data...; opt, batchsize=nothing, kw...)
  if batchsize !== nothing
    for d in DataLoader(data; batchsize, kw...)
      g, _... = gradient(loss, model, d...)  # same argument order as train!
      update!(opt, model, g)
    end
  else
    g, _... = gradient(loss, model, data...)  # exact same arguments as train!
    update!(opt, model, g)
  end
end
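
For what it's worth, a call under that hypothetical signature might look like the following, reusing the model and opt placeholders from the earlier sketch (loss3 is a model-first wrapper I'm introducing just for the example; none of this is the PR's API):

```julia
loss3(m, x, y) = Flux.Losses.mse(m(x), y)        # model-first wrapper, in the spirit of #2090 above

X, Y = randn(Float32, 2, 1000), randn(Float32, 1, 1000)
train!(loss3, model, X, Y; opt, batchsize = 32)  # arrays passed directly, batched internally
```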

It also added an explicit way of changing the AD used, via @train_autodiff. RFC, I guess; tests for it ran on Tracker & Yota. This was removed in bbc0f85, for now, to make things more orthogonal. The macro was the same as this bit of #2029.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@darsnack (Member) left a comment:

What do you think about merging this and #2083? With the latest changes, they are orthogonal, so the plan would be to rebase #2083 to replace explicit_withgradient once this is merged.

Review thread on src/train.jl (outdated), lines 95 to 97:
* Instead of `loss` being a function which typically accepts two arguments
(the input `x` and expected output `y` from each element of `data`)
now it should typically accept three, the first of which is the `model` itself.
Comment (Member):

We never restricted specifically to 2 arguments (and we don't seem to restrict to 3 now either). I think the change is

  • <= v0.13: loss accepts as many arguments as length(first(data))
  • > v0.13: loss accepts at least 1 argument, the model, and can accept N additional arguments where N = length(first(data))

I think the distinction is important, since for things like language models, length(first(data)) == 1 (conceivably), and the loss handles taking the single data argument and turning it into a supervised problem.

@mcabbott (author) replied, Oct 20, 2022:

Yea, I don't know how to word this. I think all the doc examples have 2 & need 3, so I wrote "typically". Maybe it needs more explanation.

<= v0.13: loss accepts as many arguments as length(first(data))

It's weirder than that, because if first(data) isn't a tuple, then it isn't splatted. I made the new code simply demand a tuple, else error, so that (at least for now) there is just one path.
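
For illustration (my own example, not from the PR), the single-path rule means non-tuple data must be wrapped by the caller:

```julia
using Flux

model = Dense(3 => 2)
opt   = Flux.setup(Adam(), model)
loss(m, x) = sum(abs2, m(x))                   # one data argument per element

data = [rand(Float32, 3) for _ in 1:4]         # elements are plain arrays, not tuples

# Flux.train!(loss, model, data, opt)          # new behaviour: errors, data must iterate tuples
Flux.train!(loss, model, ((d,) for d in data), opt)  # wrapping each element restores training
```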

@mcabbott (author) replied:

I've removed the "3 arguments" bit, to say just "Instead of loss being a function which accepts only the data, now it must also accept the model itself, as the first argument."

@mcabbott (author) commented:

What do you think about merging this and #2083? With the latest changes, they are orthogonal

Yes. I simplified this to just the train!-API story, kicking the question of whether or how you might select an AD to use down the road.

I have comments on #2083 which I should tidy up and post there.

I think the major question here is whether, when making a new and incompatible method for train!, we should

  • Do the minimal thing as here, and keep data as something which iterates tuples, with the 4th positional argument being the optimiser state. Or:
  • Do more substantial changes at the same time, e.g. to avoid having 4 positional arguments.

@mcabbott commented Nov 1, 2022:

Do more substantial changes at the same time

So the confusing features of train! are:

  1. It takes 4 positional arguments in a hard-to-remember order
  2. It doesn't really do much more than a loop with gradient + update!
  3. It splats the data into the loss, sometimes (or in this PR, always, else error)

One alternative would be to change it to

train!(loss, model, data...; state, kw...)

which calls gradient(loss, model, data...) -- no mystery order, no mystery splat. And then to build in DataLoader, so that train!(loss, model, data...; state, batchsize = 32) breaks up data for you. Each batch is the same shape as the original, though, so this does not introduce confusion about argument order.

Since the keyword state (or something?) is required, this method should not be easily confused with the old one. But it does need the data in quite a different form.

@darsnack (Member) left a comment:

My only hesitancy around this PR involves the choices made to keep train! as close to the old one as possible, e.g. not requiring users to capture the return value.

I would prefer returning both the model and the state from train!. In v0.13.X, we do not need to document this change to the behavior. We silently upgrade to full Optimisers.jl support and throw deprecation warnings whenever Params and AbstractOptimiser are used. Then some versions down the line, we kill support for implicit.

Also, is there a reason for a separate Train submodule?

(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
```
"""
function setup(rule::Optimisers.AbstractRule, model)
Comment (Member):

I am hesitant to create a function in the Flux namespace that clashes with Optimisers.jl. It is hard enough already to keep track of where "Flux functions" actually come from.

Why not extend Optimisers.setup for Flux.Optimise.AbstractOptimiser and remove the mutability check? I am guessing this is to guard against immutable models since train! does not return the model?

@mcabbott (author) replied, Nov 10, 2022:

I thought I did (in addition) extend Optimisers.setup that way, but in fact I made an error. Can change it.

Yes the guard against immutable models is the point. All Flux models are assumed mutable right now, and this just makes the check explicit.

I don't love the collision, but neither name is exported, and the consequences of using the wrong one are (I think) slight. You lose the safety check but any model which does work with Flux.setup will also work correctly with Optimisers.setup.

We can of course make train! return the model. But this isn't enough, as you also have to re-do your code to keep, rather than discard, the returned model. It's a bit awkward.
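
To make the mutability guard concrete, here is a rough sketch of the idea (my own, not the PR's implementation; checked_setup and the crude DenseArray test are placeholders):

```julia
using Optimisers, Functors

# Walk the model, warn about numeric leaves that may not be updatable in place,
# then defer to Optimisers.setup as usual.
function checked_setup(rule::Optimisers.AbstractRule, model)
    fmap(model; exclude = x -> x isa AbstractArray{<:Number}) do x
        x isa DenseArray || @warn "found a parameter array that may not be mutated in place" summary(x)
        x
    end
    return Optimisers.setup(rule, model)
end
```

Flux.setup in this PR plays a role like this (plus translating the old rule types), which is why a model that passes it will also work with Optimisers.setup directly, as noted above.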

@mcabbott (author) added:

BTW, I picture Flux 0.14 deleting Flux.Optimise and exporting Adam etc from Optimisers.jl.

Code that goes using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt) will work equally well on 0.13 and 0.14. You don't have to load Optimisers.jl yourself at all, and all will be safe.

If you do load Optimisers.jl yourself and use its functions, then you have opted into the state, model = update!(state, model, grad) thing where you are supposed to get back the new model.

Reply (Member):

Code that goes using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt) will work equally well on 0.13 and 0.14.

I guess the question is: should this be the case? I do think there should be a patch release of v0.13.X that accepts Optimisers.Adam, etc. and upgrades Flux.Optimise.Adam with a warning. This will allow train! to work as quoted above too. But in v0.14, I was expecting that we force people to start using model = train!(...). Previously, train! and update! worked similarly (mutating optimizers and model), and we could say train! is "just" a loop. Diverging how they work seems worse than a minor code refactor on a breaking release, especially given that people will have seen the warnings beforehand.

@mcabbott (author) replied, Nov 14, 2022:

I do think there should be a patch release of v0.13.X that accepts Optimisers.Adam, etc. and upgrades Flux.Optimise.Adam with a warning

Yes. That's what this PR provides.

I was expecting that we force people to start using model = train!(...)

Especially given people will get warnings from before

But how? You want train! not to mutate, so that everyone will wonder why their model isn't training, and why it's called train!? Or worse, make it return a copy and write NaN into the old model to trash it? These seem awful to me: deliberate breakage for which we gain nothing.

Reply (Member):

Okay, let's forget my suggestions about the warnings, as I agree about the orthogonality.

One option here is to return the model from train! which would allow for immutable models to work. Mutable models still don't need to capture the return value to work. So, we don't force people to do model = train!(...). And we still have Flux.setup here to work in the reverse direction: warn if any leaf is immutable.

Reply (Member):

But I agree that if we are adding Flux.setup, then this seems like something that can be revisited later too.

@mcabbott (author) replied, Nov 14, 2022:

One thing we should consider is changing Optimisers.jl. Maybe its present update! should be called update!!, as that's something of a convention for "tries to mutate but may fail".

Then in a later Optimisers.jl release, we can introduce a new update! which demands mutability and fails on immutable parameters. And that's the one we identify with Flux's function.

That's now FluxML/Optimisers.jl#116
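
For reference, a small runnable illustration of the current Optimisers.jl behaviour that this naming discussion is about: update never mutates, while update! mutates where possible, which is the "may fail to mutate" behaviour the comment above suggests calling update!!.

```julia
using Optimisers

x  = Float32[1, 2, 3]                    # a "model" that is just one mutable array
st = Optimisers.setup(Descent(0.1), x)
g  = Float32[1, 1, 1]

st2, x2 = Optimisers.update(st, x, g)    # pure: x itself is untouched
st3, x3 = Optimisers.update!(st, x, g)   # mutates x in place (when it can) and also returns it
```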

Reply (Member):

I think train! should guarantee mutation; this is the widespread Julian convention. We could have train and train!! for the non-mutating and mutate-if-possible versions.
In that case, whether it returns the model or not doesn't matter much. Base functions such as replace! and map! return the mutated input, maybe just for REPL convenience? In our case, returning the model in the REPL would just be an annoyance, I guess.

@mcabbott (author) replied, Nov 18, 2022:

Yes, the present one returns nothing, which as you say means you don't get a screenful of stuff, and also serves as a reminder that it mutates the model.

I think I'd be happiest if update! did the same: Flux.update! does now, and I'd keep that after unifying with Optimisers.update! too.

I understand the attraction of state, model = update(state, model, grad) but IMO it's a pain to remember the order, and update! is now in a weird place where it does guarantee to mutate the state, but not the model.

@mcabbott commented Nov 10, 2022:

choices made to keep train! as close to the old one as possible.

See suggestions above for how we might change it. But one of the goals is making the transition as easy as possible, which argues for keeping the weird data splat thing.

So, do not require users to capture the return value, etc.

How could you require this? You could encourage it, but since code without it will continue to work, it seems tricky to keep saying "you really ought to...".

I would prefer returning both the model and the state from train!.

Note that there is no need to return the state. This is guaranteed to be mutated, even for a model of immutable arrays. We could change update! to return just one thing, which would be less confusing (one less thing to remember the order of!) but perhaps a confusing change, and would have to differ from update no-bang.

We silently upgrade to full Optimisers.jl support and throw deprecation warnings whenever Params and AbstractOptimiser are used. Then some versions down the line, we kill support for implicit.

Yes, that's precisely what this PR aims to do. It claims that 0.14 will remove this, but how soon that happens, we'll see.

Also, is there a reason for a separate Train submodule?

No very strong one; there's a Losses module, and an Optimise module which (1) is a terrible near-clash, and (2) neatly contains everything which the no-more-implicit change will delete entirely.

Would be fine to remove the sub-module, maybe that's better?

@mcabbott changed the title from "Add explicit train! without removing implicit one" to "Add explicit train!, unify update!, and auto-translate the two Adams" on Nov 11, 2022.
@mcabbott mentioned this pull request on Nov 16, 2022.
* `data` must iterate tuples, otherwise you get an error.
(Previously non-tuple types were not splatted into the loss.
Pass in `((d,) for d in data)` to simulate this.)
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
Comment (Member):

Does Documenter need Flux.Train.setup instead?

@mcabbott (author) replied:

Maybe. I just tried, and I think neither this new train! nor setup appear in the docs at present. That section needs to be re-worked for explicit parameters.

@mcabbott (author) added:

Doc update is #2114, will need to be rebased on this & then checked.

@CarloLucibello (Member) commented:

I would keep the batchmemaybe(d) thing. It is not uncommon for datasets to yield named tuples or dictionaries (e.g. Hugging Face does that), and wrapping dataloaders into generators is ugly and annoying. Also because of the "let's try to disrupt as little as possible" argument.

@mcabbott mentioned this pull request on Nov 19, 2022.
@CarloLucibello (Member) commented:

Conditional on fixing the failing test, LGTM.

Comment on lines +72 to +80
gs = gradient(marg -> marg(x), m)
@test gs isa Tuple
@test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly
@test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly
@test_throws ErrorException Flux.update!(Flux.Adam(), m, gs) # friendly
@test_throws ErrorException Flux.update!(Flux.Adam(), m, gs[1]) # friendly
s = Flux.setup(Adam(), m)
@info "ignore this warning, just testing an upgrade path:"
Flux.update!(s, m, gs) # Chain + Tuple can be unambiguously sorted out
@mcabbott (author) commented:

Most recent commits add some friendly errors to most ways you could use update! wrong, by mixing up implicit & explicit bits, or forgetting to call setup.

Comment on lines -115 to -127
The dimensions of these model parameters depend on the number of inputs and outputs. Since models can have hundreds of inputs and several layers, it helps to have a function to collect the parameters into the data structure Flux expects:

```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
julia> parameters = Flux.params(predict)
Params([Float32[0.9066542], Float32[0.0]])
```

These are the parameters Flux will change, one step at a time, to improve predictions. At each step, the contents of this `Params` object changes too, since it is just a collection of references to the mutable arrays inside the model:

```jldoctest overview
julia> predict.weight in parameters, predict.bias in parameters
(true, true)
```
@mcabbott (author) commented:

Also new: remove all use of params from the "overview" page.
