diff --git a/iit/utils/eval_ablations.py b/iit/utils/eval_ablations.py index 95ba969..459e2c4 100644 --- a/iit/utils/eval_ablations.py +++ b/iit/utils/eval_ablations.py @@ -335,7 +335,7 @@ def ablate_nodes( ) changed_result = (~ll_unchanged).cpu().float() * accuracy if relative_change: - return changed_result.sum().item() / (accuracy.float().sum().item() + 1e-6) + return changed_result.sum() / (accuracy.float().sum() + 1e-6) return (~ll_unchanged).cpu().float().mean()