Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add long roberta and llama baselines for long docs #26

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llama/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Code modified from: https://github.com/4AI/LS-LLaMA/tree/main
262 changes: 262 additions & 0 deletions llama/evaluate_models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5d7b1d18",
"metadata": {},
"source": [
"### Evaluate Models\n"
]
},
{
"cell_type": "markdown",
"id": "80b1b2c7",
"metadata": {},
"source": [
"##### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "87dc70f1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"import pickle\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import torch.nn.functional as F\n",
"from sklearn import metrics"
]
},
{
"cell_type": "markdown",
"id": "ed3988f7",
"metadata": {},
"source": [
"##### Evaluation Parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2a1fc80e",
"metadata": {},
"outputs": [],
"source": [
"threshold = 0.5 # currently we don't maximize val f1 to find the threshold... need to grab scores for all the val sets if we do this\n",
"num_std = 1.96\n",
"num_bootstrap = 1000\n",
"line_width = 2\n",
"alpha = 0.2\n",
"font_size = 16\n",
"legend_size = 10\n",
"x_size = 10\n",
"y_size = 10"
]
},
{
"cell_type": "markdown",
"id": "1ac2d927-76e0-48ab-8777-48bc70206d07",
"metadata": {},
"source": [
"##### Initialize Score, Model, and Color Arrays"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "acc793b2-80d0-45ac-9a3a-15dcf8fb53fb",
"metadata": {},
"outputs": [],
"source": [
"# Define master lists of labels, scores, names, and colors\n",
"all_y_trues, all_y_scores, all_model_names, all_colors = [], [], [], []"
]
},
{
"cell_type": "markdown",
"id": "aef4eaf2-fff5-4f83-8ade-2367a2513aa8",
"metadata": {},
"source": [
"##### Load Fine-Tuned Torch LM Results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "26d15d8e-81cc-4cd5-a2dc-765325cb5a55",
"metadata": {},
"outputs": [],
"source": [
"# # ls llama 3 8b\n",
"# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n",
"# ls_llama_8b_last_labels = pickle.load(f)\n",
"# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_scores.pkl\", \"rb\") as f:\n",
"# ls_llama_8b_last_scores = pickle.load(f)\n",
"\n",
"# ls_llama_8b_last_scores_transformed = torch.sigmoid(torch.tensor(ls_llama_8b_last_scores))\n",
"# all_model_names.append(\"LS Llama-3 8B (Last)\")\n",
"# all_y_trues.append(ls_llama_8b_last_labels)\n",
"# all_y_scores.append(ls_llama_8b_last_scores_transformed)\n",
"# all_colors.append('#ab20fd')\n",
"\n",
"# ls unllama 3 8b\n",
"with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n",
" ls_unllama_8b_max_labels = pickle.load(f)\n",
"with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_scores.pkl\", \"rb\") as f:\n",
" ls_unllama_8b_max_scores = pickle.load(f)\n",
"\n",
"ls_unllama_8b_max_scores_transformed = torch.sigmoid(torch.tensor(ls_unllama_8b_max_scores)).numpy()\n",
"all_model_names.append(\"LS UnLlama-3 8B (Max)\")\n",
"all_y_trues.append(ls_unllama_8b_max_labels)\n",
"all_y_scores.append(ls_unllama_8b_max_scores_transformed)\n",
"\n",
"# BELT Max 5 segments\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_labels.pkl\", \"rb\") as f:\n",
" belt_5_max_labels = pickle.load(f)\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_scores.pkl\", \"rb\") as f:\n",
" belt_5_max_scores = pickle.load(f)\n",
"\n",
"all_model_names.append(\"BELT 128 step 5 seg (Max)\")\n",
"all_y_trues.append(belt_5_max_labels)\n",
"all_y_scores.append(belt_5_max_scores)\n",
"\n",
"# BELT Max 128 segments\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_labels.pkl\", \"rb\") as f:\n",
" belt_128_max_labels = pickle.load(f)\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_scores.pkl\", \"rb\") as f:\n",
" belt_128_max_scores = pickle.load(f)\n",
"\n",
"all_model_names.append(\"BELT 448 step 128 seg (Max)\")\n",
"all_y_trues.append(belt_128_max_labels)\n",
"all_y_scores.append(belt_128_max_scores)"
]
},
{
"cell_type": "markdown",
"id": "6f21cd42",
"metadata": {},
"source": [
"##### Print Performance for all Metrics for all Models"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2665e9ce-07c4-4196-9fe7-7f911123b8f9",
"metadata": {},
"outputs": [],
"source": [
"def print_mean_ci_of_metric_list(metric_list, metric_name, num_std):\n",
" mean_metric = np.mean(metric_list)\n",
" std_metric = np.std(metric_list)\n",
" metric_low = np.maximum(mean_metric - std_metric * num_std, 0)\n",
" metric_high = np.minimum(mean_metric + std_metric * num_std, 1)\n",
"\n",
" print(\n",
" f\"{metric_name}: {round(mean_metric, 3)} ([{round(metric_low, 3)} - {round(metric_high, 3)}] 95% CI)\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4d6ff4a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Results for LS UnLlama-3 8B (Max)\n",
"\n",
"Micro Average Precision: 0.277 ([0.256 - 0.299] 95% CI)\n",
"Micro ROC AUC: 0.828 ([0.818 - 0.839] 95% CI)\n",
"\n",
"Results for BELT 128 step 5 seg (Max)\n",
"\n",
"Micro Average Precision: 0.707 ([0.698 - 0.716] 95% CI)\n",
"Micro ROC AUC: 0.942 ([0.94 - 0.944] 95% CI)\n",
"\n",
"Results for BELT 448 step 128 seg (Max)\n",
"\n",
"Micro Average Precision: 0.804 ([0.797 - 0.812] 95% CI)\n",
"Micro ROC AUC: 0.971 ([0.969 - 0.972] 95% CI)\n"
]
}
],
"source": [
"model2metric_df = {}\n",
"for y_trues, y_scores, name in zip(\n",
" all_y_trues, all_y_scores, all_model_names\n",
"):\n",
" \n",
" micro_aps, macro_aps, micro_roc_aucs, macro_roc_aucs = [], [], [], []\n",
" for i in range(num_bootstrap):\n",
" \n",
" # Sample N records with replacement where N is the total number of records\n",
" sample_indices = np.random.choice(len(y_trues), len(y_trues))\n",
" sample_labels = np.array(y_trues)[sample_indices]\n",
" sample_scores = np.array(y_scores)[sample_indices]\n",
" \n",
" micro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n",
" micro_aps.append(micro_ap)\n",
"\n",
" # macro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n",
" # macro_aps.append(macro_ap)\n",
"\n",
" micro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n",
" micro_roc_aucs.append(micro_roc_auc)\n",
"\n",
" # macro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n",
" # macro_roc_aucs.append(macro_roc_auc)\n",
" \n",
" metric_df = pd.DataFrame({\n",
" \"micro_aps\": micro_aps,\n",
" \"micro_roc_aucs\": micro_roc_aucs,\n",
" })\n",
" model2metric_df[name] = metric_df\n",
"\n",
" print(f\"\\nResults for {name}\\n\")\n",
" print_mean_ci_of_metric_list(micro_aps, metric_name=\"Micro Average Precision\", num_std=num_std)\n",
" print_mean_ci_of_metric_list(micro_roc_aucs, metric_name=\"Micro ROC AUC\", num_std=num_std)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10 - SDK v2",
"language": "python",
"name": "python310-sdkv2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading