# perplexity.py (170 lines, 143 loc, 6.17 KB)
# Forked from tomaarsen/attention_sinks.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Adapted from https://github.com/mit-han-lab/streaming-llm

Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see how the speed changes over time.

Usage:
    python benchmark/perplexity.py --experiment attention_sinks
    python benchmark/perplexity.py --experiment transformers
    python benchmark/perplexity.py --experiment windowed
"""
import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer
def compute_perplexity(
    model,
    tokenizer,
    dataset,
    experiment: str,
    output_dir: str = "outputs",
    data_column: str = "text",
    num_samples: int = 1,
    num_tokens: Optional[int] = None,
    overwrite: bool = False,
) -> None:
    """Run ``model`` over the dataset one token at a time and log per-token perplexity to CSV.

    Every sample is tokenized and its tokens are fed to the model one by one,
    reusing the ``past_key_values`` returned by the previous step. The
    key/value cache and the running statistics are carried across samples.

    One row is recorded per predicted token with the columns ``input_length``,
    ``nll``, ``ppl``, ``overall_ppl`` (exp of the mean NLL over all tokens
    processed so far), ``cuda_vram_allocated`` (GB allocated on CUDA device 0)
    and ``latency`` (seconds spent on that step, excluding the CSV write).

    The rows are written to ``<output_dir>/<experiment>.csv`` every 10 tokens
    and once more when the function exits, whether it finishes normally,
    reaches ``num_tokens``, or is aborted by an exception such as
    ``KeyboardInterrupt``.

    Args:
        model: Called as ``model(input_ids, past_key_values=..., use_cache=True)``;
            must expose ``device`` and ``config.vocab_size``, and its output
            must expose ``logits`` and ``past_key_values``.
        tokenizer: Called as ``tokenizer(text, return_tensors="pt")``; its
            output must expose ``input_ids``.
        dataset: Iterable of samples that can be indexed with ``data_column``.
        experiment: Name of the experiment; used as the CSV file name.
        output_dir: Directory for the CSV file; created if it does not exist.
        data_column: Key of the text field in each sample.
        num_samples: Maximum number of samples taken from ``dataset``.
        num_tokens: Stop after this many tokens were processed in total.
            ``None`` (or 0) means no limit.
        overwrite: Allow replacing an already existing output file.

    Raises:
        ValueError: If the output file already exists and ``overwrite`` is false.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{experiment}.csv"
    if output_file.exists() and not overwrite:
        raise ValueError(
            f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
        )

    columns = ["input_length", "nll", "ppl", "overall_ppl", "cuda_vram_allocated", "latency"]
    # Rows are appended as whole tuples, so an interrupt can never leave the
    # columns with different lengths (which would make the DataFrame invalid).
    records = []

    def save_records() -> None:
        pd.DataFrame(records, columns=columns).to_csv(output_file, index=False)

    loss_fn = CrossEntropyLoss(reduction="none")
    past_key_values = None
    num_processed_tokens = 0
    # Running sum of the NLLs, so that the overall perplexity does not require
    # re-scanning the full history on every token.
    nll_sum = 0.0

    try:
        for sample in itertools.islice(dataset, num_samples):
            encodings = tokenizer(sample[data_column], return_tensors="pt")

            seq_len = encodings.input_ids.size(1)
            print(f"sequence length: {seq_len}")
            pbar = tqdm(range(0, seq_len - 1))

            for idx in pbar:
                start_t = time.time()
                # Feed exactly one token; the context lives in past_key_values.
                input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
                with torch.no_grad():
                    outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
                    logits = outputs.logits.view(-1, model.config.vocab_size)
                    past_key_values = outputs.past_key_values
                    # The label is the token that follows the one just fed.
                    label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
                    neg_log_likelihood = loss_fn(logits, label)
                    perplexity = neg_log_likelihood.exp()
                nll = neg_log_likelihood.item()
                ppl = perplexity.item()
                pbar.set_description(f"nll: {nll:>5.2f}, ppl: {ppl:>8.2f}")

                # Store data and save every 10 tokens
                nll_sum += nll
                records.append(
                    (
                        idx + 1,
                        nll,
                        ppl,
                        # torch's exp saturates to inf instead of raising on overflow.
                        torch.tensor(nll_sum / (len(records) + 1)).exp().item(),
                        torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024,  # in GB
                        time.time() - start_t,
                    )
                )
                if num_processed_tokens % 10 == 0:
                    save_records()

                num_processed_tokens += 1
                if num_tokens and num_processed_tokens >= num_tokens:
                    return
    finally:
        # Always flush on the way out: the periodic save alone would drop the
        # rows recorded since the last multiple of 10, and an interrupt (e.g.
        # Ctrl+C during the forward pass) must not lose them either.
        if records:
            save_records()
