From c7d91bb6a64db68a7419d9523d18550c6ce9f98b Mon Sep 17 00:00:00 2001 From: Max Isom Date: Mon, 30 Dec 2024 16:56:30 -0600 Subject: [PATCH] [BUG]: properly catch and propagate panics in component handlers --- rust/worker/src/execution/operator.rs | 17 ++-- .../src/execution/orchestration/compact.rs | 7 +- .../src/execution/orchestration/count.rs | 7 +- .../worker/src/execution/orchestration/get.rs | 87 ++++++++++--------- .../src/execution/orchestration/knn_filter.rs | 59 ++++++------- .../execution/orchestration/orchestrator.rs | 17 +++- rust/worker/src/system/receiver.rs | 2 - rust/worker/src/system/system.rs | 29 +++++-- rust/worker/src/system/types.rs | 12 +-- rust/worker/src/system/wrapped_message.rs | 27 ++---- rust/worker/src/utils/panic.rs | 31 ++++++- 11 files changed, 162 insertions(+), 133 deletions(-) diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index 9bc0e631492..86a15fe0867 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -1,4 +1,4 @@ -use crate::{system::ReceiverForMessage, utils::get_panic_message}; +use crate::{system::ReceiverForMessage, utils::PanicError}; use async_trait::async_trait; use chroma_error::{ChromaError, ErrorCodes}; use futures::FutureExt; @@ -34,7 +34,7 @@ where #[derive(Debug, Error)] pub(super) enum TaskError { #[error("Panic occurred while handling task: {0:?}")] - Panic(Option), + Panic(PanicError), #[error("Task failed with error: {0:?}")] TaskFailed(#[from] Err), } @@ -149,13 +149,11 @@ where } } Err(panic_value) => { - let panic_message = get_panic_message(panic_value); - match self .reply_channel .send( TaskResult { - result: Err(TaskError::Panic(panic_message.clone())), + result: Err(TaskError::Panic(PanicError::new(panic_value))), task_id: self.task_id, }, None, @@ -171,12 +169,6 @@ where ); } }; - - // Re-panic so the message handler can catch it - panic!( - "{}", - panic_message.unwrap_or("Unknown panic occurred in 
task".to_string()) - ); } }; } @@ -296,6 +288,7 @@ mod tests { let result = &results_guard.first().unwrap().result; assert!(result.is_err()); - matches!(result, Err(TaskError::Panic(Some(msg))) if msg == "MockOperator panicking"); + let err = result.as_ref().unwrap_err(); + assert!(err.to_string().contains("MockOperator panicking")); } } diff --git a/rust/worker/src/execution/orchestration/compact.rs b/rust/worker/src/execution/orchestration/compact.rs index fb241c5f2fc..d305a5d873e 100644 --- a/rust/worker/src/execution/orchestration/compact.rs +++ b/rust/worker/src/execution/orchestration/compact.rs @@ -35,6 +35,7 @@ use crate::system::ChannelError; use crate::system::ComponentContext; use crate::system::ComponentHandle; use crate::system::Handler; +use crate::utils::PanicError; use async_trait::async_trait; use chroma_blockstore::provider::BlockfileProvider; use chroma_error::ChromaError; @@ -143,8 +144,8 @@ impl ChromaError for GetSegmentWritersError { #[derive(Error, Debug)] pub enum CompactionError { - #[error("Panic running task: {0}")] - Panic(String), + #[error("Panic during compaction: {0}")] + Panic(#[from] PanicError), #[error("FetchLog error: {0}")] FetchLog(#[from] FetchLogError), #[error("Partition error: {0}")] @@ -167,7 +168,7 @@ where { fn from(value: TaskError) -> Self { match value { - TaskError::Panic(e) => CompactionError::Panic(e.unwrap_or_default()), + TaskError::Panic(e) => CompactionError::Panic(e), TaskError::TaskFailed(e) => e.into(), } } diff --git a/rust/worker/src/execution/orchestration/count.rs b/rust/worker/src/execution/orchestration/count.rs index 8d91359b83c..eb38d608ebd 100644 --- a/rust/worker/src/execution/orchestration/count.rs +++ b/rust/worker/src/execution/orchestration/count.rs @@ -17,6 +17,7 @@ use crate::{ }, }, system::{ChannelError, ComponentContext, ComponentHandle, Handler}, + utils::PanicError, }; use super::orchestrator::Orchestrator; @@ -29,8 +30,8 @@ pub enum CountError { FetchLog(#[from] FetchLogError), 
#[error("Error running Count Record Operator: {0}")] CountRecord(#[from] CountRecordsError), - #[error("Panic running task: {0}")] - Panic(String), + #[error("Panic: {0}")] + Panic(#[from] PanicError), #[error("Error receiving final result: {0}")] Result(#[from] RecvError), } @@ -53,7 +54,7 @@ where { fn from(value: TaskError) -> Self { match value { - TaskError::Panic(e) => CountError::Panic(e.unwrap_or_default()), + TaskError::Panic(e) => CountError::Panic(e), TaskError::TaskFailed(e) => e.into(), } } diff --git a/rust/worker/src/execution/orchestration/get.rs b/rust/worker/src/execution/orchestration/get.rs index 2b955e3fc78..f41d2e27254 100644 --- a/rust/worker/src/execution/orchestration/get.rs +++ b/rust/worker/src/execution/orchestration/get.rs @@ -18,6 +18,7 @@ use crate::{ }, }, system::{ChannelError, ComponentContext, ComponentHandle, Handler}, + utils::PanicError, }; use super::orchestrator::Orchestrator; @@ -32,8 +33,8 @@ pub enum GetError { Filter(#[from] FilterError), #[error("Error running Limit Operator: {0}")] Limit(#[from] LimitError), - #[error("Panic running task: {0}")] - Panic(String), + #[error("Panic: {0}")] + Panic(#[from] PanicError), #[error("Error running Projection Operator: {0}")] Projection(#[from] ProjectionError), #[error("Error receiving final result: {0}")] @@ -60,7 +61,7 @@ where { fn from(value: TaskError) -> Self { match value { - TaskError::Panic(e) => GetError::Panic(e.unwrap_or_default()), + TaskError::Panic(e) => e.into(), TaskError::TaskFailed(e) => e.into(), } } @@ -75,46 +76,46 @@ type GetResult = Result; /// /// # Pipeline /// ```text -/// ┌────────────┐ -/// │ │ -/// │ on_start │ -/// │ │ -/// └──────┬─────┘ -/// │ -/// ▼ -/// ┌────────────────────┐ -/// │ │ -/// │ FetchLogOperator │ -/// │ │ -/// └─────────┬──────────┘ -/// │ -/// ▼ -/// ┌───────────────────┐ -/// │ │ -/// │ FilterOperator │ -/// │ │ -/// └─────────┬─────────┘ -/// │ -/// ▼ -/// ┌─────────────────┐ -/// │ │ -/// │ LimitOperator │ -/// │ │ -/// 
└────────┬────────┘ -/// │ -/// ▼ -/// ┌──────────────────────┐ -/// │ │ -/// │ ProjectionOperator │ -/// │ │ -/// └──────────┬───────────┘ -/// │ -/// ▼ -/// ┌──────────────────┐ -/// │ │ -/// │ result_channel │ -/// │ │ -/// └──────────────────┘ +/// ┌────────────┐ +/// │ │ +/// │ on_start │ +/// │ │ +/// └──────┬─────┘ +/// │ +/// ▼ +/// ┌────────────────────┐ +/// │ │ +/// │ FetchLogOperator │ +/// │ │ +/// └─────────┬──────────┘ +/// │ +/// ▼ +/// ┌───────────────────┐ +/// │ │ +/// │ FilterOperator │ +/// │ │ +/// └─────────┬─────────┘ +/// │ +/// ▼ +/// ┌─────────────────┐ +/// │ │ +/// │ LimitOperator │ +/// │ │ +/// └────────┬────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ │ +/// │ ProjectionOperator │ +/// │ │ +/// └──────────┬───────────┘ +/// │ +/// ▼ +/// ┌──────────────────┐ +/// │ │ +/// │ result_channel │ +/// │ │ +/// └──────────────────┘ /// ``` #[derive(Debug)] pub struct GetOrchestrator { diff --git a/rust/worker/src/execution/orchestration/knn_filter.rs b/rust/worker/src/execution/orchestration/knn_filter.rs index 54632c3755b..e24c16dba85 100644 --- a/rust/worker/src/execution/orchestration/knn_filter.rs +++ b/rust/worker/src/execution/orchestration/knn_filter.rs @@ -29,6 +29,7 @@ use crate::{ utils::distance_function_from_segment, }, system::{ChannelError, ComponentContext, ComponentHandle, Handler}, + utils::PanicError, }; use super::orchestrator::Orchestrator; @@ -57,8 +58,8 @@ pub enum KnnError { KnnProjection(#[from] KnnProjectionError), #[error("Error inspecting collection dimension")] NoCollectionDimension, - #[error("Panic running task: {0}")] - Panic(String), + #[error("Panic: {0}")] + Panic(#[from] PanicError), #[error("Error receiving final result: {0}")] Result(#[from] RecvError), #[error("Invalid distance function")] @@ -92,7 +93,7 @@ where { fn from(value: TaskError) -> Self { match value { - TaskError::Panic(e) => KnnError::Panic(e.unwrap_or_default()), + TaskError::Panic(e) => e.into(), TaskError::TaskFailed(e) => 
e.into(), } } @@ -116,32 +117,32 @@ type KnnFilterResult = Result; /// /// # Pipeline /// ```text -/// ┌────────────┐ -/// │ │ -/// │ on_start │ -/// │ │ -/// └──────┬─────┘ -/// │ -/// ▼ -/// ┌────────────────────┐ -/// │ │ -/// │ FetchLogOperator │ -/// │ │ -/// └─────────┬──────────┘ -/// │ -/// ▼ -/// ┌───────────────────┐ -/// │ │ -/// │ FilterOperator │ -/// │ │ -/// └─────────┬─────────┘ -/// │ -/// ▼ -/// ┌──────────────────┐ -/// │ │ -/// │ result_channel │ -/// │ │ -/// └──────────────────┘ +/// ┌────────────┐ +/// │ │ +/// │ on_start │ +/// │ │ +/// └──────┬─────┘ +/// │ +/// ▼ +/// ┌────────────────────┐ +/// │ │ +/// │ FetchLogOperator │ +/// │ │ +/// └─────────┬──────────┘ +/// │ +/// ▼ +/// ┌───────────────────┐ +/// │ │ +/// │ FilterOperator │ +/// │ │ +/// └─────────┬─────────┘ +/// │ +/// ▼ +/// ┌──────────────────┐ +/// │ │ +/// │ result_channel │ +/// │ │ +/// └──────────────────┘ /// ``` #[derive(Debug)] pub struct KnnFilterOrchestrator { diff --git a/rust/worker/src/execution/orchestration/orchestrator.rs b/rust/worker/src/execution/orchestration/orchestrator.rs index ddc96f93c26..9c629325d1a 100644 --- a/rust/worker/src/execution/orchestration/orchestrator.rs +++ b/rust/worker/src/execution/orchestration/orchestrator.rs @@ -1,20 +1,20 @@ -use core::fmt::Debug; -use std::any::type_name; - use async_trait::async_trait; use chroma_error::ChromaError; +use core::fmt::Debug; +use std::any::type_name; use tokio::sync::oneshot::{self, error::RecvError, Sender}; use tracing::Span; use crate::{ execution::{dispatcher::Dispatcher, operator::TaskMessage}, system::{ChannelError, Component, ComponentContext, ComponentHandle, System}, + utils::PanicError, }; #[async_trait] pub trait Orchestrator: Debug + Send + Sized + 'static { type Output: Send; - type Error: ChromaError + From + From; + type Error: ChromaError + From + From + From; /// Returns the handle of the dispatcher fn dispatcher(&self) -> ComponentHandle; @@ -108,4 +108,13 @@ impl Component for O 
{ } } } + + fn on_handler_panic(&mut self, panic_value: Box) { + let channel = self.take_result_channel(); + let error = PanicError::new(panic_value); + + if channel.send(Err(O::Error::from(error))).is_err() { + tracing::error!("Error reporting panic to {}", Self::name()); + }; + } } diff --git a/rust/worker/src/system/receiver.rs b/rust/worker/src/system/receiver.rs index a2d79d8fb8e..cb688249247 100644 --- a/rust/worker/src/system/receiver.rs +++ b/rust/worker/src/system/receiver.rs @@ -70,8 +70,6 @@ pub enum RequestError { SendError, #[error("Failed to receive response")] ReceiveError, - #[error("Message handler panicked")] - HandlerPanic(Option), } impl ChromaError for RequestError { diff --git a/rust/worker/src/system/system.rs b/rust/worker/src/system/system.rs index bdd3f955b07..b94458baafa 100644 --- a/rust/worker/src/system/system.rs +++ b/rust/worker/src/system/system.rs @@ -138,7 +138,8 @@ where #[cfg(test)] mod tests { - use crate::system::RequestError; + use crate::utils::get_panic_message; + use std::sync::Mutex; use super::*; use async_trait::async_trait; @@ -147,13 +148,15 @@ mod tests { struct TestComponent { queue_size: usize, counter: usize, + caught_panic: Arc>>, } impl TestComponent { - fn new(queue_size: usize) -> Self { + fn new(queue_size: usize, caught_panic: Arc>>) -> Self { TestComponent { queue_size, counter: 0, + caught_panic, } } } @@ -185,12 +188,19 @@ mod tests { fn queue_size(&self) -> usize { self.queue_size } + + fn on_handler_panic(&mut self, panic_value: Box) { + self.caught_panic + .lock() + .unwrap() + .replace(get_panic_message(&panic_value).unwrap()); + } } #[tokio::test] async fn response_types() { let system = System::new(); - let component = TestComponent::new(10); + let component = TestComponent::new(10, Arc::new(Mutex::new(None))); let handle = system.start_component(component); assert_eq!(1, handle.request(1, None).await.unwrap()); @@ -198,15 +208,18 @@ mod tests { } #[tokio::test] - async fn catches_panic() { + async 
fn catches_handler_panic_with_hook() { + let caught_panic = Arc::new(Mutex::new(None)); + let system = System::new(); - let component = TestComponent::new(10); + let component = TestComponent::new(10, caught_panic.clone()); let handle = system.start_component(component); - let err = handle.request(0, None).await.unwrap_err(); + handle.request(0, None).await.unwrap_err(); + assert_eq!( - RequestError::HandlerPanic(Some("Invalid input".to_string())), - err + caught_panic.lock().unwrap().clone().unwrap(), + "Invalid input".to_string() ); // Component is still alive diff --git a/rust/worker/src/system/types.rs b/rust/worker/src/system/types.rs index 1ff3e9365f7..a423d89202d 100644 --- a/rust/worker/src/system/types.rs +++ b/rust/worker/src/system/types.rs @@ -44,6 +44,11 @@ pub trait Component: Send + Sized + Debug + 'static { ComponentRuntime::Inherit } async fn start(&mut self, _ctx: &ComponentContext) -> () {} + fn on_handler_panic(&mut self, panic: Box) { + // Default behavior is to log and then resume the panic + tracing::error!("Handler panicked: {:?}", panic); + std::panic::resume_unwind(panic); + } } /// A handler is a component that can process messages of a given type. 
@@ -154,12 +159,7 @@ impl ComponentSender { let result = rx.await.map_err(|_| RequestError::ReceiveError)?; - match result { - Ok(result) => Ok(result), - Err(err) => match err { - super::MessageHandlerError::Panic(p) => Err(RequestError::HandlerPanic(p)), - }, - } + Ok(result) } } diff --git a/rust/worker/src/system/wrapped_message.rs b/rust/worker/src/system/wrapped_message.rs index 23063ae812d..5b754eb5882 100644 --- a/rust/worker/src/system/wrapped_message.rs +++ b/rust/worker/src/system/wrapped_message.rs @@ -1,10 +1,7 @@ -use crate::utils::get_panic_message; - use super::{Component, ComponentContext, Handler, Message}; use async_trait::async_trait; use futures::FutureExt; use std::{fmt::Debug, panic::AssertUnwindSafe}; -use thiserror::Error; use tokio::sync::oneshot; // Why is this separate from the WrappedMessage struct? WrappedMessage is only generic @@ -28,14 +25,6 @@ impl HandleableMessageImpl { } } -#[derive(Debug, Error)] -pub(super) enum MessageHandlerError { - #[error("Panic occurred while handling message: {0:?}")] - Panic(Option), -} - -type MessageHandlerWrappedResult = Result; - /// Erases the type of the message so it can be sent over a channel and optionally bundles a tracing context. 
#[derive(Debug)] pub(crate) struct WrappedMessage @@ -49,7 +38,7 @@ where impl WrappedMessage { pub(super) fn new( message: M, - reply_channel: Option>>, + reply_channel: Option>, tracing_context: Option, ) -> Self where @@ -80,8 +69,7 @@ where } #[async_trait] -impl HandleableMessage - for Option>> +impl HandleableMessage for Option> where C: Component + Handler, M: Message, @@ -96,18 +84,13 @@ where Ok(result) => { if let Some(reply_channel) = message.reply_channel { reply_channel - .send(Ok(result)) + .send(result) .expect("message reply channel was unexpectedly dropped by caller"); } } Err(panic_value) => { - let panic_message = get_panic_message(panic_value); - - if let Some(reply_channel) = message.reply_channel { - reply_channel - .send(Err(MessageHandlerError::Panic(panic_message))) - .expect("message reply channel was unexpectedly dropped by caller"); - } + tracing::error!("Panic occurred while handling message: {:?}", panic_value); + component.on_handler_panic(panic_value); } }; } diff --git a/rust/worker/src/utils/panic.rs b/rust/worker/src/utils/panic.rs index 6d2dc6095db..a8feae081ff 100644 --- a/rust/worker/src/utils/panic.rs +++ b/rust/worker/src/utils/panic.rs @@ -1,7 +1,36 @@ +use chroma_error::ChromaError; use std::any::Any; +use thiserror::Error; + +#[derive(Error)] +#[error("Panic {:?}", get_panic_message(.0))] +pub struct PanicError(Box); + +impl std::fmt::Debug for PanicError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Panic: {:?}", + get_panic_message(&self.0) + .unwrap_or("panic does not have displayable message".to_string()) + ) + } +} + +impl PanicError { + pub(crate) fn new(panic_value: Box) -> Self { + PanicError(panic_value) + } +} + +impl ChromaError for PanicError { + fn code(&self) -> chroma_error::ErrorCodes { + chroma_error::ErrorCodes::Internal + } +} /// Extracts the panic message from the value returned by `std::panic::catch_unwind`. 
-pub(crate) fn get_panic_message(value: Box<dyn Any + Send>) -> Option<String> { +pub(crate) fn get_panic_message(value: &Box<dyn Any + Send>) -> Option<String> { #[allow(clippy::manual_map)] if let Some(s) = value.downcast_ref::<&str>() { Some(s.to_string())