diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index d1c21b0..9fa73ce 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use miniserve::{http::StatusCode, Content, Request, Response}; use serde::{Deserialize, Serialize}; use tokio::join; @@ -16,19 +18,23 @@ async fn chat(req: Request) -> Response { let Request::Post(body) = req else { return Err(StatusCode::METHOD_NOT_ALLOWED); }; - let Ok(mut messages) = serde_json::from_str::(&body) else { + let Ok(mut data) = serde_json::from_str::(&body) else { return Err(StatusCode::INTERNAL_SERVER_ERROR); }; - let (i, mut responses) = join!( + let messages = Arc::new(data.messages); + let messages_ref = Arc::clone(&messages); + let (i, responses) = join!( chatbot::gen_random_number(), - chatbot::query_chat(&messages.messages) + tokio::spawn(async move { chatbot::query_chat(&messages_ref).await }) ); + let mut responses = responses.unwrap(); let response = responses.remove(i % responses.len()); - messages.messages.push(response); + data.messages = Arc::into_inner(messages).unwrap(); + data.messages.push(response); - Ok(Content::Json(serde_json::to_string(&messages).unwrap())) + Ok(Content::Json(serde_json::to_string(&data).unwrap())) } #[tokio::main]