Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to request manual quit on tui #2489

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion burn-book/src/building-blocks/learner.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The learner builder provides numerous options when it comes to configurations.
| Renderer | Configure how to render metrics (default is CLI) |
| Grad Accumulation | Configure the number of steps before applying gradients |
| File Checkpointer | Configure how the model, optimizer and scheduler states are saved |
| Num Epochs | Set the number of epochs. |
| Num Epochs | Set the number of epochs |
| Devices | Set the devices to be used |
| Checkpoint | Restart training from a checkpoint |
| Application logging | Configure the application logging installer (default is writing to `experiment.log`) |
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-train/src/renderer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ pub use base::*;

mod cli;

/// The tui renderer
#[cfg(feature = "tui")]
mod tui;
pub mod tui;
use crate::TrainingInterrupter;

/// Return the default metrics renderer.
Expand Down
56 changes: 56 additions & 0 deletions crates/burn-train/src/renderer/tui/renderer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub struct TuiMetricsRenderer {
interuptor: TrainingInterrupter,
popup: PopupState,
previous_panic_hook: Option<Arc<PanicHook>>,
manual_quit: bool,
}

impl MetricsRenderer for TuiMetricsRenderer {
Expand Down Expand Up @@ -116,9 +117,15 @@ impl TuiMetricsRenderer {
interuptor,
popup: PopupState::Empty,
previous_panic_hook: Some(previous_panic_hook),
manual_quit: false,
}
}

/// Enable manual quit after training.
pub fn enable_manual_quit(&mut self) {
self.manual_quit = true;
}

fn render(&mut self) -> Result<(), Box<dyn Error>> {
let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);
if self.last_update.elapsed() < tick_rate {
Expand Down Expand Up @@ -200,6 +207,49 @@ impl TuiMetricsRenderer {

Ok(())
}

fn handle_post_training(&mut self) -> Result<(), Box<dyn Error>> {
self.popup = PopupState::Full(
"Training is done".to_string(),
vec![Callback::new(
"Training Done",
"Press 'x' to close this popup. Press 'q' to exit the application after the \
popup is closed.",
'x',
PopupCancel,
)],
);

self.draw().ok();

loop {
if let Ok(true) = event::poll(Duration::from_millis(MAX_REFRESH_RATE_MILLIS)) {
match event::read() {
Ok(event @ Event::Key(key)) => {
if self.popup.is_empty() {
self.metrics_numeric.on_event(&event);
if let KeyCode::Char('q') = key.code {
break;
}
} else {
self.popup.on_event(&event);
}
self.draw().ok();
}

Ok(Event::Resize(..)) => {
self.draw().ok();
}
Err(err) => {
eprintln!("Error reading event: {}", err);
break;
}
_ => continue,
}
}
}
Ok(())
}
}

struct QuitPopupAccept(TrainingInterrupter);
Expand Down Expand Up @@ -230,6 +280,12 @@ impl Drop for TuiMetricsRenderer {
// Reset the terminal back to raw mode. This can be skipped during
// panicking because the panic hook has already reset the terminal
if !std::thread::panicking() {
if self.manual_quit {
if let Err(err) = self.handle_post_training() {
eprintln!("Error in post-training handling: {}", err);
}
}

disable_raw_mode().ok();
execute!(self.terminal.backend_mut(), LeaveAlternateScreen).unwrap();
self.terminal.show_cursor().ok();
Expand Down