Skip to content

Commit

Permalink
Add coord check plotting and cleanup bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Nolan Dey authored and Nolan Dey committed Sep 5, 2024
1 parent f594aa1 commit bcadbc3
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 190 deletions.
158 changes: 0 additions & 158 deletions coord_check/mup/plot.ipynb

This file was deleted.

5 changes: 0 additions & 5 deletions coord_check/mup/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,3 @@ do
--compile=False
done
done
# --wandb_project='owt' \
# --wandb_run_name='gpt2' \
# --warmup_iters=0 \
# --lr_decay_iters=10 \
# --min_lr=6e-5 \
170 changes: 170 additions & 0 deletions coord_check/plot.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions coord_check/sp/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ do
--backend='nccl' \
--device='mps' \
--dtype='float32' \
--compile=False
--compile=False \
--mup_enable_coord_check_logging=True
done
done
done
5 changes: 0 additions & 5 deletions coord_check/sp_with_mup_hidden_init/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,3 @@ do
--compile=False
done
done
# --wandb_project='owt' \
# --wandb_run_name='gpt2' \
# --warmup_iters=0 \
# --lr_decay_iters=10 \
# --min_lr=6e-5 \
5 changes: 0 additions & 5 deletions coord_check/sp_with_mup_hidden_init_and_lr/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,3 @@ do
--compile=False
done
done
# --wandb_project='owt' \
# --wandb_run_name='gpt2' \
# --warmup_iters=0 \
# --lr_decay_iters=10 \
# --min_lr=6e-5 \
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,3 @@ do
--compile=False
done
done
# --wandb_project='owt' \
# --wandb_run_name='gpt2' \
# --warmup_iters=0 \
# --lr_decay_iters=10 \
# --min_lr=6e-5 \
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,3 @@ do
--compile=False
done
done
# --wandb_project='owt' \
# --wandb_run_name='gpt2' \
# --warmup_iters=0 \
# --lr_decay_iters=10 \
# --min_lr=6e-5 \
5 changes: 3 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@ def forward(self, idx, targets=None):

if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
if self.config.mup_enabled:
### Begin muP code ###
logits *= self.config.mup_output_alpha / self.config.mup_width_multiplier
# Scaling `x` instead of `logits` allows coord check to log change
x *= self.config.mup_output_alpha / self.config.mup_width_multiplier
### End muP code ###
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def log(log_dict):
"lr": lr,
"mfu": running_mfu*100, # convert to percentage
}
if mup_enabled and mup_enable_coord_check_logging and coord_check_dict is not None:
if mup_enable_coord_check_logging and coord_check_dict is not None:
for key in coord_check_dict:
log_dict[key + '_act_abs_mean'] = np.mean(coord_check_dict[key])
if wandb_log:
Expand All @@ -324,7 +324,7 @@ def log(log_dict):
if iter_num == 0 and eval_only:
break

if mup_enabled and mup_enable_coord_check_logging:
if mup_enable_coord_check_logging:
coord_check_dict = {
'token_embedding': [],
'attn': [],
Expand Down Expand Up @@ -388,7 +388,7 @@ def hook(module, input, output, key):
iter_num += 1
local_iter_num += 1

if mup_enabled and mup_enable_coord_check_logging:
if mup_enable_coord_check_logging:
for handle in coord_check_handles:
handle.remove()

Expand Down

0 comments on commit bcadbc3

Please sign in to comment.