Skip to content

Commit

Permalink
cancel requests immediately on connection drop
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Dec 2, 2024
1 parent 35f2d56 commit aaddb60
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
36 changes: 36 additions & 0 deletions llgtrt/src/routes/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,34 @@ fn take_final_logs(llg: &mut Constraint) -> Result<String> {
Ok(llg.flush_logs())
}

/// Guard tied to a single in-flight request.
///
/// While `req_id` is `Some`, dropping the token asks the executor to cancel
/// that request (see the `Drop` impl). Call `disarm` once the request has
/// finished normally so that dropping the token becomes a no-op.
#[must_use = "dropping the token right away cancels the request it guards"]
struct ReqCancelToken {
    /// Request to cancel on drop; `None` once the token has been disarmed.
    req_id: Option<ReqId>,
}

impl ReqCancelToken {
    /// Forgets the guarded request id so that dropping this token no longer
    /// triggers a cancellation. Calling it more than once is harmless.
    pub fn disarm(&mut self) {
        // `take` leaves `None` behind; the previous id is simply discarded.
        self.req_id.take();
    }
}

impl Drop for ReqCancelToken {
    /// Cancels the guarded request if the token is still armed.
    ///
    /// An armed token (`req_id` is `Some`) logs a warning and forwards the id
    /// to `AsyncExecutor::lock().cancel_request`; whatever that call returns
    /// is deliberately discarded. A disarmed token only emits a debug log.
    fn drop(&mut self) {
        match self.req_id {
            None => {
                log::debug!("not dropping ReqCancelToken");
            }
            Some(pending_id) => {
                log::warn!("dropping ReqCancelToken: {}", pending_id);
                // Best effort: the outcome of the cancellation is ignored.
                let _ = AsyncExecutor::lock().cancel_request(pending_id);
            }
        }
    }
}

impl ReqInfo {
/// Builds an armed `ReqCancelToken` carrying this request's id.
///
/// The returned token cancels the request when dropped unless it is
/// disarmed first.
fn cancel_token(&self) -> ReqCancelToken {
    let req_id = Some(self.req_id);
    ReqCancelToken { req_id }
}

/// Returns `true` when every fork has a `stop_reason` set.
///
/// With no forks at all the result is `true` as well.
fn all_forks_stopped(&self) -> bool {
    let any_still_running = self
        .forks
        .iter()
        .any(|fork| fork.stop_reason.is_none());
    !any_still_running
}
Expand Down Expand Up @@ -638,6 +665,7 @@ impl ReqInfo {
async fn completions_stream(
mut client: ReqInfo,
) -> Result<Sse<impl Stream<Item = anyhow::Result<Event>>>, AppError> {
let mut token = client.cancel_token();
let result0 = client
.recv
.recv()
Expand All @@ -649,6 +677,8 @@ async fn completions_stream(
}

let response_stream = try_stream! {
let mut token = client.cancel_token();

if client.is_run {
yield client.initial_run();
}
Expand All @@ -667,12 +697,16 @@ async fn completions_stream(
}

yield Event::default().data("[DONE]");

token.disarm();
};

token.disarm();
Ok(Sse::new(response_stream))
}

async fn completions(mut client: ReqInfo) -> Result<Json<Value>, AppError> {
let mut token = client.cancel_token();
let mut logprobs = vec![];
while let Some(mut result) = client.recv.recv().await {
log::trace!("infer response: {:?}", result.response);
Expand Down Expand Up @@ -745,6 +779,8 @@ async fn completions(mut client: ReqInfo) -> Result<Json<Value>, AppError> {
serde_json::to_value(Completion::of_chat_completion(chat_compl))?
};

token.disarm();

Ok(Json(inner))
}

Expand Down
2 changes: 1 addition & 1 deletion scripts/req.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def main():
PROMPT_SIZE = 40_000
NUM_REPS = 1
NUM_JOKES = 100
MAX_TOKENS = 4000
MAX_TOKENS = 400
one_round()
return

Expand Down

0 comments on commit aaddb60

Please sign in to comment.