
Training improvements #17

Merged
117 changes: 87 additions & 30 deletions iit/model_pairs/base_model_pair.py
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, final, Type
from typing import Any, Callable, final, Type, Optional

import numpy as np
import torch as t
@@ -8,6 +9,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm # type: ignore
from transformer_lens.hook_points import HookedRootModule, HookPoint # type: ignore
from IPython.display import clear_output
Owner: Why is this needed here? Don't think it is being used...

Collaborator Author: It's not! Good catch, that was leftover from getting the tqdm stuff working.


import wandb # type: ignore
from iit.model_pairs.ll_model import LLModel
@@ -18,6 +20,24 @@
from iit.utils.index import Ix, TorchIndex
from iit.utils.metric import MetricStoreCollection, MetricType

def in_notebook() -> bool:
Owner: I think we can do this by just importing tqdm: much cleaner that way. (at least according to this)

Collaborator Author: I tried using just tqdm, but it definitely didn't work in notebook mode. I think I hunted down all of the print statements, too.

Cleaner compromise than what's here now: moved this block to utils/tqdm.py, and added from iit.utils.tqdm import tqdm.

    try:
        # This will only work in Jupyter notebooks
        shell = get_ipython().__class__.__name__  # type: ignore
        if shell == 'ZMQInteractiveShell':
            return True  # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other types of interactive shells
    except NameError:
        return False  # Probably standard Python interpreter

if in_notebook():
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm
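
A minimal sketch of the iit/utils/tqdm.py module described in the thread above (the module path comes from the comment; the contents are assumed to simply wrap this block):

# iit/utils/tqdm.py -- sketch: re-export the right tqdm for the environment.
def _in_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__  # type: ignore  # only defined under IPython
        return shell == 'ZMQInteractiveShell'
    except NameError:
        return False  # plain Python interpreter

if _in_notebook():
    from tqdm.notebook import tqdm  # widget-based progress bar
else:
    from tqdm import tqdm  # terminal progress bar

Call sites then just do from iit.utils.tqdm import tqdm.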


class BaseModelPair(ABC):
hl_model: HookedRootModule
@@ -177,7 +197,7 @@ def get_IIT_loss_over_batch(
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
label_idx = self.get_label_idxs()
# IIT loss is only computed on the tokens we care about
loss = loss_fn(ll_output[label_idx.as_index], hl_output[label_idx.as_index])
loss = loss_fn(ll_output[label_idx.as_index].to(hl_output.device), hl_output[label_idx.as_index])
Owner: We should probably just raise if the dataset, hl_model, and ll_model aren't on the same device during init or at the start of training. This usually just hides the main problem and makes it harder to find bugs.

Collaborator Author: Makes sense, I'll add an assert to the beginning of train() and remove all of these.
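
A sketch of the assert mentioned here, placed at the top of train(); the helper name is hypothetical and assumes both models expose parameters():

# Hypothetical device check; fails fast instead of hiding mismatches
# behind .to(...) calls inside the loss computation.
def _assert_same_device(self) -> None:
    ll_device = next(self.ll_model.parameters()).device
    hl_device = next(self.hl_model.parameters()).device
    assert ll_device == hl_device, (
        f"ll_model is on {ll_device} but hl_model is on {hl_device}; "
        "move both to the same device before calling train()."
    )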

return loss

def clip_grad_fn(self) -> None:
@@ -218,11 +238,9 @@ def train(
self,
train_set: IITDataset,
test_set: IITDataset,
optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
epochs: int = 1000,
use_wandb: bool = False,
wandb_name_suffix: str = "",
optimizer_kwargs: dict = {},
) -> None:
training_args = self.training_args
print(f"{training_args=}")
@@ -242,15 +260,19 @@

early_stop = training_args["early_stop"]

optimizer_kwargs['lr'] = training_args["lr"]
optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **training_args['optimizer_kwargs'])
loss_fn = self.loss_fn
scheduler_cls = training_args.get("lr_scheduler", None)
scheduler_kwargs = training_args.get("scheduler_kwargs", {})
if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau:
mode = training_args.get("scheduler_mode", "max")
lr_scheduler = scheduler_cls(optimizer, mode=mode, factor=0.1, patience=10)
if 'patience' not in scheduler_kwargs:
scheduler_kwargs['patience'] = 10
if 'factor' not in scheduler_kwargs:
scheduler_kwargs['factor'] = 0.1
lr_scheduler = scheduler_cls(optimizer, mode=mode, **scheduler_kwargs)
elif scheduler_cls:
lr_scheduler = scheduler_cls(optimizer)
lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
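
With scheduler_kwargs now read from training_args, callers can tune the plateau scheduler instead of relying on the hard-coded values; an illustrative configuration:

import torch as t

training_args = {
    "lr_scheduler": t.optim.lr_scheduler.ReduceLROnPlateau,
    "scheduler_mode": "max",
    # Overrides the defaults filled in above (patience=10, factor=0.1).
    "scheduler_kwargs": {"patience": 5, "factor": 0.5},
}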

if use_wandb and not wandb.run:
wandb.init(project="iit", name=wandb_name_suffix,
@@ -261,19 +283,35 @@
wandb.config.update({"method": self.wandb_method})
wandb.run.log_code() # type: ignore

for epoch in tqdm(range(epochs)):
train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer)
test_metrics = self._run_eval_epoch(test_loader, loss_fn)
if scheduler_cls:
self.step_scheduler(lr_scheduler, test_metrics)
self.test_metrics = test_metrics
self.train_metrics = train_metrics
self._print_and_log_metrics(
epoch, MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), use_wandb
)

# Set seed before iterating on loaders for reproducibility.
t.manual_seed(training_args["seed"])
Owner: Is it possible to use a generator for the loaders, like we do for numpy? I think I used to set this once globally in the training script before; my bad. :(

Collaborator Author: I'm not totally sure? I got this solution here. It seems like the random operation is set when the dataloader is turned into an iterable, and someone could use torch functions between initializing and training the model pair, which could hinder reproducibility without putting something here.
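
For comparison, the generator-based approach the owner mentions would scope the seeding to the loader itself rather than the global RNG; a sketch, with train_set and training_args as in the surrounding code:

import torch as t
from torch.utils.data import DataLoader

g = t.Generator()
g.manual_seed(training_args["seed"])
# Shuffle order is now drawn from g, so torch calls made between
# initialization and training no longer affect batch order.
train_loader = DataLoader(train_set, batch_size=training_args["batch_size"],
                          shuffle=True, generator=g)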

with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar:
with tqdm(total=len(train_loader), desc="Training Batches") as batch_pbar:
for epoch in range(epochs):
batch_pbar.reset()

train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar)
test_metrics = self._run_eval_epoch(test_loader, loss_fn)
if scheduler_cls:
self.step_scheduler(lr_scheduler, test_metrics)
self.test_metrics = test_metrics
self.train_metrics = train_metrics
current_epoch_log = self._print_and_log_metrics(
epoch=epoch,
metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics),
optimizer=optimizer,
use_wandb=use_wandb,
)

if early_stop and self._check_early_stop_condition(test_metrics):
break
epoch_pbar.update(1)
Owner: It would be nicer if we could move the entire logic to _print_and_log_metrics; current_epoch_log can remain there. And logging it to wandb might be useful as well!

Collaborator Author: Moved this logic to _print_and_log_metrics. I think everything that makes up the string is already being logged to wandb.

epoch_pbar.set_postfix_str(current_epoch_log.strip(', '))
epoch_pbar.set_description(f"Epoch {epoch + 1}/{epochs}")

if early_stop and self._check_early_stop_condition(test_metrics):
break

self._run_epoch_extras(epoch_number=epoch+1)

if use_wandb:
wandb.log({"final epoch": epoch})
Expand All @@ -298,16 +336,16 @@ def _run_train_epoch(
self,
loader: DataLoader,
loss_fn: Callable[[Tensor, Tensor], Tensor],
optimizer: t.optim.Optimizer
optimizer: t.optim.Optimizer,
pbar: tqdm
) -> MetricStoreCollection:
self.ll_model.train()
train_metrics = self.make_train_metrics()
for i, (base_input, ablation_input) in tqdm(
enumerate(loader), total=len(loader)
):
for i, (base_input, ablation_input) in enumerate(loader):
train_metrics.update(
self.run_train_step(base_input, ablation_input, loss_fn, optimizer)
)
pbar.update(1)
return train_metrics

@final
@@ -341,17 +379,36 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo
return True

@final
@staticmethod
def _print_and_log_metrics(
self,
epoch: int,
metrics: MetricStoreCollection,
use_wandb: bool = False
) -> None:
print(f"\nEpoch {epoch}:", end=" ")
optimizer: t.optim.Optimizer,
use_wandb: bool = False,
print_metrics: bool = True,
) -> str:

# Print the current epoch's metrics
current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, "
for k in self.training_args.keys():
if 'weight' in k and 'schedule' not in k:
current_epoch_log += f"{k}: {self.training_args[k]:.2e}, "
if use_wandb:
wandb.log({"epoch": epoch})
wandb.log({"lr": optimizer.param_groups[0]["lr"]})

for metric in metrics:
print(metric, end=", ")
if metric.type == MetricType.ACCURACY:
current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2f}, "
Owner: str(metric) does this automatically.

Collaborator Author: Changed to current_epoch_log += str(metric) + ", ".

else:
current_epoch_log += f"{metric.get_name()}: {metric.get_value():.2e}, "
if use_wandb:
wandb.log({metric.get_name(): metric.get_value()})
print()
if print_metrics:
tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}')

return current_epoch_log

def _run_epoch_extras(self, epoch_number: int) -> None:
""" Optional method for running extra code at the end of each epoch """
pass
61 changes: 36 additions & 25 deletions iit/model_pairs/iit_behavior_model_pair.py
@@ -19,12 +19,12 @@ def __init__(
training_args: dict = {}
):
default_training_args = {
"lr": 0.001,
"atol": 5e-2,
"early_stop": True,
"use_single_loss": False,
"use_single_loss": True,
"iit_weight": 1.0,
"behavior_weight": 1.0,
"iit_weight_schedule" : lambda s, i: s,
"behavior_weight_schedule" : lambda s, i: s,
}
training_args = {**default_training_args, **training_args}
super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
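
The new *_schedule entries default to the identity function, so both weights stay constant unless a caller supplies a schedule. An illustrative decay (not part of this PR):

# s is the current weight, i the epoch number passed by _run_epoch_extras;
# this halves the IIT weight every 10 epochs.
training_args = {
    "iit_weight_schedule": lambda s, i: s * 0.5 if i % 10 == 0 else s,
}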
@@ -56,8 +56,9 @@ def get_behaviour_loss_over_batch(
) -> Tensor:
base_x, base_y = base_input[0:2]
output = self.ll_model(base_x)
behavior_loss = loss_fn(output, base_y)
return behavior_loss
indx = self.get_label_idxs()
behaviour_loss = loss_fn(output[indx.as_index], base_y[indx.as_index].to(output.device))
return behaviour_loss

def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None:
optimizer.zero_grad()
@@ -108,32 +109,38 @@ def run_eval_step(

# compute IIT loss and accuracy
label_idx = self.get_label_idxs()
hl_node = self.sample_hl_name()
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
hl_output.to(ll_output.device)
hl_output = hl_output[label_idx.as_index]
ll_output = ll_output[label_idx.as_index]
if self.hl_model.is_categorical():
loss = loss_fn(ll_output, hl_output)
if ll_output.shape == hl_output.shape:
# To handle the case when labels are one-hot
hl_output = t.argmax(hl_output, dim=-1)
top1 = t.argmax(ll_output, dim=-1)
accuracy = (top1 == hl_output).float().mean()
IIA = accuracy.item()
else:
loss = loss_fn(ll_output, hl_output)
IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
# Don't just sample one HL node; compute IIA for all HL nodes and average.
iias = []
Owner: This would increase the validation time for large tasks (e.g. IOI) like crazy. Maybe it would be better to keep it under a flag, and optionally add another tqdm to the validation round.

Collaborator Author: Added "val_IIA_sampling" to the training args, with a default value of "random" (can choose "all").

for hl_node in self.corr.keys():
# hl_node = self.sample_hl_name()
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
hl_output = hl_output.to(ll_output.device)
hl_output = hl_output[label_idx.as_index]
ll_output = ll_output[label_idx.as_index]
if self.hl_model.is_categorical():
loss = loss_fn(ll_output, hl_output)
if ll_output.shape == hl_output.shape:
# To handle the case when labels are one-hot
hl_output = t.argmax(hl_output, dim=-1)
top1 = t.argmax(ll_output, dim=-1)
accuracy = (top1 == hl_output).float().mean()
IIA = accuracy.item()
else:
loss = loss_fn(ll_output, hl_output)
IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
iias.append(IIA)
IIA = sum(iias) / len(iias)

# compute behavioral accuracy
base_x, base_y = base_input[0:2]
output = self.ll_model(base_x)
output = self.ll_model(base_x)[label_idx.as_index] #convert ll logits -> one-hot max label
if self.hl_model.is_categorical():
top1 = t.argmax(output, dim=-1)
if output.shape == base_y.shape:
if output.shape[-1] == base_y.shape[-1]:
# To handle the case when labels are one-hot
# TODO: is there a better way?
base_y = t.argmax(base_y, dim=-1)
base_y = t.argmax(base_y, dim=-1).squeeze()
base_y = base_y.to(top1.device)
Owner: Again, moving to devices here might not be best for debugging...

Collaborator Author: Removed the .to(device) line.

accuracy = (top1 == base_y).float().mean()
else:
accuracy = ((output - base_y).abs() < atol).float().mean()
Expand All @@ -151,4 +158,8 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bo
return metric.get_value() == 100
else:
return super()._check_early_stop_condition(test_metrics)
return False
return False

def _run_epoch_extras(self, epoch_number: int) -> None:
self.training_args['iit_weight'] = self.training_args['iit_weight_schedule'](self.training_args['iit_weight'], epoch_number)
self.training_args['behavior_weight'] = self.training_args['behavior_weight_schedule'](self.training_args['behavior_weight'], epoch_number)
10 changes: 7 additions & 3 deletions iit/model_pairs/iit_model_pair.py
@@ -23,22 +23,26 @@ def __init__(
self.hl_model.requires_grad_(False)

self.corr = corr
print(self.hl_model.hook_dict)
print(self.corr.keys())
assert all([str(k) in self.hl_model.hook_dict for k in self.corr.keys()])
default_training_args = {
"batch_size": 256,
"lr": 0.001,
"num_workers": 0,
"early_stop": True,
"lr_scheduler": None,
"scheduler_val_metric": ["val/accuracy", "val/IIA"],
"scheduler_mode": "max",
"scheduler_kwargs": {},
"clip_grad_norm": 1.0,
"seed": 0,
"detach_while_caching": True,
"optimizer_cls": t.optim.Adam,
"optimizer_kwargs" : {
"lr": 0.001,
"betas": (0.9, 0.9)
},
}
training_args = {**default_training_args, **training_args}
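
With the optimizer configured through training_args, swapping it out becomes a config change. An illustrative call, assuming the constructor signature shown above:

import torch as t

model_pair = IITModelPair(
    hl_model, ll_model, corr=corr,
    training_args={
        "optimizer_cls": t.optim.AdamW,
        "optimizer_kwargs": {"lr": 3e-4, "weight_decay": 0.01},
    },
)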

if isinstance(ll_model, HookedRootModule):
ll_model = LLModel.make_from_hooked_transformer(
ll_model, detach_while_caching=training_args["detach_while_caching"]
5 changes: 2 additions & 3 deletions iit/model_pairs/probed_sequential_pair.py
@@ -82,7 +82,6 @@ def train(
self,
train_set: IITDataset,
test_set: IITDataset,
optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
epochs: int = 1000,
use_wandb: bool = False,
wandb_name_suffix: str = "",
@@ -107,9 +106,9 @@
for p in probes.values():
params += list(p.parameters())
optimizer_kwargs['lr'] = training_args["lr"]
probe_optimizer = optimizer_cls(params, **optimizer_kwargs)
probe_optimizer = training_args['optimizer_cls'](params, **optimizer_kwargs)

optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **optimizer_kwargs)
loss_fn = t.nn.CrossEntropyLoss()

if use_wandb and not wandb.run: