Skip to content

Commit

Permalink
fix: relation filters should use linking fields on both sides of IN c…
Browse files Browse the repository at this point in the history
…lause (#4318)
  • Loading branch information
Weakky authored Oct 9, 2023
1 parent a4e8771 commit 9cc3db6
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 71 deletions.
3 changes: 2 additions & 1 deletion psl/builtin-connectors/src/cockroach_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector
MultiSchema |
FilteredInlineChildNestedToOneDisconnect |
InsertReturning |
UpdateReturning
UpdateReturning |
RowIn
});

const SCALAR_TYPE_DEFAULTS: &[(ScalarType, CockroachType)] = &[
Expand Down
3 changes: 2 additions & 1 deletion psl/builtin-connectors/src/mysql_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector
SupportsTxIsolationReadUncommitted |
SupportsTxIsolationReadCommitted |
SupportsTxIsolationRepeatableRead |
SupportsTxIsolationSerializable
SupportsTxIsolationSerializable |
RowIn
});

const CONSTRAINT_SCOPES: &[ConstraintScope] = &[ConstraintScope::GlobalForeignKey, ConstraintScope::ModelKeyIndex];
Expand Down
3 changes: 2 additions & 1 deletion psl/builtin-connectors/src/postgres_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector
SupportsTxIsolationSerializable |
NativeUpsert |
InsertReturning |
UpdateReturning
UpdateReturning |
RowIn
});

pub struct PostgresDatamodelConnector;
Expand Down
3 changes: 2 additions & 1 deletion psl/builtin-connectors/src/sqlite_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector
OrderByNullsFirstLast |
SupportsTxIsolationSerializable |
NativeUpsert |
FilteredInlineChildNestedToOneDisconnect
FilteredInlineChildNestedToOneDisconnect |
RowIn
// InsertReturning - While SQLite does support RETURNING, it does not return column information on the way back from the database.
// This column type information is necessary in order to preserve consistency for some data types such as int, where values could overflow.
// Since we care to stay consistent with reads, it is not enabled.
Expand Down
1 change: 1 addition & 0 deletions psl/psl-core/src/datamodel_connector/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ capabilities!(
NativeUpsert,
InsertReturning,
UpdateReturning,
RowIn, // Connector supports (a, b) IN (c, d) expression.
);

/// Contains all capabilities that the connector is able to serve.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,89 @@ mod one_relation {
Ok(())
}

// https://github.com/prisma/prisma/issues/21356
fn schema_21356() -> String {
let schema = indoc! {
r#"model User {
#id(id, Int, @id)
name String?
posts Post[]
userId Int
userId2 Int
@@unique([userId, userId2])
}
model Post {
#id(id, Int, @id)
title String?
userId Int?
userId_2 Int?
author User? @relation(fields: [userId, userId_2], references: [userId, userId2])
}"#
};

schema.to_owned()
}

#[connector_test(schema(schema_21356))]
async fn repro_21356(runner: Runner) -> TestResult<()> {
run_query!(
&runner,
r#"mutation { createOneUser(data: { id: 1, userId: 1, userId2: 1, name: "Bob", posts: { create: { id: 1, title: "Hello" } } }) { id } }"#
);

insta::assert_snapshot!(
run_query!(&runner, r#"{ findManyUser(where: { posts: { some: { author: { name: "Bob" } } } }) { id } }"#),
@r###"{"data":{"findManyUser":[{"id":1}]}}"###
);

Ok(())
}

// https://github.com/prisma/prisma/issues/21366
fn schema_21366() -> String {
let schema = indoc! {
r#"model device {
#id(id, Int, @id)
device_id String @unique
current_state device_state? @relation(fields: [device_id], references: [device_id], onDelete: NoAction)
}
model device_state {
#id(id, Int, @id)
device_id String @unique
device device[]
}"#
};

schema.to_owned()
}

#[connector_test(schema(schema_21366))]
async fn repro_21366(runner: Runner) -> TestResult<()> {
run_query!(
&runner,
r#"mutation {
createOnedevice(data: { id: 1, current_state: { create: { id: 1, device_id: "1" } } }) {
id
}
}
"#
);

insta::assert_snapshot!(
run_query!(&runner, r#"{ findManydevice_state(where: { device: { some: { device_id: "1" } } }) { id } }"#),
@r###"{"data":{"findManydevice_state":[{"id":1}]}}"###
);

Ok(())
}

