Skip to content

Commit

Permalink
implement Subquery (#86)
Browse files Browse the repository at this point in the history
* feat: implement table aliases

* feat: implement subquery

* fix: allow binding a table even when the corresponding TableName cannot be found (e.g. when the subquery contains a join and declares no alias)

* style: code fmt
  • Loading branch information
KKould authored Oct 12, 2023
1 parent 8786f0f commit b6700ac
Show file tree
Hide file tree
Showing 13 changed files with 170 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/binder/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
};

if let Some(table) = self.context.transaction.table(&table_name.to_string()) {
if let Some(table) = self.context.table(&table_name.to_string()) {
let cols = table.all_columns();
let ext_source = ExtSource {
path: match target {
Expand Down
5 changes: 3 additions & 2 deletions src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ impl<'a, T: Transaction> Binder<'a, T> {
from: &TableWithJoins,
selection: &Option<Expr>,
) -> Result<LogicalPlan, BindError> {
if let TableFactor::Table { name, .. } = &from.relation {
if let TableFactor::Table { name, alias, .. } = &from.relation {
let name = lower_case_name(name);
let (_, name) = split_name(&name)?;
let (table_name, mut plan) = self._bind_single_table_ref(None, name)?;
let (table_name, mut plan) =
self._bind_single_table_ref(None, name, Self::trans_alias(alias))?;

if let Some(predicate) = selection {
plan = self.bind_where(plan, predicate)?;
Expand Down
1 change: 0 additions & 1 deletion src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl<'a, T: Transaction> Binder<'a, T> {
if let Some(table) = table_name.or(bind_table_name) {
let table_catalog = self
.context
.transaction
.table(table)
.ok_or_else(|| BindError::InvalidTable(table.to_string()))?;

Expand Down
2 changes: 1 addition & 1 deletion src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let (_, name) = split_name(&name)?;
let table_name = Arc::new(name.to_string());

if let Some(table) = self.context.transaction.table(&table_name) {
if let Some(table) = self.context.table(&table_name) {
let mut columns = Vec::new();

if idents.is_empty() {
Expand Down
54 changes: 47 additions & 7 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ pub enum InputRefType {

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
pub(crate) transaction: &'a T,
transaction: &'a T,
pub(crate) bind_table: BTreeMap<TableName, (TableCatalog, Option<JoinType>)>,
aliases: BTreeMap<String, ScalarExpression>,
table_aliases: BTreeMap<String, TableName>,
group_by_exprs: Vec<ScalarExpression>,
pub(crate) agg_calls: Vec<ScalarExpression>,
}
Expand All @@ -41,11 +42,20 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
transaction,
bind_table: Default::default(),
aliases: Default::default(),
table_aliases: Default::default(),
group_by_exprs: vec![],
agg_calls: Default::default(),
}
}

/// Looks up the catalog entry for `table_name`.
///
/// The name is first checked against the registered table aliases. When it is
/// an alias, the lookup is made with the table name the alias maps to;
/// otherwise `table_name` itself is handed to the transaction.
pub fn table(&self, table_name: &String) -> Option<&TableCatalog> {
    let lookup_name: &String = match self.table_aliases.get(table_name) {
        Some(real_name) => &**real_name,
        None => table_name,
    };

    self.transaction.table(lookup_name)
}

// Tips: The order of this index is based on Aggregate being bound first.
pub fn input_ref_index(&self, ty: InputRefType) -> usize {
match ty {
Expand All @@ -54,12 +64,42 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
}
}

pub fn add_alias(&mut self, alias: String, expr: ScalarExpression) {
if self.aliases.contains_key(&alias) {
return;
pub fn add_alias(&mut self, alias: String, expr: ScalarExpression) -> Result<(), BindError> {
let is_exist = self.aliases.insert(alias.clone(), expr).is_some();
if is_exist {
return Err(BindError::InvalidColumn(format!("{} duplicated", alias)));
}

Ok(())
}

/// Registers `alias` as another name for `table`.
///
/// # Errors
///
/// Returns [`BindError::InvalidTable`] when `alias` is already registered.
/// The existing mapping is left untouched in that case, so a rejected
/// duplicate can never redirect an alias that was bound earlier.
pub fn add_table_alias(&mut self, alias: String, table: TableName) -> Result<(), BindError> {
    use std::collections::btree_map::Entry;

    // A single entry lookup both detects the duplicate and performs the insert,
    // without cloning the alias or the table name.
    match self.table_aliases.entry(alias) {
        Entry::Occupied(slot) => Err(BindError::InvalidTable(format!(
            "{} duplicated",
            slot.key()
        ))),
        Entry::Vacant(slot) => {
            slot.insert(table);
            Ok(())
        }
    }
}

pub fn add_bind_table(
&mut self,
table: TableName,
table_catalog: TableCatalog,
join_type: Option<JoinType>,
) -> Result<(), BindError> {
let is_bound = self
.bind_table
.insert(table.clone(), (table_catalog.clone(), join_type))
.is_some();
if is_bound {
return Err(BindError::InvalidTable(format!("{} duplicated", table)));
}

self.aliases.insert(alias, expr);
Ok(())
}

pub fn has_agg_call(&self, expr: &ScalarExpression) -> bool {
Expand Down Expand Up @@ -175,8 +215,8 @@ pub enum BindError {
AmbiguousColumn(String),
#[error("binary operator types mismatch: {0} != {1}")]
BinaryOpTypeMismatch(String, String),
#[error("subquery in FROM must have an alias")]
SubqueryMustHaveAlias,
#[error("subquery error: {0}")]
Subquery(String),
#[error("agg miss: {0}")]
AggMiss(String),
#[error("catalog error: {0}")]
Expand Down
94 changes: 60 additions & 34 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use itertools::Itertools;
use sqlparser::ast;
use sqlparser::ast::{
Distinct, Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select,
SelectItem, SetExpr, TableFactor, TableWithJoins,
SelectItem, SetExpr, TableAlias, TableFactor, TableWithJoins,
};

impl<'a, T: Transaction> Binder<'a, T> {
Expand Down Expand Up @@ -128,18 +128,25 @@ impl<'a, T: Transaction> Binder<'a, T> {
let (left_name, mut plan) = self.bind_single_table_ref(relation, None)?;

if !joins.is_empty() {
let left_name = Self::unpack_name(left_name, true);

for join in joins {
plan = self.bind_join(left_name.clone(), plan, join)?;
plan = self.bind_join(&left_name, plan, join)?;
}
}
Ok(plan)
}

/// Unwraps the name of a bound table.
///
/// `is_left` only selects which side ("Left" or "Right") is reported in the
/// panic message.
///
/// # Panics
///
/// Panics with `"<side>: Table is not named"` when `table_name` is `None`.
fn unpack_name(table_name: Option<TableName>, is_left: bool) -> TableName {
    // Build the message lazily so the common `Some` path does not allocate.
    table_name.unwrap_or_else(|| {
        let title = if is_left { "Left" } else { "Right" };
        panic!("{}: Table is not named", title)
    })
}

fn bind_single_table_ref(
&mut self,
table: &TableFactor,
joint_type: Option<JoinType>,
) -> Result<(TableName, LogicalPlan), BindError> {
) -> Result<(Option<TableName>, LogicalPlan), BindError> {
let plan_with_name = match table {
TableFactor::Table { name, alias, .. } => {
let obj_name = name
Expand All @@ -148,45 +155,69 @@ impl<'a, T: Transaction> Binder<'a, T> {
.map(|ident| Ident::new(ident.value.to_lowercase()))
.collect_vec();

let (_database, _schema, mut table): (&str, &str, &str) = match obj_name.as_slice()
{
let (_database, _schema, table): (&str, &str, &str) = match obj_name.as_slice() {
[table] => (DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, &table.value),
[schema, table] => (DEFAULT_DATABASE_NAME, &schema.value, &table.value),
[database, schema, table] => (&database.value, &schema.value, &table.value),
_ => return Err(BindError::InvalidTableName(obj_name)),
};
if let Some(alias) = alias {
table = &alias.name.value;
}

self._bind_single_table_ref(joint_type, table)?
let (table, plan) =
self._bind_single_table_ref(joint_type, table, Self::trans_alias(alias))?;
(Some(table), plan)
}
TableFactor::Derived {
subquery, alias, ..
} => {
let plan = self.bind_query(subquery)?;
let mut tables = plan.referenced_table();

if let Some(alias) = Self::trans_alias(alias) {
let alias = Arc::new(alias.clone());

if tables.len() > 1 {
todo!("Implement virtual tables for multiple table aliases");
}
// FIXME
self.context
.add_table_alias(alias.to_string(), tables.remove(0))?;

(Some(alias), plan)
} else {
((tables.len() > 1).then(|| tables.pop()).flatten(), plan)
}
}
_ => unimplemented!(),
};

Ok(plan_with_name)
}

/// Borrows the alias identifier text from an optional table alias.
///
/// Returns `None` when no alias was supplied.
pub(crate) fn trans_alias(alias: &Option<TableAlias>) -> Option<&String> {
    match alias {
        Some(table_alias) => Some(&table_alias.name.value),
        None => None,
    }
}

pub(crate) fn _bind_single_table_ref(
&mut self,
joint_type: Option<JoinType>,
join_type: Option<JoinType>,
table: &str,
alias: Option<&String>,
) -> Result<(Arc<String>, LogicalPlan), BindError> {
let table_name = Arc::new(table.to_string());

if self.context.bind_table.contains_key(&table_name) {
return Err(BindError::InvalidTable(format!("{} duplicated", table)));
}

let table_catalog = self
.context
.transaction
.table(&table_name)
.cloned()
.ok_or_else(|| BindError::InvalidTable(format!("bind table {}", table)))?;

self.context
.bind_table
.insert(table_name.clone(), (table_catalog.clone(), joint_type));
.add_bind_table(table_name.clone(), table_catalog.clone(), join_type)?;

if let Some(alias) = alias {
self.context
.add_table_alias(alias.to_string(), table_name.clone())?;
}

Ok((
table_name.clone(),
Expand All @@ -213,7 +244,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let expr = self.bind_expr(expr)?;
let alias_name = alias.to_string();

self.context.add_alias(alias_name.clone(), expr.clone());
self.context.add_alias(alias_name.clone(), expr.clone())?;

select_items.push(ScalarExpression::Alias {
expr: Box::new(expr),
Expand All @@ -236,7 +267,6 @@ impl<'a, T: Transaction> Binder<'a, T> {
for table_name in self.context.bind_table.keys().cloned() {
let table = self
.context
.transaction
.table(&table_name)
.ok_or_else(|| BindError::InvalidTable(table_name.to_string()))?;
for col in table.all_columns() {
Expand All @@ -249,7 +279,7 @@ impl<'a, T: Transaction> Binder<'a, T> {

fn bind_join(
&mut self,
left_table: TableName,
left_table: &String,
left: LogicalPlan,
join: &Join,
) -> Result<LogicalPlan, BindError> {
Expand All @@ -266,21 +296,17 @@ impl<'a, T: Transaction> Binder<'a, T> {
JoinOperator::CrossJoin => (JoinType::Cross, None),
_ => unimplemented!(),
};

let (right_table, right) = self.bind_single_table_ref(relation, Some(join_type))?;

let left_table = self
.context
.transaction
.table(&left_table)
.cloned()
.ok_or_else(|| BindError::InvalidTable(format!("Left: {} not found", left_table)))?;
let right_table = self
.context
.transaction
.table(&right_table)
.cloned()
.ok_or_else(|| BindError::InvalidTable(format!("Right: {} not found", right_table)))?;
let right_table = Self::unpack_name(right_table, false);

let left_table =
self.context.table(left_table).cloned().ok_or_else(|| {
BindError::InvalidTable(format!("Left: {} not found", left_table))
})?;
let right_table =
self.context.table(&right_table).cloned().ok_or_else(|| {
BindError::InvalidTable(format!("Right: {} not found", right_table))
})?;

let on = match joint_condition {
Some(constraint) => self.bind_join_constraint(&left_table, &right_table, constraint)?,
Expand Down
4 changes: 2 additions & 2 deletions src/catalog/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ pub type TableName = Arc<String>;

#[derive(Debug, Clone, PartialEq)]
pub struct TableCatalog {
pub name: TableName,
pub(crate) name: TableName,
/// Mapping from column names to column ids
column_idxs: BTreeMap<String, ColumnId>,
pub(crate) columns: BTreeMap<ColumnId, ColumnRef>,
pub indexes: Vec<IndexMetaRef>,
pub(crate) indexes: Vec<IndexMetaRef>,
}

impl TableCatalog {
Expand Down
10 changes: 5 additions & 5 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ impl<S: Storage> Database<S> {
RuleImpl::PushPredicateIntoScan,
],
)
.batch(
"Combine Operators".to_string(),
HepBatchStrategy::fix_point_topdown(10),
vec![RuleImpl::CollapseProject, RuleImpl::CombineFilter],
)
.batch(
"Column Pruning".to_string(),
HepBatchStrategy::fix_point_topdown(10),
Expand All @@ -102,11 +107,6 @@ impl<S: Storage> Database<S> {
RuleImpl::EliminateLimits,
],
)
.batch(
"Combine Operators".to_string(),
HepBatchStrategy::fix_point_topdown(10),
vec![RuleImpl::CollapseProject, RuleImpl::CombineFilter],
)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution/executor/dml/copy_to_file.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::planner::operator::copy_to_file::CopyToFileOperator;

#[warn(dead_code)]
#[allow(dead_code)]
pub struct CopyToFile {
op: CopyToFileOperator,
}
2 changes: 1 addition & 1 deletion src/optimizer/rule/column_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl Rule for PushProjectIntoScan {
new_scan_op.columns = project_op
.columns
.iter()
.filter(|expr| matches!(expr.unpack_alias(), ScalarExpression::ColumnRef(_)))
.map(ScalarExpression::unpack_alias)
.cloned()
.collect_vec();

Expand Down
2 changes: 2 additions & 0 deletions src/optimizer/rule/combine_operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ impl Rule for CollapseProject {
if let Operator::Project(child_op) = graph.operator(child_id) {
if is_subset_exprs(&op.columns, &child_op.columns) {
graph.remove_node(child_id, false);
} else {
graph.remove_node(node_id, false);
}
}
}
Expand Down
Loading

0 comments on commit b6700ac

Please sign in to comment.