Skip to content

Commit

Permalink
fix: move alias counting to Context
Browse files Browse the repository at this point in the history
  • Loading branch information
jacek-prisma committed Jan 7, 2025
1 parent 444210d commit 0794925
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 167 deletions.
16 changes: 16 additions & 0 deletions query-engine/connectors/sql-query-connector/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::sync::{self, atomic::AtomicUsize};

use quaint::prelude::ConnectionInfo;
use telemetry::TraceParent;

use crate::filter::alias::Alias;

pub(super) struct Context<'a> {
connection_info: &'a ConnectionInfo,
pub(crate) traceparent: Option<TraceParent>,
Expand All @@ -10,6 +14,8 @@ pub(super) struct Context<'a> {
/// Maximum number of bind parameters allowed for a single query.
/// None is unlimited.
pub(crate) max_bind_values: Option<usize>,

alias_counter: AtomicUsize,
}

impl<'a> Context<'a> {
Expand All @@ -22,10 +28,20 @@ impl<'a> Context<'a> {
traceparent,
max_insert_rows,
max_bind_values: Some(max_bind_values),

alias_counter: Default::default(),
}
}

pub(crate) fn schema_name(&self) -> &str {
self.connection_info.schema_name()
}

pub(crate) fn next_table_alias(&self) -> Alias {
Alias::Table(self.alias_counter.fetch_add(1, sync::atomic::Ordering::SeqCst))
}

pub(crate) fn next_join_alias(&self) -> Alias {
Alias::Join(self.alias_counter.fetch_add(1, sync::atomic::Ordering::SeqCst))
}
}
60 changes: 19 additions & 41 deletions query-engine/connectors/sql-query-connector/src/filter/alias.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,37 @@
use std::fmt;

use crate::{model_extensions::AsColumn, *};

use quaint::prelude::Column;
use query_structure::ScalarField;

#[derive(Clone, Copy, Debug)]
/// A distinction in aliasing to separate the parent table and the joined data
/// in the statement.
#[derive(Default)]
pub enum AliasMode {
#[default]
Table,
Join,
}

