Skip to content

Commit

Permalink
kble: wait for child process
Browse files Browse the repository at this point in the history
  • Loading branch information
kobkaz committed Apr 22, 2024
1 parent 86eb08a commit 71ba574
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 34 deletions.
98 changes: 73 additions & 25 deletions kble/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@ use std::collections::HashMap;
use tokio::sync::broadcast;
use tracing::{debug, warn};

struct Connection {
backend: plug::Backend,
stream: Option<plug::PlugStream>,
sink: Option<plug::PlugSink>,
}

struct Connections<'a> {
// Some: connections not used yet
// None: connections is used in a link
map: HashMap<&'a str, (Option<plug::PlugStream>, Option<plug::PlugSink>)>,
map: HashMap<&'a str, Connection>,
}

struct Link<'a> {
Expand All @@ -37,8 +43,10 @@ pub async fn run(config: &Config) -> Result<()> {
let links = future::join_all(link_futs).await;
let links = links.into_iter().chain(std::iter::once(terminated_link));

let link_close_futs = future::try_join_all(links.map(|link| link.close()));
future::try_join(conns.close_all(), link_close_futs).await?;
for link in links {
conns.return_link(link);
}
conns.close_and_wait().await?;

Ok(())
}
Expand All @@ -50,30 +58,77 @@ impl<'a> Connections<'a> {
}
}

fn insert(&mut self, name: &'a str, stream: plug::PlugStream, sink: plug::PlugSink) {
self.map.insert(name, (Some(stream), Some(sink)));
fn insert(
&mut self,
name: &'a str,
backend: plug::Backend,
stream: plug::PlugStream,
sink: plug::PlugSink,
) {
self.map.insert(
name,
Connection {
backend,
stream: Some(stream),
sink: Some(sink),
},
);
}

fn return_link(&mut self, link: Link<'a>) {
let conn = self.map.get_mut(link.source_name).unwrap_or_else(|| {
panic!(
"tried to return a invalid link with source name {}",
link.source_name,
)
});
conn.stream = Some(link.source);

let conn = self.map.get_mut(link.dest_name).unwrap_or_else(|| {
panic!(
"tried to return a invalid link with dest name {}",
link.dest_name,
)
});
conn.sink = Some(link.dest);
}

// close all connections whose sink is not used in a link
async fn close_all(self) -> Result<()> {
let futs = self.map.into_iter().map(|(name, (_, sink))| async move {
if let Some(mut s) = sink {
debug!("Closing {name}");
s.close().await?;
debug!("Closed {name}");
// close all connections
// assume all links are returned
async fn close_and_wait(self) -> Result<()> {
let futs = self.map.into_iter().map(|(name, conn)| async move {
let fut = async {
if let Some(mut s) = conn.sink {
debug!("Closing {name}");
s.close().await?;
debug!("Closed {name}");
}
debug!("Waiting for plug {name} to exit");
conn.backend.wait().await?;
debug!("Plug {name} exited");
anyhow::Ok(())
};
let close_result = tokio::time::timeout(std::time::Duration::from_secs(10), fut).await;

match close_result {
Ok(result) => result,
Err(_) => {
// abandon the connection
warn!("Plug {name} didn't exit in time");
Ok(())
}
}
anyhow::Ok(())
});
future::try_join_all(futs).await?;
Ok(())
}

fn take_stream(&mut self, name: &str) -> Option<plug::PlugStream> {
self.map.get_mut(name)?.0.take()
self.map.get_mut(name)?.stream.take()
}

fn take_sink(&mut self, name: &str) -> Option<plug::PlugSink> {
self.map.get_mut(name)?.1.take()
self.map.get_mut(name)?.sink.take()
}
}

Expand All @@ -87,16 +142,16 @@ async fn connect_to_plugs(config: &Config) -> Result<Connections> {
}
});

let (sink, stream) = match connect_result {
let (backend, sink, stream) = match connect_result {
Ok(p) => p,
Err(e) => {
warn!("Error connecting to {name}: {e}");
conns.close_all().await?;
conns.close_and_wait().await?;
return Err(e);
}
};
debug!("Connected to {name}");
conns.insert(name.as_str(), stream, sink);
conns.insert(name.as_str(), backend, stream, sink);
}
Ok(conns)
}
Expand Down Expand Up @@ -149,11 +204,4 @@ impl<'a> Link<'a> {
}
self
}

