Skip to content

Commit

Permalink
Add thumbs up / down to each message (aws-samples#256)
Browse files Browse the repository at this point in the history
* add

* fix type check

* fix type check

* add frontend

* add doc

* fix

* fix

* fix type err

* add: store used chunk to ddb

* add: isLargeMessage to glue schema

* refactoring

* fix

* fix

---------

Co-authored-by: Yusuke Wada <[email protected]>
  • Loading branch information
statefb and wadabee authored Apr 26, 2024
1 parent ba72df2 commit 41d2639
Show file tree
Hide file tree
Showing 28 changed files with 706 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
> [!Tip]
> 🔔**Claude3 Opus supported.** As of 04/17/2024, Bedrock only supports the `us-west-2` region. In this repository, Bedrock uses the `us-east-1` region by default. Therefore, if you plan to use it, please change the value of `bedrockRegion` before deployment. For more details, please refer [here](#deploy-using-cdk).
> [!Info]
> [!Important]
> We'd like to hear your feedback to implement bot creation permission management feature. The plan is to grant permissions to individual users through the admin panel, but this may increase operational overhead for existing users. [Please take the survey](https://github.com/aws-samples/bedrock-claude-chat/issues/161#issuecomment-2058194533).
> [!Warning]
Expand Down
48 changes: 48 additions & 0 deletions backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
decompose_conv_id,
)
from app.repositories.models.conversation import (
ChunkModel,
ContentModel,
ConversationMeta,
ConversationModel,
FeedbackModel,
MessageModel,
)
from app.utils import get_current_time
Expand Down Expand Up @@ -204,6 +206,27 @@ def find_conversation_by_id(user_id: str, conversation_id: str) -> ConversationM
children=v["children"],
parent=v["parent"],
create_time=float(v["create_time"]),
feedback=(
FeedbackModel(
thumbs_up=v["feedback"]["thumbs_up"],
category=v["feedback"]["category"],
comment=v["feedback"]["comment"],
)
if v.get("feedback")
else None
),
used_chunks=(
[
ChunkModel(
content=c["content"],
source=c["source"],
rank=c["rank"],
)
for c in v["used_chunks"]
]
if v.get("used_chunks")
else None
),
)
for k, v in message_map.items()
},
Expand Down Expand Up @@ -325,3 +348,28 @@ def change_conversation_title(user_id: str, conversation_id: str, new_title: str
logger.info(f"Updated conversation title response: {response}")

return response


def update_feedback(
    user_id: str, conversation_id: str, message_id: str, feedback: FeedbackModel
):
    """Attach feedback to one message and persist the updated message map.

    Loads the conversation, sets ``feedback`` on the target message (a
    KeyError propagates if ``message_id`` is unknown), then writes the whole
    serialized ``MessageMap`` attribute back to the conversation item.
    NOTE(review): this is a read-modify-write of the entire map, so a
    concurrent writer to the same conversation could be overwritten —
    confirm this is acceptable for feedback updates.
    """
    logger.info(f"Updating feedback for conversation: {conversation_id}")
    conversation = find_conversation_by_id(user_id, conversation_id)
    conversation.message_map[message_id].feedback = feedback

    # Serialize the full map once; the item stores MessageMap as a JSON string.
    serialized_map = json.dumps(
        {msg_id: msg.model_dump() for msg_id, msg in conversation.message_map.items()}
    )

    response = _get_table_client(user_id).update_item(
        Key={
            "PK": user_id,
            "SK": compose_conv_id(user_id, conversation_id),
        },
        UpdateExpression="set MessageMap = :m",
        ExpressionAttributeValues={":m": serialized_map},
        # Guard against creating a brand-new item for a missing conversation.
        ConditionExpression="attribute_exists(PK) AND attribute_exists(SK)",
        ReturnValues="UPDATED_NEW",
    )
    logger.info(f"Updated feedback response: {response}")
    return response
14 changes: 14 additions & 0 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,27 @@ class ContentModel(BaseModel):
body: str


class FeedbackModel(BaseModel):
    """User feedback stored on a single message (thumbs up/down plus reason)."""

    thumbs_up: bool  # True = positive, False = negative
    category: str  # reason category; empty string when the caller supplied none
    comment: str  # free-text comment; empty string when the caller supplied none


class ChunkModel(BaseModel):
    """A retrieved knowledge chunk referenced while generating a RAG reply."""

    content: str  # chunk text as returned by the vector search
    source: str  # origin of the chunk — presumably a document URL/path; confirm format
    rank: int  # retrieval rank from the vector search results


class MessageModel(BaseModel):
    """A single node in a conversation's message tree."""

    role: str  # e.g. "system" / "user" / "assistant"
    content: list[ContentModel]
    model: type_model_name
    children: list[str]  # ids of child messages in the tree
    parent: str | None  # id of the parent message; None for the root "system" node
    create_time: float  # creation timestamp (presumably epoch seconds — confirm)
    feedback: FeedbackModel | None  # user feedback; None until submitted
    used_chunks: list[ChunkModel] | None  # chunks used for RAG generation; None otherwise


class ConversationModel(BaseModel):
Expand Down
1 change: 0 additions & 1 deletion backend/app/routes/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
remove_bot_by_id,
remove_uploaded_file,
)
from app.usecases.chat import chat, fetch_conversation, propose_conversation_title
from app.user import User
from fastapi import APIRouter, Request

