Skip to content

Commit

Permalink
Merge pull request #16 from evanhanders/typing
Browse files Browse the repository at this point in the history
Adds automatic mypy typing and cleans up a bunch of typing errors
  • Loading branch information
cybershiptrooper authored Aug 5, 2024
2 parents cacb50c + 9eea6f7 commit 2aeeb3d
Show file tree
Hide file tree
Showing 53 changed files with 1,116 additions and 735 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# CI workflow: run mypy static type checking on every push and pull request,
# once per supported Python version.
name: Mypy Type Checking

on: [push, pull_request]

jobs:
  mypy:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # One job per supported interpreter; each run uses its own config file
        # (mypy-<version>.ini, see the final step) so version-specific
        # settings can differ.
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      # NOTE(review): v2 of checkout/setup-python is deprecated — it runs on
      # the retired node12 runtime. v4 is the current supported major release
      # and is backward-compatible for this usage.
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install mypy
      - name: Run mypy
        run: |
          mypy --config-file mypy-${{ matrix.python-version }}.ini
43 changes: 25 additions & 18 deletions eval_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime

import torch as t
from torch import Tensor
import wandb
from tqdm import tqdm

Expand All @@ -10,6 +11,7 @@
from iit.utils.config import DEVICE
from iit.utils.plotter import plot_ablation_stats
from iit.utils.wrapper import get_hook_points
from iit.tasks.mnist_pvr.dataset import ImagePVRDataset


def evaluate_model_on_ablations(
Expand All @@ -18,15 +20,15 @@ def evaluate_model_on_ablations(
test_set: t.utils.data.Dataset,
eval_args: dict,
verbose: bool = False,
):
) -> dict:
print("reached evaluate_model!")
stats_per_layer = {}
for hook_point in tqdm(get_hook_points(ll_model), desc="Hook points"):
_, hl_model, corr = get_alignment(
task,
config={
"hook_point": hook_point,
"input_shape": test_set.get_input_shape(),
"input_shape": test_set.get_input_shape(), # type: ignore
},
)
model_pair = IITProbeSequentialPair(
Expand All @@ -40,23 +42,26 @@ def evaluate_model_on_ablations(
# set up stats
hookpoint_stats = {}
for hl_node, _ in model_pair.corr.items():
hookpoint_stats[hl_node] = 0
hookpoint_stats[hl_node] = t.zeros(1)
# find test accuracy
with t.no_grad():
for base_input_lists in tqdm(dataloader, desc=f"Ablations on {hook_point}"):
base_input = [x.to(DEVICE) for x in base_input_lists]
base_input: tuple[Tensor, Tensor, Tensor] = (x.to(DEVICE) for x in base_input_lists) # type: ignore
for hl_node, ll_nodes in model_pair.corr.items():
ablated_input = test_set.patch_batch_at_hl(
list(base_input[0]),
list(base_input_lists[-1]),
hl_node,
list(base_input[1]),
)
ablated_input = (
t.stack(ablated_input[0]).to(DEVICE), # input
t.stack(ablated_input[1]).to(DEVICE), # label
t.stack(ablated_input[2]).to(DEVICE),
) # intermediate_data
if isinstance(test_set, ImagePVRDataset):
ablated_input_pre = test_set.patch_batch_at_hl(
list(base_input),
list(base_input_lists),
hl_node,
)
ablated_input = (
t.stack(ablated_input_pre[0]).to(DEVICE), # input
t.stack(ablated_input_pre[1]).to(DEVICE), # label
t.stack(ablated_input_pre[2]).to(DEVICE),
) # intermediate_data
else:
raise ValueError(f"patch_batch_at_hl not implemented for this dataset type: {type(test_set)}")

# unsqueeze if single element
if ablated_input[1].shape == ():
assert (
Expand Down Expand Up @@ -115,14 +120,16 @@ def evaluate_model_on_ablations(
ll_model, hl_model, corr = get_alignment(
task, config={"input_shape": test_set.get_input_shape()}
)
assert ll_model is not None
assert hl_model is not None
model_pair = IITProbeSequentialPair(
ll_model=ll_model, hl_model=hl_model, corr=corr, training_args=training_args
)
if train:
model_pair.train(
train_set,
test_set,
epochs=training_args["epochs"],
epochs=int(training_args["epochs"]),
use_wandb=use_wandb,
)
else:
Expand All @@ -147,8 +154,8 @@ def evaluate_model_on_ablations(

if use_wandb:
wandb.init(project="iit")
wandb.run.name = f"{leaky_task}_ablation"
wandb.run.save()
wandb.run.name = f"{leaky_task}_ablation" # type: ignore
wandb.run.save() # type: ignore
wandb.config.update(eval_args)

leaky_stats_per_layer = evaluate_model_on_ablations(
Expand Down
15 changes: 9 additions & 6 deletions eval_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch as t
from tqdm import tqdm
from iit.utils.plotter import plot_probe_stats
from iit.utils.iit_dataset import IITDataset
import os
import wandb
from datetime import datetime
Expand All @@ -21,14 +22,14 @@ def evaluate_model_on_probes(
use_wandb: bool = False,
verbose: bool = False,
save_probes: bool = False,
):
) -> dict:
print("reached evaluate_model!")
probe_stats_per_layer = {}
log_stats_per_layer = {}
if use_wandb:
wandb.init(project="iit")
wandb.run.name = f"{task}_probes"
wandb.run.save()
wandb.run.name = f"{task}_probes" # type: ignore
wandb.run.save() # type: ignore
# add training args to wandb config
wandb.config.update(probe_training_args)

Expand All @@ -37,7 +38,7 @@ def evaluate_model_on_probes(
task,
config={
"hook_point": hook_point,
"input_shape": test_set.get_input_shape(),
"input_shape": test_set.get_input_shape(), # type: ignore
},
)
model_pair = IITProbeSequentialPair(
Expand All @@ -47,7 +48,7 @@ def evaluate_model_on_probes(
training_args=probe_training_args,
)

input_shape = train_set.get_input_shape()
input_shape = train_set.get_input_shape() # type: ignore
trainer_out = train_probes_on_model_pair(
model_pair, input_shape, train_set, probe_training_args
)
Expand Down Expand Up @@ -103,13 +104,15 @@ def evaluate_model_on_probes(
ll_model, hl_model, corr = get_alignment(
task, config={"input_shape": test_set.get_input_shape()}
)
assert ll_model is not None
assert hl_model is not None
model_pair = IITProbeSequentialPair(
ll_model=ll_model, hl_model=hl_model, corr=corr, training_args=training_args
)
model_pair.train(
train_set,
test_set,
epochs=training_args["epochs"],
epochs=int(training_args["epochs"]),
use_wandb=use_wandb,
)
if use_wandb:
Expand Down
4 changes: 3 additions & 1 deletion eval_ioi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from iit.utils.argparsing import IOIArgParseNamespace
from iit.utils.eval_scripts import eval_ioi
import torch

Expand Down Expand Up @@ -45,4 +46,5 @@
)

args = parser.parse_args()
eval_ioi(args)
namespace = IOIArgParseNamespace(**vars(args))
eval_ioi(namespace)
Loading

0 comments on commit 2aeeb3d

Please sign in to comment.