
Training improvements #17

Merged

Conversation

evanhanders (Collaborator):

A bunch of small changes to make training smoother & a bit more robust; most importantly:

  • use_single_loss now defaults to true, and the default optimizer arguments for Adam use betas = (0.9, 0.9) rather than (0.9, 0.999). I find that this makes training with a single loss step more robust, but please check this for your cases!
  • progress bars are now reused during training, with a log printed below them, instead of being recreated every epoch.
  • removes redundancy between training_args definitions.
  • puts the SIIT loss calculation into its own function, mirroring the IIT loss.
  • changes how nodes are sampled for SIIT and makes this a training argument; there is now an option for SIIT to randomly sample from all of the nodes that aren't in the circuit and ablate those, rather than just ablating one (see the sketch after this list).
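For illustration, a minimal sketch of that sampling option; the names sample_siit_nodes, all_nodes, circuit_nodes, and siit_sampling are hypothetical, not the PR's actual identifiers:

import random

def sample_siit_nodes(all_nodes: list, circuit_nodes: set, siit_sampling: str) -> list:
    # nodes eligible for ablation are those outside the high-level circuit
    non_circuit = [n for n in all_nodes if n not in circuit_nodes]
    if siit_sampling == "sample_all":
        # new option: ablate a random non-empty subset of all non-circuit nodes
        k = random.randint(1, len(non_circuit))
        return random.sample(non_circuit, k)
    # previous behaviour: ablate a single randomly chosen non-circuit node
    return [random.choice(non_circuit)]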

mypy type checking and pytest tests/ pass, but it's possible some downstream stuff broke; unclear.

evanhanders marked this pull request as ready for review August 22, 2024 22:32
@@ -18,6 +20,24 @@
from iit.utils.index import Ix, TorchIndex
from iit.utils.metric import MetricStoreCollection, MetricType

def in_notebook() -> bool:
cybershiptrooper (Owner):

I think we can do this by just importing tqdm: much cleaner that way. (at least according to this)

evanhanders (Collaborator, Author):

I tried using just tqdm, but it definitely didn't work in notebook mode. I think I hunted down all of the print statements, too.

A cleaner compromise than what's here now: I moved this block to utils/tqdm.py and added from iit.utils.tqdm import tqdm.
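For reference, a minimal sketch of what such a utils/tqdm.py shim could look like (an assumption about the moved block, not necessarily the file as committed):

def in_notebook() -> bool:
    # detect whether we're running under a Jupyter kernel
    try:
        from IPython import get_ipython
        shell = get_ipython()
        return shell is not None and "IPKernelApp" in shell.config
    except ImportError:
        return False

if in_notebook():
    from tqdm.notebook import tqdm  # widget-based bars that rerender in place
else:
    from tqdm import tqdm  # plain terminal bars

__all__ = ["tqdm"]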

@@ -177,7 +197,7 @@ def get_IIT_loss_over_batch(
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
label_idx = self.get_label_idxs()
# IIT loss is only computed on the tokens we care about
loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index])
loss = loss_fn(ll_output[label_idx.as_index].to(hl_output.device), hl_output[label_idx.as_index])
cybershiptrooper (Owner):

We should probably just raise if the dataset, hl_model and ll_model aren't on the same device during init / at the start of training. Casting like this usually just hides the main problem and makes it harder to find bugs.

evanhanders (Collaborator, Author):

Makes sense, I'll add an assert to the beginning of train() and remove all of these.
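For illustration, the check could look something like this sketch (reading devices via parameters() is an assumption about the model classes; the dataset check would mirror it):

import torch as t

def assert_same_device(hl_model, ll_model, example_batch: t.Tensor) -> None:
    hl_device = next(hl_model.parameters()).device
    ll_device = next(ll_model.parameters()).device
    assert hl_device == ll_device == example_batch.device, (
        f"hl_model ({hl_device}), ll_model ({ll_device}) and dataset "
        f"({example_batch.device}) must be on the same device before training"
    )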


if early_stop and self._check_early_stop_condition(test_metrics):
    break
epoch_pbar.update(1)
cybershiptrooper (Owner):

It would be nicer if we could move the entire logic into _print_and_log_metrics; current_epoch_log can live there. Logging it to wandb might be useful as well!

evanhanders (Collaborator, Author):

Moved this logic to _print_and_log_metrics. I think everything that makes up the string is already being logged to wandb.

for metric in metrics:
    print(metric, end=", ")
    if metric.type == MetricType.ACCURACY:
        current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2f}, "
cybershiptrooper (Owner):

str(metric) does this automatically

evanhanders (Collaborator, Author):

changed to current_epoch_log += str(metric) + ", "
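For reference, a minimal self-contained sketch of the resulting pattern (epoch_pbar stands in for the reused tqdm bar from the diff above; metric objects are only assumed to have a useful __str__):

from tqdm import tqdm

def log_metrics_below_bar(epoch_pbar: tqdm, metrics: list) -> None:
    current_epoch_log = ""
    for metric in metrics:
        current_epoch_log += str(metric) + ", "
    # tqdm's write prints below the progress bar without breaking its rendering
    epoch_pbar.write(current_epoch_log.rstrip(", "))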

iit/model_pairs/iit_behavior_model_pair.py (outdated comment, resolved)
@@ -21,14 +21,9 @@ def __init__(
training_args: dict = {}
):
default_training_args = {
"batch_size": 256,
cybershiptrooper (Owner):

Would be really helpful if we could maintain the default args as they were before. Or at least store the default hyperparams we used before in some config for reproducibility.

evanhanders (Collaborator, Author):

I think all the defaults are preserved (they were just set in multiple places) EXCEPT I did change use_single_loss and optimizer_kwargs. I'll change those back to the defaults from before.

"strict_weight": 1.0,
"clip_grad_norm": 1.0,
"strict_weight_schedule" : lambda s, i: s,
cybershiptrooper (Owner):

This is a cool idea!

Maybe it is better to implement it as

@property 
def strict_weight_at_epoch(self):
    return self.training_args.strict_weight_schedule(<args_from_self>)

Instead of changing the strict weight variable after each epoch? (or a method like strict_weight_for_epoch = self.get_scheduled_strict_weight() and then calculate the loss).

This lambda is also throwing me off a bit, maybe renaming the args will make it clearer...

It also seems like this is achievable by using different optimisers/lrs for each loss (and not using single loss)? No idea which one's better though...
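As a concrete, hypothetical version of the property idea with the schedule arguments renamed for clarity (current_epoch and the dict-based training_args here are assumptions, not the library's actual API):

class ExampleModelPair:
    def __init__(self, training_args: dict):
        self.training_args = training_args
        self.current_epoch = 0

    @property
    def strict_weight_at_epoch(self) -> float:
        base_weight = self.training_args["strict_weight"]
        # default schedule: constant strict weight across epochs
        schedule = self.training_args.get(
            "strict_weight_schedule", lambda base_weight, epoch: base_weight
        )
        return schedule(base_weight, self.current_epoch)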

evanhanders (Collaborator, Author):

Hm, it's unclear to me which is the right way to go right now. I think it's best to remove it (right now it's not doing anything); if you find that this is a useful idea down the road, you can add it back however you see fit?

iit_loss = 0
ll_loss = 0
behavior_loss = 0
iit_loss = t.zeros(1)
cybershiptrooper (Owner):

I'm not completely sure why this is needed; you can usually add floats and tensors without messing up the grads, right?

evanhanders (Collaborator, Author):

This is for mypy type-checking. step_on_loss expects a Tensor instead of a float and the .item() call at the end is a type error if this isn't a Tensor.

I think the right way to resolve this is to remove the isinstance(..., Tensor) check at the end of the function, since the value is now always a tensor. I'll do that.
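For reference, a minimal self-contained sketch of the pattern that was settled on here (the surrounding training loop is omitted):

import torch as t

iit_loss = t.zeros(1)
ll_loss = t.zeros(1)
behavior_loss = t.zeros(1)
# ...accumulate per-batch losses into these tensors during the epoch...
# .item() is now always valid, so no isinstance(..., Tensor) check is needed
losses = {
    "iit_loss": iit_loss.item(),
    "ll_loss": ll_loss.item(),
    "behavior_loss": behavior_loss.item(),
}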

@@ -8,6 +9,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm # type: ignore
from transformer_lens.hook_points import HookedRootModule, HookPoint # type: ignore
from IPython.display import clear_output
cybershiptrooper (Owner):

Why is this needed here? Don't think it is being used...

evanhanders (Collaborator, Author):

It's not! Good catch, that was leftover from getting tqdm stuff working.

)

# Set seed before iterating on loaders for reproducibility.
t.manual_seed(training_args["seed"])
cybershiptrooper (Owner):

Is it possible to use a generator for the loaders like we do for numpy? I think I used to set this once globally in the training script before; my bad. :(

evanhanders (Collaborator, Author):

I'm not totally sure? I got this solution here. It seems like the random state is captured when the dataloader is turned into an iterable, and someone could call torch random functions between initializing and training the model pair, which would hurt reproducibility if the seed weren't set here.
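For comparison, the generator-based approach mentioned above would look roughly like this sketch (not what the PR does; the PR calls t.manual_seed right before iterating on the loaders, and the seed value here is illustrative):

import torch as t
from torch.utils.data import DataLoader, TensorDataset

g = t.Generator()
g.manual_seed(0)
dataset = TensorDataset(t.arange(10, dtype=t.float32).unsqueeze(1))
# shuffling order now depends only on g, not on the global torch RNG state
loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=g)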

cybershiptrooper (Owner) commented Aug 23, 2024:

mypy type checking and pytest tests/ passes, but it's possible some downstream stuff broke? Unclear.

It doesn't look particularly problematic. I'll have a more careful look in a while.
Might be worth checking if circuits benchmark tests pass after updating its poetry version...

Thanks for adding these!

Will definitely check my cases though. This seems important in general; somehow I can't reproduce the 4 new trained cases after pulling the newer PRs. Maybe these changes help. :")

evanhanders (Collaborator, Author):

OK! I think I responded to all of your comments and pushed updates. I also found a problem that was causing circuits-bench tests to fail and fixed it, so they're all passing on my end.

I'll be offline for the next three weeks starting in a few hours, so if there are other problems / stylistic things, please feel free to edit the branch of my repo / this PR to get those fixed!

evanhanders (Collaborator, Author):

Also, I just added back one .to(device) call in the eval() step. It's really helpful for me not to keep my entire dataset on cuda / mps, especially when training successive models in a notebook, so run_eval_step now moves the dataset labels onto the model's device.
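A rough sketch of that idea (run_eval_step's real signature isn't shown in this excerpt, so the function here is only illustrative):

import torch as t

def eval_accuracy(model_output: t.Tensor, labels: t.Tensor) -> float:
    # keep the dataset on CPU; move only this batch of labels to the model's device
    labels = labels.to(model_output.device)
    return (model_output.argmax(dim=-1) == labels).float().mean().item()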

cybershiptrooper (Owner) commented Aug 23, 2024:

Great! The changes look fine now.

Merging.

Thanks for the PR!

cybershiptrooper merged commit e0be350 into cybershiptrooper:main on Aug 23, 2024
3 checks passed