diff --git a/requirements.txt b/requirements.txt
index 53f650b9..029abc84 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 gradio
 mdtex2html
 pypinyin
-jieba
+tiktoken
 socksio
 tqdm
+colorama
diff --git a/utils.py b/utils.py
index 2099f153..2abe097d 100644
--- a/utils.py
+++ b/utils.py
@@ -12,8 +12,9 @@
 import mdtex2html
 from pypinyin import lazy_pinyin
 from presets import *
-import jieba
+import tiktoken
 from tqdm import tqdm
+import colorama
 
 if TYPE_CHECKING:
     from typing import TypedDict
@@ -47,11 +48,12 @@ def postprocess(
     )
     return y
 
-def count_words(input_str):
-    print("Counting words in the input...")
-    words = jieba.lcut(input_str)
+def count_token(input_str):
+    print("Counting tokens in the input...")
+    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+    length = len(encoding.encode(input_str))
     print("Counting done!")
-    return len(words)
+    return length
 
 def parse_text(text):
     lines = text.split("\n")
@@ -97,8 +99,7 @@ def construct_assistant(text):
     return construct_text("assistant", text)
 
 def construct_token_message(token, stream=False):
-    extra = "[Rough count (response is streaming)] " if stream else ""
-    return f"{extra}Token count: {token}"
+    return f"Token count: {token}"
 
 def get_response(openai_api_key, system_prompt, history, temperature, top_p, stream):
     headers = {
@@ -135,10 +136,12 @@ def get_return_value():
     counter = 0
     status_text = "Starting to stream the response..."
     history.append(construct_user(inputs))
+    user_token_count = 0
     if len(previous_token_count) == 0:
-        rough_user_token_count = count_words(inputs) + count_words(system_prompt)
+        user_token_count = count_token(inputs) + count_token(system_prompt)
     else:
-        rough_user_token_count = count_words(inputs)
+        user_token_count = count_token(inputs)
+    print(f"Input token count: {user_token_count}")
     try:
         response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
     except requests.exceptions.ConnectTimeout:
@@ -162,7 +165,7 @@
             # decode each line as response data is in bytes
             if chunklength > 6 and "delta" in chunk['choices'][0]:
                 finish_reason = chunk['choices'][0]['finish_reason']
-                status_text = construct_token_message(sum(previous_token_count)+token_counter+rough_user_token_count, stream=True)
+                status_text = construct_token_message(sum(previous_token_count)+token_counter+user_token_count, stream=True)
                 if finish_reason == "stop":
                     print("Generation finished")
                     yield get_return_value()
@@ -197,6 +200,7 @@ def predict_all(openai_api_key, system_prompt, history, inputs, chatbot, previou
 
 def predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature, stream=False, should_check_token_count = True):
     # repetition_penalty, top_k
+    print(colorama.Fore.BLUE + f"Input: {inputs}" + colorama.Style.RESET_ALL)
     if stream:
         print("Using streaming mode")
         iter = stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
@@ -207,6 +211,7 @@
         chatbot, history, status_text, token_count = predict_all(openai_api_key, system_prompt, history, inputs, chatbot, token_count, top_p, temperature)
         yield chatbot, history, status_text, token_count
     print(f"Transmission finished. Current token count: {token_count}")
+    print(colorama.Fore.BLUE + f"Response: {history[-1]['content']}" + colorama.Style.RESET_ALL)
     if stream:
         max_token = max_token_streaming
     else:
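
For reference, a minimal standalone sketch of the tiktoken API this patch adopts (the `count_token` name mirrors the patch; everything else in the snippet is illustrative only):

    import tiktoken

    def count_token(input_str):
        # Look up the byte-pair encoding used by gpt-3.5-turbo (cl100k_base).
        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        # Encode the caller's actual text and count the resulting tokens.
        return len(encoding.encode(input_str))

    print(count_token("tiktoken is great!"))  # e.g. 6 tokens under cl100k_base

Because tiktoken reports the model's actual token count rather than the word segments jieba produced, the "rough count" caveat that construct_token_message used to prepend in streaming mode is no longer needed.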