#[derive(Clone, Copy, Debug, Default)]
/// Aliasing tool to count the nesting level to help with heavily nested
/// self-related queries.
pub struct Alias {
counter: usize,
mode: AliasMode,
#[derive(Debug, Clone, Copy)]
pub enum Alias {
Table(usize),
Join(usize),
}

impl Alias {
/// Increment the alias as a new copy.
///
/// Use when nesting one level down to a new subquery. `AliasMode` is
/// required due to the fact the current mode can be in `AliasMode::Join`.
pub fn inc(&self, mode: AliasMode) -> Self {
Self {
counter: self.counter + 1,
mode,
pub fn to_join_alias(self) -> Self {
match self {
Self::Table(index) | Self::Join(index) => Self::Join(index),
}
}

/// Flip the alias to a different mode keeping the same nesting count.
pub fn flip(&self, mode: AliasMode) -> Self {
Self {
counter: self.counter,
mode,
pub fn to_table_alias(self) -> Self {
match self {
Self::Table(index) | Self::Join(index) => Self::Table(index),
}
}
}

/// A string representation of the current alias. The current mode can be
/// overridden by defining the `mode_override`.
pub fn to_string(&self, mode_override: Option<AliasMode>) -> String {
match mode_override.unwrap_or(self.mode) {
AliasMode::Table => format!("t{}", self.counter),
AliasMode::Join => format!("j{}", self.counter),
impl fmt::Display for Alias {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Table(index) => write!(f, "t{}", index),
Self::Join(index) => write!(f, "j{}", index),
}
}

#[cfg(feature = "relation_joins")]
pub fn to_table_string(&self) -> String {
self.to_string(Some(AliasMode::Table))
}
}

pub(crate) trait AliasedColumn {
Expand All @@ -73,7 +51,7 @@ impl AliasedColumn for &ScalarField {
impl AliasedColumn for Column<'static> {
fn aliased_col(self, alias: Option<Alias>, _ctx: &Context<'_>) -> Column<'static> {
match alias {
Some(alias) => self.table(alias.to_string(None)),
Some(alias) => self.table(alias.to_string()),
None => self,
}
}
Expand Down
66 changes: 18 additions & 48 deletions query-engine/connectors/sql-query-connector/src/filter/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ pub(crate) trait FilterVisitorExt {

#[derive(Debug, Clone, Default)]
pub struct FilterVisitor {
/// The last alias that's been rendered.
last_alias: Option<Alias>,
/// The parent alias, used when rendering nested filters so that a child filter can refer to its join.
parent_alias: Option<Alias>,
/// Whether filters can return top-level joins.
Expand All @@ -56,14 +54,6 @@ impl FilterVisitor {
}
}

/// Returns the next join/table alias by increasing the counter of the last alias.
fn next_alias(&mut self, mode: AliasMode) -> Alias {
let next_alias = self.last_alias.unwrap_or_default().inc(mode);
self.last_alias = Some(next_alias);

next_alias
}

/// Returns the parent alias, if there's one set, so that nested filters can refer to the parent join/table.
fn parent_alias(&self) -> Option<Alias> {
self.parent_alias
Expand Down Expand Up @@ -92,14 +82,6 @@ impl FilterVisitor {
res
}

fn update_last_alias(&mut self, nested_visitor: &Self) -> &mut Self {
if let Some(alias) = nested_visitor.last_alias {
self.last_alias = Some(alias);
}

self
}

fn create_nested_visitor(&self, parent_alias: Alias) -> Self {
let mut nested_visitor = self.clone();
nested_visitor.is_nested = true;
Expand All @@ -111,8 +93,6 @@ impl FilterVisitor {
fn visit_nested_filter<T>(&mut self, parent_alias: Alias, f: impl FnOnce(&mut Self) -> T) -> T {
let mut nested_visitor = self.create_nested_visitor(parent_alias);
let res = f(&mut nested_visitor);
// Ensures the alias counter is updated after building the nested filter so that we don't render duplicate aliases.
self.update_last_alias(&nested_visitor);

res

Check failure on line 97 in query-engine/connectors/sql-query-connector/src/filter/visitor.rs

View workflow job for this annotation

GitHub Actions / clippy linting

returning the result of a `let` binding from a block
}
Expand Down Expand Up @@ -154,31 +134,31 @@ impl FilterVisitor {
/// - For M2M relations, as we need to traverse the join table so the join is not superfluous
/// - SQL Server because it does not support (a, b) IN (subselect)
fn visit_relation_filter_select_no_row(&mut self, filter: RelationFilter, ctx: &Context<'_>) -> Select<'static> {
let alias = self.next_alias(AliasMode::Table);
let table_alias = ctx.next_table_alias();
let condition = filter.condition;
let table = filter.field.as_table(ctx);
let ids = ModelProjection::from(filter.field.model().primary_identifier());

let selected_identifier: Vec<Column> = filter
.field
.identifier_columns(ctx)
.map(|col| col.aliased_col(Some(alias), ctx))
.map(|col| col.aliased_col(Some(table_alias), ctx))
.collect();

let join_columns: Vec<Column> = filter
.field
.join_columns(ctx)
.map(|c| c.aliased_col(Some(alias), ctx))
.map(|c| c.aliased_col(Some(table_alias), ctx))
.collect();

let related_table = filter.field.related_model().as_table(ctx);
let related_join_columns: Vec<_> = ModelProjection::from(filter.field.related_field().linking_fields())
.as_columns(ctx)
.map(|col| col.aliased_col(Some(alias.flip(AliasMode::Join)), ctx))
.map(|col| col.aliased_col(Some(table_alias.to_join_alias()), ctx))
.collect();

let (nested_conditions, nested_joins) = self
.visit_nested_filter(alias.flip(AliasMode::Join), |nested_visitor| {
.visit_nested_filter(table_alias.to_join_alias(), |nested_visitor| {
nested_visitor.visit_filter(*filter.nested_filter, ctx)
});

Expand All @@ -197,10 +177,10 @@ impl FilterVisitor {
.fold(nested_conditions, |acc, column| acc.and(column.is_not_null()));

let join = related_table
.alias(alias.to_string(Some(AliasMode::Join)))
.alias(table_alias.to_join_alias().to_string())
.on(Row::from(related_join_columns).equals(Row::from(join_columns)));

let select = Select::from_table(table.alias(alias.to_string(Some(AliasMode::Table))))
let select = Select::from_table(table.alias(table_alias.to_string()))
.columns(selected_identifier)
.inner_join(join)
.so_that(nested_conditons);
Expand All @@ -221,7 +201,7 @@ impl FilterVisitor {
/// )
/// ```
fn visit_relation_filter_select_with_row(&mut self, filter: RelationFilter, ctx: &Context<'_>) -> Select<'static> {
let alias = self.next_alias(AliasMode::Table);
let alias = ctx.next_table_alias();
let condition = filter.condition;
let linking_fields = ModelProjection::from(filter.field.linking_fields());

Expand Down Expand Up @@ -250,7 +230,7 @@ impl FilterVisitor {
.into_iter()
.fold(nested_conditions, |acc, column| acc.and(column.is_not_null()));

let select = Select::from_table(related_table.alias(alias.to_string(Some(AliasMode::Table))))
let select = Select::from_table(related_table.alias(alias.to_string()))
.columns(related_columns)
.so_that(conditions);

Expand Down Expand Up @@ -395,12 +375,12 @@ impl FilterVisitorExt for FilterVisitor {
filter: RelationFilter,
ctx: &Context<'_>,
) -> (ConditionTree<'static>, Option<Vec<AliasedJoin>>) {
let parent_alias = self.parent_alias().map(|a| a.to_string(None));
let parent_alias = self.parent_alias().map(|a| a.to_string());

match &filter.condition {
// { to_one: { isNot: { ... } } }
RelationCondition::NoRelatedRecord if self.can_render_join() && !filter.field.is_list() => {
let alias = self.next_alias(AliasMode::Join);
let alias = ctx.next_join_alias();

let linking_fields_null: Vec<_> =
ModelProjection::from(filter.field.related_model().primary_identifier())
Expand All @@ -411,12 +391,7 @@ impl FilterVisitorExt for FilterVisitor {
.collect();
let null_filter = ConditionTree::And(linking_fields_null);

let join = compute_one2m_join(
&filter.field,
alias.to_string(None).as_str(),
parent_alias.as_deref(),
ctx,
);
let join = compute_one2m_join(&filter.field, alias.to_string().as_str(), parent_alias.as_deref(), ctx);

let mut output_joins = vec![join];

Expand All @@ -433,7 +408,7 @@ impl FilterVisitorExt for FilterVisitor {
}
// { to_one: { is: { ... } } }
RelationCondition::ToOneRelatedRecord if self.can_render_join() && !filter.field.is_list() => {
let alias = self.next_alias(AliasMode::Join);
let alias = ctx.next_join_alias();

let linking_fields_not_null: Vec<_> =
ModelProjection::from(filter.field.related_model().primary_identifier())
Expand All @@ -444,12 +419,7 @@ impl FilterVisitorExt for FilterVisitor {
.collect();
let not_null_filter = ConditionTree::And(linking_fields_not_null);

let join = compute_one2m_join(
&filter.field,
alias.to_string(None).as_str(),
parent_alias.as_deref(),
ctx,
);
let join = compute_one2m_join(&filter.field, alias.to_string().as_str(), parent_alias.as_deref(), ctx);
let mut output_joins = vec![join];

let (conditions, nested_joins) = self.visit_nested_filter(alias, |nested_visitor| {
Expand Down Expand Up @@ -487,7 +457,7 @@ impl FilterVisitorExt for FilterVisitor {
ctx: &Context<'_>,
) -> (ConditionTree<'static>, Option<Vec<AliasedJoin>>) {
let parent_alias = self.parent_alias();
let parent_alias_string = parent_alias.as_ref().map(|a| a.to_string(None));
let parent_alias_string = parent_alias.as_ref().map(|a| a.to_string());

// If the relation is inlined, we simply check whether the linking fields are null.
//
Expand All @@ -514,7 +484,7 @@ impl FilterVisitorExt for FilterVisitor {
// WHERE "j1"."parentId" IS NULL OFFSET;
// ```
if self.can_render_join() {
let alias = self.next_alias(AliasMode::Join);
let alias = ctx.next_join_alias();

let conditions: Vec<_> = ModelProjection::from(filter.field.related_field().linking_fields())
.as_columns(ctx)
Expand All @@ -525,7 +495,7 @@ impl FilterVisitorExt for FilterVisitor {

let join = compute_one2m_join(
&filter.field,
alias.to_string(None).as_str(),
alias.to_string().as_str(),
parent_alias_string.as_deref(),
ctx,
);
Expand All @@ -544,7 +514,7 @@ impl FilterVisitorExt for FilterVisitor {
let relation = filter.field.relation();
let table = relation.as_table(ctx);
let relation_table = match parent_alias {
Some(ref alias) => table.alias(alias.to_string(None)),
Some(ref alias) => table.alias(alias.to_table_alias().to_string()),
None => table,
};

Expand Down
Loading

0 comments on commit 0794925

Please sign in to comment.