From 71ba574fa98d827a7416467b8d6a6c9f192a68a6 Mon Sep 17 00:00:00 2001 From: KOBAYASHI Kazuhiro Date: Tue, 23 Apr 2024 08:07:11 +0900 Subject: [PATCH] kble: wait for child process --- kble/src/app.rs | 98 ++++++++++++++++++++++++++++++++++++------------ kble/src/plug.rs | 39 ++++++++++++++----- 2 files changed, 103 insertions(+), 34 deletions(-) diff --git a/kble/src/app.rs b/kble/src/app.rs index e4651ec..01037c0 100644 --- a/kble/src/app.rs +++ b/kble/src/app.rs @@ -8,10 +8,16 @@ use std::collections::HashMap; use tokio::sync::broadcast; use tracing::{debug, warn}; +struct Connection { + backend: plug::Backend, + stream: Option, + sink: Option, +} + struct Connections<'a> { // Some: connections not used yet // None: connections is used in a link - map: HashMap<&'a str, (Option, Option)>, + map: HashMap<&'a str, Connection>, } struct Link<'a> { @@ -37,8 +43,10 @@ pub async fn run(config: &Config) -> Result<()> { let links = future::join_all(link_futs).await; let links = links.into_iter().chain(std::iter::once(terminated_link)); - let link_close_futs = future::try_join_all(links.map(|link| link.close())); - future::try_join(conns.close_all(), link_close_futs).await?; + for link in links { + conns.return_link(link); + } + conns.close_and_wait().await?; Ok(()) } @@ -50,30 +58,77 @@ impl<'a> Connections<'a> { } } - fn insert(&mut self, name: &'a str, stream: plug::PlugStream, sink: plug::PlugSink) { - self.map.insert(name, (Some(stream), Some(sink))); + fn insert( + &mut self, + name: &'a str, + backend: plug::Backend, + stream: plug::PlugStream, + sink: plug::PlugSink, + ) { + self.map.insert( + name, + Connection { + backend, + stream: Some(stream), + sink: Some(sink), + }, + ); + } + + fn return_link(&mut self, link: Link<'a>) { + let conn = self.map.get_mut(link.source_name).unwrap_or_else(|| { + panic!( + "tried to return a invalid link with source name {}", + link.source_name, + ) + }); + conn.stream = Some(link.source); + + let conn = self.map.get_mut(link.dest_name).unwrap_or_else(|| { + panic!( + "tried to return a invalid link with dest name {}", + link.dest_name, + ) + }); + conn.sink = Some(link.dest); } - // close all connections whose sink is not used in a link - async fn close_all(self) -> Result<()> { - let futs = self.map.into_iter().map(|(name, (_, sink))| async move { - if let Some(mut s) = sink { - debug!("Closing {name}"); - s.close().await?; - debug!("Closed {name}"); + // close all connections + // assume all links are returned + async fn close_and_wait(self) -> Result<()> { + let futs = self.map.into_iter().map(|(name, conn)| async move { + let fut = async { + if let Some(mut s) = conn.sink { + debug!("Closing {name}"); + s.close().await?; + debug!("Closed {name}"); + } + debug!("Waiting for plug {name} to exit"); + conn.backend.wait().await?; + debug!("Plug {name} exited"); + anyhow::Ok(()) + }; + let close_result = tokio::time::timeout(std::time::Duration::from_secs(10), fut).await; + + match close_result { + Ok(result) => result, + Err(_) => { + // abandon the connection + warn!("Plug {name} didn't exit in time"); + Ok(()) + } } - anyhow::Ok(()) }); future::try_join_all(futs).await?; Ok(()) } fn take_stream(&mut self, name: &str) -> Option { - self.map.get_mut(name)?.0.take() + self.map.get_mut(name)?.stream.take() } fn take_sink(&mut self, name: &str) -> Option { - self.map.get_mut(name)?.1.take() + self.map.get_mut(name)?.sink.take() } } @@ -87,16 +142,16 @@ async fn connect_to_plugs(config: &Config) -> Result { } }); - let (sink, stream) = match connect_result { + let (backend, sink, stream) = match connect_result { Ok(p) => p, Err(e) => { warn!("Error connecting to {name}: {e}"); - conns.close_all().await?; + conns.close_and_wait().await?; return Err(e); } }; debug!("Connected to {name}"); - conns.insert(name.as_str(), stream, sink); + conns.insert(name.as_str(), backend, stream, sink); } Ok(conns) } @@ -149,11 +204,4 @@ impl<'a> Link<'a> { } self } - - async fn close(mut self) -> Result<()> { - debug!("Closing {}", self.dest_name); - self.dest.close().await?; - debug!("Closed {}", self.dest_name); - Ok(()) - } } diff --git a/kble/src/plug.rs b/kble/src/plug.rs index 41ba096..d43c779 100644 --- a/kble/src/plug.rs +++ b/kble/src/plug.rs @@ -5,7 +5,7 @@ use futures::{future, stream, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use pin_project::pin_project; use tokio::{ io::{AsyncRead, AsyncWrite}, - process::{ChildStdin, ChildStdout}, + process::{Child, ChildStdin, ChildStdout}, }; use tokio_tungstenite::{ tungstenite::{protocol::Role, Message}, @@ -16,7 +16,26 @@ use url::Url; pub type PlugSink = Pin, Error = anyhow::Error> + Send + 'static>>; pub type PlugStream = Pin>> + Send + 'static>>; -pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> { +pub enum Backend { + WebSocketClient, + StdioProcess(Child), +} + +impl Backend { + pub async fn wait(self) -> Result<()> { + match self { + Backend::WebSocketClient => Ok(()), + Backend::StdioProcess(mut proc) => { + proc.wait() + .await + .with_context(|| format!("Failed to wait for {:?}", proc))?; + Ok(()) + } + } + } +} + +pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { match url.scheme() { "exec" => connect_exec(url).await, "ws" | "wss" => connect_ws(url).await, @@ -24,7 +43,7 @@ pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> { } } -async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> { +async fn connect_exec(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { assert_eq!(url.scheme(), "exec"); ensure!(url.username().is_empty()); ensure!(url.password().is_none()); @@ -32,18 +51,19 @@ async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> { ensure!(url.port().is_none()); ensure!(url.query().is_none()); ensure!(url.fragment().is_none()); - let proc = tokio::process::Command::new("sh") + let mut proc = tokio::process::Command::new("sh") .args(["-c", url.path()]) .stderr(Stdio::inherit()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn() .with_context(|| format!("Failed to spawn {}", url))?; - let stdin = proc.stdin.unwrap(); - let stdout = proc.stdout.unwrap(); + let stdin = proc.stdin.take().unwrap(); + let stdout = proc.stdout.take().unwrap(); let stdio = ChildStdio { stdin, stdout }; let wss = WebSocketStream::from_raw_socket(stdio, Role::Client, None).await; - Ok(wss_to_pair(wss)) + let (stream, sink) = wss_to_pair(wss); + Ok((Backend::StdioProcess(proc), stream, sink)) } #[pin_project] @@ -89,11 +109,12 @@ impl AsyncRead for ChildStdio { } } -async fn connect_ws(url: &Url) -> Result<(PlugSink, PlugStream)> { +async fn connect_ws(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> { let (wss, _resp) = tokio_tungstenite::connect_async(url) .await .with_context(|| format!("Failed to connect to {}", url))?; - Ok(wss_to_pair(wss)) + let (stream, sink) = wss_to_pair(wss); + Ok((Backend::WebSocketClient, stream, sink)) } fn wss_to_pair(wss: WebSocketStream) -> (PlugSink, PlugStream)