Merge pull request #15 from locuslab/refactor_eval
add aggregate stat to the eval pipeline
zhilif authored Mar 18, 2024
2 parents 8cdf80f + e429290 commit 8889542
Showing 123 changed files with 783,830 additions and 270,052 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -20,4 +20,9 @@ __pycache__
*.sh
paper_models
notebook/
*.ttf

*.ttf
scripts/
reproduce.ipynb
upload_to_hf.py

12 changes: 12 additions & 0 deletions README.md
@@ -11,6 +11,8 @@ The TOFU dataset serves as a benchmark for evaluating unlearning performance of
- [**Leaderboard on Hugging Face Spaces**](https://huggingface.co/spaces/locuslab/tofu_leaderboard): Current rankings and submissions for the TOFU dataset challenges.
- [**Summary on Twitter**](https://x.com/_akhaliq/status/1745643293839327268): A concise summary and key takeaways from the project.

## Updates 03/18
We have released a new evaluation pipeline; see the section on model evaluation below. We have noticed that the Llama-2 model has a reproducibility issue due to the internal randomness of flash attention, so you are encouraged to collect your own retain results. Our Hugging Face leaderboard results and the numbers/figures in the paper are also subject to update. Feel free to contact us if you run into any issues!

## Applicability 🚀

@@ -63,6 +65,16 @@ CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=$port evaluate_
```
You can modify the configuration in `config/eval_everything.yaml`. We suggest evaluating with a single GPU; we are also working on a script that supports multi-GPU evaluation.

By default, the evaluation results are written to `${model_path}/eval_results/ds_size${ds_size}`; you can change this via the `save_dir` field in `config/eval_everything.yaml`.
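
Since the aggregation script below is driven by Hydra, the evaluation config fields can likely also be overridden directly on the command line. A minimal sketch, where the entry-point name `evaluate_util.py` and the override value are assumptions to adapt to your setup:
```
# Hypothetical override of the output directory (script name assumed;
# use the same entry point as in the command above)
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=$port \
    evaluate_util.py save_dir=eval_results/my_run
```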

The evaluation results on the four datasets (forget, retain, real_world, real_author) are aggregated into a single JSON file named `eval_log_aggregated.json`. Finally, you can run
```
python aggregate_eval_stat.py retain_result=${path_to_aggregated_retain_result} ckpt_result=${path_to_aggregated_ckpt_result} \
    method_name=${method_name} save_file=${save_filename}
```
to obtain an aggregated CSV result that contains the overall model utility and forget quality. Here `${path_to_aggregated_retain_result}` and `${path_to_aggregated_ckpt_result}` are the paths to the `eval_log_aggregated.json` files of the retain model and of your unlearned checkpoint, respectively. The retain results are provided in `data/`.
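
For example (the checkpoint path and method name below are hypothetical placeholders; the retain path is one of the results provided in `data/`, and `save_file` matches the default in `config/aggregate_eval_stat.yaml`):
```
python aggregate_eval_stat.py \
    retain_result=data/retain90_llama_wd0.01/eval_results/ds_size300/eval_log_aggregated.json \
    ckpt_result=/path/to/your_model/eval_results/ds_size300/eval_log_aggregated.json \
    method_name=grad_ascent save_file=aggr_result.csv
```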


### Available forget sets are:

- `forget01`: Forgetting 1% of the original dataset, all entries correspond to a single author.
Empty file added __init__.py
Empty file.
108 changes: 108 additions & 0 deletions aggregate_eval_stat.py
@@ -0,0 +1,108 @@

from omegaconf import OmegaConf
import hydra
import json
import numpy as np
from scipy.stats import hmean, ks_2samp
import csv
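
# Forget Quality: two-sample Kolmogorov-Smirnov test comparing the truth-ratio
# distribution of the unlearned checkpoint against that of the retain model on the
# forget set; the KS p-value is reported as 'Forget Quality'.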
def get_forget_quality(unlearn_result, retain_result):
unlearn_forget_result = unlearn_result['eval_log_forget.json']
retain_forget_result = retain_result['eval_log_forget.json']

unlearn_paraphrase_np_values = np.array(list(unlearn_forget_result['avg_paraphrased_loss'].values()))
unlearn_perturbed_np_values = np.array(list(unlearn_forget_result['average_perturb_loss'].values()))
unlearn_perturbed_np_values = unlearn_perturbed_np_values.mean(axis=-1)

retain_paraphrase_np_values = np.array(list(retain_forget_result['avg_paraphrased_loss'].values()))
retain_perturbed_np_values = np.array(list(retain_forget_result['average_perturb_loss'].values()))
retain_perturbed_np_values = retain_perturbed_np_values.mean(axis=-1)

unlearn_truth_ratio = np.exp( unlearn_perturbed_np_values - unlearn_paraphrase_np_values)
retain_truth_ratio = np.exp( retain_perturbed_np_values - retain_paraphrase_np_values)

test_res = ks_2samp(unlearn_truth_ratio, retain_truth_ratio)
return {'Forget Quality': test_res.pvalue, 'KS Test PVal Forget': test_res.pvalue, 'KS Test Forget': test_res.statistic}

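# Model Utility: mean ROUGE, answer probability, and truth-ratio statistics for each
# evaluation task; 'Model Utility' itself is the harmonic mean of the nine non-Forget
# statistics (Retain, Real Authors, Real World x ROUGE, Prob., Truth Ratio).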
def get_model_utility(eval_result_dict):
eval_task_dict = {
'eval_real_author_wo_options.json': 'Real Authors',
'eval_real_world_wo_options.json': 'Real World',
'eval_log.json': 'Retain',
'eval_log_forget.json': 'Forget'
}
eval_tasks = list(eval_task_dict.keys())
metrics = ['ROUGE', 'Prob.', 'Truth Ratio']

output_result = {}
for eval_task in eval_tasks:
for metric in metrics:
output_result[metric + ' ' + eval_task_dict[eval_task]] = []

# k is different files
for k, v in eval_result_dict.items():
# getting Probability
if 'eval_log' in k:
gt_probs = np.exp(-1 * np.array(list(eval_result_dict[k]['avg_gt_loss'].values())))
avg_gt_prob = np.mean(gt_probs)
else:
avg_true_prob = np.exp(-1 * np.array(list(eval_result_dict[k]['avg_gt_loss'].values())))
avg_false_prob = np.exp(-1 * np.array(list(eval_result_dict[k]['average_perturb_loss'].values())))
avg_all_prob = np.concatenate([np.expand_dims(avg_true_prob, axis=-1), avg_false_prob], axis=1).sum(-1)
avg_gt_prob = np.mean(avg_true_prob/avg_all_prob)
output_result[f'Prob. {eval_task_dict[k]}'] = avg_gt_prob

# getting ROUGE
avg_rouge = np.array(list(eval_result_dict[k]['rougeL_recall'].values())).mean()
output_result[f'ROUGE {eval_task_dict[k]}'] = avg_rouge

# getting Truth Ratio
avg_paraphrase_np_values = np.array(list(eval_result_dict[k]['avg_paraphrased_loss'].values()))

avg_perturbed_np_values = np.array(list(eval_result_dict[k]['average_perturb_loss'].values()))
avg_perturbed_np_values = avg_perturbed_np_values.mean(axis=-1)

curr_stat_1 = np.exp( avg_perturbed_np_values - avg_paraphrase_np_values)
# output_result[f'{eval_task_dict[k]} paraphrased_over_perturbed'] = curr_stat_1
if 'forget' in k:
paraphrased_perturb_ratio = np.mean(np.minimum(curr_stat_1, 1/curr_stat_1))
else:
paraphrased_perturb_ratio = np.mean(np.maximum(0, 1 - 1/curr_stat_1))
output_result[f'Truth Ratio {eval_task_dict[k]}'] = paraphrased_perturb_ratio

model_utility_cands = []
for k, v in output_result.items():
if 'Forget' not in k:
model_utility_cands.append(v)
output_result['Model Utility'] = hmean(model_utility_cands)
return output_result

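# Entry point: load the aggregated eval logs for the unlearned checkpoint and the retain
# baseline, compute Model Utility and Forget Quality, and write them as a one-row CSV.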
@hydra.main(version_base=None, config_path="config", config_name="aggregate_eval_stat")
def main(cfg):
if cfg.retain_result is None or cfg.ckpt_result is None:
        raise ValueError("Please provide both retain_result and ckpt_result")

retain_result = json.load(open(cfg.retain_result))
ckpt_result = json.load(open(cfg.ckpt_result))

    # We assume that retain_result and ckpt_result follow this structure:
    # the top level has the keys ['eval_log.json', 'eval_log_forget.json', 'eval_real_world_wo_options.json', 'eval_real_author_wo_options.json'];
    # the second level contains the actual metrics: ['avg_gt_loss', 'average_perturb_loss', 'avg_paraphrased_loss', 'rougeL_recall'];
    # within each metric, we have {data_idx: measurement}.

model_utility = get_model_utility(ckpt_result)
forget_quality = get_forget_quality(ckpt_result, retain_result)
model_utility['Forget Quality'] = forget_quality['Forget Quality']

model_utility['Method'] = cfg.method_name
model_utility['Submitted By'] = cfg.submitted_by
# dump the model utility to a temp.csv
    with open(cfg.save_file, 'w', newline='') as f:  # newline='' avoids blank rows in the CSV on Windows
w = csv.DictWriter(f, model_utility.keys())
w.writeheader()
w.writerow(model_utility)
return model_utility

if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions config/aggregate_eval_stat.yaml
@@ -0,0 +1,5 @@
retain_result: null
ckpt_result: null
method_name: temp
submitted_by: john
save_file: aggr_result.csv
4 changes: 3 additions & 1 deletion config/eval_everything.yaml
@@ -34,4 +34,6 @@ use_pretrained: false

batch_size: 30
reinitialize_weights: false
retain_result: null

retain_result: null

4 changes: 3 additions & 1 deletion config/finetune.yaml
@@ -11,5 +11,7 @@ batch_size: 16
gradient_accumulation_steps: 1
num_epochs: 5
lr: 1e-5
save_dir: /data/locus/llm_weights/zhilif/TOFU/final_ft_noLORA_5_epochs_inst_lr${lr}_${model_family}_${split}_${weight_decay}
save_dir: /data/locus/llm_weights/zhilif/TOFU/ft_epoch${num_epochs}_lr${lr}_${model_family}_${split}_wd${weight_decay}

weight_decay: 0.01
seed: 42
14 changes: 9 additions & 5 deletions config/forget.yaml
@@ -10,16 +10,19 @@ split: forget01
data_path: locuslab/TOFU
batch_size: 16
gradient_accumulation_steps: 4
num_epochs: 10
num_epochs: 5
forget_loss: grad_ascent
save_dir: ${model_path}/${forget_loss}_${lr}_${split}

save_dir: ${model_path}/${forget_loss}_${lr}_${split}_${num_epochs}
overwrite_dir: false
weight_decay: 0.01
save_model: true
eval_while_train: true
eval_while_train: false
eval_only: false
seed: 42

eval:
retain_result: data/retain90_llama_wd0.01/eval_results/ds_size300/eval_log_aggregated.json
# retain_result: data/retain90_llama_wd0.01/eval_results/ds_size300/eval_log_aggregated.json
model_path: ${..model_path}
model_family: ${..model_family}
save_dir: ${..save_dir}
@@ -48,4 +51,5 @@ eval:
overwrite: true
use_pretrained: false

batch_size: 30
batch_size: 30
retain_result: null