# llama_3_inference_text_completion_test.py
# Forked from PaulPauls/llama3_interpretability_sae
import argparse
import logging
import os
from pathlib import Path
import torch
from llama_3_inference import Llama3Inference
from sae import load_sae_model
from utils.cuda_utils import set_torch_seed_for_inference
def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments for the text-completion test script.

    Returns:
        argparse.Namespace with:
            llama_model_dir (Path): directory containing the Llama3
                ``tokenizer.model``, ``params.json`` and consolidated
                checkpoint files (required).
            sae_model_path (Path | None): optional path to a trained SAE
                checkpoint; ``None`` disables the SAE layer hook.
    """
    parser = argparse.ArgumentParser(
        description="Sample Llama3 text completion, optionally with an SAE layer hook."
    )
    parser.add_argument(
        "--llama_model_dir",
        type=Path,
        required=True,
        help="Directory containing tokenizer.model, params.json and consolidated.00.pth",
    )
    parser.add_argument(
        "--sae_model_path",
        type=Path,
        default=None,
        help="Optional path to a trained SAE model checkpoint",
    )
    return parser.parse_args()
def main() -> None:
    """Generate sample Llama3 text completions, optionally through an SAE hook.

    Loads a Llama3 model from the checkpoint directory given on the command
    line. If an SAE checkpoint path is supplied, the SAE's forward pass is
    hooked into the transformer layer selected by ``sae_layer_idx`` (hard-coded
    below). Streams completions for a fixed batch of prompts, redrawing the
    console as new tokens arrive.

    Raises:
        ValueError: if ``--sae_model_path`` is given but ``sae_layer_idx``
            is not configured.
    """
    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Parse arguments and resolve the Llama3 checkpoint file layout.
    args = parse_arguments()
    args.llama_model_dir = args.llama_model_dir.resolve()
    llama_tokenizer_path = args.llama_model_dir / "tokenizer.model"
    llama_params_path = args.llama_model_dir / "params.json"
    llama_model_path = args.llama_model_dir / "consolidated.00.pth"
    if args.sae_model_path is not None:
        args.sae_model_path = args.sae_model_path.resolve()

    # Inference configuration (hard-coded for this test script).
    max_new_tokens = 128
    temperature = 0.7
    top_p = 0.9
    seed = 42
    # SAE settings. sae_layer_idx must be set to a concrete layer index when an
    # SAE checkpoint is supplied; sae_h_bias, if used, appears to be an
    # (latent_index, value) pair — see the h_bias construction below.
    sae_layer_idx = None
    sae_h_bias = None
    sae_top_k = 64
    sae_normalization_eps = 1e-6
    sae_dtype = torch.float32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info("#### Starting sample Llama3 Text Completion")
    logging.info("#### Arguments:")
    logging.info(f"# llama_model_dir={args.llama_model_dir}")
    logging.info(f"# sae_model_path={args.sae_model_path}")
    logging.info("#### Configuration:")
    logging.info(f"# max_new_tokens={max_new_tokens}")
    logging.info(f"# temperature={temperature}")
    logging.info(f"# top_p={top_p}")
    logging.info(f"# seed={seed}")
    logging.info(f"# sae_layer_idx={sae_layer_idx}")
    logging.info(f"# sae_h_bias={sae_h_bias}")
    logging.info(f"# sae_top_k={sae_top_k}")
    logging.info(f"# sae_normalization_eps={sae_normalization_eps}")
    logging.info(f"# sae_dtype={sae_dtype}")
    logging.info(f"# device={device}")

    # Set up CUDA and seed for inference
    set_torch_seed_for_inference(seed)

    # Load the SAE model if provided and set up the forward fn for the specified sae_layer_idx
    sae_layer_forward_fn = None
    if args.sae_model_path is not None:
        # Raise instead of assert: asserts are stripped under `python -O`,
        # and a missing layer index would otherwise fail much later.
        if sae_layer_idx is None:
            raise ValueError("sae_layer_idx must be set when an SAE model path is provided")
        sae_model = load_sae_model(
            model_path=args.sae_model_path,
            sae_top_k=sae_top_k,
            sae_normalization_eps=sae_normalization_eps,
            device=device,
            dtype=sae_dtype,
        )
        sae_layer_forward_fn = {sae_layer_idx: sae_model.forward}
        if sae_h_bias is not None:
            logging.info("Setting SAE h_bias...")
            # Bias vector is zero everywhere except the single chosen latent.
            h_bias = torch.zeros(sae_model.n_latents)
            h_bias[sae_h_bias[0]] = sae_h_bias[1]
            h_bias = h_bias.to(sae_dtype).to(device)
            sae_model.set_latent_bias(h_bias)

    # Initialize the Llama3Inference generator
    llama_inference = Llama3Inference(
        tokenizer_path=llama_tokenizer_path,
        params_path=llama_params_path,
        model_path=llama_model_path,
        device=device,
        sae_layer_forward_fn=sae_layer_forward_fn,
    )

    # Prepare batch for text completion
    logging.info("Generating sample text completions...")
    text_prompts = [
        "Once upon a time, in a land far, far away",
        "The quick brown fox jumps over",
        "In the year 2050, technology had advanced to the point where",
        "The secret to happiness is",
    ]

    # Generate text completions and print results iteratively
    text_completions = [""] * len(text_prompts)
    for next_tokens_text in llama_inference.generate_text_completions(
        prompts=text_prompts,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    ):
        # Clear the console for a more 'commercial LLM web UI' feel.
        # "cls" on Windows, "clear" elsewhere (plain "clear" fails on Windows).
        os.system("cls" if os.name == "nt" else "clear")
        # Update each completion with the new tokens (or initial minimal sequence) and print
        for i, new_token in enumerate(next_tokens_text):
            text_completions[i] += new_token
            # Print current state
            print(f"#### Text Completion {i + 1}: ".ljust(80, "#"))
            print(text_completions[i])
            print("#" * 80)
    logging.info("#### FIN!")
# Script entry point: run the sample text-completion generation when executed directly.
if __name__ == "__main__":
    main()