Skip to content

Commit

Permalink
fix(cubesql): Don't clone AST on routing
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Aug 27, 2024
1 parent 4366299 commit 343df70
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 34 deletions.
2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/src/compile/query_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ impl QueryEngine for SqlQueryEngine {
}

fn sanitize_statement(&self, stmt: &Self::AstStatementType) -> Self::AstStatementType {
SensitiveDataSanitizer::new().replace(&stmt)
SensitiveDataSanitizer::new().replace(stmt.clone())
}
}

Expand Down
21 changes: 12 additions & 9 deletions rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ impl QueryRouter {
("query".to_string(), stmt.to_string()),
(
"sanitizedQuery".to_string(),
SensitiveDataSanitizer::new().replace(stmt).to_string(),
SensitiveDataSanitizer::new()
.replace(stmt.clone())
.to_string(),
),
]));
let msg = err.message();
Expand Down Expand Up @@ -712,25 +714,26 @@ impl QueryRouter {
}
}

pub fn rewrite_statement(stmt: &ast::Statement) -> ast::Statement {
pub fn rewrite_statement(stmt: ast::Statement) -> ast::Statement {
let stmt = CastReplacer::new().replace(stmt);
let stmt = ToTimestampReplacer::new().replace(&stmt);
let stmt = UdfWildcardArgReplacer::new().replace(&stmt);
let stmt = DateTokenNormalizeReplacer::new().replace(&stmt);
let stmt = RedshiftDatePartReplacer::new().replace(&stmt);
let stmt = ApproximateCountDistinctVisitor::new().replace(&stmt);
let stmt = ToTimestampReplacer::new().replace(stmt);
let stmt = UdfWildcardArgReplacer::new().replace(stmt);
let stmt = DateTokenNormalizeReplacer::new().replace(stmt);

Check failure on line 721 in rust/cubesql/cubesql/src/compile/router.rs

View workflow job for this annotation

GitHub Actions / Build native Linux 18 x86_64-unknown-linux-gnu Python fallback

mismatched types
let stmt = RedshiftDatePartReplacer::new().replace(stmt);
let stmt = ApproximateCountDistinctVisitor::new().replace(stmt);

stmt
}

pub async fn convert_statement_to_cube_query(
stmt: &ast::Statement,
stmt: ast::Statement,
meta: Arc<MetaContext>,
session: Arc<Session>,
qtrace: &mut Option<Qtrace>,
span_id: Option<Arc<SpanId>>,
) -> CompilationResult<QueryPlan> {
let stmt = rewrite_statement(stmt);

if let Some(qtrace) = qtrace {
qtrace.set_visitor_replaced_statement(&stmt);
}
Expand All @@ -745,5 +748,5 @@ pub async fn convert_sql_to_cube_query(
session: Arc<Session>,
) -> CompilationResult<QueryPlan> {
let stmt = parse_sql_to_statement(&query, session.state.protocol.clone(), &mut None)?;
convert_statement_to_cube_query(&stmt, meta, session, &mut None, None).await
convert_statement_to_cube_query(stmt, meta, session, &mut None, None).await
}
8 changes: 4 additions & 4 deletions rust/cubesql/cubesql/src/compile/test/rewrite_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ pub async fn create_test_postgresql_cube_context(

pub fn query_to_logical_plan(query: String, context: &CubeContext) -> LogicalPlan {
let stmt = parse_sql_to_statement(&query, DatabaseProtocol::PostgreSQL, &mut None).unwrap();
let stmt = rewrite_statement(&stmt);
let stmt = rewrite_statement(stmt);
let df_query_planner = SqlToRel::new_with_options(context, true);

return df_query_planner
.statement_to_plan(Statement::Statement(Box::new(stmt.clone())))
.unwrap();
df_query_planner
.statement_to_plan(Statement::Statement(Box::new(stmt)))
.unwrap()
}

pub fn rewrite_runner(plan: LogicalPlan, context: Arc<CubeContext>) -> CubeRunner {
Expand Down
12 changes: 6 additions & 6 deletions rust/cubesql/cubesql/src/sql/postgres/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ impl AsyncPostgresShim {
.await?;

let plan = convert_statement_to_cube_query(
&prepared_statement,
prepared_statement,
meta,
self.session.clone(),
&mut None,
Expand Down Expand Up @@ -1132,10 +1132,10 @@ impl AsyncPostgresShim {
.await?;

let stmt_replacer = StatementPlaceholderReplacer::new();
let hacked_query = stmt_replacer.replace(&query)?;
let hacked_query = stmt_replacer.replace(query.clone())?;

let plan = convert_statement_to_cube_query(
&hacked_query,
hacked_query,
meta,
self.session.clone(),
qtrace,
Expand Down Expand Up @@ -1393,7 +1393,7 @@ impl AsyncPostgresShim {
})?;

let plan = convert_statement_to_cube_query(
&cursor.query,
cursor.query.clone(),
meta,
self.session.clone(),
qtrace,
Expand Down Expand Up @@ -1475,7 +1475,7 @@ impl AsyncPostgresShim {
let select_stmt = Statement::Query(query);
// It's just a verification that we can compile that query.
let _ = convert_statement_to_cube_query(
&select_stmt,
select_stmt.clone(),
meta.clone(),
self.session.clone(),
&mut None,
Expand Down Expand Up @@ -1648,7 +1648,7 @@ impl AsyncPostgresShim {
}
other => {
let plan = convert_statement_to_cube_query(
&other,
other,
meta.clone(),
self.session.clone(),
qtrace,
Expand Down
28 changes: 14 additions & 14 deletions rust/cubesql/cubesql/src/sql/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,8 @@ impl StatementPlaceholderReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> Result<ast::Statement, ConnectionError> {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> Result<ast::Statement, ConnectionError> {
let mut result = stmt;

self.visit_statement(&mut result)?;

Expand Down Expand Up @@ -671,8 +671,8 @@ impl CastReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -874,8 +874,8 @@ impl RedshiftDatePartReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -930,8 +930,8 @@ impl ToTimestampReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand All @@ -957,8 +957,8 @@ impl UdfWildcardArgReplacer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -1046,8 +1046,8 @@ impl ApproximateCountDistinctVisitor {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down Expand Up @@ -1075,8 +1075,8 @@ impl SensitiveDataSanitizer {
Self {}
}

pub fn replace(mut self, stmt: &ast::Statement) -> ast::Statement {
let mut result = stmt.clone();
pub fn replace(mut self, stmt: ast::Statement) -> ast::Statement {
let mut result = stmt;

self.visit_statement(&mut result).unwrap();

Expand Down

0 comments on commit 343df70

Please sign in to comment.