def main():
    """Parse the command-line arguments and run the perplexity benchmark.

    Depending on ``--experiment``, the model class is imported from
    ``transformers`` (plain model) or from ``attention_sinks``:

    * ``attention_sinks``: ``--attention_sink_size`` sink tokens plus a window
      of ``--window_size - --attention_sink_size`` recent tokens.
    * ``windowed``: no sink tokens and a window of ``--window_size`` tokens.
    * ``transformers``: no extra keyword arguments are passed.

    The model is loaded in float16 with ``device_map="auto"``, the dataset is
    streamed, and the results are written by ``compute_perplexity`` to
    ``<output_dir>/<experiment>.csv``.
    """
    parser = argparse.ArgumentParser()
    # Which experiment to run?
    parser.add_argument(
        "--experiment", choices=["attention_sinks", "transformers", "windowed"], default="attention_sinks"
    )

    # Model args
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")

    # Dataset args
    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
    parser.add_argument("--data_column", type=str, default="text")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
    # NOTE: there is no `--num_samples` option; only a single sample is supported for now.
    parser.add_argument("--num_tokens", type=int, default=8192)

    # Where to log
    parser.add_argument("--output_dir", type=str, default="benchmark/outputs")
    parser.add_argument("--overwrite", action="store_true")

    # Window size for windowed and attention_sinks
    parser.add_argument("--window_size", type=int, default=1024)

    # Attention Sinks-only settings
    # Attention Sink window size is calculated with args.window_size - args.attention_sink_size
    parser.add_argument("--attention_sink_size", type=int, default=4)

    args = parser.parse_args()

    # Work out the cache layout first, so that inconsistent sizes are rejected
    # before the (expensive) model download/load starts.
    kwargs = {}
    if args.experiment != "transformers":
        # "windowed" is the attention_sinks implementation without any sink tokens.
        sink_size = args.attention_sink_size if args.experiment == "attention_sinks" else 0
        if not 0 <= sink_size <= args.window_size:
            parser.error(
                f"inconsistent cache sizes: expected 0 <= attention sink size ({sink_size}) "
                f"<= --window_size ({args.window_size})"
            )
        kwargs = {
            "attention_sink_size": sink_size,
            "attention_sink_window_size": args.window_size - sink_size,  # default: 1020 for attention_sinks
        }

    # Initialize the model, either via transformers or via attention_sinks
    if args.experiment == "transformers":
        from transformers import AutoModelForCausalLM
    else:
        from attention_sinks import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto",
        **kwargs,
    )
    model.eval()
    # Load the tokenizer from the same revision as the model, so the two cannot diverge.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=args.trust_remote_code,
    )

    # Set up the dataset
    dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)

    compute_perplexity(
        model,
        tokenizer,
        dataset,
        args.experiment,
        output_dir=args.output_dir,
        data_column=args.data_column,
        num_samples=1,  # <- No support for more than one instance now
        num_tokens=args.num_tokens,
        overwrite=args.overwrite,
    )
# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()