-
Notifications
You must be signed in to change notification settings - Fork 0
/
gradio-gguf-chat_llama2.py
144 lines (139 loc) · 6.26 KB
/
gradio-gguf-chat_llama2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from llama_cpp import Llama #llama.cppのPythonライブラリであるllama-cpp-python
import gradio as gr #Gradio AIの試作で広く使われているブラウザGUIライブラリ
import re #テキスト整形のライブラリ
import datetime #時刻取得ライブラリ
model="hoge_huga.gguf" #対象のモデルのパスを入力。
llm = Llama(
model_path=model,
n_gpu_layers=-1, # #GPUにロードするレイヤー数(llama-cpp-pythonがcuda版の場合)
n_ctx=4096, # 最大コンテキストサイズ。入力の上限。
flash_attn=True,
last_n_tokens_size =0, # Maximum number of tokens to keep in the last_n_tokens deque.
)
role = "[INST] <<SYS>>\n\
あなたは優秀な日本語を話すチャットボットアシスタントです。\n\
現在の日付・曜日・時刻はそれぞれ{day}です。\n\
回答に確信が持てない質問、または自信を持って答えるのに十分な情報が無い質問は回答を拒否してください。\n\
<</SYS>>"
history = ""
output_history =""
# AIに質問する関数
def complement(role,prompt,turn_config):
global history,output_history
day = (str(datetime.datetime.now().year)\
+"年"+str(datetime.datetime.now().month)\
+"月"+str(datetime.datetime.now().day)\
+"日"+str(datetime.datetime.now().strftime(" %a "))\
+str(datetime.datetime.now().hour)\
+"時"+str(datetime.datetime.now().minute)+"分")\
.replace("Sun", "日曜日").replace("Mon", "月曜日").replace("Tue", "火曜日").replace("Wed", "水曜日")\
.replace("Thu", "木曜日").replace("Fri", "金曜日").replace("Sat", "土曜日")
role = role.replace("{day}",day)
role += "\n\n"
if prompt !="":
prompt_C2L = (role+history + "USER: "+prompt+"\nASSISTANT: ")\
.replace("\nASSISTANT: ", " [/INST]").replace("<|endoftext|>\n", "</s>").replace("</s>\nUSER: ", "</s><s>[INST] ").replace("USER: ", "<s>")
output = llm(
prompt=prompt_C2L, # 元々calm2-7b-chat用に作ったプログラムなのでここで整形。
max_tokens=1024,
temperature = 0.8,
top_p=0.95,
min_p=0.05,
typical_p=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
repeat_penalty=1.1,
top_k=40,
seed=-1,
tfs_z=1.0,
mirostat_mode=0,
mirostat_tau=5.0,
mirostat_eta=0.1,
stop=["<</SYS>>","[INST]","[/INST]","</SYS>","prompt_tokens"] # ストップ。特定の文字を生成したらその文字を生成せず停止する。
)
output= output["choices"][0]["text"]
output =output.replace("\\n", "\n").replace("\\u3000", "\u3000").replace("!","!").replace("?","?")
while output[-1]=="\n":
output=output[:-1]
while output[0]=="\n":
output=output[1:]
print( prompt_C2L+output+"</s>")
history =history +"USER: "+prompt +"\nASSISTANT: " + output+"<|endoftext|>\n"
turn = re.split(r'(?=USER: )', history)
del turn[0:1]
output_history =''.join(turn)
output_history = output_history.replace("<|endoftext|>", '')
turn_count = len(turn)
if turn_count > turn_config:
del turn[0:turn_count - int(turn_config)]
history =''.join(turn)
output_history =''.join(turn)
output_history = output_history.replace("<|endoftext|>", '')
turn_count = len(turn)
if prompt =="":
if history =="":
output=""
if history !="":
output=re.split(r'(?=USER: |ASSISTANT: )', output_history)
output = (output[len(output)-1]).replace("ASSISTANT: ","")
turn = re.split(r'(?=USER: )', history)
del turn[0:1]
output_history =''.join(turn)
output_history = output_history.replace("<|endoftext|>", '')
turn_count = len(turn)
return output, output_history
# 履歴リセット関数
def hist_rst():
global history
prompt=""
output=""
history=""
output_history=""
return prompt, output, output_history
# 会話Undo関数
def undo():
global history
turn = re.split(r'(?=USER: )', history)
output=re.split(r'(?=USER: |ASSISTANT: )', history)
del turn[0:1]
del output[0:1]
if len(turn)>=2:
prompt= output[len(output)-4]
output= output[len(output)-3]
prompt= prompt.replace("USER: ", '')
output=output.replace("<|endoftext|>", '').replace("ASSISTANT: ", '')
if len(turn)<2:
prompt=""
output=""
del turn[len(turn)-1:len(turn)]
history =''.join(turn)
output_history =''.join(turn)
output_history = output_history.replace("<|endoftext|>", '')
return prompt, output,output_history
# Blocksの作成
with gr.Blocks(title="チャットボット",theme=gr.themes.Base(primary_hue="orange", secondary_hue="blue")) as demo:
# コンポーネント
gr.Markdown(
"""
チャットボット(llama2-chat)
""")
# UI
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(lines=2,label="質問入力")
output = gr.Textbox(label="回答出力")
greet_btn = gr.Button(value="送信",variant='primary')
with gr.Accordion(label="会話履歴設定", open=False ):
with gr.Accordion(label="システムプロンプト", open=False):
role = gr.Textbox(lines=26,label="Llama2-Chatのチャットテンプレートで書いてください", value=role)
turn_config = gr.Number(label="会話ターン数設定",value=10,minimum=1,maximum=20)
disphist = gr.Textbox(lines=10,label="会話履歴出力")
undo_btn = gr.Button(value="1ターン戻す",variant='secondary')
reset_btn = gr.Button(value="履歴リセット",variant='secondary')
# イベントハンドラー
greet_btn.click(fn=complement, inputs=[role,prompt,turn_config], outputs=[output,disphist])
undo_btn.click(fn=undo, outputs=[prompt,output, disphist])
reset_btn.click(fn=hist_rst, outputs=[prompt,output, disphist])
#demo.launch(auth=("XXXX","YYYY"),share=True, server_port=7860,show_api=False)
#demo.launch(server_name="192.168.x.xxx", server_port=7860,show_api=False)
demo.launch(show_api=False)