Training improvements #17
Changes from 7 commits
@@ -1,5 +1,6 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Callable, final, Type
+from typing import Any, Callable, final, Type, Optional

 import numpy as np
 import torch as t
@@ -8,6 +9,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm  # type: ignore
 from transformer_lens.hook_points import HookedRootModule, HookPoint  # type: ignore
+from IPython.display import clear_output

 import wandb  # type: ignore
 from iit.model_pairs.ll_model import LLModel
@@ -18,6 +20,24 @@
 from iit.utils.index import Ix, TorchIndex
 from iit.utils.metric import MetricStoreCollection, MetricType

+def in_notebook() -> bool:
+    try:
+        # This will only work in Jupyter notebooks
+        shell = get_ipython().__class__.__name__  # type: ignore
+        if shell == 'ZMQInteractiveShell':
+            return True  # Jupyter notebook or qtconsole
+        elif shell == 'TerminalInteractiveShell':
+            return False  # Terminal running IPython
+        else:
+            return False  # Other types of interactive shells
+    except NameError:
+        return False  # Probably standard Python interpreter
+
+if in_notebook():
+    from tqdm.notebook import tqdm
+else:
+    from tqdm import tqdm
+
+
 class BaseModelPair(ABC):
     hl_model: HookedRootModule

Review comment (on in_notebook): I think we can do this by just importing tqdm: much cleaner that way (at least according to this).

Reply: I tried using just tqdm, but it definitely didn't work in notebook mode. I think I hunted down all of the print statements, too. Cleaner compromise than what's here now: moved this block to utils/tqdm.py and added from iit.utils.tqdm import tqdm.
@@ -177,7 +197,7 @@ def get_IIT_loss_over_batch(
         hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
         label_idx = self.get_label_idxs()
         # IIT loss is only computed on the tokens we care about
-        loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index])
+        loss = loss_fn(ll_output[label_idx.as_index].to(hl_output.device), hl_output[label_idx.as_index])
         return loss

     def clip_grad_fn(self) -> None:

Review comment (on the .to(hl_output.device) call): We should probably just raise if dataset, hl_model and ll_model aren't on the same device during init / at the start of training. This usually just hides the main problem and makes it harder to find bugs.

Reply: Makes sense, I'll add an assert to the beginning of train() and remove all of these.
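A sketch of the assert described in that reply (the helper name and the way the dataset exposes its device are assumptions, not the repo's actual API):

# Hypothetical fail-fast device check, called at the top of train()
# instead of scattering .to(...) calls through the loss code.
def _assert_same_devices(self, dataset_device: t.device) -> None:
    ll_device = next(self.ll_model.parameters()).device
    hl_device = next(self.hl_model.parameters()).device
    assert ll_device == hl_device == dataset_device, (
        f"ll_model ({ll_device}), hl_model ({hl_device}) and the dataset "
        f"({dataset_device}) must be on the same device before training"
    )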
@@ -218,11 +238,9 @@ def train(
         self,
         train_set: IITDataset,
         test_set: IITDataset,
-        optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
         epochs: int = 1000,
         use_wandb: bool = False,
         wandb_name_suffix: str = "",
-        optimizer_kwargs: dict = {},
     ) -> None:
         training_args = self.training_args
         print(f"{training_args=}")
@@ -242,15 +260,19 @@ def train(

         early_stop = training_args["early_stop"]

-        optimizer_kwargs['lr'] = training_args["lr"]
-        optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
+        optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **training_args['optimizer_kwargs'])
         loss_fn = self.loss_fn
         scheduler_cls = training_args.get("lr_scheduler", None)
+        scheduler_kwargs = training_args.get("scheduler_kwargs", {})
         if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau:
             mode = training_args.get("scheduler_mode", "max")
-            lr_scheduler = scheduler_cls(optimizer, mode=mode, factor=0.1, patience=10)
+            if 'patience' not in scheduler_kwargs:
+                scheduler_kwargs['patience'] = 10
+            if 'factor' not in scheduler_kwargs:
+                scheduler_kwargs['factor'] = 0.1
+            lr_scheduler = scheduler_cls(optimizer, mode=mode, **scheduler_kwargs)
         elif scheduler_cls:
-            lr_scheduler = scheduler_cls(optimizer)
+            lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)

         if use_wandb and not wandb.run:
             wandb.init(project="iit", name=wandb_name_suffix,
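For reference, a hedged usage sketch of the new config-driven optimizer/scheduler setup. The class name IITModelPair and the surrounding objects (hl_model, ll_model, corr, datasets) are assumed from context, not shown in this diff:

training_args = {
    "optimizer_cls": t.optim.AdamW,
    "optimizer_kwargs": {"lr": 1e-3, "weight_decay": 1e-2},
    "lr_scheduler": t.optim.lr_scheduler.ReduceLROnPlateau,
    "scheduler_mode": "max",
    # overrides the patience=10 / factor=0.1 defaults filled in above
    "scheduler_kwargs": {"patience": 5, "factor": 0.5},
}
model_pair = IITModelPair(hl_model, ll_model, corr=corr, training_args=training_args)
model_pair.train(train_set, test_set, epochs=500)

Note that train() now reads training_args['optimizer_cls'] and training_args['optimizer_kwargs'] with plain indexing, so callers must supply them (or they must be defaulted upstream).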
@@ -261,19 +283,35 @@ def train(
         wandb.config.update({"method": self.wandb_method})
         wandb.run.log_code()  # type: ignore

-        for epoch in tqdm(range(epochs)):
-            train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer)
-            test_metrics = self._run_eval_epoch(test_loader, loss_fn)
-            if scheduler_cls:
-                self.step_scheduler(lr_scheduler, test_metrics)
-            self.test_metrics = test_metrics
-            self.train_metrics = train_metrics
-            self._print_and_log_metrics(
-                epoch, MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), use_wandb
-            )
-
+        # Set seed before iterating on loaders for reproducibility.
+        t.manual_seed(training_args["seed"])

Review comment (on the manual_seed call): Is it possible to use a generator for the loaders, like we do for numpy? I think I used to set this once globally in the training script before; my bad. :(

Reply: I'm not totally sure? I got this solution here. It seems like the random operation is set when the dataloader is turned into an iterable, and someone could use torch functions between initializing and training the model pair, which could hinder reproducibility without putting something here.
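The reviewer's generator idea, sketched under the assumption that IITDataset works with a stock DataLoader (batch_size and the seed value here are placeholders): a seeded torch.Generator scopes the shuffling randomness to the loader itself, so torch calls made between model-pair construction and train() cannot perturb batch order.

import torch as t
from torch.utils.data import DataLoader

g = t.Generator()
g.manual_seed(0)  # the code above would use training_args["seed"]
# Shuffling now draws from `g` rather than the global torch RNG.
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, generator=g)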
+        with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar:
+            with tqdm(total=len(train_loader), desc="Training Batches") as batch_pbar:
+                for epoch in range(epochs):
+                    batch_pbar.reset()
+
+                    train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar)
+                    test_metrics = self._run_eval_epoch(test_loader, loss_fn)
+                    if scheduler_cls:
+                        self.step_scheduler(lr_scheduler, test_metrics)
+                    self.test_metrics = test_metrics
+                    self.train_metrics = train_metrics
+                    current_epoch_log = self._print_and_log_metrics(
+                        epoch=epoch,
+                        metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics),
+                        optimizer=optimizer,
+                        use_wandb=use_wandb,
+                    )
+
+                    epoch_pbar.update(1)
+                    epoch_pbar.set_postfix_str(current_epoch_log.strip(', '))
+                    epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}")
+
-        if early_stop and self._check_early_stop_condition(test_metrics):
-            break
+                    if early_stop and self._check_early_stop_condition(test_metrics):
+                        break
+
+                    self._run_epoch_extras(epoch_number=epoch+1)

         if use_wandb:
             wandb.log({"final epoch": epoch})

Review comment (on the epoch_pbar postfix/description lines): Would be nicer if we can move the entire logic to _print_and_log_metrics.

Reply: Moved this logic to _print_and_log_metrics. I think everything that makes up the string is already being logged to wandb.
@@ -298,16 +336,16 @@ def _run_train_epoch(
         self,
         loader: DataLoader,
         loss_fn: Callable[[Tensor, Tensor], Tensor],
-        optimizer: t.optim.Optimizer
+        optimizer: t.optim.Optimizer,
+        pbar: tqdm
     ) -> MetricStoreCollection:
         self.ll_model.train()
         train_metrics = self.make_train_metrics()
-        for i, (base_input, ablation_input) in tqdm(
-            enumerate(loader), total=len(loader)
-        ):
+        for i, (base_input, ablation_input) in enumerate(loader):
             train_metrics.update(
                 self.run_train_step(base_input, ablation_input, loss_fn, optimizer)
             )
+            pbar.update(1)
         return train_metrics

     @final
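The nested-bar pattern above is easy to get wrong, so here is a standalone toy version of the key move: one persistent batch-level bar is rewound with reset() each epoch instead of constructing a fresh tqdm per epoch.

from tqdm import tqdm

epochs, n_batches = 3, 50
with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar:
    with tqdm(total=n_batches, desc="Training Batches") as batch_pbar:
        for epoch in range(epochs):
            batch_pbar.reset()        # rewind the same bar instead of re-creating it
            for _ in range(n_batches):
                batch_pbar.update(1)  # per-batch work would go here
            epoch_pbar.update(1)
            epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}")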
@@ -341,17 +379,36 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool:
             return True

     @final
-    @staticmethod
     def _print_and_log_metrics(
         self,
         epoch: int,
         metrics: MetricStoreCollection,
-        use_wandb: bool = False
-    ) -> None:
-        print(f"\nEpoch {epoch}:", end=" ")
+        optimizer: t.optim.Optimizer,
+        use_wandb: bool = False,
+        print_metrics: bool = True,
+    ) -> str:
+
+        # Print the current epoch's metrics
+        current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, "
+        for k in self.training_args.keys():
+            if 'weight' in k and 'schedule' not in k:
+                current_epoch_log += f"{k}: {self.training_args[k]:.2e}, "
         if use_wandb:
             wandb.log({"epoch": epoch})
+            wandb.log({"lr": optimizer.param_groups[0]["lr"]})

         for metric in metrics:
-            print(metric, end=", ")
+            if metric.type == MetricType.ACCURACY:
+                current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2f}, "
+            else:
+                current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2e}, "
             if use_wandb:
                 wandb.log({metric.get_name(): metric.get_value()})
-        print()
+        if print_metrics:
+            tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}')
+
+        return current_epoch_log

     def _run_epoch_extras(self, epoch_number: int) -> None:
         """ Optional method for running extra code at the end of each epoch """
         pass

Review comment (on the per-metric formatting branch): str(metric) does this automatically.

Reply: Changed to current_epoch_log += str(metric) + ", ".
The remaining changes are in a second file:
@@ -19,12 +19,12 @@ def __init__(
         training_args: dict = {}
     ):
         default_training_args = {
-            "lr": 0.001,
             "atol": 5e-2,
             "early_stop": True,
-            "use_single_loss": False,
+            "use_single_loss": True,
             "iit_weight": 1.0,
             "behavior_weight": 1.0,
+            "iit_weight_schedule": lambda s, i: s,
+            "behavior_weight_schedule": lambda s, i: s,
         }
         training_args = {**default_training_args, **training_args}
         super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
@@ -56,8 +56,9 @@ def get_behaviour_loss_over_batch(
     ) -> Tensor:
         base_x, base_y = base_input[0:2]
         output = self.ll_model(base_x)
-        behavior_loss = loss_fn(output, base_y)
-        return behavior_loss
+        indx = self.get_label_idxs()
+        behaviour_loss = loss_fn(output[indx.as_index], base_y[indx.as_index].to(output.device))
+        return behaviour_loss

     def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None:
         optimizer.zero_grad()
@@ -108,32 +109,38 @@ def run_eval_step(

         # compute IIT loss and accuracy
         label_idx = self.get_label_idxs()
-        hl_node = self.sample_hl_name()
-        hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
-        hl_output.to(ll_output.device)
-        hl_output = hl_output[label_idx.as_index]
-        ll_output = ll_output[label_idx.as_index]
-        if self.hl_model.is_categorical():
-            loss = loss_fn(ll_output, hl_output)
-            if ll_output.shape == hl_output.shape:
-                # To handle the case when labels are one-hot
-                hl_output = t.argmax(hl_output, dim=-1)
-            top1 = t.argmax(ll_output, dim=-1)
-            accuracy = (top1 == hl_output).float().mean()
-            IIA = accuracy.item()
-        else:
-            loss = loss_fn(ll_output, hl_output)
-            IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
+        # don't just sample one HL node; compute IIA for all HL nodes and average
+        iias = []
+        for hl_node in self.corr.keys():
+            # hl_node = self.sample_hl_name()
+            hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
+            hl_output = hl_output.to(ll_output.device)
+            hl_output = hl_output[label_idx.as_index]
+            ll_output = ll_output[label_idx.as_index]
+            if self.hl_model.is_categorical():
+                loss = loss_fn(ll_output, hl_output)
+                if ll_output.shape == hl_output.shape:
+                    # To handle the case when labels are one-hot
+                    hl_output = t.argmax(hl_output, dim=-1)
+                top1 = t.argmax(ll_output, dim=-1)
+                accuracy = (top1 == hl_output).float().mean()
+                IIA = accuracy.item()
+            else:
+                loss = loss_fn(ll_output, hl_output)
+                IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
+            iias.append(IIA)
+        IIA = sum(iias) / len(iias)

Review comment (on the all-nodes loop): This would increase the validation time for large tasks (e.g. IOI) like crazy. Maybe it would be better to keep it under a flag and optionally add another tqdm to the validation round.

Reply: Added "val_IIA_sampling" to the training metrics, with a default value of "random" (can choose "all").
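A sketch of how the "val_IIA_sampling" switch from that reply might gate the loop; the exact key location and wiring in the final commit are assumptions, and this is a method-body fragment rather than standalone code:

# "random" keeps the old cheap behaviour (one sampled HL node per eval step);
# "all" averages IIA over every HL node in the correspondence.
sampling = self.training_args.get("val_IIA_sampling", "random")
if sampling == "all":
    hl_nodes = list(self.corr.keys())   # full sweep; slow on large tasks like IOI
else:
    hl_nodes = [self.sample_hl_name()]  # single random node
# ...then run the intervention loop above over hl_nodes and average:
# IIA = sum(iias) / len(iias)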
         # compute behavioral accuracy
         base_x, base_y = base_input[0:2]
-        output = self.ll_model(base_x)
+        output = self.ll_model(base_x)[label_idx.as_index]  # convert ll logits -> one-hot max label
         if self.hl_model.is_categorical():
             top1 = t.argmax(output, dim=-1)
-            if output.shape == base_y.shape:
+            if output.shape[-1] == base_y.shape[-1]:
                 # To handle the case when labels are one-hot
                 # TODO: is there a better way?
-                base_y = t.argmax(base_y, dim=-1)
+                base_y = t.argmax(base_y, dim=-1).squeeze()
+            base_y = base_y.to(top1.device)
             accuracy = (top1 == base_y).float().mean()
         else:
             accuracy = ((output - base_y).abs() < atol).float().mean()

Review comment (on the .to(top1.device) line): Again, moving to devices here might not be best for debugging...

Reply: Removed the .to(device) line.
@@ -151,4 +158,8 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool:
                 return metric.get_value() == 100
             else:
                 return super()._check_early_stop_condition(test_metrics)
         return False
+
+    def _run_epoch_extras(self, epoch_number: int) -> None:
+        self.training_args['iit_weight'] = self.training_args['iit_weight_schedule'](self.training_args['iit_weight'], epoch_number)
+        self.training_args['behavior_weight'] = self.training_args['behavior_weight_schedule'](self.training_args['behavior_weight'], epoch_number)
Review comment: Why is this needed here? I don't think it is being used...

Reply: It's not! Good catch; that was leftover from getting the tqdm stuff working.
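A usage sketch of the new weight-schedule hooks (the class name IITModelPair is assumed): each schedule is called as f(current_weight, epoch_number) from _run_epoch_extras at the end of every epoch, so the loss weights can be annealed over training.

training_args = {
    "iit_weight": 1.0,
    "behavior_weight": 1.0,
    # decay the IIT weight by 5% per epoch...
    "iit_weight_schedule": lambda w, epoch: w * 0.95,
    # ...while leaving the behavior weight constant (the default)
    "behavior_weight_schedule": lambda w, epoch: w,
}
model_pair = IITModelPair(hl_model, ll_model, corr=corr, training_args=training_args)

Because the schedules rewrite training_args in place, the current weights are also what _print_and_log_metrics picks up when it formats the per-epoch log line.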