Skip to content

Commit

Permalink
Make token estimation take sent message into account
Browse files Browse the repository at this point in the history
  • Loading branch information
nygrenh committed Oct 16, 2024
1 parent f508c5a commit c611c50
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
31 changes: 21 additions & 10 deletions services/headless-lms/chatbot/src/azure_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl ChatRequest {
chatbot_configuration_id: Uuid,
conversation_id: Uuid,
message: &str,
) -> anyhow::Result<(Self, ChatbotConversationMessage)> {
) -> anyhow::Result<(Self, ChatbotConversationMessage, i32)> {
let configuration =
models::chatbot_configurations::get_by_id(conn, chatbot_configuration_id).await?;

Expand Down Expand Up @@ -172,6 +172,9 @@ impl ChatRequest {
Vec::new()
};

let serialized_messages = serde_json::to_string(&api_chat_messages)?;
let request_estimated_tokens = estimate_tokens(&serialized_messages);

Ok((
Self {
messages: api_chat_messages,
Expand All @@ -185,6 +188,7 @@ impl ChatRequest {
stream: true,
},
new_message,
request_estimated_tokens,
))
}
}
Expand Down Expand Up @@ -244,6 +248,7 @@ struct RequestCancelledGuard {
received_string: Arc<Mutex<Vec<String>>>,
pool: PgPool,
done: Arc<AtomicBool>,
request_estimated_tokens: i32,
}

impl Drop for RequestCancelledGuard {
Expand All @@ -256,6 +261,7 @@ impl Drop for RequestCancelledGuard {
let response_message_id = self.response_message_id;
let received_string = self.received_string.clone();
let pool = self.pool.clone();
let request_estimated_tokens = self.request_estimated_tokens;
tokio::spawn(async move {
info!("Verifying the received message has been handled");
let mut conn = pool.acquire().await.expect("Could not acquire connection");
Expand All @@ -275,11 +281,13 @@ impl Drop for RequestCancelledGuard {
estimated_cost, full_response_as_string
);

// Update with request_estimated_tokens + estimated_cost
models::chatbot_conversation_messages::update(
&mut conn,
response_message_id,
&full_response_as_string,
true,
request_estimated_tokens + estimated_cost,
)
.await
.expect("Could not update response message");
Expand All @@ -295,13 +303,14 @@ pub async fn send_chat_request_and_parse_stream(
conversation_id: Uuid,
message: &str,
) -> anyhow::Result<impl Stream<Item = anyhow::Result<Bytes>>> {
let (chat_request, new_message) = ChatRequest::build_and_insert_incoming_message_to_db(
conn,
chatbot_configuration_id,
conversation_id,
message,
)
.await?;
let (chat_request, new_message, request_estimated_tokens) =
ChatRequest::build_and_insert_incoming_message_to_db(
conn,
chatbot_configuration_id,
conversation_id,
message,
)
.await?;

let full_response_text = Arc::new(Mutex::new(Vec::new()));
let done = Arc::new(AtomicBool::new(false));
Expand Down Expand Up @@ -337,7 +346,7 @@ pub async fn send_chat_request_and_parse_stream(
message: None,
is_from_chatbot: true,
message_is_complete: false,
used_tokens: 0,
used_tokens: request_estimated_tokens,
order_number: response_order_number,
},
)
Expand All @@ -348,6 +357,7 @@ pub async fn send_chat_request_and_parse_stream(
received_string: full_response_text.clone(),
pool: pool.clone(),
done: done.clone(),
request_estimated_tokens,
};

let request = REQWEST_CLIENT
Expand Down Expand Up @@ -396,7 +406,8 @@ pub async fn send_chat_request_and_parse_stream(
&mut conn,
response_message.id,
&full_response_as_string,
true
true,
request_estimated_tokens + estimated_cost,
).await?;
break;
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,20 @@ pub async fn update(
id: Uuid,
message: &str,
message_is_complete: bool,
used_tokens: i32,
) -> ModelResult<ChatbotConversationMessage> {
let res = sqlx::query_as!(
ChatbotConversationMessage,
r#"
UPDATE chatbot_conversation_messages
SET message = $2, message_is_complete = $3, updated_at = NOW()
SET message = $2, message_is_complete = $3, used_tokens = $4
WHERE id = $1
RETURNING *
"#,
id,
message,
message_is_complete
message_is_complete,
used_tokens
)
.fetch_one(conn)
.await?;
Expand Down

0 comments on commit c611c50

Please sign in to comment.