
Commit: log both kinds of requests
magdyksaleh committed Mar 12, 2024
1 parent 29a9da0 commit 08e92ba
Showing 1 changed file with 40 additions and 6 deletions.
router/src/server.rs — 46 changes: 40 additions & 6 deletions
@@ -92,7 +92,14 @@ async fn compat_generate(
         .await
         .into_response())
     } else {
-        let (headers, generation) = generate(infer, info, req_headers, Json(req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation.0])).into_response())
     }
@@ -162,8 +169,14 @@ async fn completions_v1(
         .await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) =
-            generate(infer, info, req_headers, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(gen_req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(CompletionResponse::from(generation.0))).into_response())
     }
@@ -233,8 +246,14 @@ async fn chat_completions_v1(
         .await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) =
-            generate(infer, info, req_headers, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(
+            infer,
+            info,
+            request_logger_sender,
+            req_headers,
+            Json(gen_req.into()),
+        )
+        .await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(ChatCompletionResponse::from(generation.0))).into_response())
     }
@@ -310,6 +329,7 @@ seed,
 async fn generate(
     infer: Extension<Infer>,
     info: Extension<Info>,
+    request_logger_sender: Extension<Arc<mpsc::Sender<(i64, String, String)>>>,
     req_headers: HeaderMap,
     mut req: Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
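For context, `request_logger_sender` reaches the handlers through axum's `Extension` mechanism, so the channel has to be created at router startup and attached as a layer. A minimal sketch of that wiring, assuming `axum`, `tokio`, and `tracing` — the startup code itself is not part of this commit, and `build_router` is a hypothetical name:

use std::sync::Arc;

use axum::{Extension, Router};
use tokio::sync::mpsc;

// Hypothetical startup wiring: create the channel whose sender the
// handlers above receive as Extension<Arc<mpsc::Sender<(i64, String, String)>>>.
fn build_router() -> Router {
    let (tx, mut rx) = mpsc::channel::<(i64, String, String)>(512);
    // Drain the channel in the background; a fuller consumer is
    // sketched after the logging hunk below.
    tokio::spawn(async move {
        while let Some((tokens, api_token, model_id)) = rx.recv().await {
            tracing::debug!("request log entry: {} {} {}", tokens, api_token, model_id);
        }
    });
    Router::new()
        // ...the /generate and v1 routes would be registered here...
        .layer(Extension(Arc::new(tx)))
}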
@@ -343,6 +363,8 @@ async fn generate(
         });
     }
 
+    let api_token = req.parameters.api_token.clone();
+
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
@@ -490,6 +512,16 @@ async fn generate(
         response.generated_text.generated_tokens as f64
     );
 
+    if std::env::var("REQUEST_LOGGER_URL").ok().is_some() {
+        let _ = request_logger_sender
+            .send((
+                response.generated_text.generated_tokens as i64,
+                api_token.unwrap_or("".to_string()),
+                info.model_id.clone(),
+            ))
+            .await;
+    }
+
     // Send response
     let mut output_text = response.generated_text.text;
     if let Some(prompt) = add_prompt {
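The diff adds only the producer side: when REQUEST_LOGGER_URL is set, the handler sends a `(generated_tokens, api_token, model_id)` tuple and moves on. A sketch of what a matching consumer could look like, assuming `reqwest` and `serde_json` and a hypothetical JSON payload shape (the real logger task is outside this change):

use tokio::sync::mpsc;

// Hypothetical consumer of the tuples sent above. Failures are logged
// and dropped so a flaky logging endpoint never fails a user request.
async fn request_logger_task(mut rx: mpsc::Receiver<(i64, String, String)>) {
    let url = match std::env::var("REQUEST_LOGGER_URL") {
        Ok(url) => url,
        Err(_) => return, // disabled; the senders are gated on the same variable
    };
    let client = reqwest::Client::new();
    while let Some((tokens, api_token, model_id)) = rx.recv().await {
        let body = serde_json::json!({
            "tokens": tokens,
            "api_token": api_token,
            "model_id": model_id,
        });
        if let Err(err) = client.post(&url).json(&body).send().await {
            tracing::warn!("failed to log request: {}", err);
        }
    }
}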
@@ -731,7 +763,9 @@ async fn generate_stream_with_callback(
                         tracing::debug!(parent: &span, "Output: {}", output_text);
                         tracing::info!(parent: &span, "Success");
 
-                        request_logger_sender.send((generated_text.generated_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await;
+                        if std::env::var("REQUEST_LOGGER_URL").ok().is_some() {
+                            let _ = request_logger_sender.send((generated_text.generated_tokens as i64, api_token.unwrap_or("".to_string()), info.model_id.clone())).await;
+                        }
 
                         let stream_token = StreamResponse {
                             token,
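With this hunk, the streaming path matches the buffered path: the send is gated on REQUEST_LOGGER_URL, and `let _ =` discards the result so a closed or full channel cannot break the response stream. A hypothetical helper, not part of this commit, that would factor out the now-duplicated pattern (using the equivalent `.is_ok()` in place of `.ok().is_some()`):

use std::sync::Arc;

use tokio::sync::mpsc;

// Hypothetical refactor: one place for the env-var gate and the
// fire-and-forget send used by both code paths.
async fn maybe_log_request(
    sender: &Arc<mpsc::Sender<(i64, String, String)>>,
    generated_tokens: i64,
    api_token: Option<String>,
    model_id: &str,
) {
    if std::env::var("REQUEST_LOGGER_URL").is_ok() {
        // Errors are intentionally ignored: request logging is best-effort.
        let _ = sender
            .send((
                generated_tokens,
                api_token.unwrap_or_default(),
                model_id.to_string(),
            ))
            .await;
    }
}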
