Merge pull request #17 from evanhanders/training_improvements
Training improvements
cybershiptrooper authored Aug 23, 2024
2 parents bc8ca9b + e4eba0a commit e0be350
Showing 6 changed files with 186 additions and 88 deletions.
102 changes: 73 additions & 29 deletions iit/model_pairs/base_model_pair.py
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, final, Type
from typing import Any, Callable, final, Type, Optional

import numpy as np
import torch as t
@@ -17,6 +18,8 @@
from iit.utils.iit_dataset import IITDataset
from iit.utils.index import Ix, TorchIndex
from iit.utils.metric import MetricStoreCollection, MetricType
from iit.utils.tqdm import tqdm



class BaseModelPair(ABC):
@@ -218,11 +221,9 @@ def train(
self,
train_set: IITDataset,
test_set: IITDataset,
optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
epochs: int = 1000,
use_wandb: bool = False,
wandb_name_suffix: str = "",
optimizer_kwargs: dict = {},
) -> None:
training_args = self.training_args
print(f"{training_args=}")
@@ -233,6 +234,10 @@ def train(
assert isinstance(test_set, IITDataset), ValueError(
f"test_set is not an instance of IITDataset, but {type(test_set)}"
)
assert self.ll_model.device == self.hl_model.device, ValueError(
"ll_model and hl_model are not on the same device"
)

train_loader, test_loader = self.make_loaders(
train_set,
test_set,
@@ -242,15 +247,19 @@

early_stop = training_args["early_stop"]

optimizer_kwargs['lr'] = training_args["lr"]
optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **training_args['optimizer_kwargs'])
loss_fn = self.loss_fn
scheduler_cls = training_args.get("lr_scheduler", None)
scheduler_kwargs = training_args.get("scheduler_kwargs", {})
if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau:
mode = training_args.get("scheduler_mode", "max")
lr_scheduler = scheduler_cls(optimizer, mode=mode, factor=0.1, patience=10)
if 'patience' not in scheduler_kwargs:
scheduler_kwargs['patience'] = 10
if 'factor' not in scheduler_kwargs:
scheduler_kwargs['factor'] = 0.1
lr_scheduler = scheduler_cls(optimizer, mode=mode, **scheduler_kwargs)
elif scheduler_cls:
lr_scheduler = scheduler_cls(optimizer)
lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)

if use_wandb and not wandb.run:
wandb.init(project="iit", name=wandb_name_suffix,
@@ -261,19 +270,32 @@ def train(
wandb.config.update({"method": self.wandb_method})
wandb.run.log_code() # type: ignore

for epoch in tqdm(range(epochs)):
train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer)
test_metrics = self._run_eval_epoch(test_loader, loss_fn)
if scheduler_cls:
self.step_scheduler(lr_scheduler, test_metrics)
self.test_metrics = test_metrics
self.train_metrics = train_metrics
self._print_and_log_metrics(
epoch, MetricStoreCollection(train_metrics.metrics + test_metrics.metrics), use_wandb
)

# Set seed before iterating on loaders for reproduceablility.
t.manual_seed(training_args["seed"])
with tqdm(range(epochs), desc="Training Epochs") as epoch_pbar:
with tqdm(total=len(train_loader), desc="Training Batches") as batch_pbar:
for epoch in range(epochs):
batch_pbar.reset()

train_metrics = self._run_train_epoch(train_loader, loss_fn, optimizer, batch_pbar)
test_metrics = self._run_eval_epoch(test_loader, loss_fn)
if scheduler_cls:
self.step_scheduler(lr_scheduler, test_metrics)
self.test_metrics = test_metrics
self.train_metrics = train_metrics
self._print_and_log_metrics(
epoch=epoch,
metrics=MetricStoreCollection(train_metrics.metrics + test_metrics.metrics),
optimizer=optimizer,
use_wandb=use_wandb,
epoch_pbar=epoch_pbar
)

if early_stop and self._check_early_stop_condition(test_metrics):
break
if early_stop and self._check_early_stop_condition(test_metrics):
break

self._run_epoch_extras(epoch_number=epoch+1)

if use_wandb:
wandb.log({"final epoch": epoch})
@@ -298,16 +320,16 @@ def _run_train_epoch(
self,
loader: DataLoader,
loss_fn: Callable[[Tensor, Tensor], Tensor],
optimizer: t.optim.Optimizer
optimizer: t.optim.Optimizer,
pbar: tqdm
) -> MetricStoreCollection:
self.ll_model.train()
train_metrics = self.make_train_metrics()
for i, (base_input, ablation_input) in tqdm(
enumerate(loader), total=len(loader)
):
for i, (base_input, ablation_input) in enumerate(loader):
train_metrics.update(
self.run_train_step(base_input, ablation_input, loss_fn, optimizer)
)
pbar.update(1)
return train_metrics

@final
@@ -341,17 +363,39 @@ def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool
return True

@final
@staticmethod
def _print_and_log_metrics(
self,
epoch: int,
metrics: MetricStoreCollection,
use_wandb: bool = False
) -> None:
print(f"\nEpoch {epoch}:", end=" ")
optimizer: t.optim.Optimizer,
use_wandb: bool = False,
print_metrics: bool = True,
epoch_pbar: Optional[tqdm] = None,
) -> str:

# Print the current epoch's metrics
current_epoch_log = f"lr: {optimizer.param_groups[0]['lr']:.2e}, "
for k in self.training_args.keys():
if 'weight' in k and 'schedule' not in k:
current_epoch_log += f"{k}: {self.training_args[k]:.2e}, "
if use_wandb:
wandb.log({"epoch": epoch})
wandb.log({"lr": optimizer.param_groups[0]["lr"]})

for metric in metrics:
print(metric, end=", ")
current_epoch_log += str(metric) + ", "
if use_wandb:
wandb.log({metric.get_name(): metric.get_value()})
print()
if print_metrics:
tqdm.write(f'Epoch {epoch+1}: {current_epoch_log.strip(", ")}')

if epoch_pbar is not None:
epoch_pbar.update(1)
epoch_pbar.set_postfix_str(current_epoch_log.strip(', '))
epoch_pbar.set_description(f"Epoch {epoch + 1}")

return current_epoch_log

def _run_epoch_extras(self, epoch_number: int) -> None:
""" Optional method for running extra code at the end of each epoch """
pass
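
The slimmed-down train() signature above drops the optimizer_cls and optimizer_kwargs parameters and reads the optimizer and LR scheduler from training_args instead. As a rough orientation, here is a minimal sketch of that configuration pattern using plain PyTorch; the AdamW/Linear choices and all values are illustrative placeholders, not the repository's defaults.

import torch as t

# Illustrative training_args; any torch optimizer / scheduler class can be used.
training_args = {
    "optimizer_cls": t.optim.AdamW,
    "optimizer_kwargs": {"lr": 3e-4, "weight_decay": 1e-2},
    "lr_scheduler": t.optim.lr_scheduler.ReduceLROnPlateau,
    "scheduler_mode": "max",
    "scheduler_kwargs": {"patience": 5, "factor": 0.5},
}

dummy_model = t.nn.Linear(8, 2)  # stands in for self.ll_model inside train()

optimizer = training_args["optimizer_cls"](
    dummy_model.parameters(), **training_args["optimizer_kwargs"]
)

scheduler_cls = training_args.get("lr_scheduler", None)
scheduler_kwargs = training_args.get("scheduler_kwargs", {})
if scheduler_cls == t.optim.lr_scheduler.ReduceLROnPlateau:
    # train() only fills in patience=10 / factor=0.1 when the caller omits them.
    scheduler_kwargs.setdefault("patience", 10)
    scheduler_kwargs.setdefault("factor", 0.1)
    lr_scheduler = scheduler_cls(
        optimizer, mode=training_args.get("scheduler_mode", "max"), **scheduler_kwargs
    )
elif scheduler_cls:
    lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)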
57 changes: 38 additions & 19 deletions iit/model_pairs/iit_behavior_model_pair.py
@@ -8,6 +8,7 @@
from iit.model_pairs.ll_model import LLModel
from iit.utils.correspondence import Correspondence
from iit.utils.metric import MetricStore, MetricStoreCollection, MetricType
from iit.utils.nodes import HLNode


class IITBehaviorModelPair(IITModelPair):
Expand All @@ -19,12 +20,11 @@ def __init__(
training_args: dict = {}
):
default_training_args = {
"lr": 0.001,
"atol": 5e-2,
"early_stop": True,
"use_single_loss": False,
"iit_weight": 1.0,
"behavior_weight": 1.0,
"val_IIA_sampling": "random", # random or all
}
training_args = {**default_training_args, **training_args}
super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
@@ -56,8 +56,9 @@ def get_behaviour_loss_over_batch(
) -> Tensor:
base_x, base_y = base_input[0:2]
output = self.ll_model(base_x)
behavior_loss = loss_fn(output, base_y)
return behavior_loss
label_indx = self.get_label_idxs()
behaviour_loss = loss_fn(output[label_indx.as_index], base_y[label_indx.as_index].to(output.device))
return behaviour_loss

def step_on_loss(self, loss: Tensor, optimizer: t.optim.Optimizer) -> None:
optimizer.zero_grad()
@@ -108,25 +109,43 @@ def run_eval_step(

# compute IIT loss and accuracy
label_idx = self.get_label_idxs()
hl_node = self.sample_hl_name()
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
hl_output.to(ll_output.device)
hl_output = hl_output[label_idx.as_index]
ll_output = ll_output[label_idx.as_index]
if self.hl_model.is_categorical():
loss = loss_fn(ll_output, hl_output)
if ll_output.shape == hl_output.shape:
# To handle the case when labels are one-hot
hl_output = t.argmax(hl_output, dim=-1)
top1 = t.argmax(ll_output, dim=-1)
accuracy = (top1 == hl_output).float().mean()
IIA = accuracy.item()

def get_node_IIT_info(hl_node: HLNode) -> tuple[float, Tensor]:
hl_output, ll_output = self.do_intervention(base_input, ablation_input, hl_node)
hl_output = hl_output.to(ll_output.device)
hl_output = hl_output[label_idx.as_index]
ll_output = ll_output[label_idx.as_index]
if self.hl_model.is_categorical():
loss = loss_fn(ll_output, hl_output)
if ll_output.shape == hl_output.shape:
# To handle the case when labels are one-hot
hl_output = t.argmax(hl_output, dim=-1)
top1 = t.argmax(ll_output, dim=-1)
accuracy = (top1 == hl_output).float().mean()
IIA = accuracy.item()
else:
loss = loss_fn(ll_output, hl_output)
IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
return IIA, loss

if self.training_args["val_IIA_sampling"] == "random":
hl_node = self.sample_hl_name()
IIA, loss = get_node_IIT_info(hl_node)
elif self.training_args["val_IIA_sampling"] == "all":
iias = []
losses = []
for hl_node in self.corr.keys():
IIA, loss = get_node_IIT_info(hl_node)
iias.append(IIA)
losses.append(loss)
IIA = sum(iias) / len(iias)
loss = t.stack(losses).mean()
else:
loss = loss_fn(ll_output, hl_output)
IIA = ((ll_output - hl_output).abs() < atol).float().mean().item()
raise ValueError(f"Invalid val_IIA_sampling: {self.training_args['val_IIA_sampling']}")

# compute behavioral accuracy
base_x, base_y = base_input[0:2]
base_y = base_y.to(self.ll_model.device) # so that input data doesn't all need to be hogging room on device.
output = self.ll_model(base_x)
if self.hl_model.is_categorical():
top1 = t.argmax(output, dim=-1)
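
The run_eval_step changes above introduce a val_IIA_sampling training argument: "random" evaluates a single sampled high-level node per step, while "all" computes IIA and loss for every node in the correspondence and averages them. A standalone, hypothetical illustration of that averaging logic, where node_metrics stands in for calling get_node_IIT_info on each correspondence key:

import random
import torch as t

# Per-node (IIA, loss) placeholders; the real values come from interventions.
node_metrics = {
    "hook_a": (0.91, t.tensor(0.20)),
    "hook_b": (0.84, t.tensor(0.35)),
}

val_IIA_sampling = "all"  # IITBehaviorModelPair defaults to "random"
if val_IIA_sampling == "random":
    IIA, loss = node_metrics[random.choice(list(node_metrics))]
elif val_IIA_sampling == "all":
    iias, losses = zip(*node_metrics.values())
    IIA = sum(iias) / len(iias)
    loss = t.stack(losses).mean()
else:
    raise ValueError(f"Invalid val_IIA_sampling: {val_IIA_sampling}")
print(IIA, loss)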
9 changes: 6 additions & 3 deletions iit/model_pairs/iit_model_pair.py
@@ -23,22 +23,25 @@ def __init__(
self.hl_model.requires_grad_(False)

self.corr = corr
print(self.hl_model.hook_dict)
print(self.corr.keys())
assert all([str(k) in self.hl_model.hook_dict for k in self.corr.keys()])
default_training_args = {
"batch_size": 256,
"lr": 0.001,
"num_workers": 0,
"early_stop": True,
"lr_scheduler": None,
"scheduler_val_metric": ["val/accuracy", "val/IIA"],
"scheduler_mode": "max",
"scheduler_kwargs": {},
"clip_grad_norm": 1.0,
"seed": 0,
"detach_while_caching": True,
"optimizer_cls": t.optim.Adam,
"optimizer_kwargs" : {
"lr": 0.001,
},
}
training_args = {**default_training_args, **training_args}

if isinstance(ll_model, HookedRootModule):
ll_model = LLModel.make_from_hooked_transformer(
ll_model, detach_while_caching=training_args["detach_while_caching"]
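
One side effect of the new defaults worth keeping in mind: the learning rate now lives inside optimizer_kwargs, and the {**default_training_args, **training_args} merge is shallow, so a user-supplied optimizer_kwargs replaces the default dict wholesale rather than being merged key by key. A small, self-contained sketch (values are illustrative):

import torch as t

default_training_args = {
    "optimizer_cls": t.optim.Adam,
    "optimizer_kwargs": {"lr": 0.001},
}

# Overriding optimizer_kwargs replaces the whole inner dict, dropping the default lr.
user_args = {"optimizer_kwargs": {"weight_decay": 0.01}}
merged = {**default_training_args, **user_args}
print(merged["optimizer_kwargs"])  # {'weight_decay': 0.01}

# To change one optimizer kwarg while keeping the default lr, supply both explicitly.
user_args = {"optimizer_kwargs": {"lr": 0.001, "weight_decay": 0.01}}
merged = {**default_training_args, **user_args}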
5 changes: 2 additions & 3 deletions iit/model_pairs/probed_sequential_pair.py
@@ -82,7 +82,6 @@ def train(
self,
train_set: IITDataset,
test_set: IITDataset,
optimizer_cls: Type[t.optim.Optimizer] = t.optim.Adam,
epochs: int = 1000,
use_wandb: bool = False,
wandb_name_suffix: str = "",
@@ -107,9 +106,9 @@
for p in probes.values():
params += list(p.parameters())
optimizer_kwargs['lr'] = training_args["lr"]
probe_optimizer = optimizer_cls(params, **optimizer_kwargs)
probe_optimizer = training_args['optimizer_cls'](params, **optimizer_kwargs)

optimizer = optimizer_cls(self.ll_model.parameters(), **optimizer_kwargs)
optimizer = training_args['optimizer_cls'](self.ll_model.parameters(), **optimizer_kwargs)
loss_fn = t.nn.CrossEntropyLoss()

if use_wandb and not wandb.run: