diff --git a/server/chatbot.py b/server/chatbot.py index 29c3c51..aad8059 100644 --- a/server/chatbot.py +++ b/server/chatbot.py @@ -35,12 +35,13 @@ def ask(self, user_request: str) -> dict: } """ prompt = self.prompt.construct_prompt(user_request) + completion = openai.Completion.create( engine="text-davinci-003", prompt=prompt, temperature=self.temprature, max_tokens=1024, - stop=["\n\n\n"], + stop=["\n"], ) if completion.get("choices") is None: raise Exception("ChatGPT API returned no choices") @@ -56,11 +57,12 @@ def ask(self, user_request: str) -> dict: self.prompt.add_to_chat_history( "You: " + user_request - + "\n\n\n" + + "\n" + "ChatGPT: " + completion["choices"][0]["text"] - + "\n\n\n", + + "\n", ) + print(self.prompt.history()) return completion def rollback(self, num: int) -> None: diff --git a/server/prompt.py b/server/prompt.py index 7396865..2f1dc5f 100644 --- a/server/prompt.py +++ b/server/prompt.py @@ -24,7 +24,7 @@ def history(self) -> str: """ Return chat history """ - return "\n\n\n\n".join(self.chat_history) + return "\n".join(self.chat_history) def construct_prompt(self, new_prompt: str) -> str: """ diff --git a/server/server.py b/server/server.py index 613bc5f..e554521 100644 --- a/server/server.py +++ b/server/server.py @@ -20,6 +20,8 @@ quit(0) app = Flask(__name__) +chatbot = Chatbot(api_key=OPEN_AI_KEY, + temprature=TEMPRATURE, base_prompt=base_prompt) @app.route("/chat", methods=["GET"]) @cross_origin() @@ -50,15 +52,14 @@ def chatbot_commands(cmd: str) -> bool: user_request = request.args.get('q') # decode the `q` parameter from UTF-8 encoding user_request = urllib.parse.unquote(user_request) - chatbot = Chatbot(api_key=OPEN_AI_KEY, - temprature=TEMPRATURE, base_prompt=base_prompt) + # Start chat PROMPT = user_request if PROMPT.startswith("!"): if chatbot_commands(PROMPT): print("continue") response = chatbot.ask(PROMPT) - print("ChatGPT: " + response["choices"][0]["text"]) + # print("ChatGPT: " + response["choices"][0]["text"]) message = response["choices"][0]["text"] message = message.replace("\n\n", "") # print(message)