Skip to content

Commit

Permalink
feat: Refactor 'DataValue' length check logic; add support for varcha…
Browse files Browse the repository at this point in the history
…r computations (#60)
  • Loading branch information
loloxwg authored Sep 17, 2023
1 parent 15e72bc commit 1853a9b
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ mod tests {
assert_eq!(op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, true));
assert_eq!(op.columns[1].name, "name".to_string());
assert_eq!(op.columns[1].nullable, true);
assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar, false));
assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar(Some(10)), false));
}
_ => unreachable!()
}
Expand Down
3 changes: 2 additions & 1 deletion src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ impl<S: Storage> Binder<S> {
for (i, expr) in expr_row.into_iter().enumerate() {
match &self.bind_expr(expr).await? {
ScalarExpression::Constant(value) => {
// Check if the value length is too long
value.check_length(columns[i].datatype())?;
let cast_value = DataValue::clone(value)
.cast(columns[i].datatype())?;

row.push(Arc::new(cast_value))
},
ScalarExpression::Unary { expr, op, .. } => {
Expand Down
1 change: 1 addition & 0 deletions src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl<S: Storage> Binder<S> {
bind_table_name.as_ref()
).await? {
ScalarExpression::ColumnRef(catalog) => {
value.check_length(catalog.datatype())?;
columns.push(catalog);
row.push(value.clone());
},
Expand Down
97 changes: 96 additions & 1 deletion src/expression/value_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ fn unpack_date(value: DataValue) -> Option<i64> {
}
}

/// Extracts the optional string payload carried by a `DataValue::Utf8`.
///
/// Returns the inner `Option<String>` unchanged when `value` is the `Utf8`
/// variant (so a null string stays `None`), and `None` for every other variant.
fn unpack_utf8(value: DataValue) -> Option<String> {
    if let DataValue::Utf8(text) = value {
        text
    } else {
        None
    }
}

pub fn unary_op(
value: &DataValue,
op: &UnaryOperator,
Expand Down Expand Up @@ -114,7 +121,7 @@ pub fn binary_op(
) -> Result<DataValue, TypeError> {
let unified_type = LogicalType::max_logical_type(
&left.logical_type(),
&right.logical_type()
&right.logical_type(),
)?;

let value = match &unified_type {
Expand Down Expand Up @@ -844,6 +851,76 @@ pub fn binary_op(
_ => todo!("unsupported operator")
}
}
LogicalType::Varchar(None) => {
let left_value = unpack_utf8(left.clone().cast(&unified_type)?);
let right_value = unpack_utf8(right.clone().cast(&unified_type)?);

match op {
BinaryOperator::Gt => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 > v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::Lt => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 < v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::GtEq => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 >= v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::LtEq => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 <= v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::Eq => {
let value = match (left_value, right_value) {
(Some(v1), Some(v2)) => {
Some(v1 == v2)
}
(None, None) => {
Some(true)
}
(_, _) => {
None
}
};

DataValue::Boolean(value)
}
BinaryOperator::NotEq => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 != v2)
} else {
None
};

DataValue::Boolean(value)
}
_ => todo!("unsupported operator")
}
}
// Utf8

_ => todo!("unsupported data type"),
};

Expand Down Expand Up @@ -1105,4 +1182,22 @@ mod test {

Ok(())
}

/// Checks the comparison operators of `binary_op` on `Utf8` values, both for
/// two non-null operands and for a null left-hand operand.
#[test]
fn test_binary_op_Utf8_compare() -> Result<(), TypeError> {
    // Builds a non-null Utf8 value from a string slice.
    let utf8 = |text: &str| DataValue::Utf8(Some(text.to_string()));

    // (left, right, operator, expected result) with both operands non-null.
    let non_null_cases = [
        ("a", "b", BinaryOperator::Gt, false),
        ("a", "b", BinaryOperator::Lt, true),
        ("a", "a", BinaryOperator::GtEq, true),
        ("a", "a", BinaryOperator::LtEq, true),
        ("a", "a", BinaryOperator::NotEq, false),
        ("a", "a", BinaryOperator::Eq, true),
    ];
    for (lhs, rhs, op, expected) in non_null_cases {
        assert_eq!(
            binary_op(&utf8(lhs), &utf8(rhs), &op)?,
            DataValue::Boolean(Some(expected))
        );
    }

    // A null left operand makes each of these comparisons evaluate to null.
    let null_lhs_ops = [
        BinaryOperator::Gt,
        BinaryOperator::Lt,
        BinaryOperator::GtEq,
        BinaryOperator::LtEq,
        BinaryOperator::NotEq,
    ];
    for op in null_lhs_ops {
        assert_eq!(
            binary_op(&DataValue::Utf8(None), &utf8("a"), &op)?,
            DataValue::Boolean(None)
        );
    }

    Ok(())
}
}
2 changes: 2 additions & 0 deletions src/types/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub enum TypeError {
InternalError(String),
#[error("cast fail")]
CastFail,
#[error("Too long")]
TooLong,
#[error("cannot be Null")]
NotNull,
#[error("try from int")]
Expand Down
24 changes: 11 additions & 13 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub enum LogicalType {
UBigint,
Float,
Double,
Varchar,
Varchar(Option<u32>),
Date,
DateTime,
}
Expand All @@ -75,7 +75,8 @@ impl LogicalType {
LogicalType::UBigint => Some(8),
LogicalType::Float => Some(4),
LogicalType::Double => Some(8),
LogicalType::Varchar => None,
// Note: variable-length types have no fixed raw length, so raw_len is None.
LogicalType::Varchar(_)=>None,
LogicalType::Date => Some(4),
LogicalType::DateTime => Some(8),
}
Expand Down Expand Up @@ -156,13 +157,13 @@ impl LogicalType {
if left.is_numeric() && right.is_numeric() {
return LogicalType::combine_numeric_types(left, right);
}
if matches!((left, right), (LogicalType::Date, LogicalType::Varchar) | (LogicalType::Varchar, LogicalType::Date)) {
if matches!((left, right), (LogicalType::Date, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::Date)) {
return Ok(LogicalType::Date);
}
if matches!((left, right), (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date)) {
return Ok(LogicalType::DateTime);
}
if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar) | (LogicalType::Varchar, LogicalType::DateTime)) {
if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::DateTime)) {
return Ok(LogicalType::DateTime);
}
Err(TypeError::InternalError(format!(
Expand Down Expand Up @@ -265,9 +266,9 @@ impl LogicalType {
LogicalType::UBigint => matches!(to, LogicalType::Float | LogicalType::Double),
LogicalType::Float => matches!(to, LogicalType::Double),
LogicalType::Double => false,
LogicalType::Varchar => false,
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar),
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar),
LogicalType::Varchar(_) => false,
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)),
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)),
}
}
}
Expand All @@ -278,11 +279,8 @@ impl TryFrom<sqlparser::ast::DataType> for LogicalType {

fn try_from(value: sqlparser::ast::DataType) -> Result<Self, Self::Error> {
match value {
sqlparser::ast::DataType::Char(_)
| sqlparser::ast::DataType::Varchar(_)
| sqlparser::ast::DataType::Nvarchar(_)
| sqlparser::ast::DataType::Text
| sqlparser::ast::DataType::String => Ok(LogicalType::Varchar),
sqlparser::ast::DataType::Char(len)
| sqlparser::ast::DataType::Varchar(len)=> Ok(LogicalType::Varchar(len.map(|len| len.length as u32))),
sqlparser::ast::DataType::Float(_) => Ok(LogicalType::Float),
sqlparser::ast::DataType::Double => Ok(LogicalType::Double),
sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint),
Expand Down Expand Up @@ -315,7 +313,7 @@ impl std::fmt::Display for LogicalType {
mod test {
use std::sync::atomic::Ordering::Release;

use crate::types::{IdGenerator, ID_BUF};
use crate::types::{IdGenerator, ID_BUF, LogicalType};

/// Tips: IdGenerator produces ids from static global state, so it must be tested on its own to avoid interference from other test methods.
#[test]
Expand Down
4 changes: 3 additions & 1 deletion src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ impl Tuple {
if bit_index(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) {
values.push(Arc::new(DataValue::none(logic_type)));
} else if let Some(len) = logic_type.raw_len() {
// Fixed-length type (e.g. int): the byte width is given by raw_len().
values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type)));
pos += len;
} else {
// Variable-length type (e.g. varchar): a 4-byte length prefix precedes the payload.
let len = u32::decode_fixed(&bytes[pos..pos + 4]) as usize;
pos += 4;
values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type)));
Expand Down Expand Up @@ -133,7 +135,7 @@ mod tests {
Arc::new(ColumnCatalog::new(
"c3".to_string(),
false,
ColumnDesc::new(LogicalType::Varchar, false)
ColumnDesc::new(LogicalType::Varchar(Some(2)), false)
)),
Arc::new(ColumnCatalog::new(
"c4".to_string(),
Expand Down
Loading

0 comments on commit 1853a9b

Please sign in to comment.