From 6c65082874dac65e885b695cca8788b632391963 Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Tue, 7 Nov 2023 13:46:53 -0500 Subject: [PATCH] trino: Show source location for query errors This will make it easier to find errors in large transpiled queries. --- src/drivers/trino/mod.rs | 73 +++++++++++++++++++++++++++++++++------- src/errors.rs | 8 +++++ 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 9ccf4c0..4d6cc3d 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -1,16 +1,18 @@ //! Trino and maybe Presto driver. -use std::{fmt, str::FromStr}; +use std::{fmt, str::FromStr, sync::Arc}; use async_trait::async_trait; +use codespan_reporting::{diagnostic::Diagnostic, files::Files}; use once_cell::sync::Lazy; -use prusto::{Client, ClientBuilder, Presto, Row}; +use prusto::{error::Error as PrustoError, Client, ClientBuilder, Presto, QueryError, Row}; use regex::Regex; use tracing::debug; use crate::{ ast::Target, - errors::{format_err, Context, Error, Result}, + errors::{format_err, Context, Error, Result, SourceError}, + known_files::KnownFiles, transforms::{self, Transform, Udf}, util::AnsiIdent, }; @@ -161,8 +163,7 @@ impl Driver for TrinoDriver { self.client .execute(sql.to_owned()) .await - .map_err(abbreviate_trino_error) - .with_context(|| format!("Failed to execute SQL: {}", sql))?; + .map_err(|err| abbreviate_trino_error(sql, err))?; Ok(()) } @@ -193,10 +194,11 @@ impl Driver for TrinoDriver { #[tracing::instrument(skip(self))] async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { + let sql = format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name)); self.client - .execute(format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name))) + .execute(sql.clone()) .await - .map_err(abbreviate_trino_error) + .map_err(|err| abbreviate_trino_error(&sql, err)) .with_context(|| format!("Failed to drop table: {}", table_name))?; Ok(()) } @@ -233,9 +235,9 @@ impl DriverImpl for TrinoDriver { ); Ok(self .client - .get_all::(sql) + .get_all::(sql.clone()) .await - .map_err(abbreviate_trino_error) + .map_err(|err| abbreviate_trino_error(&sql, err)) .with_context(|| format!("Failed to get columns for table: {}", table_name))? .into_vec() .into_iter() @@ -265,9 +267,9 @@ impl DriverImpl for TrinoDriver { ); let rows = self .client - .get_all::(sql) + .get_all::(sql.clone()) .await - .map_err(abbreviate_trino_error) + .map_err(|err| abbreviate_trino_error(&sql, err)) .with_context(|| format!("Failed to query table: {}", table_name))? .into_vec() .into_iter() @@ -313,7 +315,54 @@ impl fmt::Display for TrinoString<'_> { } /// These errors are pages long. -fn abbreviate_trino_error(e: prusto::error::Error) -> Error { +fn abbreviate_trino_error(sql: &str, e: PrustoError) -> Error { + if let PrustoError::QueryError(e) = &e { + // We can make these look pretty. + let QueryError { + message, + error_code, + error_location, + .. + } = e; + let mut files = KnownFiles::default(); + let file_id = files.add_string("trino.sql", sql); + + let offset = if let Some(loc) = error_location { + // We don't want to panic, because we're already processing an + // error, and the error comes from an external source. So just + // muddle through and return Span::Unknown or a bogus location + // if our input data is too odd. + // + // Convert from u32, defaulting negative values to 1. (Although + // lines count from 1.) + let line_number = usize::try_from(loc.line_number).unwrap_or(0); + let column_number = usize::try_from(loc.column_number).unwrap_or(0); + files + .line_range(file_id, line_number.saturating_sub(1)) + .ok() + .map(|r| r.start + column_number.saturating_sub(1)) + } else { + None + }; + + if let Some(offset) = offset { + let diagnostic = Diagnostic::error() + .with_message(message.clone()) + .with_code(format!("TRINO {}", error_code)) + .with_labels(vec![codespan_reporting::diagnostic::Label::primary( + file_id, + offset..offset, + ) + .with_message("Trino error")]); + + return Error::Source(Box::new(SourceError { + alternate_summary: message.clone(), + diagnostic, + files_override: Some(Arc::new(files)), + })); + } + } + let msg = e .to_string() .lines() diff --git a/src/errors.rs b/src/errors.rs index 6780b65..d5093a1 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -3,6 +3,7 @@ use std::{ error::{self, Error as _}, fmt, result, + sync::Arc, }; use anstream::eprintln; @@ -190,6 +191,10 @@ where pub struct SourceError { pub alternate_summary: String, pub diagnostic: Diagnostic, + /// If you're not using the standard set of known files, perhaps because + /// you're in a database driver, you can override the [`KnownFiles`] used to + /// display this error. + pub files_override: Option>, } impl SourceError { @@ -212,6 +217,7 @@ impl SourceError { SourceError { alternate_summary, diagnostic, + files_override: None, } } else { let alternate_summary = format!("{} (at unknown location): {}", summary, annotation); @@ -219,6 +225,7 @@ impl SourceError { SourceError { alternate_summary, diagnostic, + files_override: None, } } } @@ -227,6 +234,7 @@ impl SourceError { pub fn emit(&self, files: &KnownFiles) { let writer = StandardStream::stderr(ColorChoice::Auto); let config = term::Config::default(); + let files = self.files_override.as_deref().unwrap_or(files); term::emit(&mut writer.lock(), &config, files, &self.diagnostic) .expect("could not write to stderr"); }