diff --git a/src/backend/app/tasks/task_logic.py b/src/backend/app/tasks/task_logic.py index ca962312..c44c11b0 100644 --- a/src/backend/app/tasks/task_logic.py +++ b/src/backend/app/tasks/task_logic.py @@ -131,40 +131,38 @@ async def request_mapping( async def get_task_state( - db: Connection, project_id: uuid.UUID, task_id: uuid.UUID - ) -> dict: - """ - Retrieve the latest state of a task by querying the task_events table. - - Args: - db (Connection): The database connection. - project_id (uuid.UUID): The project ID. - task_id (uuid.UUID): The task ID. - - Returns: - dict: A dictionary containing the task's state and associated metadata. - """ - try: - async with db.cursor(row_factory=dict_row) as cur: - await cur.execute( - """ - SELECT state, user_id, created_at, comment - FROM task_events - WHERE project_id = %(project_id)s AND task_id = %(task_id)s - ORDER BY created_at DESC - LIMIT 1; - """, - { - "project_id": str(project_id), - "task_id": str(task_id), - }, - ) - result = await cur.fetchone() - return result - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"An error occurred while retrieving the task state: {str(e)}" - ) + db: Connection, project_id: uuid.UUID, task_id: uuid.UUID +) -> dict: + """ + Retrieve the latest state of a task by querying the task_events table. + Args: + db (Connection): The database connection. + project_id (uuid.UUID): The project ID. + task_id (uuid.UUID): The task ID. + Returns: + dict: A dictionary containing the task's state and associated metadata. + """ + try: + async with db.cursor(row_factory=dict_row) as cur: + await cur.execute( + """ + SELECT state, user_id, created_at, comment + FROM task_events + WHERE project_id = %(project_id)s AND task_id = %(task_id)s + ORDER BY created_at DESC + LIMIT 1; + """, + { + "project_id": str(project_id), + "task_id": str(task_id), + }, + ) + result = await cur.fetchone() + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An error occurred while retrieving the task state: {str(e)}", + ) diff --git a/src/backend/app/tasks/task_routes.py b/src/backend/app/tasks/task_routes.py index 7d67576a..47e3cab7 100644 --- a/src/backend/app/tasks/task_routes.py +++ b/src/backend/app/tasks/task_routes.py @@ -369,27 +369,36 @@ async def new_event( ) case EventType.RESET: + # Fetch the task state current_task_state = await task_logic.get_task_state( db, project_id, task_id ) - if ( - current_task_state["state"] == State.LOCKED_FOR_MAPPING.name - and user_id == current_task_state["user_id"] - ): - # Task is locked by the user, so reset it (unlock) - return await task_logic.update_task_state( - db, - project_id, - task_id, - user_id, - "Task reset by user", - State.LOCKED_FOR_MAPPING, - State.UNLOCKED_TO_MAP, + # Extract state and user from the result + state = current_task_state.get("state") + locked_user_id = current_task_state.get("user_id") + + # Determine error conditions in a single pass + if state != State.LOCKED_FOR_MAPPING.name: + raise HTTPException( + status_code=400, + detail="Task state does not match expected state for reset operation.", ) - raise HTTPException( - status_code=400, - detail="Task is not locked by the user or cannot be reset.", + if user_id != locked_user_id: + raise HTTPException( + status_code=403, + detail="You cannot unlock this task as it is locked by another user.", + ) + + # Proceed with resetting the task + return await task_logic.update_task_state( + db, + project_id, + task_id, + user_id, + f"Task has been reset by user {user_data.name}.", + State.LOCKED_FOR_MAPPING, + State.UNLOCKED_TO_MAP, ) return True