Skip to content

Commit

Permalink
enh: use specific error variant for exceed max rows error. (#3699)
Browse files Browse the repository at this point in the history
Co-authored-by: Henry Fontanier <[email protected]>
  • Loading branch information
fontanierh and Henry Fontanier authored Feb 14, 2024
1 parent f871058 commit aebc77a
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 66 deletions.
2 changes: 1 addition & 1 deletion core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2309,7 +2309,7 @@ async fn sqlite_workers_heartbeat(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to expire SQLite worker databases",
Some(e),
Some(e.into()),
)
}
Ok(_) => (),
Expand Down
22 changes: 15 additions & 7 deletions core/bin/sqlite_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use datadog_formatting_layer::DatadogFormattingLayer;
use dust::{
databases::database::Table,
databases_store::{self, store::DatabasesStore},
sqlite_workers::sqlite_database::SqliteDatabase,
sqlite_workers::sqlite_database::{QueryError, SqliteDatabase},
utils::{error_response, APIResponse},
};
use hyper::{Body, Client, Request, StatusCode};
Expand Down Expand Up @@ -216,12 +216,20 @@ async fn databases_query(
response: Some(json!(results)),
}),
),
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to query database",
Some(e),
),
Err(e) => match e {
QueryError::ExceededMaxRows(max) => error_response(
StatusCode::BAD_REQUEST,
"too_many_result_rows",
&format!("Result contains too many rows (max: {})", max),
Some(e.into()),
),
_ => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to query database",
Some(e.into()),
),
},
}
}

Expand Down
142 changes: 103 additions & 39 deletions core/src/sqlite_workers/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use hyper::{Body, Client, Request};
use hyper::{body::Bytes, Body, Client, Request};
use serde::{Deserialize, Serialize};
use serde_json::json;
use urlencoding::encode;
Expand All @@ -17,6 +17,31 @@ pub struct SqliteWorker {
url: String,
}

/// Error returned by `SqliteWorker` client calls.
#[derive(Debug)]
pub enum SqliteWorkerError {
/// The request could not be built or executed, or the response could not be read or parsed.
ClientError(anyhow::Error),
/// The worker answered with an error. Fields, in order: the underlying error, the optional
/// error code taken from the response body (`error.code`), and the HTTP status. The status
/// may be 200 when a successful response carries an error or no `response` payload.
ServerError(anyhow::Error, Option<String>, u16),
}

/// Renders the error as a single line, prefixed with `SqliteWorkerError`.
///
/// For `ServerError`, the error code (empty when absent) and the HTTP status are included
/// ahead of the underlying error message.
impl std::fmt::Display for SqliteWorkerError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::ClientError(e) => write!(f, "SqliteWorkerError: Client error: {}", e),
            Self::ServerError(e, code, status) => write!(
                f,
                "SqliteWorkerError (code={}, status={}): Server error: {}",
                // Borrow the code instead of cloning the `String`; a missing code renders as "".
                code.as_deref().unwrap_or_default(),
                status,
                e
            ),
        }
    }
}

// Marker impl: relies on the `Debug` and `Display` impls above and keeps the default
// `source()`, so the wrapped `anyhow::Error` is only visible through the formatted message.
impl std::error::Error for SqliteWorkerError {}

impl SqliteWorker {
pub fn new(url: String, last_heartbeat: u64) -> Self {
Self {
Expand All @@ -41,7 +66,7 @@ impl SqliteWorker {
database_unique_id: &str,
tables: &Vec<Table>,
query: &str,
) -> Result<Vec<QueryResult>> {
) -> Result<Vec<QueryResult>, SqliteWorkerError> {
let worker_url = self.url();

let req = Request::builder()
Expand All @@ -58,70 +83,109 @@ impl SqliteWorker {
"query": query,
})
.to_string(),
))?;
))
.map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to build request: {}", e))
})?;

let res = Client::new().request(req).await?;
let res = Client::new().request(req).await.map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to execute request: {}", e))
})?;

let body_bytes = get_response_body(res).await?;

