Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - OnConflict support for Insert mutations #503

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion sql/load_sql_context.sql
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,14 @@ select
array[]::text[]
),
'is_unique', pi.indisunique and pi.indpred is null,
'is_primary_key', pi.indisprimary
'is_primary_key', pi.indisprimary,
'name', pc_ix.relname
)
)
from
pg_catalog.pg_index pi
join pg_catalog.pg_class pc_ix
on pi.indexrelid = pc_ix.oid
where
pi.indrelid = pc.oid
),
Expand Down
156 changes: 142 additions & 14 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@ use crate::parser_util::*;
use crate::sql_types::*;
use graphql_parser::query::*;
use serde::Serialize;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;

#[derive(Clone, Debug)]
pub struct InsertBuilder {
pub alias: String,
pub struct OnConflictBuilder {
pub constraint: Index, // Could probably get away with a name ref
pub update_fields: HashSet<Arc<Column>>, // Could probably get away with a name ref
pub filter: FilterBuilder,
}

#[derive(Clone, Debug)]
pub struct InsertBuilder {
// args
pub objects: Vec<InsertRowBuilder>,

Expand All @@ -22,6 +27,8 @@ pub struct InsertBuilder {

//fields
pub selections: Vec<InsertSelection>,

pub on_conflict: Option<OnConflictBuilder>,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -176,6 +183,117 @@ where
parse_node_id(node_id_base64_encoded_json_string)
}

fn read_argument_on_conflict<'a, T>(
field: &__Field,
query_field: &graphql_parser::query::Field<'a, T>,
variables: &serde_json::Value,
variable_definitions: &Vec<VariableDefinition<'a, T>>,
) -> Result<Option<OnConflictBuilder>, String>
where
T: Text<'a> + Eq + AsRef<str>,
{
let conflict_type: OnConflictType = match field.get_arg("onConflict") {
None => return Ok(None),
Some(x) => match x.type_().unmodified_type() {
__Type::OnConflictInput(insert_on_conflict) => insert_on_conflict,
_ => return Err("Could not locate Insert Entity type".to_string()),
},
};

let validated: gson::Value = read_argument(
"onConflict",
field,
query_field,
variables,
variable_definitions,
)?;

let on_conflict_builder = match validated {
gson::Value::Absent | gson::Value::Null => None,
gson::Value::Object(contents) => {
let constraint = match contents
.get("constraint")
.expect("OnConflict revalidation error. Expected constraint")
{
gson::Value::String(ix_name) => conflict_type
.table
.indexes
.iter()
.find(|ix| &ix.name == ix_name)
.expect("OnConflict revalidation error. constraint: unknown constraint name"),
_ => {
return Err(
"OnConflict revalidation error. Expected constraint as String".to_string(),
)
}
};

// TODO: Filter reading logic is partially duplicated from read_argument_filter
// ideally this should be refactored
let filter_gson = contents
.get("filter")
.expect("onConflict revalidation error");

let filter = match filter_gson {
gson::Value::Null | gson::Value::Absent => FilterBuilder { elems: vec![] },
gson::Value::Object(_) => {
let filter_type = conflict_type
.input_fields()
.expect("Failed to unwrap input fields on OnConflict type")
.iter()
.find(|in_f| in_f.name() == "filter")
.expect("Failed to get filter input_field on onConflict type")
.type_()
.unmodified_type();

if !matches!(filter_type, __Type::FilterEntity(_)) {
return Err("Could not locate Filter Entity type".to_string());
}
let filter_field_map = input_field_map(&filter_type);
let filter_elems = create_filters(&filter_gson, &filter_field_map)?;
FilterBuilder {
elems: filter_elems,
}
}
_ => return Err("OnConflict revalidation error. invalid filter object".to_string()),
};

let update_fields = match contents
.get("updateFields")
.expect("OnConflict revalidation error. Expected updateFields")
{
gson::Value::Array(col_names) => {
let mut update_columns: HashSet<Arc<Column>> = HashSet::new();
for col_name in col_names {
match col_name {
gson::Value::String(c) => {
let col = conflict_type.table.columns.iter().find(|column| &column.name == c).expect("OnConflict revalidation error. updateFields: unknown column name");
update_columns.insert(Arc::clone(col));
}
_ => return Err("OnConflict revalidation error. Expected updateFields to be column names".to_string()),
}
}
update_columns
}
_ => {
return Err(
"OnConflict revalidation error. Expected updateFields to be an array"
.to_string(),
)
}
};

Some(OnConflictBuilder {
constraint: constraint.clone(),
update_fields,
filter,
})
}
_ => return Err("Insert re-validation errror".to_string()),
};
Ok(on_conflict_builder)
}

fn read_argument_objects<'a, T>(
field: &__Field,
query_field: &graphql_parser::query::Field<'a, T>,
Expand Down Expand Up @@ -272,12 +390,27 @@ where
.name()
.ok_or("Encountered type without name in connection builder")?;
let field_map = field_map(&type_);
let alias = alias_or_name(query_field);

match &type_ {
__Type::InsertResponse(xtype) => {
// Raise for disallowed arguments
restrict_allowed_arguments(&["objects"], query_field)?;
let allowed_args = field
.args
.iter()
.map(|iv| iv.name())
.collect::<HashSet<String>>();

match allowed_args.contains("onConflict") {
true => restrict_allowed_arguments(&["objects", "onConflict"], query_field)?,
false => restrict_allowed_arguments(&["objects"], query_field)?,
}

let on_conflict: Option<OnConflictBuilder> = match allowed_args.contains("onConflict") {
true => {
read_argument_on_conflict(field, query_field, variables, variable_definitions)?
}
false => None,
};

let objects: Vec<InsertRowBuilder> =
read_argument_objects(field, query_field, variables, variable_definitions)?;
Expand Down Expand Up @@ -320,10 +453,10 @@ where
}
}
Ok(InsertBuilder {
alias,
table: Arc::clone(&xtype.table),
objects,
selections: builder_fields,
on_conflict,
})
}
_ => Err(format!(
Expand All @@ -335,8 +468,6 @@ where

#[derive(Clone, Debug)]
pub struct UpdateBuilder {
pub alias: String,

// args
pub filter: FilterBuilder,
pub set: SetBuilder,
Expand Down Expand Up @@ -438,7 +569,6 @@ where
.name()
.ok_or("Encountered type without name in update builder")?;
let field_map = field_map(&type_);
let alias = alias_or_name(query_field);

match &type_ {
__Type::UpdateResponse(xtype) => {
Expand Down Expand Up @@ -490,7 +620,6 @@ where
}
}
Ok(UpdateBuilder {
alias,
filter,
set,
at_most,
Expand All @@ -507,8 +636,6 @@ where

#[derive(Clone, Debug)]
pub struct DeleteBuilder {
pub alias: String,

// args
pub filter: FilterBuilder,
pub at_most: i64,
Expand Down Expand Up @@ -544,7 +671,6 @@ where
.name()
.ok_or("Encountered type without name in delete builder")?;
let field_map = field_map(&type_);
let alias = alias_or_name(query_field);

match &type_ {
__Type::DeleteResponse(xtype) => {
Expand Down Expand Up @@ -594,7 +720,6 @@ where
}
}
Ok(DeleteBuilder {
alias,
filter,
at_most,
table: Arc::clone(&xtype.table),
Expand Down Expand Up @@ -1060,11 +1185,14 @@ where
variable_definitions,
)?;

//return Err(format!("Err {:?}", validated));

let filter_type = field
.get_arg("filter")
.expect("failed to get filter argument")
.type_()
.unmodified_type();

if !matches!(filter_type, __Type::FilterEntity(_)) {
return Err("Could not locate Filter Entity type".to_string());
}
Expand Down
Loading
Loading