
add step! #1833

Status: Open. FelixBenning wants to merge 12 commits into master.
Conversation

@FelixBenning commented on Jan 13, 2022:

Add step!, as suggested in #666, as a single step of train!, so that more exotic optimisers can simply overload step! and still be able to use the train! wrapper.
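
For concreteness, here is a minimal sketch of the idea (not the PR's actual implementation), assuming Flux's implicit-parameter API of early 2022 (Params, gradient, Optimise.update!); train! would then simply call this once per batch:

```julia
using Flux
using Flux.Optimise: AbstractOptimiser, update!

# One optimisation step: take gradients of the zero-argument loss closure
# with respect to `ps` and apply the optimiser's update rule.
function step!(loss, ps::Flux.Params, opt::AbstractOptimiser)
  gs = Flux.gradient(loss, ps)
  update!(opt, ps, gs)
end

# Example: a single step on one batch (x, y).
m = Dense(2, 1)
mse_loss(x, y) = Flux.Losses.mse(m(x), y)
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
step!(() -> mse_loss(x, y), Flux.params(m), Descent(0.1))
```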

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

@DhairyaLGandhi (Member) commented:

Similar to #1017, which should be complete.

@darsnack (Member) commented:

True, sorry I missed that. Let's just move ahead with this one since it has a docstring and Felix put in some time recently.

Review thread on src/optimise/train.jl (outdated, resolved).
@darsnack (Member) left a review comment:

A couple of docstring changes. The implementation looks good to me.

Can you also update the docs for "Custom Training loops" to use this new function? And add the docstring into the docs in that location?

One other suggestion is to call this trainstep!, since step! is very generic, and we might want to hold onto the name for something else.

Two more review threads on src/optimise/train.jl (outdated, resolved).
@@ -81,29 +81,29 @@ batchmemaybe(x) = tuple(x)
 batchmemaybe(x::Tuple) = x

 """
-    step!(loss, params, opt)
+    optimstep!(loss, params, opt)
@FelixBenning (Author) commented:
I suggest optimstep! instead of trainstep! to indicate that this is the optimiser interface and to keep the ML jargon to a minimum.
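
To illustrate the "optimiser interface" point, here is a hedged sketch of how an exotic optimiser could hook in; MyLineSearch and the free-standing optimstep! below are illustrative names only (in the PR this would be a new method of the function defined in Flux.Optimise):

```julia
using Flux
using Flux.Optimise: AbstractOptimiser, update!

struct MyLineSearch <: AbstractOptimiser  # hypothetical optimiser that needs the loss value
  eta::Float64
end

# Hypothetical overload: compute loss and gradient together, then use both
# to decide the update, instead of rewriting the whole training loop.
function optimstep!(loss, ps::Flux.Params, opt::MyLineSearch)
  l, back = Flux.pullback(loss, ps)  # loss value and pullback in one pass
  gs = back(one(l))                  # gradients with respect to ps
  # ...use `l` and `gs` to pick a step length here...
  update!(Descent(opt.eta), ps, gs)
  return l
end
```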

@mcabbott (Member) commented on Mar 20, 2022:

One vote for something evoking train! to stress that they are closely related.

If the longer-term plan is to use Optimisers.jl, this may not fit with train! at all -- some recent discussion is in #1902 (comment). In that case there will be an implicit-style train! & Params story, and an explicit-style gradient and Optimisers.update!. With such a divide, this function wants to be clearly on the train! & Params side.

Maybe it should just be 3-arg train!? Without a data iterator, there is no iteration, that's all:

train!(loss, ::Params, data, ::AbstractOptimiser)  # calls loss(d...) for d in data
train!(loss, ::Params, ::AbstractOptimiser)        # calls loss() since there is no data
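
A hedged sketch of how those two methods could relate, using illustrative names (train_step!, train_loop!) to avoid clashing with the exported Flux.train!; batchmemaybe is the internal helper visible in the diff above:

```julia
using Flux
using Flux.Optimise: AbstractOptimiser, update!, batchmemaybe

# 3-arg form: no data iterator, so take a single step on loss().
function train_step!(loss, ps::Flux.Params, opt::AbstractOptimiser)
  gs = Flux.gradient(loss, ps)   # calls loss()
  update!(opt, ps, gs)
end

# 4-arg form: iterate over the data and take one step per batch,
# i.e. it calls loss(d...) for d in data.
function train_loop!(loss, ps::Flux.Params, data, opt::AbstractOptimiser)
  for d in data
    train_step!(() -> loss(batchmemaybe(d)...), ps, opt)
  end
end
```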

# Calculate the gradients of the parameters
# with respect to the loss function
grads = Flux.gradient(parameters) do
  loss(d...)
end
# Update the parameters based on the chosen optimiser
Flux.update!(opt, parameters, grads)
@FelixBenning (Author) commented:
This is right at the beginning instead of in the Custom Training Loops section. It seems to me that the custom training loop section might now either be redundant or instead demonstrate how to do a custom gradient calculation.

Review threads on NEWS.md and src/optimise/Optimise.jl (outdated, resolved).
FelixBenning and others added 2 commits May 24, 2022 12:34
Co-authored-by: Brian Chen <[email protected]>