Skip to content

Commit

Permalink
chris review this
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 committed Sep 6, 2023
1 parent 318078b commit d3f3a60
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
22 changes: 22 additions & 0 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ def fetch_chat_messages_by_session(
return list(result)


def fetch_chat_message(
chat_session_id: int, message_number: int, edit_number: int, db_session: Session
) -> ChatMessage:
stmt = select(ChatMessage).where(
(ChatMessage.chat_session_id == chat_session_id)
& (ChatMessage.message_number == message_number)
& (ChatMessage.edit_number == edit_number)
)

chat_message = db_session.execute(stmt).scalar_one_or_none()

if not chat_message:
raise ValueError("Invalid Chat Message specified")

return chat_message


def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession:
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
result = db_session.execute(stmt)
Expand Down Expand Up @@ -135,6 +152,11 @@ def _set_latest_chat_message_no_commit(
edit_number: int,
db_session: Session,
) -> None:
if message_number != 0 and parent_edit_number is None:
raise ValueError(
"Only initial message in a chat is allowed to not have a parent"
)

db_session.query(ChatMessage).filter(
and_(
ChatMessage.chat_session_id == chat_session_id,
Expand Down
36 changes: 28 additions & 8 deletions backend/danswer/server/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
from danswer.db.chat import fetch_chat_message
from danswer.db.chat import fetch_chat_messages_by_session
from danswer.db.chat import fetch_chat_session_by_id
from danswer.db.chat import fetch_chat_sessions_by_user
Expand Down Expand Up @@ -88,8 +89,8 @@ def get_chat_session_messages(
messages=[
ChatMessageDetail(
message_number=msg.message_number,
parent_edit_number=msg.parent_edit_number,
edit_number=msg.edit_number,
parent_edit_number=msg.parent_edit_number,
latest=msg.latest,
message=msg.message,
message_type=msg.message_type,
Expand Down Expand Up @@ -217,6 +218,7 @@ def handle_new_chat_message(
if parent_edit_number is not None:
raise ValueError("Initial message in session cannot have parent")

# Create new message at the right place in the tree and label it latest for its parent
new_message = create_new_chat_message(
chat_session_id=chat_session_id,
message_number=message_number,
Expand Down Expand Up @@ -251,19 +253,28 @@ def stream_chat_tokens() -> Iterator[str]:
return StreamingResponse(stream_chat_tokens(), media_type="application/json")


@router.post("/regenerate")
@router.post("/regenerate-from-parent")
def regenerate_message_given_parent(
parent_message: ChatMessageIdentifier,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
"""Regenerate an LLM response given a particular parent message
The parent message is set as latest and a new LLM response is set as
the latest following message"""
chat_session_id = parent_message.chat_session_id
message_number = parent_message.message_number
parent_edit_number = parent_message.parent_edit_number
edit_number = parent_message.edit_number
user_id = user.id if user is not None else None

chat_session = fetch_chat_session_by_id(parent_message.chat_session_id, db_session)
chat_message = fetch_chat_message(
chat_session_id=chat_session_id,
message_number=message_number,
edit_number=edit_number,
db_session=db_session,
)

chat_session = chat_message.chat_session

if chat_session.deleted:
raise ValueError("Chat session has been deleted")
Expand All @@ -280,11 +291,13 @@ def regenerate_message_given_parent(
set_latest_chat_message(
chat_session_id,
message_number,
parent_edit_number,
chat_message.parent_edit_number,
edit_number,
db_session,
)

# The parent message, now set as latest, may have follow on messages
# Don't want to include those in the context to LLM
mainline_messages = _create_chat_chain(
chat_session_id, db_session, stop_after=message_number
)
Expand All @@ -300,7 +313,7 @@ def stream_regenerate_tokens() -> Iterator[str]:
create_new_chat_message(
chat_session_id=chat_session_id,
message_number=message_number + 1,
parent_edit_number=parent_edit_number,
parent_edit_number=edit_number,
message=llm_output,
message_type=MessageType.ASSISTANT,
db_session=db_session,
Expand All @@ -311,13 +324,20 @@ def stream_regenerate_tokens() -> Iterator[str]:

@router.put("/set-message-as-latest")
def set_message_as_latest(
chat_message: ChatMessageIdentifier,
message_identifier: ChatMessageIdentifier,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user is not None else None

chat_session = fetch_chat_session_by_id(chat_message.chat_session_id, db_session)
chat_message = fetch_chat_message(
chat_session_id=message_identifier.chat_session_id,
message_number=message_identifier.message_number,
edit_number=message_identifier.edit_number,
db_session=db_session,
)

chat_session = chat_message.chat_session

if chat_session.deleted:
raise ValueError("Chat session has been deleted")
Expand Down
3 changes: 1 addition & 2 deletions backend/danswer/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ class CreateChatRequest(BaseModel):
class ChatMessageIdentifier(BaseModel):
chat_session_id: int
message_number: int
parent_edit_number: int | None
edit_number: int


Expand All @@ -186,8 +185,8 @@ class ChatSessionIdsResponse(BaseModel):

class ChatMessageDetail(BaseModel):
message_number: int
parent_edit_number: int | None
edit_number: int
parent_edit_number: int | None
latest: bool
message: str
message_type: MessageType
Expand Down

0 comments on commit d3f3a60

Please sign in to comment.