Skip to content

Commit

Permalink
fix learning rate bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cybershiptrooper committed Sep 4, 2024
1 parent 73c1663 commit 84fdd6f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
6 changes: 5 additions & 1 deletion iit/model_pairs/base_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,11 @@ def train(

early_stop = training_args["early_stop"]

optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **training_args['optimizer_kwargs'])
optimizer = training_args['optimizer_cls'](
self.ll_model.parameters(),
lr=training_args["lr"],
**training_args['optimizer_kwargs']
)
loss_fn = self.loss_fn
scheduler_cls = training_args.get("lr_scheduler", None)
scheduler_kwargs = training_args.get("scheduler_kwargs", {})
Expand Down
2 changes: 1 addition & 1 deletion iit/model_pairs/iit_behavior_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
):
default_training_args = {
"atol": 5e-2,
"use_single_loss": False,
"use_single_loss": True,
"iit_weight": 1.0,
"behavior_weight": 1.0,
"val_IIA_sampling": "random", # random or all
Expand Down
3 changes: 2 additions & 1 deletion iit/model_pairs/iit_model_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ def __init__(
"scheduler_kwargs": {},
"clip_grad_norm": 1.0,
"seed": 0,
"lr": 0.001,
"detach_while_caching": True,
"optimizer_cls": t.optim.Adam,
"optimizer_kwargs" : {
"lr": 0.001,
# "betas": (0.9, 0.9)
},
}
training_args = {**default_training_args, **training_args}
Expand Down

0 comments on commit 84fdd6f

Please sign in to comment.