fix some bugs, skip loss updates when weight is 0
cybershiptrooper committed Aug 6, 2024
1 parent 02101ec commit e3b5be7
Showing 2 changed files with 70 additions and 42 deletions.
iit/model_pairs/strict_iit_model_pair.py (102 changes: 60 additions & 42 deletions)
@@ -32,9 +32,12 @@ def __init__(
         }
         training_args = {**default_training_args, **training_args}
         super().__init__(hl_model, ll_model, corr=corr, training_args=training_args)
-        self.nodes_not_in_circuit = node_picker.get_nodes_not_in_circuit(
-            self.ll_model, self.corr
-        )
+        self.nodes_not_in_circuit = node_picker.get_nodes_not_in_circuit(self.ll_model, self.corr)
+        assert (
+            self.training_args["iit_weight"] > 0
+            or self.training_args["behavior_weight"] > 0
+            or self.training_args["strict_weight"] > 0
+        ), ValueError("At least one of the losses should be non-zero")
 
     @staticmethod
     def make_train_metrics() -> MetricStoreCollection:
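Note on the new constructor guard: in Python the second operand of assert is only the failure message object, so the ValueError above is constructed but never raised; a failing check surfaces as an AssertionError (and the whole check is skipped under python -O). A minimal standalone sketch of the pattern, with a stand-in training_args dict:

    # Stand-in for self.training_args; all three weights disabled on purpose.
    training_args = {"iit_weight": 0.0, "behavior_weight": 0.0, "strict_weight": 0.0}

    try:
        assert (
            training_args["iit_weight"] > 0
            or training_args["behavior_weight"] > 0
            or training_args["strict_weight"] > 0
        ), ValueError("At least one of the losses should be non-zero")
    except AssertionError as err:
        # err.args[0] is the ValueError instance used as the message.
        print(f"rejected config: {err}")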
@@ -45,11 +48,12 @@ def make_train_metrics() -> MetricStoreCollection:
                 MetricStore("train/strict_loss", MetricType.LOSS),
             ]
         )
 
     @staticmethod
     def make_test_metrics() -> MetricStoreCollection:
         return MetricStoreCollection(
-            IITBehaviorModelPair.make_test_metrics().metrics + [MetricStore("val/strict_accuracy", MetricType.ACCURACY)],
+            IITBehaviorModelPair.make_test_metrics().metrics
+            + [MetricStore("val/strict_accuracy", MetricType.ACCURACY)],
         )
 
     def sample_ll_node(self) -> LLNode:
@@ -64,48 +68,57 @@ def run_train_step(
     ) -> dict:
         use_single_loss = self.training_args["use_single_loss"]
 
-        hl_node = self.sample_hl_name()  # sample a high-level variable to ablate
-        iit_loss = (
-            self.get_IIT_loss_over_batch(base_input, ablation_input, hl_node, loss_fn)
-            * self.training_args["iit_weight"]
-        )
-        if not use_single_loss:
-            self.step_on_loss(iit_loss, optimizer)
+        iit_loss = 0
+        ll_loss = 0
+        behavior_loss = 0
+
+        if self.training_args["iit_weight"] > 0:
+            hl_node = self.sample_hl_name()  # sample a high-level variable to ablate
+            iit_loss = (
+                self.get_IIT_loss_over_batch(base_input, ablation_input, hl_node, loss_fn)
+                * self.training_args["iit_weight"]
+            )
+            if not use_single_loss:
+                self.step_on_loss(iit_loss, optimizer)
 
         # loss for nodes that are not in the circuit
         # should not have causal effect on the high-level output
-        base_x, base_y = base_input[0:2]
-        ablation_x, ablation_y = ablation_input[0:2]
-        ll_node = self.sample_ll_node()
-        _, cache = self.ll_model.run_with_cache(ablation_x)
-        self.ll_cache = cache
-        out = self.ll_model.run_with_hooks(
-            base_x, fwd_hooks=[(ll_node.name, self.make_ll_ablation_hook(ll_node))]
-        )
-        # print(out.shape, base_y.shape)
-        label_idx = self.get_label_idxs()
-        ll_loss = (
-            loss_fn(out[label_idx.as_index], base_y[label_idx.as_index].to(self.ll_model.cfg.device))
-            * self.training_args["strict_weight"]
-        )  # do this only for the tokens that we care about for IIT
-        if not use_single_loss:
-            self.step_on_loss(ll_loss, optimizer)
+        if self.training_args["strict_weight"] > 0:
+            base_x, base_y = base_input[0:2]
+            ablation_x, ablation_y = ablation_input[0:2]
+            ll_node = self.sample_ll_node()
+            _, cache = self.ll_model.run_with_cache(ablation_x)
+            self.ll_cache = cache
+            out = self.ll_model.run_with_hooks(
+                base_x, fwd_hooks=[(ll_node.name, self.make_ll_ablation_hook(ll_node))]
+            )
+
+            label_idx = self.get_label_idxs()
+            ll_loss = (
+                loss_fn(
+                    out[label_idx.as_index], base_y[label_idx.as_index].to(self.ll_model.cfg.device)
+                )
+                * self.training_args["strict_weight"]
+            )  # do this only for the tokens that we care about for IIT
+            if not use_single_loss:
+                self.step_on_loss(ll_loss, optimizer)
 
-        behavior_loss = (
-            self.get_behaviour_loss_over_batch(base_input, loss_fn)
-            * self.training_args["behavior_weight"]
-        )
-        if not use_single_loss:
-            self.step_on_loss(behavior_loss, optimizer)
+        if self.training_args["behavior_weight"] > 0:
+            behavior_loss = (
+                self.get_behaviour_loss_over_batch(base_input, loss_fn)
+                * self.training_args["behavior_weight"]
+            )
+            if not use_single_loss:
+                self.step_on_loss(behavior_loss, optimizer)
 
         if use_single_loss:
             total_loss = iit_loss + behavior_loss + ll_loss
             self.step_on_loss(total_loss, optimizer)
 
         return {
-            "train/iit_loss": iit_loss.item(),
-            "train/behavior_loss": behavior_loss.item(),
-            "train/strict_loss": ll_loss.item(),
+            "train/iit_loss": iit_loss.item() if isinstance(iit_loss, Tensor) else iit_loss,
+            "train/behavior_loss": behavior_loss.item() if isinstance(behavior_loss, Tensor) else behavior_loss,
+            "train/strict_loss": ll_loss.item() if isinstance(ll_loss, Tensor) else ll_loss,
         }
 
     def run_eval_step(
@@ -117,7 +130,7 @@ def run_eval_step(
         eval_returns = super().run_eval_step(base_input, ablation_input, loss_fn)
         base_x, base_y = base_input[0:2]
         ablation_x, ablation_y = ablation_input[0:2]
-
+
         _, cache = self.ll_model.run_with_cache(ablation_x)
         label_idx = self.get_label_idxs()
         base_y = base_y[label_idx.as_index].to(self.ll_model.cfg.device)
@@ -134,7 +147,9 @@ def run_eval_step(
             top1 = t.argmax(ll_output, dim=-1)
             accuracy = (top1 == base_y).float().mean().item()
         else:
-            accuracy = ((ll_output - base_y).abs() < self.training_args["atol"]).float().mean().item()
+            accuracy = (
+                ((ll_output - base_y).abs() < self.training_args["atol"]).float().mean().item()
+            )
             accuracies.append(accuracy)
 
         if len(accuracies) > 0:
@@ -149,10 +164,13 @@ def run_eval_step(
     def _check_early_stop_condition(self, test_metrics: MetricStoreCollection) -> bool:
         metrics_to_check = []
         for metric in test_metrics:
-            if metric.get_name() == "val/strict_accuracy" and self.training_args["strict_weight"] > 0:
+            if (
+                metric.get_name() == "val/strict_accuracy"
+                and self.training_args["strict_weight"] > 0
+            ):
                 metrics_to_check.append(metric)
             if metric.get_name() == "val/accuracy" and self.training_args["behavior_weight"] > 0:
                 metrics_to_check.append(metric)
             if metric.get_name() == "val/IIA" and self.training_args["iit_weight"] > 0:
                 metrics_to_check.append(metric)
-        return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check))
+        return super()._check_early_stop_condition(MetricStoreCollection(metrics_to_check))
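Taken together, the run_train_step changes make each loss term optional: every loss starts out as the int 0, a term is only computed (and, when use_single_loss is off, stepped on) if its weight is positive, and the logging now guards .item() because a skipped loss is still a plain int rather than a Tensor. The early-stop change applies the same gating to validation metrics, so a disabled loss can no longer block early stopping. A self-contained sketch of the gating-and-logging pattern, with made-up weights and loss values (the real method computes these via get_IIT_loss_over_batch and friends; the diff does not show where Tensor is imported, so the sketch uses t.Tensor directly):

    import torch as t

    def gated_losses(weights: dict, raw: dict) -> dict:
        # Start from the int 0 so skipped terms still work in
        # `total = iit_loss + behavior_loss + ll_loss` (int 0 + Tensor is a Tensor).
        iit_loss = 0
        ll_loss = 0
        behavior_loss = 0

        if weights["iit_weight"] > 0:
            iit_loss = raw["iit"] * weights["iit_weight"]
        if weights["strict_weight"] > 0:
            ll_loss = raw["strict"] * weights["strict_weight"]
        if weights["behavior_weight"] > 0:
            behavior_loss = raw["behavior"] * weights["behavior_weight"]

        # A skipped loss is an int, so .item() must be guarded when logging.
        return {
            "train/iit_loss": iit_loss.item() if isinstance(iit_loss, t.Tensor) else iit_loss,
            "train/behavior_loss": behavior_loss.item() if isinstance(behavior_loss, t.Tensor) else behavior_loss,
            "train/strict_loss": ll_loss.item() if isinstance(ll_loss, t.Tensor) else ll_loss,
        }

    print(gated_losses(
        {"iit_weight": 1.0, "behavior_weight": 1.0, "strict_weight": 0.0},
        {"iit": t.tensor(0.7), "behavior": t.tensor(0.5), "strict": t.tensor(0.3)},
    ))
    # e.g. {'train/iit_loss': 0.699..., 'train/behavior_loss': 0.5, 'train/strict_loss': 0}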
iit/tasks/ioi/ioi_hl.py (10 changes: 10 additions & 0 deletions)
@@ -1,3 +1,4 @@
+from argparse import Namespace
 from typing import Callable
 
 import torch as t
@@ -100,7 +101,16 @@ def __init__(self, d_vocab: int, names: Tensor):
         self.hook_name_mover = HookPoint()
 
         self.d_vocab = d_vocab
+        self.cfg = Namespace(
+            d_vocab=d_vocab,
+            d_vocab_out=d_vocab,
+            device=t.device("cuda") if t.cuda.is_available() else t.device("cpu")
+        )
         self.setup()
+
+    @property
+    def device(self):
+        return self.cfg.device
 
     def is_categorical(self) -> bool:
         return True
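The ioi_hl.py change gives the high-level IOI model a cfg attribute built from a plain argparse.Namespace, so code that only reads config fields (as run_train_step does via self.ll_model.cfg.device) can treat both model types alike. A minimal sketch of the same duck-typing trick on a hypothetical class (HLModelSketch and its d_vocab value are illustrative, not from the repo):

    from argparse import Namespace

    import torch as t

    class HLModelSketch:
        """Hypothetical stand-in for the high-level IOI model."""

        def __init__(self, d_vocab: int):
            # A Namespace duck-types `.cfg` field access without pulling in a
            # transformer_lens-style config class.
            self.cfg = Namespace(
                d_vocab=d_vocab,
                d_vocab_out=d_vocab,
                device=t.device("cuda") if t.cuda.is_available() else t.device("cpu"),
            )

        @property
        def device(self) -> t.device:
            return self.cfg.device

    model = HLModelSketch(d_vocab=40)
    print(model.cfg.d_vocab_out, model.device)  # 40 cpu (or cuda, if available)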