async fn test_data(runner: &Runner) -> TestResult<()> {
runner
.query(indoc! { r#"
Expand Down
201 changes: 134 additions & 67 deletions query-engine/connectors/sql-query-connector/src/filter/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{model_extensions::*, Context};

use connector_interface::filter::*;
use prisma_models::prelude::*;
use psl::datamodel_connector::ConnectorCapability;
use quaint::ast::concat;
use quaint::ast::*;
use std::convert::TryInto;
Expand Down Expand Up @@ -110,84 +111,149 @@ impl FilterVisitor {
res
}

fn visit_relation_filter_select(&mut self, filter: RelationFilter, ctx: &Context<'_>) -> Select<'static> {
fn visit_relation_filter_select(
&mut self,
filter: RelationFilter,
ctx: &Context<'_>,
) -> (ModelProjection, Select<'static>) {
let is_many_to_many = filter.field.relation().is_many_to_many();
// HACK: This is temporary. A fix should be done in Quaint instead of branching out here.
// See https://www.notion.so/prismaio/Spec-Faulty-Tuple-Join-on-SQL-Server-55b8232fb44f4a6cb4d3f36428f17bac
// for more info
let support_row_in = filter
.field
.dm
.schema
.connector
.capabilities()
.contains(ConnectorCapability::RowIn);
let has_compound_fields = filter.field.linking_fields().into_inner().len() > 1;

// If the relation is an M2M relation we don't have a choice but to join
// If the connector does not support (a, b) IN (SELECT c, d) and there are several linking fields, then we must use a join.
// Hint: SQL Server does not support `ROW() IN ()`.
if is_many_to_many || (!support_row_in && has_compound_fields) {
self.visit_relation_filter_select_no_row(filter, ctx)
} else {
self.visit_relation_filter_select_with_row(filter, ctx)
}
}

/// Traverses a relation filter using this rough SQL structure:
///
/// ```sql
/// (parent.id) IN (
/// SELECT id FROM parent
/// INNER JOIN child ON (child.parent_id = parent.id)
/// WHERE <filter>
/// )
/// ```
/// We need this in two cases:
/// - For M2M relations, as we need to traverse the join table so the join is not superfluous
/// - SQL Server because it does not support (a, b) IN (subselect)
fn visit_relation_filter_select_no_row(
&mut self,
filter: RelationFilter,
ctx: &Context<'_>,
) -> (ModelProjection, Select<'static>) {
let alias = self.next_alias(AliasMode::Table);
let condition = filter.condition;
let table = filter.field.as_table(ctx);
let ids = ModelProjection::from(filter.field.model().primary_identifier());

// Perf: We can skip a join if the relation is inlined on the related model.
// In this case, we can select the related table's foreign key instead of joining.
// This is not possible in the case of M2M implicit relations.
if filter.field.related_field().is_inlined_on_enclosing_model() {
let related_table = filter.field.related_model().as_table(ctx);
let related_columns: Vec<_> = ModelProjection::from(filter.field.related_field().linking_fields())
.as_columns(ctx)
.map(|col| col.aliased_col(Some(alias), ctx))
.collect();
let selected_identifier: Vec<Column> = filter
.field
.identifier_columns(ctx)
.map(|col| col.aliased_col(Some(alias), ctx))
.collect();

let (nested_conditions, nested_joins) =
self.visit_nested_filter(alias, |this| this.visit_filter(*filter.nested_filter, ctx));
let nested_conditions = nested_conditions.invert_if(condition.invert_of_subselect());
let join_columns: Vec<Column> = filter
.field
.join_columns(ctx)
.map(|c| c.aliased_col(Some(alias), ctx))
.collect();

let conditions = related_columns
.clone()
.into_iter()
.fold(nested_conditions, |acc, column| acc.and(column.is_not_null()));
let related_table = filter.field.related_model().as_table(ctx);
let related_join_columns: Vec<_> = ModelProjection::from(filter.field.related_field().linking_fields())
.as_columns(ctx)
.map(|col| col.aliased_col(Some(alias.flip(AliasMode::Join)), ctx))
.collect();

let select = Select::from_table(related_table.alias(alias.to_string(Some(AliasMode::Table))))
.columns(related_columns)
.so_that(conditions);
let (nested_conditions, nested_joins) = self
.visit_nested_filter(alias.flip(AliasMode::Join), |nested_visitor| {
nested_visitor.visit_filter(*filter.nested_filter, ctx)
});

if let Some(nested_joins) = nested_joins {
nested_joins.into_iter().fold(select, |acc, join| acc.join(join.data))
} else {
select
}
let nested_conditions = nested_conditions.invert_if(condition.invert_of_subselect());
let nested_conditons = selected_identifier
.clone()
.into_iter()
.fold(nested_conditions, |acc, column| acc.and(column.is_not_null()));

let join = related_table
.alias(alias.to_string(Some(AliasMode::Join)))
.on(Row::from(related_join_columns).equals(Row::from(join_columns)));

let select = Select::from_table(table.alias(alias.to_string(Some(AliasMode::Table))))
.columns(selected_identifier)
.inner_join(join)
.so_that(nested_conditons);

let select = if let Some(nested_joins) = nested_joins {
nested_joins.into_iter().fold(select, |acc, join| acc.join(join.data))
} else {
let table = filter.field.as_table(ctx);
let selected_identifier: Vec<Column> = filter
.field
.identifier_columns(ctx)
.map(|col| col.aliased_col(Some(alias), ctx))
.collect();
select
};

let join_columns: Vec<Column> = filter
.field
.join_columns(ctx)
.map(|c| c.aliased_col(Some(alias), ctx))
.collect();
(ids, select)
}

let related_table = filter.field.related_model().as_table(ctx);
let related_join_columns: Vec<_> = ModelProjection::from(filter.field.related_field().linking_fields())
.as_columns(ctx)
.map(|col| col.aliased_col(Some(alias.flip(AliasMode::Join)), ctx))
.collect();
/// Traverses a relation filter using this rough SQL structure:
///
/// ```sql
/// (parent.id1, parent.id2) IN (
/// SELECT id1, id2 FROM child
/// WHERE <filter>
/// )
/// ```
fn visit_relation_filter_select_with_row(
&mut self,
filter: RelationFilter,
ctx: &Context<'_>,
) -> (ModelProjection, Select<'static>) {
let alias = self.next_alias(AliasMode::Table);
let condition = filter.condition;
let linking_fields = ModelProjection::from(filter.field.linking_fields());

let (nested_conditions, nested_joins) = self
.visit_nested_filter(alias.flip(AliasMode::Join), |nested_visitor| {
nested_visitor.visit_filter(*filter.nested_filter, ctx)
});
let related_table = filter.field.related_model().as_table(ctx);
// Select linking fields to match the linking fields of the parent record
let related_columns: Vec<_> = filter
.field
.related_field()
.join_columns(ctx)
.map(|col| col.aliased_col(Some(alias), ctx))
.collect();

let nested_conditions = nested_conditions.invert_if(condition.invert_of_subselect());
let nested_conditons = selected_identifier
.clone()
.into_iter()
.fold(nested_conditions, |acc, column| acc.and(column.is_not_null()));
let (nested_conditions, nested_joins) =
self.visit_nested_filter(alias, |this| this.visit_filter(*filter.nested_filter, ctx));
let nested_conditions = nested_conditions.invert_if(condition.invert_of_subselect());

let join = related_table
.alias(alias.to_string(Some(AliasMode::Join)))
.on(Row::from(related_join_columns).equals(Row::from(join_columns)));
let conditions = related_columns
.clone()
.into_iter()
.fold(nested_conditions, |acc, column| acc.and(column.is_not_null()));

let select = Select::from_table(table.alias(alias.to_string(Some(AliasMode::Table))))
.columns(selected_identifier)
.inner_join(join)
.so_that(nested_conditons);
let select = Select::from_table(related_table.alias(alias.to_string(Some(AliasMode::Table))))
.columns(related_columns)
.so_that(conditions);

if let Some(nested_joins) = nested_joins {
nested_joins.into_iter().fold(select, |acc, join| acc.join(join.data))
} else {
select
}
}
let select = if let Some(nested_joins) = nested_joins {
nested_joins.into_iter().fold(select, |acc, join| acc.join(join.data))
} else {
select
};

(linking_fields, select)
}
}

Expand Down Expand Up @@ -392,11 +458,12 @@ impl FilterVisitorExt for FilterVisitor {
}

_ => {
let ids = ModelProjection::from(filter.field.model().primary_identifier()).as_columns(ctx);
let columns: Vec<Column<'static>> = ids.map(|col| col.aliased_col(self.parent_alias(), ctx)).collect();

let condition = filter.condition;
let sub_select = self.visit_relation_filter_select(filter, ctx);
let (ids, sub_select) = self.visit_relation_filter_select(filter, ctx);
let columns: Vec<Column<'static>> = ids
.as_columns(ctx)
.map(|col| col.aliased_col(self.parent_alias(), ctx))
.collect();

let comparison = match condition {
RelationCondition::AtLeastOneRelatedRecord => Row::from(columns).in_selection(sub_select),
Expand Down

0 comments on commit 9cc3db6

Please sign in to comment.