Expand Down
34 changes: 34 additions & 0 deletions backend/app/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
delete_conversation_by_id,
delete_conversation_by_user_id,
find_conversation_by_user_id,
update_feedback,
)
from app.repositories.models.conversation import FeedbackModel
from app.routes.schemas.conversation import (
ChatInput,
ChatOutput,
Conversation,
ConversationMetaOutput,
FeedbackInput,
FeedbackOutput,
NewTitleInput,
ProposedTitle,
RelatedDocumentsOutput,
Expand Down Expand Up @@ -119,3 +123,33 @@ def get_proposed_title(request: Request, conversation_id: str):

title = propose_conversation_title(current_user.id, conversation_id)
return ProposedTitle(title=title)


@router.put(
    "/conversation/{conversation_id}/{message_id}/feedback",
    response_model=FeedbackOutput,
)
def put_feedback(
    request: Request,
    conversation_id: str,
    message_id: str,
    feedback_input: FeedbackInput,
):
    """Store thumbs up/down feedback on a single message.

    Optional fields (`category`, `comment`) are normalized to an empty
    string before persisting, and the normalized values are echoed back
    to the client.
    """
    current_user: User = request.state.current_user

    # Normalize once so the stored model and the response cannot drift apart
    # (previously the `x if x else ""` fallbacks were duplicated in both places).
    feedback = FeedbackModel(
        thumbs_up=feedback_input.thumbs_up,
        category=feedback_input.category if feedback_input.category else "",
        comment=feedback_input.comment if feedback_input.comment else "",
    )
    update_feedback(
        user_id=current_user.id,
        conversation_id=conversation_id,
        message_id=message_id,
        feedback=feedback,
    )
    return FeedbackOutput(
        thumbs_up=feedback.thumbs_up,
        category=feedback.category,
        comment=feedback.comment,
    )
34 changes: 33 additions & 1 deletion backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal

from app.routes.schemas.base import BaseSchema
from pydantic import Field
from pydantic import Field, model_validator, root_validator, validator

type_model_name = Literal[
"claude-instant-v1",
Expand All @@ -23,6 +23,36 @@ class Content(BaseSchema):
body: str = Field(..., description="Content body. Text or base64 encoded image.")


class FeedbackInput(BaseSchema):
    """Request body for the message-feedback endpoint.

    `category` must be provided when `thumbs_up` is False so that negative
    feedback always carries a reason.
    """

    thumbs_up: bool
    category: str | None = Field(
        None, description="Reason category. Required if thumbs_up is False."
    )
    comment: str | None = Field(None, description="optional comment")

    @model_validator(mode="before")
    @classmethod
    def check_category(cls, values):
        # Runs on the raw input payload (pre-validation), replacing the
        # deprecated `@root_validator(pre=True)`; this codebase already uses
        # pydantic v2 (`model_dump` in repositories/conversation.py).
        thumbs_up = values.get("thumbs_up")
        category = values.get("category")

        if not thumbs_up and category is None:
            raise ValueError("category is required if `thumbs_up` is `False`")

        return values


class FeedbackOutput(BaseSchema):
    """Feedback as returned to the client.

    Unlike FeedbackInput, `category` and `comment` are always strings here;
    missing values are normalized to "" by the endpoint (see put_feedback).
    """

    thumbs_up: bool
    category: str
    comment: str


class Chunk(BaseSchema):
    """API-facing view of a knowledge chunk used to generate a RAG answer."""

    content: str  # chunk text
    source: str  # origin of the chunk — presumably a document URL/path; confirm format
    rank: int  # retrieval rank from the vector search


class MessageInput(BaseSchema):
role: str
content: list[Content]
Expand All @@ -38,6 +68,8 @@ class MessageOutput(BaseSchema):
content: list[Content]
model: type_model_name
children: list[str]
feedback: FeedbackOutput | None
used_chunks: list[Chunk] | None
parent: str | None


Expand Down
71 changes: 66 additions & 5 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from app.repositories.custom_bot import find_alias_by_id, store_alias
from app.repositories.models.conversation import (
ChunkModel,
ContentModel,
ConversationModel,
MessageModel,
Expand All @@ -22,14 +23,21 @@
from app.routes.schemas.conversation import (
ChatInput,
ChatOutput,
Chunk,
Content,
Conversation,
FeedbackOutput,
MessageOutput,
RelatedDocumentsOutput,
)
from app.usecases.bot import fetch_bot, modify_bot_last_used_time
from app.utils import get_anthropic_client, get_current_time, is_running_on_lambda
from app.vector_search import SearchResult, get_source_link, search_related_docs
from app.vector_search import (
SearchResult,
filter_used_results,
get_source_link,
search_related_docs,
)
from ulid import ULID

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,7 +73,7 @@ def prepare_conversation(
)

initial_message_map = {
# Dummy system message
# Dummy system message, which is used for root node of the message tree.
"system": MessageModel(
role="system",
content=[
Expand All @@ -79,6 +87,8 @@ def prepare_conversation(
children=[],
parent=None,
create_time=current_time,
feedback=None,
used_chunks=None,
)
}
parent_id = "system"
Expand All @@ -100,6 +110,8 @@ def prepare_conversation(
children=[],
parent="system",
create_time=current_time,
feedback=None,
used_chunks=None,
)
initial_message_map["system"].children.append("instruction")

Expand Down Expand Up @@ -157,6 +169,8 @@ def prepare_conversation(
children=[],
parent=parent_id,
create_time=current_time,
feedback=None,
used_chunks=None,
)
conversation.message_map[message_id] = new_message
conversation.message_map[parent_id].children.append(message_id) # type: ignore
Expand Down Expand Up @@ -259,18 +273,19 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
user_msg_id, conversation, bot = prepare_conversation(user_id, chat_input)

message_map = conversation.message_map
search_results = []
if bot and is_running_on_lambda():
# NOTE: `is_running_on_lambda`is a workaround for local testing due to no postgres mock.
# Fetch most related documents from vector store
# NOTE: Currently embedding not support multi-modal. For now, use the last content.
query = conversation.message_map[user_msg_id].content[-1].body
results = search_related_docs(
search_results = search_related_docs(
bot_id=bot.id, limit=SEARCH_CONFIG["max_results"], query=query
)
logger.info(f"Search results from vector store: {results}")
logger.info(f"Search results from vector store: {search_results}")

# Insert contexts to instruction
conversation_with_context = insert_knowledge(conversation, results)
conversation_with_context = insert_knowledge(conversation, search_results)
message_map = conversation_with_context.message_map

messages = trace_to_root(
Expand All @@ -291,6 +306,14 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
response: AnthropicMessage = client.messages.create(**args)
reply_txt = response.content[0].text

# Used chunks for RAG generation
used_chunks = None
if bot and is_running_on_lambda():
used_chunks = [
ChunkModel(content=r.content, source=r.source, rank=r.rank)
for r in filter_used_results(reply_txt, search_results)
]

# Issue id for new assistant message
assistant_msg_id = str(ULID())
# Append bedrock output to the existing conversation
Expand All @@ -301,6 +324,8 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
)
conversation.message_map[assistant_msg_id] = message

Expand Down Expand Up @@ -341,6 +366,19 @@ def chat(user_id: str, chat_input: ChatInput) -> ChatOutput:
model=message.model,
children=message.children,
parent=message.parent,
feedback=None,
used_chunks=(
[
Chunk(
content=c.content,
source=c.source,
rank=c.rank,
)
for c in message.used_chunks
]
if message.used_chunks
else None
),
),
bot_id=conversation.bot_id,
)
Expand Down Expand Up @@ -389,6 +427,8 @@ def propose_conversation_title(
children=[],
parent=conversation.last_message_id,
create_time=get_current_time(),
feedback=None,
used_chunks=None,
)
messages.append(new_message)

Expand Down Expand Up @@ -419,6 +459,27 @@ def fetch_conversation(user_id: str, conversation_id: str) -> Conversation:
model=message.model,
children=message.children,
parent=message.parent,
feedback=(
FeedbackOutput(
thumbs_up=message.feedback.thumbs_up,
category=message.feedback.category,
comment=message.feedback.comment,
)
if message.feedback
else None
),
used_chunks=(
[
Chunk(
content=c.content,
source=c.source,
rank=c.rank,
)
for c in message.used_chunks
]
if message.used_chunks
else None
),
)
for message_id, message in conversation.message_map.items()
}
Expand Down
Loading

0 comments on commit 41d2639

Please sign in to comment.