Training improvements #17

Merged
102 changes: 73 additions & 29 deletions iit/model_pairs/base_model_pair.py
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, final, Type
from typing import Any, Callable, final, Type, Optional

import numpy as np
import torch as t
@@ -17,6 +18,8 @@
from iit.utils.iit_dataset import IITDataset
from iit.utils.index import Ix, TorchIndex
from iit.utils.metric import MetricStoreCollection, MetricType
from iit.utils.tqdm import tqdm



class BaseModelPair(ABC):
@@ -218,11 +221,9 @@ def train(
self,
train_set: IITDataset,
test_set: IITDataset,
optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
epochs: int = 1000,
use_wandb: bool = False,
wandb_name_suffix: str = "",
optimizer_kwargs: dict = {},
) -> None:
training_args = self.training_args
print(f"{training_args=}")
@@ -233,6 +234,10 @@
assert isinstance(test_set, IITDataset), ValueError(
f"test_set is not an instance of IITDataset, but {type(test_set)}"
)
assert self.ll_model.device == self.hl_model.device, ValueError(
"ll_model and hl_model are not on the same device"
)

train_loader, test_loader = self.make_loaders(
train_set,
test_set,
@@ -242,15 +247,19 @@

early_stop = training_args["early_stop"]

optimizer_kwargs['lr'] = training_args["lr"]
optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **training_args['optimizer_kwargs'])
loss_fn = self.loss_fn
scheduler_cls = training_args.get("lr_scheduler", None)
scheduler_kwargs = training_args.get("scheduler_kwargs", {})
if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau:
mode = training_args.get("scheduler_mode", "max")
lr_scheduler = scheduler_cls(optimizer, mode=mode, factor=0.1, patience=10)
if 'patience' not in scheduler_kwargs:
scheduler_kwargs['patience'] = 10
if 'factor' not in scheduler_kwargs:
scheduler_kwargs['factor'] = 0.1
lr_scheduler = scheduler_cls(optimizer, mode=mode, **scheduler_kwargs)
elif scheduler_cls:
lr_scheduler = scheduler_cls(optimizer)
lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)

if use_wandb and not wandb.run:
wandb.init(project="iit", name=wandb_name_suffix,
@@ -261,19 +270,32 @@
wandb.config.update({"method": self.wandb_method})
wandb.run.log_code() # type: ignore

for epoch in tqdm(range(epochs)):
train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer)
test_metrics = self._run_eval_epoch(test_loader, loss_fn)
if scheduler_cls:
self.step_scheduler(lr_scheduler, test_metrics)
self.test_metrics = test_metrics
self.train_metrics = train_metrics
self._print_and_log_metrics(
epoch, MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), use_wandb
)

# Set seed before iterating on loaders for reproducibility.
t.manual_seed(training_args["seed"])
Owner


Is it possible to use a generator for the loaders, like we do for numpy? I think I used to set this once globally in the training script before; my bad. :(

Collaborator Author


I'm not totally sure? I got this solution here. It seems like the random operation is set when the dataloader is turned into an iterable, and someone could use torch functions between initializing and training the model pair, which could hinder reproducibility without putting something here.
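
A minimal sketch of the generator-based alternative discussed above, assuming the loaders are plain torch.utils.data.DataLoader objects (the TensorDataset is a stand-in for an IITDataset):

```python
import torch as t
from torch.utils.data import DataLoader, TensorDataset

seed = 0
generator = t.Generator()
generator.manual_seed(seed)

dataset = TensorDataset(t.arange(10).float())  # stand-in for an IITDataset
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    generator=generator,  # shuffle order now depends only on this generator
)

# Re-seeding the generator (rather than the global RNG) reproduces the order,
# even if unrelated torch random ops run in between.
first_pass = [batch[0].tolist() for batch in loader]
generator.manual_seed(seed)
second_pass = [batch[0].tolist() for batch in loader]
assert first_pass == second_pass
```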

with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar:
with tqdm(total=len(train_loader), desc="Training Batches") as batch_pbar:
for epoch in range(epochs):
batch_pbar.reset()

train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar)
test_metrics = self._run_eval_epoch(test_loader, loss_fn)
if scheduler_cls:
self.step_scheduler(lr_scheduler, test_metrics)
self.test_metrics = test_metrics
self.train_metrics = train_metrics
self._print_and_log_metrics(
epoch=epoch,
metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics),
optimizer=optimizer,
use_wandb=use_wandb,
epoch_pbar=epoch_pbar
)

if early_stop and self._check_early_stop_condition(test_metrics):
break
if early_stop and self._check_early_stop_condition(test_metrics):
break

self._run_epoch_extras(epoch_number=epoch+1)

if use_wandb:
wandb.log({"final epoch": epoch})
@@ -298,16 +320,16 @@ def _run_train_epoch(
self,
loader: DataLoader,
loss_fn: Callable[[Tensor, Tensor], Tensor],
optimizer: t.optim.Optimizer
optimizer: t.optim.Optimizer,
pbar: tqdm
) -> MetricStoreCollection:
self.ll_model.train()
train_metrics = self.make_train_metrics()
for i, (base_input, ablation_input) in tqdm(
enumerate(loader), total=len(loader)
):
for i, (base_input, ablation_input) in enumerate(loader):
train_metrics.update(
self.run_train_step(base_input, ablation_input, loss_fn, optimizer)
)
pbar.update(1)
return train_metrics

@final
@@ -341,17 +363,39 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool:
return True

@final
@staticmethod
def _print_and_log_metrics(
self,
epoch: int,
metrics: MetricStoreCollection,
use_wandb: bool = False
) -> None:
print(f"\nEpoch {epoch}:", end=" ")
optimizer: t.optim.Optimizer,
use_wandb: bool = False,
print_metrics: bool = True,
epoch_pbar: Optional[tqdm] = None,
) -> str:

# Print the current epoch's metrics
current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, "
for k in self.training_args.keys():
if 'weight' in k and 'schedule' not in k:
current_epoch_log += f"{k}: {self.training_args[k]:.2e}, "
if use_wandb:
wandb.log({"epoch": epoch})
wandb.log({"lr": optimizer.param_groups[0]["lr"]})

for metric in metrics:
print(metric, end=", ")
current_epoch_log += str(metric) + ", "
if use_wandb:
wandb.log({metric.get_name(): metric.get_value()})
print()
if print_metrics:
tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}')

if epoch_pbar is not None:
epoch_pbar.update(1)
epoch_pbar.set_postfix_str(current_epoch_log.strip(', '))
epoch_pbar.set_description(f"Epoch {epoch + 1}")

return current_epoch_log

def _run_epoch_extras(self, epoch_number: int) -> None:
""" Optional method for running extra code at the end of each epoch """
pass
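
The new training loop drives two nested tqdm bars: an outer bar advanced once per epoch and an inner batch bar that is reset at the start of each epoch, with per-epoch logs emitted via tqdm.write so they don't clobber the bars. A standalone sketch of that pattern, using the plain tqdm package rather than the iit.utils.tqdm wrapper and a sleep as a stand-in for a training step:

```python
import time
from tqdm import tqdm

epochs, n_batches = 3, 5

with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar:
    with tqdm(total=n_batches, desc="Training Batches") as batch_pbar:
        for epoch in range(epochs):
            batch_pbar.reset()  # reuse the inner bar instead of recreating it
            for _ in range(n_batches):
                time.sleep(0.01)  # stand-in for run_train_step
                batch_pbar.update(1)
            tqdm.write(f"Epoch {epoch + 1}: lr: 1.00e-03, ...")  # printed above the bars
            epoch_pbar.set_postfix_str("lr: 1.00e-03")
            epoch_pbar.set_description(f"Epoch {epoch + 1}")
            epoch_pbar.update(1)
```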
57 changes: 38 additions & 19 deletions iit/model_pairs/iit_behavior_model_pair.py
@@ -8,6 +8,7 @@
from iit.model_pairs.ll_model import LLModel
from iit.utils.correspondence import Correspondence
from iit.utils.metric import MetricStore, MetricStoreCollection, MetricType
from iit.utils.nodes import HLNode


class IITBehaviorModelPair(IITModelPair):
@@ -19,12 +20,11 @@ def __init__(
training_args: dict = {}
):
default_training_args = {
"lr": 0.001,
"atol": 5e-2,
"early_stop": True,
"use_single_loss": False,
"iit_weight": 1.0,
"behavior_weight": 1.0,
"val_IIA_sampling": "random", # random or all
}
training_args = {**default_training_args, **training_args}
super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
@@ -56,8 +56,9 @@ def get_behaviour_loss_over_batch(
) -> Tensor:
base_x, base_y = base_input[0:2]
output = self.ll_model(base_x)
behavior_loss = loss_fn(output, base_y)
return behavior_loss
label_indx = self.get_label_idxs()
behaviour_loss = loss_fn(output[label_indx.as_index], base_y[label_indx.as_index].to(output.device))
return behaviour_loss

def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None:
optimizer.zero_grad()
Expand Down Expand Up @@ -108,25 +109,43 @@ def run_eval_step(

# compute IIT loss and accuracy
label_idx = self.get_label_idxs()
hl_node = self.sample_hl_name()
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
hl_output.to(ll_output.device)
hl_output = hl_output[label_idx.as_index]
ll_output = ll_output[label_idx.as_index]
if self.hl_model.is_categorical():
loss = loss_fn(ll_output, hl_output)
if ll_output.shape == hl_output.shape:
# To handle the case when labels are one-hot
hl_output = t.argmax(hl_output, dim=-1)
top1 = t.argmax(ll_output, dim=-1)
accuracy = (top1 == hl_output).float().mean()
IIA = accuracy.item()

def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
hl_output = hl_output.to(ll_output.device)
hl_output = hl_output[label_idx.as_index]
ll_output = ll_output[label_idx.as_index]
if self.hl_model.is_categorical():
loss = loss_fn(ll_output, hl_output)
if ll_output.shape == hl_output.shape:
# To handle the case when labels are one-hot
hl_output = t.argmax(hl_output, dim=-1)
top1 = t.argmax(ll_output, dim=-1)
accuracy = (top1 == hl_output).float().mean()
IIA = accuracy.item()
else:
loss = loss_fn(ll_output, hl_output)
IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
return IIA, loss

if self.training_args["val_IIA_sampling"] == "random":
hl_node = self.sample_hl_name()
IIA, loss = get_node_IIT_info(hl_node)
elif self.training_args["val_IIA_sampling"] == "all":
iias = []
losses = []
for hl_node in self.corr.keys():
IIA, loss = get_node_IIT_info(hl_node)
iias.append(IIA)
losses.append(loss)
IIA = sum(iias) / len(iias)
loss = t.stack(losses).mean()
else:
loss = loss_fn(ll_output, hl_output)
IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
raise ValueError(f"Invalid val_IIA_sampling: {self.training_args['val_IIA_sampling']}")

# compute behavioral accuracy
base_x, base_y = base_input[0:2]
base_y = base_y.to(self.ll_model.device) # so that input data doesn't all need to be hogging room on device.
output = self.ll_model(base_x)
if self.hl_model.is_categorical():
top1 = t.argmax(output, dim=-1)
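
For reference, a self-contained sketch of the two IIA computations used in run_eval_step above: top-1 agreement when the high-level model is categorical, and an atol tolerance check otherwise (the tensors are made-up examples; atol matches the default above):

```python
import torch as t

atol = 5e-2

# Categorical case: IIA is top-1 agreement between low-level logits and high-level labels.
ll_logits = t.tensor([[2.0, 0.1], [0.3, 1.5], [0.2, 0.1]])
hl_labels = t.tensor([0, 1, 1])
iia_categorical = (t.argmax(ll_logits, dim=-1) == hl_labels).float().mean().item()  # 2/3

# Continuous case: IIA is the fraction of outputs within atol of the high-level output.
ll_out = t.tensor([0.10, 0.52, 0.90])
hl_out = t.tensor([0.12, 0.60, 0.91])
iia_continuous = ((ll_out - hl_out).abs() < atol).float().mean().item()  # 2/3
```

With val_IIA_sampling set to "all", these per-node values are averaged over every node in the correspondence rather than computed for a single sampled node.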
9 changes: 6 additions & 3 deletions iit/model_pairs/iit_model_pair.py
@@ -23,22 +23,25 @@ def __init__(
self.hl_model.requires_grad_(False)

self.corr = corr
print(self.hl_model.hook_dict)
print(self.corr.keys())
assert all([str(k) in self.hl_model.hook_dict for k in self.corr.keys()])
default_training_args = {
"batch_size": 256,
"lr": 0.001,
"num_workers": 0,
"early_stop": True,
"lr_scheduler": None,
"scheduler_val_metric": ["val/accuracy", "val/IIA"],
"scheduler_mode": "max",
"scheduler_kwargs": {},
"clip_grad_norm": 1.0,
"seed": 0,
"detach_while_caching": True,
"optimizer_cls": t.optim.Adam,
"optimizer_kwargs" : {
"lr": 0.001,
},
}
training_args = {**default_training_args, **training_args}

if isinstance(ll_model, HookedRootModule):
ll_model = LLModel.make_from_hooked_transformer(
ll_model, detach_while_caching=training_args["detach_while_caching"]
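
A standalone sketch of how the new training_args keys are consumed, mirroring the optimizer construction and the ReduceLROnPlateau defaults added in base_model_pair.py (the Linear model is a stand-in for the low-level model; key names follow the defaults above):

```python
import torch as t

training_args = {
    "optimizer_cls": t.optim.Adam,
    "optimizer_kwargs": {"lr": 0.001},
    "lr_scheduler": t.optim.lr_scheduler.ReduceLROnPlateau,
    "scheduler_mode": "max",
    "scheduler_kwargs": {},
}

model = t.nn.Linear(4, 2)  # stand-in for self.ll_model
optimizer = training_args["optimizer_cls"](
    model.parameters(), **training_args["optimizer_kwargs"]
)

scheduler_kwargs = dict(training_args["scheduler_kwargs"])
if training_args["lr_scheduler"] is t.optim.lr_scheduler.ReduceLROnPlateau:
    # patience and factor fall back to 10 and 0.1 when not supplied, as in the PR
    scheduler_kwargs.setdefault("patience", 10)
    scheduler_kwargs.setdefault("factor", 0.1)
    lr_scheduler = training_args["lr_scheduler"](
        optimizer, mode=training_args["scheduler_mode"], **scheduler_kwargs
    )
    lr_scheduler.step(0.95)  # stepped on a validation metric each epoch
```

Callers can then swap in, for example, optimizer_cls=t.optim.AdamW or a different lr via optimizer_kwargs without changing the train() signature.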
5 changes: 2 additions & 3 deletions iit/model_pairs/probed_sequential_pair.py
@@ -82,7 +82,6 @@ def train(
self,
train_set: IITDataset,
test_set: IITDataset,
optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
epochs: int = 1000,
use_wandb: bool = False,
wandb_name_suffix: str = "",
@@ -107,9 +106,9 @@
for p in probes.values():
params += list(p.parameters())
optimizer_kwargs['lr'] = training_args["lr"]
probe_optimizer = optimizer_cls(params, **optimizer_kwargs)
probe_optimizer = training_args['optimizer_cls'](params, **optimizer_kwargs)

optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **optimizer_kwargs)
loss_fn = t.nn.CrossEntropyLoss()

if use_wandb and not wandb.run: