Reverse weight decay #567
base: main
I think this PR needs to go into the `train-olmo-large` branch, no?
You are right. Then we need to make sure we compute this every time.
…On Fri, May 3, 2024, 08:45 Akshita Bhagia ***@***.***> wrote:
In olmo/train.py:
> + if should_log_optim_metrics_this_step:
+ emb_decay_factor = 1.0 - optim_metrics["param/transformer.wte.weight.norm"]
+ else:
+ emb_decay_factor = 1.0
We compute the norm of the gradient every step (`grad/transformer.wte.weight.norm`), not the norm of the parameter itself (`param/transformer.wte.weight.norm`). Don't we need the latter?
Done
@epwalsh, can you look at this as well? This gets all up in your code.
olmo/optim.py
    if cfg.optimizer.decay_embeddings:
        decay.add(fpn)
    elif cfg.optimizer.reverse_embedding_decay:
        embeddings_decay.add(fpn)
What happens if these are both set? We should check against that somewhere.
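A minimal sketch of such a check, assuming a hypothetical `validate_optimizer_flags` helper and a stand-in `OptimizerFlags` dataclass (neither exists in this PR). Only the two flag names are taken from the diff above; where the check should live in the real config code is left open.

```python
from dataclasses import dataclass


@dataclass
class OptimizerFlags:
    # Stand-in for the two optimizer config fields in the diff; not the real config class.
    decay_embeddings: bool = False
    reverse_embedding_decay: bool = False


def validate_optimizer_flags(flags: OptimizerFlags) -> None:
    # Hypothetical helper: fail fast instead of letting the if/elif silently
    # prefer decay_embeddings when both options are enabled.
    if flags.decay_embeddings and flags.reverse_embedding_decay:
        raise ValueError(
            "decay_embeddings and reverse_embedding_decay are mutually exclusive"
        )
```

Failing at config time seems preferable to relying on the `if`/`elif` order, which would otherwise decide quietly which flag wins.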
CHANGELOG.md
@@ -23,6 +23,7 @@ shared memory implementation can be used by passing `use_legacy_shared_mem_impl`
- Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
- Tokenizer patch
- Added option to specify number of model replicas when using hybrid sharding.
- Added reverse_embedding_decay option.
This name also needs to be updated.
@@ -43,6 +43,7 @@ def clip_grads_and_collect_metrics(
    global_step: int,
    collect_param_metrics: bool = True,
    process_group: Optional[dist.ProcessGroup] = None,
    regularize_embeddings: bool = False,
Why is this a parameter to this function? Shouldn't it be just captured in the parameter groups? That's how all the other regularization works.
    if group["name"] == "embedding_decay_group":
        group["weight_decay"] *= emb_decay_factor
Doesn't this multiply up `emb_decay_factor` across batches? It feels like this should just be set, not multiplied. Or is there some other bit that resets `group["weight_decay"]` every time?
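To make the concern concrete, here is a small sketch comparing the two behaviours on a plain dict that stands in for the optimizer param group. The function names and the `base_wd` argument are illustrative assumptions, and the sketch assumes nothing else resets the group between steps.

```python
def multiplied_each_step(base_wd, factors):
    # Mirrors `group["weight_decay"] *= emb_decay_factor` with nothing resetting the group.
    group = {"name": "embedding_decay_group", "weight_decay": base_wd}
    seen = []
    for factor in factors:
        group["weight_decay"] *= factor
        seen.append(group["weight_decay"])
    return seen


def assigned_each_step(base_wd, factors):
    # Alternative: recompute from the configured base value on every step.
    group = {"name": "embedding_decay_group", "weight_decay": base_wd}
    seen = []
    for factor in factors:
        group["weight_decay"] = base_wd * factor
        seen.append(group["weight_decay"])
    return seen
```

With a constant factor below 1, the multiplied version shrinks geometrically toward zero, while the assigned version stays at `base_wd * factor`.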
    emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
    emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size
    emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size))
    emb_decay_factor = 1.0 - emb_std
If we're using this to plug into the value for WD, that means it needs to be negative when we want to pull up the values. So then it would be `emb_std - 1`?
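A quick way to check the sign, assuming the optimizer applies decoupled (AdamW-style) decay of the form `w <- w * (1 - lr * wd)`. The helper name and the numbers used with it are illustrative only and ignore the gradient part of the update.

```python
def apply_decoupled_decay(w, lr, wd):
    # Decay term of a decoupled weight-decay step, in isolation:
    # positive wd shrinks |w|, negative wd grows it.
    return w * (1.0 - lr * wd)
```

For example, with `emb_std = 0.02`, `lr = 0.1` and a base decay of `0.1`, the factor `1.0 - emb_std` keeps the effective decay positive, so the weight still shrinks, whereas `emb_std - 1.0` makes it negative and the weight grows.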
    )

    emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
    emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size
    emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size))
I believe the denominator should be `float(self.cfg.model.d_model * emb_size)`. And I'm not sure about the numerator either... I don't see how this is equivalent to standard deviation since the summation terms in the norm are not centered by the mean, no?
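The gap between the two quantities can be sketched as follows. Dividing the squared L2 norm by the element count gives the mean square, which equals the variance only when the mean is zero; the population variance needs the mean subtracted. Both function names are hypothetical, and `numel` stands for the total number of entries in the embedding matrix.

```python
import math


def rms_from_norm(l2_norm, numel):
    # What the snippet above effectively computes: sqrt(sum(x^2) / N).
    return math.sqrt(l2_norm ** 2 / numel)


def std_from_norm_and_mean(l2_norm, mean, numel):
    # Population standard deviation via Var[x] = E[x^2] - (E[x])^2.
    # Clamp at zero to guard against tiny negative values from rounding.
    variance = l2_norm ** 2 / numel - mean ** 2
    return math.sqrt(max(variance, 0.0))
```

If the embedding mean is close to zero the two agree closely, so the norm-based value may be an acceptable approximation, but that would need to be verified on the actual weights.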
Update: @AkshitaB and I discussed this; we think we need to calculate this metric separately in `optim.py`.
We also talked about how this standard deviation will be a little biased, since it will include parts of the embedding that are never used: we inflate the embedding size beyond vocab size to be a multiple of 128. But this is probably okay since that's only a small part of the embeddings.
Actually, I think this is a big problem. Embeddings will want to be small, so this will push them up. Unused, or rarely used embeddings will never get updated, so they will get bigger and bigger, skewing the calculation of the stddev more and more.
Figuring out which embeddings to exclude from the stddev computation is going to be tricky in the distributed setting though.
Thinking out loud here... what if we force the unused params to be zero from the beginning? They would still bias the standard deviation by as much as they differ from the mean, but they would always be zero, I think.
That would work if we were starting with this from scratch, but what about the case when we want to use this to "rescue" a run? Can we explicitly make the unused embeddings zero when we load the model? And will it matter if we do so halfway through training?
> Can we explicitly make the unused embeddings zero when we load the model?
I think that's our best bet. I can't think of any issues that would introduce in the middle of training. I suspect those parameters are 0 anyway due to weight decay and zero gradients.
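A framework-free sketch of that idea, using nested lists in place of the real embedding weight. The function name is hypothetical, and the sketch assumes the unused rows are exactly those at index `vocab_size` and above.

```python
def zero_unused_rows(embedding_rows, vocab_size):
    # Rows at index >= vocab_size exist only because the table is padded
    # beyond the vocabulary; no token id ever selects them.
    return [
        list(row) if i < vocab_size else [0.0] * len(row)
        for i, row in enumerate(embedding_rows)
    ]
```

This returns a new table rather than mutating its input; on the actual model the equivalent step would presumably be done in place on the weight right after loading the checkpoint.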
Rare tokens would still be an issue, but not any more than they always are.
Goal: Perform reverse weight decay on embeddings
Multiply the `weight_decay` factor for the embeddings layer by `(1 - norm(embeddings))`.
TODO:
I tried this on a tiny test model config and got an overflow error. Possibly this will not be an issue with the actual model.
Note: I created the branch from `train-olmo-large`. See this for actual diffs for this PR.