From 4d33546197a57a72c6e11dd896e064eaff6b8d73 Mon Sep 17 00:00:00 2001 From: KirbytroNic0528 <133230754+KirbytroNic0528@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:11:06 +0800 Subject: [PATCH] Update gettest.py --- scripts/gettest.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/scripts/gettest.py b/scripts/gettest.py index 34fc1c059..c781344b3 100644 --- a/scripts/gettest.py +++ b/scripts/gettest.py @@ -6,24 +6,21 @@ def generate_test_code(moonbit_code, path, api_key): filename = os.path.basename(path) - test_prompt = ChatPromptTemplate.from_template( - """As a MoonBit language engineer, your task is to write a series of test cases to - verify the correctness of a project. - Based on the provided format and the understanding of the function's purpose - from the filename {filename}, write corresponding test cases for the given MoonBit function - including edge cases and any potential error scenarios: - test {{ - assert_eq!(f(x)) - assert_eq!(f(x)) - }} - Provide test cases for the MoonBit function given as {moonbit_code}. - Note that your output should only contain the code for the test cases, - without any analysis, explanations, or any other statements. - Also, ensure that you are generating test cases for the MoonBit language, - and do not confuse MoonBit language with any other - Attention, just generate a test function and the function must not exceed 1500 bytes . - """ - ) + test_prompt = ChatPromptTemplate.from_messages([ + ("system", "As a MoonBit language engineer, your task is to write a series of test cases to verify the correctness of a project."), + ("system", "Based on the provided format and the understanding of the function's purpose from the filename {filename}, write corresponding test cases for the given MoonBit function including edge cases and any potential error scenarios:"), + ("system", "The test cases should be formatted as follows:"), + ("system", "test {{"), + ("system", " assert_eq!(f(x))"), + ("system", " assert_eq!(f(x))"), + ("system", "}}"), + ("system", "Provide test cases for the MoonBit function given as moonbit."), + ("system", "Note that your output should only contain the code for the test cases, without any analysis, explanations, or any other statements."), + ("system", "Also, ensure that you are generating test cases for the MoonBit language, and do not confuse MoonBit language with any other."), + ("system","Attention,you don't need to provide the results of the assertions. "), + ("user", "{moonbit}"), + ]) + test_llm = ChatZhipuAI( api_key=api_key, model="glm-4-9b:772570335:v3:odbzuhb9", temperature=0.5, max_tokens=4095 @@ -31,7 +28,7 @@ def generate_test_code(moonbit_code, path, api_key): test_retriever_chain = test_prompt | test_llm | StrOutputParser() test_code_output = test_retriever_chain.invoke( - {"moonbit_code": moonbit_code, "filename": filename} + {"moonbit": moonbit, "filename": filename} ) test_code = test_code_output.replace("```moonbit\n", "").rstrip( "```"