
Commit

Use tiktoken to count input tokens precisely
GaiZhenbiao committed Mar 13, 2023
1 parent 893df38 commit 9c45970
Showing 2 changed files with 17 additions and 11 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,6 +1,7 @@
 gradio
 mdtex2html
 pypinyin
-jieba
+tiktoken
 socksio
 tqdm
+colorama
25 changes: 15 additions & 10 deletions utils.py
@@ -12,8 +12,9 @@
 import mdtex2html
 from pypinyin import lazy_pinyin
 from presets import *
-import jieba
+import tiktoken
 from tqdm import tqdm
+import colorama
 
 if TYPE_CHECKING:
     from typing import TypedDict
@@ -47,11 +48,12 @@ def postprocess(
         )
     return y
 
-def count_words(input_str):
-    print("计算输入字数中……")
-    words = jieba.lcut(input_str)
+def count_token(input_str):
+    print("计算输入Token计数中……")
+    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+    length = len(encoding.encode(input_str))
     print("计算完成!")
-    return len(words)
+    return length
 
 def parse_text(text):
     lines = text.split("\n")
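
For context, a minimal standalone sketch of what the new count_token does (the model name matches the diff; the sample input is illustrative): tiktoken resolves the tokenizer registered for a model, and the token count is simply the length of the encoded id sequence.

    import tiktoken

    def count_token(input_str):
        # Resolve the tokenizer registered for the model
        # (cl100k_base in the case of gpt-3.5-turbo).
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        # The token count is the number of ids the encoder produces.
        return len(encoding.encode(input_str))

    # Illustrative usage; CJK text typically costs more tokens
    # per character than English.
    print(count_token("你好,世界!"))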
@@ -97,8 +99,7 @@ def construct_assistant(text):
     return construct_text("assistant", text)
 
 def construct_token_message(token, stream=False):
-    extra = "【粗略计数(因为实时传输回答)】 " if stream else ""
-    return f"{extra}Token 计数: {token}"
+    return f"Token 计数: {token}"
 
 def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
     headers = {
@@ -135,10 +136,12 @@ def get_return_value():
     counter = 0
     status_text = "开始实时传输回答……"
     history.append(construct_user(inputs))
+    user_token_count = 0
     if len(previous_token_count) == 0:
-        rough_user_token_count = count_words(inputs) + count_words(system_prompt)
+        user_token_count = count_token(inputs) + count_token(system_prompt)
     else:
-        rough_user_token_count = count_words(inputs)
+        user_token_count = count_token(inputs)
+    print(f"输入token计数: {user_token_count}")
     try:
         response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
     except requests.exceptions.ConnectTimeout:
@@ -162,7 +165,7 @@ def get_return_value():
             # decode each line as response data is in bytes
             if chunklength > 6 and "delta" in chunk['choices'][0]:
                 finish_reason = chunk['choices'][0]['finish_reason']
-                status_text = construct_token_message(sum(previous_token_count)+token_counter+rough_user_token_count, stream=True)
+                status_text = construct_token_message(sum(previous_token_count)+token_counter+user_token_count, stream=True)
                 if finish_reason == "stop":
                     print("生成完毕")
                     yield get_return_value()
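
Since streamed chat completions carry no usage field, the status line assembles the running total client-side: stored per-message counts, plus the delta chunks received so far, plus the freshly counted prompt. A sketch of that arithmetic; the variable names come from the diff, but the values are made up:

    # Made-up values, for illustration only.
    previous_token_count = [120, 85]  # per-message counts from earlier turns
    token_counter = 42                # delta chunks received so far in this stream
    user_token_count = 30             # count_token(inputs) for the new prompt

    running_total = sum(previous_token_count) + token_counter + user_token_count
    print(f"Token 计数: {running_total}")  # Token 计数: 277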
@@ -197,6 +200,7 @@ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previou
 
 
 def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True): # repetition_penalty, top_k
+    print(colorama.Fore.BLUE + f"输入为:{inputs}" + colorama.Style.RESET_ALL)
     if stream:
         print("使用流式传输")
         iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
@@ -207,6 +211,7 @@ def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count
         chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
         yield chatbot, history, status_text, token_count
     print(f"传输完毕。当前token计数为{token_count}")
+    print(colorama.Fore.BLUE + f"回答为:{history[-1]['content']}" + colorama.Style.RESET_ALL)
     if stream:
         max_token = max_token_streaming
     else:
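
The two colorama print statements added in the last two hunks wrap the logged prompt and reply in ANSI color codes so they stand out in the terminal. A minimal sketch of the pattern; the sample strings are illustrative, and the init call is only needed for Windows consoles:

    import colorama

    colorama.init()  # enable ANSI escape handling on Windows; harmless elsewhere
    print(colorama.Fore.BLUE + "输入为:hello" + colorama.Style.RESET_ALL)
    print(colorama.Fore.BLUE + "回答为:world" + colorama.Style.RESET_ALL)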
