diff --git a/app/domain/data/text_exercise_dto.py b/app/domain/data/text_exercise_dto.py index 7040b18..3fd1533 100644 --- a/app/domain/data/text_exercise_dto.py +++ b/app/domain/data/text_exercise_dto.py @@ -8,7 +8,7 @@ class TextExerciseDTO(BaseModel): id: int - name: str + title: str course: CourseDTO problem_statement: str = Field(alias="problemStatement") start_date: Optional[datetime] = Field(alias="startDate", default=None) diff --git a/app/domain/status/text_exercise_chat_status_update_dto.py b/app/domain/status/text_exercise_chat_status_update_dto.py new file mode 100644 index 0000000..dd063ff --- /dev/null +++ b/app/domain/status/text_exercise_chat_status_update_dto.py @@ -0,0 +1,5 @@ +from app.domain.status.status_update_dto import StatusUpdateDTO + + +class TextExerciseChatStatusUpdateDTO(StatusUpdateDTO): + result: str = [] diff --git a/app/domain/text_exercise_chat_pipeline_execution_dto.py b/app/domain/text_exercise_chat_pipeline_execution_dto.py index 03ff7c1..efae1ad 100644 --- a/app/domain/text_exercise_chat_pipeline_execution_dto.py +++ b/app/domain/text_exercise_chat_pipeline_execution_dto.py @@ -1,10 +1,11 @@ from pydantic import BaseModel, Field -from domain import PipelineExecutionDTO +from domain import PipelineExecutionDTO, PyrisMessage from domain.data.text_exercise_dto import TextExerciseDTO class TextExerciseChatPipelineExecutionDTO(BaseModel): execution: PipelineExecutionDTO exercise: TextExerciseDTO - current_answer: str = Field(alias="currentAnswer") + conversation: list[PyrisMessage] = Field(default=[]) + current_submission: str = Field(alias="currentSubmission", default="") diff --git a/app/pipeline/prompts/text_exercise_chat_prompts.py b/app/pipeline/prompts/text_exercise_chat_prompts.py index 390cb95..cf49924 100644 --- a/app/pipeline/prompts/text_exercise_chat_prompts.py +++ b/app/pipeline/prompts/text_exercise_chat_prompts.py @@ -1,4 +1,30 @@ -def system_prompt( +def fmt_guard_prompt( + exercise_name: str, + course_name: str, + course_description: str, + problem_statement: str, + user_input: str, +) -> str: + return """ + You check whether a user's input is on-topic and appropriate discourse in the context of a writing exercise. + The exercise is called '{exercise_name}' in the course '{course_name}'. + The course has the following description: + {course_description} + The exercise has the following problem statement: + {problem_statement} + The user says: + {user_input} + If this is on-topic and appropriate discussion, respond with "Yes". If not, respond with "No". + """.format( + exercise_name=exercise_name, + course_name=course_name, + course_description=course_description, + problem_statement=problem_statement, + user_input=user_input, + ) + + +def fmt_system_prompt( exercise_name: str, course_name: str, course_description: str, @@ -6,7 +32,7 @@ def system_prompt( start_date: str, end_date: str, current_date: str, - current_answer: str, + current_submission: str, ) -> str: return """ The student is working on a free-response exercise called '{exercise_name}' in the course '{course_name}'. @@ -19,7 +45,7 @@ def system_prompt( The exercise began on {start_date} and will end on {end_date}. The current date is {current_date}. This is what the student has written so far: - {current_answer} + {current_submission} You are a writing tutor. Provide feedback to the student on their response, giving specific tips to better answer the problem statement. @@ -31,5 +57,31 @@ def system_prompt( start_date=start_date, end_date=end_date, current_date=current_date, - current_answer=current_answer, + current_submission=current_submission, + ) + + +def fmt_rejection_prompt( + exercise_name: str, + course_name: str, + course_description: str, + problem_statement: str, + user_input: str, +) -> str: + return """ + The user is working on a free-response exercise called '{exercise_name}' in the course '{course_name}'. + The course has the following description: + {course_description} + The exercise has the following problem statement: + {problem_statement} + The user has asked the following question: + {user_input} + The question is off-topic or inappropriate. + Briefly explain that you cannot help with their query, and prompt them to focus on the exercise. + """.format( + exercise_name=exercise_name, + course_name=course_name, + course_description=course_description, + problem_statement=problem_statement, + user_input=user_input, ) diff --git a/app/pipeline/text_exercise_chat_pipeline.py b/app/pipeline/text_exercise_chat_pipeline.py index 253504a..38d5754 100644 --- a/app/pipeline/text_exercise_chat_pipeline.py +++ b/app/pipeline/text_exercise_chat_pipeline.py @@ -5,11 +5,14 @@ from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments from app.pipeline import Pipeline from domain import PyrisMessage, IrisMessageRole -from domain.data.text_message_content_dto import TextMessageContentDTO from domain.text_exercise_chat_pipeline_execution_dto import ( TextExerciseChatPipelineExecutionDTO, ) -from pipeline.prompts.text_exercise_chat_prompts import system_prompt +from pipeline.prompts.text_exercise_chat_prompts import ( + fmt_system_prompt, + fmt_rejection_prompt, + fmt_guard_prompt, +) from web.status.status_update import TextExerciseChatCallback logger = logging.getLogger(__name__) @@ -34,26 +37,67 @@ def __call__( if not dto.exercise: raise ValueError("Exercise is required") - prompt = system_prompt( - exercise_name=dto.exercise.name, + should_respond = self.guard(dto) + self.callback.done("Responding" if should_respond else "Rejecting") + + if should_respond: + response = self.respond(dto) + else: + response = self.reject(dto) + + self.callback.done(final_result=response) + + def guard(self, dto: TextExerciseChatPipelineExecutionDTO) -> bool: + guard_prompt = fmt_guard_prompt( + exercise_name=dto.exercise.title, + course_name=dto.exercise.course.name, + course_description=dto.exercise.course.description, + problem_statement=dto.exercise.problem_statement, + user_input=dto.current_submission, + ) + guard_prompt = PyrisMessage( + sender=IrisMessageRole.SYSTEM, + contents=[{"text_content": guard_prompt}], + ) + response = self.request_handler.chat([guard_prompt], CompletionArguments()) + response = response.contents[0].text_content + return "yes" in response.lower() + + def respond(self, dto: TextExerciseChatPipelineExecutionDTO) -> str: + system_prompt = fmt_system_prompt( + exercise_name=dto.exercise.title, course_name=dto.exercise.course.name, course_description=dto.exercise.course.description, problem_statement=dto.exercise.problem_statement, start_date=str(dto.exercise.start_date), end_date=str(dto.exercise.end_date), current_date=str(datetime.now()), - current_answer=dto.current_answer, + current_submission=dto.current_submission, ) - prompt = PyrisMessage( + system_prompt = PyrisMessage( sender=IrisMessageRole.SYSTEM, - contents=[TextMessageContentDTO(text_content=prompt)], + contents=[{"text_content": system_prompt}], ) - - # done building prompt + prompts = [system_prompt] + dto.conversation response = self.request_handler.chat( - [prompt], CompletionArguments(temperature=0.4) + prompts, CompletionArguments(temperature=0.4) ) - response = response.contents[0].text_content + return response.contents[0].text_content - self.callback.done(response) + def reject(self, dto: TextExerciseChatPipelineExecutionDTO) -> str: + rejection_prompt = fmt_rejection_prompt( + exercise_name=dto.exercise.title, + course_name=dto.exercise.course.name, + course_description=dto.exercise.course.description, + problem_statement=dto.exercise.problem_statement, + user_input=dto.current_submission, + ) + rejection_prompt = PyrisMessage( + sender=IrisMessageRole.SYSTEM, + contents=[{"text_content": rejection_prompt}], + ) + response = self.request_handler.chat( + [rejection_prompt], CompletionArguments(temperature=0.4) + ) + return response.contents[0].text_content diff --git a/app/web/routers/pipelines.py b/app/web/routers/pipelines.py index fae6092..c53ab66 100644 --- a/app/web/routers/pipelines.py +++ b/app/web/routers/pipelines.py @@ -98,9 +98,9 @@ def run_course_chat_pipeline(variant: str, dto: CourseChatPipelineExecutionDTO): def run_text_exercise_chat_pipeline_worker(dto, variant): try: callback = TextExerciseChatCallback( - run_id=dto.settings.authentication_token, - base_url=dto.settings.artemis_base_url, - initial_stages=dto.initial_stages, + run_id=dto.execution.settings.authentication_token, + base_url=dto.execution.settings.artemis_base_url, + initial_stages=dto.execution.initial_stages, ) match variant: case "default" | "text_exercise_chat_pipeline_reference_impl": @@ -193,6 +193,14 @@ def get_pipeline(feature: str): description="Default programming exercise chat variant.", ) ] + case "TEXT_EXERCISE_CHAT": + return [ + FeatureDTO( + id="default", + name="Default Variant", + description="Default text exercise chat variant.", + ) + ] case "COURSE_CHAT": return [ FeatureDTO( diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index 1ddf1ca..da0078b 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -5,18 +5,21 @@ import requests from abc import ABC -from ...domain.status.competency_extraction_status_update_dto import ( +from app.domain.status.competency_extraction_status_update_dto import ( CompetencyExtractionStatusUpdateDTO, ) -from ...domain.chat.course_chat.course_chat_status_update_dto import ( +from app.domain.chat.course_chat.course_chat_status_update_dto import ( CourseChatStatusUpdateDTO, ) -from ...domain.status.stage_state_dto import StageStateEnum -from ...domain.status.stage_dto import StageDTO -from ...domain.chat.exercise_chat.exercise_chat_status_update_dto import ( +from app.domain.status.stage_state_dto import StageStateEnum +from app.domain.status.stage_dto import StageDTO +from app.domain.status.text_exercise_chat_status_update_dto import ( + TextExerciseChatStatusUpdateDTO, +) +from app.domain.chat.exercise_chat.exercise_chat_status_update_dto import ( ExerciseChatStatusUpdateDTO, ) -from ...domain.status.status_update_dto import StatusUpdateDTO +from app.domain.status.status_update_dto import StatusUpdateDTO import logging logger = logging.getLogger(__name__) @@ -230,13 +233,22 @@ def __init__( stage = len(stages) stages += [ StageDTO( - weight=40, + weight=20, state=StageStateEnum.NOT_STARTED, name="Thinking", - ) + ), + StageDTO( + weight=20, + state=StageStateEnum.NOT_STARTED, + name="Responding", + ), ] super().__init__( - url, run_id, StatusUpdateDTO(stages=stages), stages[stage], stage + url, + run_id, + TextExerciseChatStatusUpdateDTO(stages=stages), + stages[stage], + stage, )