Skip to content

Commit

Permalink
Connect child process to pty
Browse files Browse the repository at this point in the history
  • Loading branch information
mdepp committed Nov 30, 2023
1 parent 51e8a2b commit 5123f28
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 88 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ encoding_rs = "0.8"
env_logger = "0.10"
iced = { version = "0.10", features = ["canvas", "tokio"] }
log = "0.4"
pty-process = { version = "0.4.0", features = ["async"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.32", features = ["full"] }
tokio-util = "0.7.10"
unicode-segmentation = "1.10"
189 changes: 102 additions & 87 deletions src/child.rs
Original file line number Diff line number Diff line change
@@ -1,79 +1,35 @@
use encoding_rs::{CoderResult, SHIFT_JIS};
use iced::{
futures::{
channel::mpsc::{self, Sender},
join, SinkExt, StreamExt,
},
subscription, Subscription,
};
use log::debug;
use std::{
future::{pending, Future},
process::Stdio,
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
task,
};

use crate::config::Config;
use anyhow::Context;
use anyhow::{anyhow, Result};
use encoding_rs::{CoderResult, SHIFT_JIS};
use iced::futures::channel::mpsc::{Receiver, Sender};
use iced::futures::{SinkExt, StreamExt};
use iced::{futures::channel::mpsc, subscription, Subscription};
use log::{debug, error, info};
use std::future::pending;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::{join, select};
use tokio_util::sync::CancellationToken;

pub fn connect(config: Config) -> Subscription<OutputEvent> {
pub fn subscribe_to_pty(config: Config) -> Subscription<OutputEvent> {
struct Connect;

subscription::channel(
std::any::TypeId::of::<Connect>(),
config.channel_buf_size,
|mut send_output| async move {
let (send_input, mut recv_input) = mpsc::channel(config.channel_buf_size);
async move |mut send_output| {
let config = config.clone();
let (send_input, recv_input) = mpsc::channel(config.channel_buf_size);
send_output
.send(OutputEvent::Connected(send_input))
.await
.unwrap();

debug!("Connecting to shell...");
let mut shell_process = tokio::process::Command::new(config.shell.clone())
.args(config.shell_args.clone())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.stdin(Stdio::piped())
.spawn()
make_pty(config, send_output.clone(), recv_input)
.await
.with_context(|| "make_pty")
.unwrap();
debug!("Connected to shell.");

let stdout = shell_process.stdout.take().unwrap();
let stderr = shell_process.stderr.take().unwrap();
let mut stdin = shell_process.stdin.take().unwrap();

let stdout_future = decode_output(
stdout,
|text| async {
debug!("Read stdout: {text:?}");
let mut send_output = send_output.clone();
send_output.send(OutputEvent::Stdout(text)).await.unwrap();
},
config.clone(),
);
let stderr_future = decode_output(
stderr,
|text| async {
debug!("Read stderr: {text:?}");
let mut send_output = send_output.clone();
send_output.send(OutputEvent::Stderr(text)).await.unwrap();
},
config.clone(),
);

let stdin_handle = task::spawn(async move {
debug!("Waiting for input messages...");
while let Some(InputEvent::Stdin(text)) = recv_input.next().await {
debug!("Write stdin: {text:?}");
stdin.write_all(text.as_bytes()).await.unwrap();
}
});

join!(stdout_future, stderr_future);
stdin_handle.abort();
send_output.send(OutputEvent::Disconnected).await.unwrap();

pending::<()>().await;
Expand All @@ -82,34 +38,93 @@ pub fn connect(config: Config) -> Subscription<OutputEvent> {
)
}

async fn decode_output<T: AsyncRead, F: Future>(
bytestream: T,
mut cb: impl FnMut(String) -> F,
async fn make_pty(
config: Config,
) {
let mut bytestream = Box::pin(bytestream);
let mut decoder = SHIFT_JIS.new_decoder_without_bom_handling();
sender: Sender<OutputEvent>,
mut receiver: Receiver<InputEvent>,
) -> Result<()> {
let mut pty = pty_process::Pty::new()?;
let mut cmd = pty_process::Command::new(config.shell)
.args(config.shell_args)
.spawn(&pty.pts()?)?;

let (mut pty_reader, mut pty_writer) = pty.split();
let cancellation_token = CancellationToken::new();

let cloned_token = cancellation_token.clone();
let write_to_pty = async move || -> Result<()> {
loop {
select! {
_ = cloned_token.cancelled() => break,
message = receiver.next() => match message {
Some(InputEvent::Stdin(text)) => {
debug!("Receive {text} from stdin");
pty_writer.write_all(text.as_bytes()).await?;
debug!("Sent to pty");
}
None => break
}
}
}
debug!("Shutting down pty writer");
pty_writer.shutdown().await?;
Ok(())
};

let mut cloned_sender = sender.clone();
let cloned_token = cancellation_token.clone();
let read_from_pty = async move || -> Result<()> {
let mut decoder = SHIFT_JIS.new_decoder_without_bom_handling();
let mut readbuf = vec![0u8; config.read_buf_size];
let mut decodebuf = vec![
0u8;
decoder
.max_utf8_buffer_length(config.read_buf_size)
.ok_or(anyhow!("Could not find decodebuf length"))?
];

let mut last = false;
while !last {
select! {
_ = cloned_token.cancelled() => break,
nbytes = pty_reader.read(&mut readbuf) => match nbytes {
Ok(nbytes) => {
debug!("Read {nbytes} bytes from pty");
last = nbytes == 0;
let (result, nwritten, _, _) =
decoder.decode_to_utf8(&readbuf[..nbytes], &mut decodebuf, last);
assert!(
result == CoderResult::InputEmpty,
"Can't have OutputFull result since decode_buf_size was set sufficiently large"
);
let text = String::from_utf8(decodebuf[..nwritten].into())?;
cloned_sender.send(OutputEvent::Stdout(text)).await?;
}
Err(err) => {
error!("pty read error: {err}");
break;
}
}
}
}
debug!("Shutting down pty reader");
Ok(())
};

let mut readbuf = vec![0u8; config.read_buf_size];
let mut decodebuf = vec![
0u8;
decoder
.max_utf8_buffer_length(config.read_buf_size)
.unwrap()
];
let cleanup = async move || -> Result<()> {
let status = cmd.wait().await?;
info!("Shell finished with status {status}");
cancellation_token.cancel();
Ok(())
};

let mut last = false;
while !last {
let nbytes = bytestream.read(&mut readbuf).await.unwrap();
last = nbytes == 0;
debug!("Read {} bytes", nbytes);
let (result, nwritten, nread, replaced) =
decoder.decode_to_utf8(&readbuf[..nbytes], &mut decodebuf, last);
debug!("Decoded result={result:?} nwritten={nwritten} nread={nread} replaced={replaced}");
// Can't have OutputFull result since decode_buf_size was set sufficiently large
assert!(result == CoderResult::InputEmpty);
cb(String::from_utf8(decodebuf[..nwritten].into()).unwrap()).await;
}
let result = join!(write_to_pty(), read_from_pty(), cleanup());
result
.0
.with_context(|| "write_to_pty")
.and(result.1.with_context(|| "read_from_pty"))
.and(result.2.with_context(|| "cleanup"))?;
Ok(())
}

#[derive(Debug, Clone)]
Expand Down
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(assert_matches)]
#![feature(try_trait_v2)]
#![feature(async_closure)]

mod child;
mod config;
Expand Down Expand Up @@ -95,7 +96,7 @@ impl Application for Firn {

fn subscription(&self) -> Subscription<Message> {
Subscription::batch([
child::connect(self.config.clone()).map(Message::ChildEvent),
child::subscribe_to_pty(self.config.clone()).map(Message::ChildEvent),
subscription::events_with(|event, status| match (&event, status) {
(Event::Keyboard(_), Status::Ignored) => Some(Message::ApplicationEvent(event)),
_ => None,
Expand Down

0 comments on commit 5123f28

Please sign in to comment.