async fn close(mut self) -> Result<()> {
debug!("Closing {}", self.dest_name);
self.dest.close().await?;
debug!("Closed {}", self.dest_name);
Ok(())
}
}
39 changes: 30 additions & 9 deletions kble/src/plug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures::{future, stream, Sink, SinkExt, Stream, StreamExt, TryStreamExt};
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
process::{ChildStdin, ChildStdout},
process::{Child, ChildStdin, ChildStdout},
};
use tokio_tungstenite::{
tungstenite::{protocol::Role, Message},
Expand All @@ -16,34 +16,54 @@ use url::Url;
pub type PlugSink = Pin<Box<dyn Sink<Vec<u8>, Error = anyhow::Error> + Send + 'static>>;
pub type PlugStream = Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + Send + 'static>>;

pub async fn connect(url: &Url) -> Result<(PlugSink, PlugStream)> {
pub enum Backend {
WebSocketClient,
StdioProcess(Child),
}

impl Backend {
pub async fn wait(self) -> Result<()> {
match self {
Backend::WebSocketClient => Ok(()),
Backend::StdioProcess(mut proc) => {
proc.wait()
.await
.with_context(|| format!("Failed to wait for {:?}", proc))?;
Ok(())
}
}
}
}

pub async fn connect(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {
match url.scheme() {
"exec" => connect_exec(url).await,
"ws" | "wss" => connect_ws(url).await,
_ => Err(anyhow!("Unsupported scheme: {}", url.scheme())),
}
}

async fn connect_exec(url: &Url) -> Result<(PlugSink, PlugStream)> {
async fn connect_exec(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {
assert_eq!(url.scheme(), "exec");
ensure!(url.username().is_empty());
ensure!(url.password().is_none());
ensure!(url.host().is_none());
ensure!(url.port().is_none());
ensure!(url.query().is_none());
ensure!(url.fragment().is_none());
let proc = tokio::process::Command::new("sh")
let mut proc = tokio::process::Command::new("sh")
.args(["-c", url.path()])
.stderr(Stdio::inherit())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.with_context(|| format!("Failed to spawn {}", url))?;
let stdin = proc.stdin.unwrap();
let stdout = proc.stdout.unwrap();
let stdin = proc.stdin.take().unwrap();
let stdout = proc.stdout.take().unwrap();
let stdio = ChildStdio { stdin, stdout };
let wss = WebSocketStream::from_raw_socket(stdio, Role::Client, None).await;
Ok(wss_to_pair(wss))
let (stream, sink) = wss_to_pair(wss);
Ok((Backend::StdioProcess(proc), stream, sink))
}

#[pin_project]
Expand Down Expand Up @@ -89,11 +109,12 @@ impl AsyncRead for ChildStdio {
}
}

async fn connect_ws(url: &Url) -> Result<(PlugSink, PlugStream)> {
async fn connect_ws(url: &Url) -> Result<(Backend, PlugSink, PlugStream)> {
let (wss, _resp) = tokio_tungstenite::connect_async(url)
.await
.with_context(|| format!("Failed to connect to {}", url))?;
Ok(wss_to_pair(wss))
let (stream, sink) = wss_to_pair(wss);
Ok((Backend::WebSocketClient, stream, sink))
}

fn wss_to_pair<S>(wss: WebSocketStream<S>) -> (PlugSink, PlugStream)
Expand Down

0 comments on commit 71ba574

Please sign in to comment.