From 08e92ba15514f5f196c596c3d9b005017d7bb173 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 12 Mar 2024 16:10:40 -0400
Subject: [PATCH] log both kinds of requests

---
 router/src/server.rs | 46 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index e4b546f4f..885b6d4e1 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -92,7 +92,14 @@ async fn compat_generate(
             .await
             .into_response())
     } else {
-        let (headers, generation) = generate(infer, info, req_headers, Json(req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation.0])).into_response())
     }
@@ -162,8 +169,14 @@ async fn completions_v1(
         .await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) =
-            generate(infer, info, req_headers, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(gen_req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(CompletionResponse::from(generation.0))).into_response())
     }
@@ -233,8 +246,14 @@ async fn chat_completions_v1(
         .await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) =
-            generate(infer, info, req_headers, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(gen_req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(ChatCompletionResponse::from(generation.0))).into_response())
     }
@@ -310,6 +329,7 @@ seed,
 async fn generate(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
     req_headers: HeaderMap,
     mut req: Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
@@ -343,6 +363,8 @@ async fn generate(
         });
     }
 
+    let api_token = req.parameters.api_token.clone();
+
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
@@ -490,6 +512,16 @@ async fn generate(
         response.generated_text.generated_tokens as f64
     );
 
+    if std::env::var("REQUEST_LOGGER_URL").ok().is_some() {
+        let _ = request_logger_sender
+            .send((
+                response.generated_text.generated_tokens as i64,
+                api_token.unwrap_or("".to_string()),
+                info.model_id.clone(),
+            ))
+            .await;
+    }
+
     // Send response
     let mut output_text = response.generated_text.text;
     if let Some(prompt) = add_prompt {
@@ -731,7 +763,9 @@ async fn generate_stream_with_callback(
                                 tracing::debug!(parent: &span, "Output: {}", output_text);
                                 tracing::info!(parent: &span, "Success");
 
-                                request_logger_sender.send((generated_text.generated_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await;
+                                if std::env::var("REQUEST_LOGGER_URL").ok().is_some() {
+                                    let _ = request_logger_sender.send((generated_text.generated_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await;
+                                }
 
                                 let stream_token = StreamResponse {
                                     token,
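
Note on the surrounding plumbing (not part of the patch): the handlers above only push
(generated_tokens, api_token, model_id) tuples into an mpsc channel. Something else must
own the Receiver, and the axum app must supply the Sender as the Extension that the new
generate() parameter extracts. Below is a minimal sketch of both halves, assuming a
hypothetical request_logger_task consumer and attach_request_logger wiring helper; the
JSON field names, channel capacity, and use of reqwest are assumptions too, not this
repository's actual code.

    use std::sync::Arc;

    use axum::{Extension, Router};
    use tokio::sync::mpsc;

    // Drains (generated_tokens, api_token, model_id) records and ships them
    // to the external logger. Errors are swallowed on purpose: request
    // handling must never fail because the logging endpoint is down.
    async fn request_logger_task(mut receiver: mpsc::Receiver<(i64, String, String)>) {
        let url = match std::env::var("REQUEST_LOGGER_URL") {
            Ok(url) => url,
            // Logging disabled; the handlers skip their sends for the same reason.
            Err(_) => return,
        };
        let client = reqwest::Client::new();
        while let Some((tokens, api_token, model_id)) = receiver.recv().await {
            let _ = client
                .post(&url)
                .json(&serde_json::json!({
                    "tokens": tokens,
                    "api_token": api_token,
                    "model_id": model_id,
                }))
                .send()
                .await;
        }
    }

    // Hypothetical wiring: create the channel, spawn the consumer, and expose
    // the Sender as the Extension<Arc<mpsc::Sender<_>>> that generate() expects.
    fn attach_request_logger(router: Router) -> Router {
        let (sender, receiver) = mpsc::channel::<(i64, String, String)>(128);
        tokio::spawn(request_logger_task(receiver));
        router.layer(Extension(Arc::new(sender)))
    }

The let _ = on each send mirrors the patch's fire-and-forget choice: Sender::send only
errors once the Receiver has been dropped, and a dead logger should never surface as a
failed generation request.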