#[derive(Deserialize)]
struct ExecuteQueryResponseBody {
error: Option<String>,
response: Option<Vec<QueryResult>>,
}

match res.status().as_u16() {
200 => {
let body = hyper::body::to_bytes(res.into_body()).await?;
let res: ExecuteQueryResponseBody = serde_json::from_slice(&body)?;
match res.error {
Some(e) => Err(anyhow!("Error executing query: {}", e))?,
None => match res.response {
Some(r) => Ok(r),
None => Err(anyhow!("No response found"))?,
},
}
}
s => Err(anyhow!(
"Failed to execute query on sqlite worker. Status: {}",
s
))?,
let body: ExecuteQueryResponseBody = serde_json::from_slice(&body_bytes).map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to parse response: {}", e))
})?;

match body.error {
Some(e) => Err(SqliteWorkerError::ServerError(anyhow!(e), None, 200))?,
None => match body.response {
Some(r) => Ok(r),
None => Err(SqliteWorkerError::ServerError(
anyhow!("No response in body"),
None,
200,
))?,
},
}
}

pub async fn invalidate_database(&self, database_unique_id: &str) -> Result<()> {
pub async fn invalidate_database(
&self,
database_unique_id: &str,
) -> Result<(), SqliteWorkerError> {
let worker_url = self.url();

let req = Request::builder()
.method("DELETE")
.uri(format!("{}/databases/{}", worker_url, database_unique_id))
.body(Body::from(""))?;
.body(Body::from(""))
.map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to build request: {}", e))
})?;

let res = Client::new().request(req).await?;
let res = Client::new().request(req).await.map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to execute request: {}", e))
})?;

match res.status().as_u16() {
200 => Ok(()),
s => Err(anyhow!(
"Failed to invalidate database on sqlite worker. Status: {}",
s
))?,
}
let _ = get_response_body(res).await?;

Ok(())
}

pub async fn expire_all(&self) -> Result<()> {
pub async fn expire_all(&self) -> Result<(), SqliteWorkerError> {
let worker_url = self.url();

let req = Request::builder()
.method("DELETE")
.uri(format!("{}/databases", worker_url))
.body(Body::from(""))?;
.body(Body::from(""))
.map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to build request: {}", e))
})?;

let res = Client::new().request(req).await.map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to execute request: {}", e))
})?;
let _ = get_response_body(res).await?;

let res = Client::new().request(req).await?;
Ok(())
}
}

match res.status().as_u16() {
200 => Ok(()),
s => Err(anyhow!(
"Failed to expire all databases on sqlite worker. Status: {}",
s
))?,
async fn get_response_body(res: hyper::Response<hyper::Body>) -> Result<Bytes, SqliteWorkerError> {
let status = res.status().as_u16();
let body = hyper::body::to_bytes(res.into_body())
.await
.map_err(|e| SqliteWorkerError::ClientError(anyhow!("Failed to read response: {}", e)))?;

match status {
200 => Ok(body),
s => {
let body_json: serde_json::Value = serde_json::from_slice(&body).map_err(|e| {
SqliteWorkerError::ClientError(anyhow!("Failed to parse response: {}", e))
})?;
let error = body_json.get("error");
let error_code = match error {
Some(e) => e
.get("code")
.map(|c| c.as_str())
.flatten()
.map(|s| s.to_string()),
None => None,
};
Err(SqliteWorkerError::ServerError(
anyhow!("Received error response from SQLite worker",),
error_code,
s,
))?
}
}
}
80 changes: 61 additions & 19 deletions core/src/sqlite_workers/sqlite_database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,25 @@ pub struct SqliteDatabase {
interrupt_handle: Option<Arc<tokio::sync::Mutex<InterruptHandle>>>,
}

/// Error returned by `SqliteDatabase::query`.
#[derive(Debug)]
pub enum QueryError {
/// The query produced more rows than allowed; carries the row limit that was exceeded.
ExceededMaxRows(usize),
/// Any other failure while running the query (wraps the underlying error).
QueryExecutionError(anyhow::Error),
}

impl std::fmt::Display for QueryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QueryError::ExceededMaxRows(limit) => {
write!(f, "Query returned more than {} rows", limit)
}
QueryError::QueryExecutionError(e) => write!(f, "Query execution error: {}", e),
}
}
}

