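"""Interactive playground: compare rollouts from a trained adversary
checkpoint against the untuned TinyLlama base, each talking to the same
defender model.

Prompts can be typed in by hand or sampled (via the "redditme" command)
from a toxicity-filtered slice of ConvoKit's reddit-corpus-small.
"""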
import random

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from convokit import Corpus, download, Conversation

from lm import LanguageModel
from toxicity.detoxify_reddit import filter_corpus_toxicity, jsonl_to_dict
from toxicity.reddit_data_helpers import filter_corpus_formatting, clean_utterance
from environment import episode
# adversary checkpoint to play against; the commented-out paths are
# alternative SFT checkpoints that can be swapped in
checkpoint = "./models/TL_v_TL_beta_1e-1_best"
# checkpoint = "/home/houjun/FineGrainedLLMDetox/sft_out/checkpoint-16000"
# checkpoint = "/home/houjun/FineGrainedLLMDetox/sft_out/checkpoint-500"

# untuned comparison model, and the defender the adversary talks to
base = "TinyLlama/TinyLlama_v1.1"
defender = "TinyLlama/TinyLlama_v1.1"
# load our initial corpus ahead of time, filtering it by the precomputed
# Detoxify scores in detox_results.jsonl and by formatting
corpus = Corpus(filename=download("reddit-corpus-small"))
id2results = jsonl_to_dict("detox_results.jsonl")
corpus = filter_corpus_toxicity(corpus, id2results, {"toxicity": 0.5})
corpus = filter_corpus_formatting(corpus)
convos = list(corpus.conversations.values())
# we only keep the last two utterances of each thread (and also discard the
# front, because the front is the self-post on reddit), dropping deleted or
# empty turns along the way
prompts = [
    [clean_utterance(j.text)
     for j in i.iter_utterances()
     if j.text.strip() != "[deleted]"
     and j.text.strip() != ""][1:][-2:]
    for i in convos
]
prompts = [[j + " " for j in i if j.strip() != ""]
           for i in prompts]
prompts = [i for i in prompts if len(i) != 0]
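# optional guard: random.choice below assumes at least one prompt survived
# the filtering above
assert len(prompts) > 0, "no usable prompts left after filtering"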
# the tuned adversary, the untuned base, and the defender are all plain
# causal-LM checkpoints
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model_base = AutoModelForCausalLM.from_pretrained(base)
# model_defender = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", attn_implementation="flash_attention_2", load_in_4bit=True, torch_dtype=torch.float16)
model_defender = AutoModelForCausalLM.from_pretrained(defender)

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")
tokenizer_defender = AutoTokenizer.from_pretrained(defender)
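# note: the models load on CPU in fp32 by default; if a GPU is available
# they could be moved over first, e.g.:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   for m in (model, model_base, model_defender):
#       m.to(device)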
# wrap everything in the project's LanguageModel interface; dont_init=True
# lets us attach the models and tokenizers by hand (note this rebinds the
# `base` and `defender` names, which previously held checkpoint strings)
adversary = LanguageModel(dont_init=True)
adversary.model = model
adversary.tokenizer = tokenizer

base = LanguageModel(dont_init=True)
base.model = model_base
base.tokenizer = tokenizer

defender = LanguageModel(dont_init=True)
defender.model = model_defender
defender.tokenizer = tokenizer_defender
# interactive loop: type utterances one per line (finish with "q"), or type
# "redditme" to sample a filtered reddit conversation as the prompt; then
# roll the episode out under both the tuned policy and the base model
while True:
    prompt = []
    r = None
    while r != 'q':
        r = input("> ").strip()
        if r == "redditme":
            prompt = random.choice(prompts)
            print("==== PROMPT ====")
            print(" ".join(prompt))
            break
        if r != "q":
            prompt.append(r)

    ro_policy = episode(adversary, defender, prompt, horizon=5, return_sequence=True)
    ro_base = episode(base, defender, prompt, horizon=5, return_sequence=True)

    print("==== POLICY ====")
    print("".join("[" + i + "] " for i in ro_policy))
    print("==== BASE ====")
    print("".join("[" + i + "] " for i in ro_base))

    # drop into the debugger after each pair of rollouts for inspection
    breakpoint()
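# example interaction:
#   > redditme   (sample a reddit prompt; rollouts run immediately)
#   or type utterances one per line and finish with:
#   > q          (run the rollouts on what you typed)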