Sensitivity wrt LR restarts #8
Comments
The model is ShuffleNet V2 and the dataset is ImageNet-1K. This isn't necessarily a bug; I just wanted to bring it to the author's attention. Feel free to close.
Thanks for bringing this up. In our analysis & experiments we haven't tried any learning rate restarts. I agree this issue may be due to numerical instability or the algorithm design. Will look into it later.
Also, RAdam doesn't obviate all need for warmup :( We found that in some cases adding additional warmup gives better performance (some discussion is at: https://github.com/LiyuanLucasLiu/RAdam#questions-and-discussions).
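(For reference, a minimal sketch of how extra warmup can be layered on top of any optimizer with PyTorch's `LambdaLR`; the model, optimizer, and `warmup_steps` value below are illustrative placeholders, not the setup used in that discussion.)

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

# Placeholder model/optimizer; the optimizer stands in for RAdam.
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Linear warmup over the first `warmup_steps` updates, then constant at the base LR.
warmup_steps = 500
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

# Call scheduler.step() once per optimizer step so the LR ramps from ~0 up to the base value.
```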
I found that the problem is in the initial 5 steps.
The default setting for betas is (0.9, 0.999), so the internal variables change as follows:
Note that step_size doesn't depend on the gradient value and it scales the learning rate. Would it be better to set step_size to 0 while N_sma < 5?
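(For reference, a minimal sketch of how the N_sma-dependent step_size is typically computed, paraphrasing the RAdam formulation rather than the exact repo code; the function name and loop below are illustrative.)

```python
import math

def radam_step_size(lr, step, beta1=0.9, beta2=0.999):
    """Illustrative sketch of RAdam's per-step scale (not copied from the repo)."""
    beta2_t = beta2 ** step
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0                       # rho_inf, ~1999 for beta2=0.999
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)  # rho_t, approximated SMA length

    bias_correction1 = 1.0 - beta1 ** step
    if n_sma >= 5:
        # Variance is tractable: apply the rectification term and use the adaptive update.
        rect = math.sqrt(
            (n_sma - 4) / (n_sma_max - 4)
            * (n_sma - 2) / n_sma
            * n_sma_max / (n_sma_max - 2)
        )
        return lr * rect / bias_correction1
    # First 5 steps with beta2=0.999: rectification is undefined, so the update
    # degenerates to SGD with momentum (no division by the second moment).
    return lr / bias_correction1

for step in range(1, 8):
    print(step, radam_step_size(lr=1e-3, step=step))
```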
@e-sha when N_sma < 5 the update degenerates to SGD with momentum (with bias correction), so step_size doesn't need to be zero. For example, in the first update the first momentum is set to 0.1 * gradient, and the bias correction in step_size cancels that 0.1 factor. You can refer to the implementation of that degenerated branch in the optimizer.
I got it. You are right.
For RAdam it outputs:
The output for Adam is:
The latter is much better.
Thanks for letting us know @e-sha, can you provide a full script to reproduce the result? I'm not sure why
I found that the problem is in gradient values.
@e-sha Thanks for letting us know :-) I guess you mean the problem is in the parameter values, or the gradient values? I think in the first iteration SGDM, although with the bias adjustment, should have smaller updates compared to Adam. For example, even if the gradient is larger than one, it will first be multiplied by 0.1 and then multiplied by 10 * learning_rate.
Yes, you are right. In the first iteration SGDM makes a step equal to learning_rate * gradient.
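(As a quick numeric check of that first-iteration argument; the learning rate and gradient values below are just illustrative.)

```python
# First update with beta1 = 0.9: the degenerated (SGDM-like) branch.
lr, beta1, grad = 1e-3, 0.9, 2.0     # a gradient deliberately larger than one

exp_avg = (1 - beta1) * grad         # first momentum = 0.1 * grad = 0.2
step_size = lr / (1 - beta1 ** 1)    # bias correction -> 10 * lr = 0.01
update = step_size * exp_avg         # 0.01 * 0.2 = 0.002 = learning_rate * gradient

print(update)                        # 0.002
```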
I see, I understand it now, thanks for sharing. BTW, people find that gradient clipping also helps to stabilize model training.
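(For example, clipping the global gradient norm before the optimizer step in PyTorch; the tiny model and optimizer here are placeholders for the actual ShuffleNet V2 + RAdam setup.)

```python
import torch
import torch.nn as nn

# Placeholder setup for illustration only.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# Clip the global gradient norm before the optimizer step to limit any single update.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```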
I'm observing sensitivity wrt LR restarts in a typical SGDR schedule with cosine annealing, as in Loshchilov & Hutter. RAdam still seems to be doing better than AdamW so far, but the jumps imply possible numerical instability at the LR discontinuities.
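(For context, a minimal sketch of that kind of restart schedule using PyTorch's built-in `CosineAnnealingWarmRestarts`; the model, optimizer, and cycle lengths below are placeholders, not my exact setup.)

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = nn.Linear(10, 1)                                    # stand-in for ShuffleNet V2
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # stand-in for RAdam / AdamW
# Cosine annealing with warm restarts: the first cycle lasts T_0 epochs, each
# following cycle is T_mult times longer; the LR jumps back up at each restart.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

for epoch in range(70):
    # ... train one epoch ...
    scheduler.step()
```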
Here's the training loss compared to AdamW (PyTorch 1.2.0 version):
Here's the validation loss:
What's the recommendation here? Should I use warmup in every cycle rather than just in the beginning? I thought RAdam was supposed to obviate the need for warmup. Is this a bug?