// Marker impl: relies on the `Debug` and `Display` impls above and keeps the default `source()`.
impl std::error::Error for QueryError {}

// Maximum number of rows a single query may return; `query` reports
// `QueryError::ExceededMaxRows` when a result set is larger than this.
const MAX_ROWS: usize = 128;

impl SqliteDatabase {
Expand Down Expand Up @@ -47,22 +66,26 @@ impl SqliteDatabase {
}
}

pub async fn query(&self, query: &str, timeout_ms: u64) -> Result<Vec<QueryResult>> {
pub async fn query(
&self,
query: &str,
timeout_ms: u64,
) -> Result<Vec<QueryResult>, QueryError> {
let query = query.to_string();
let conn = self.conn.clone();

let query_future = task::spawn_blocking(move || {
let conn = match conn {
Some(conn) => conn.clone(),
None => Err(anyhow!("Database not initialized"))?,
};
let conn = conn.ok_or(QueryError::QueryExecutionError(anyhow!(
"Database not initialized"
)))?;

// This lock is a parking_lot so it's blocking but we're in a spawn_blocking, so OK.
let conn = conn.lock();
let time_query_start = utils::now();

// Execute the query and collect results
let mut stmt = conn.prepare(&query)?;
let mut stmt = conn
.prepare(&query)
.map_err(|e| QueryError::QueryExecutionError(anyhow::Error::new(e)))?;

let column_names = stmt
.column_names()
.into_iter()
Expand Down Expand Up @@ -113,14 +136,23 @@ impl SqliteDatabase {
))
})
.collect::<Result<serde_json::Value>>()
})?
// Limit to 128 rows.
.take(MAX_ROWS)
.collect::<Result<Vec<_>>>()?
})
// At this point we have a result (from the query_and_then fn itself) of results (for each
// individual row parsing). We wrap the potential top-level error in a QueryError and bubble it up.
.map_err(|e| QueryError::QueryExecutionError(anyhow::Error::new(e)))?
.take(MAX_ROWS + 1)
.collect::<Result<Vec<_>, _>>()
// Thanks to the collect above, we now have a single top-level result.
// We wrap the potential error in a QueryError and bubble up if needed.
.map_err(QueryError::QueryExecutionError)?
.into_par_iter()
.map(|value| QueryResult { value })
.collect::<Vec<_>>();

if result_rows.len() > MAX_ROWS {
return Err(QueryError::ExceededMaxRows(MAX_ROWS));
}

info!(
duration = utils::now() - time_query_start,
"DSSTRUCTSTAT - WORKER Finished executing user query"
Expand All @@ -129,16 +161,26 @@ impl SqliteDatabase {
Ok(result_rows)
});

match timeout(std::time::Duration::from_millis(timeout_ms), query_future).await {
Ok(r) => r?,
match timeout(std::time::Duration::from_millis(timeout_ms), query_future)
.await
.map_err(|_| QueryError::QueryExecutionError(anyhow!("Join error")))?
{
Ok(r) => r,
Err(_) => {
let interrupt_handle = match &self.interrupt_handle {
Some(interrupt_handle) => interrupt_handle.clone(),
None => Err(anyhow!("Database not initialized"))?,
};
let interrupt_handle =
self.interrupt_handle
.as_ref()
.ok_or(QueryError::QueryExecutionError(anyhow!(
"Database is not initialized"
)))?;

let interrupt_handle = interrupt_handle.lock().await;
interrupt_handle.interrupt();
Err(anyhow!("Query execution timed out after {}ms", timeout_ms))?

Err(QueryError::QueryExecutionError(anyhow!(format!(
"Query execution timed out after {} ms",
timeout_ms
))))
}
}
}
Expand Down

0 comments on commit aebc77a

Please sign in to comment.