Skip to content

Commit

Permalink
feat: Refactor 'DataValue' length check logic; add support for varcha…
Browse files Browse the repository at this point in the history
…r computations

This commit revises the mechanism for checking length limits on variable-length types such as varchar, moving the check from 'src/binder/insert.rs' to 'src/types/value.rs' to centralize it. It also implements expression computation for the Utf8 data type. Furthermore, the now-unnecessary `Nvarchar` type has been removed. This overhaul improves robustness against incorrect inputs in variable-length fields and extends the computational abilities of the Utf8 data type.
  • Loading branch information
loloxwg committed Sep 16, 2023
1 parent ddbdbee commit e66417d
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 59 deletions.
8 changes: 2 additions & 6 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,10 @@ impl<S: Storage> Binder<S> {
for (i, expr) in expr_row.into_iter().enumerate() {
match &self.bind_expr(expr).await? {
ScalarExpression::Constant(value) => {
// Check if the value length is too long
value.check_length(columns[i].datatype())?;
let cast_value = DataValue::clone(value)
.cast(columns[i].datatype())?;

// Check if the value length is too long
if Some(cast_value.to_raw().len()) > columns[i].datatype().raw_len() {
return Err(BindError::InvalidTable(format!("value length is too long")))
}

row.push(Arc::new(cast_value))
},
ScalarExpression::Unary { expr, op, .. } => {
Expand Down
1 change: 1 addition & 0 deletions src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl<S: Storage> Binder<S> {
bind_table_name.as_ref()
).await? {
ScalarExpression::ColumnRef(catalog) => {
value.check_length(catalog.datatype())?;
columns.push(catalog);
row.push(value.clone());
},
Expand Down
97 changes: 96 additions & 1 deletion src/expression/value_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ fn unpack_date(value: DataValue) -> Option<i64> {
}
}

/// Extracts the optional string payload from a `DataValue::Utf8`.
///
/// Returns the inner `Option<String>` unchanged for a `Utf8` value and
/// `None` for every other variant.
fn unpack_utf8(value: DataValue) -> Option<String> {
    if let DataValue::Utf8(text) = value {
        text
    } else {
        None
    }
}

pub fn unary_op(
value: &DataValue,
op: &UnaryOperator,
Expand Down Expand Up @@ -114,7 +121,7 @@ pub fn binary_op(
) -> Result<DataValue, TypeError> {
let unified_type = LogicalType::max_logical_type(
&left.logical_type(),
&right.logical_type()
&right.logical_type(),
)?;

let value = match &unified_type {
Expand Down Expand Up @@ -844,6 +851,76 @@ pub fn binary_op(
_ => todo!("unsupported operator")
}
}
// String comparison arm: both operands are cast to the unified type and
// unpacked into `Option<String>`; `None` is the absent (presumably SQL NULL) string.
// NOTE(review): only `Varchar(None)` is matched here, so a unified type of
// `Varchar(Some(n))` would fall through to the `todo!("unsupported data type")`
// arm below — confirm `max_logical_type` never yields a length-bounded Varchar.
LogicalType::Varchar(None) => {
let left_value = unpack_utf8(left.clone().cast(&unified_type)?);
let right_value = unpack_utf8(right.clone().cast(&unified_type)?);

// Every operator produces `DataValue::Boolean`; ordering uses `String`'s
// `PartialOrd`, i.e. byte-wise lexicographic comparison.
match op {
// Ordering operators: the result is `None` unless both operands are present.
BinaryOperator::Gt => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 > v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::Lt => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 < v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::GtEq => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 >= v2)
} else {
None
};

DataValue::Boolean(value)
}
BinaryOperator::LtEq => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 <= v2)
} else {
None
};

DataValue::Boolean(value)
}
// Equality: two absent operands compare as equal (`Some(true)`); exactly
// one absent operand yields `None`.
// NOTE(review): `NotEq` below returns `None` for two absent operands, so
// `Eq` and `NotEq` are not complements in that case — confirm this is intended.
BinaryOperator::Eq => {
let value = match (left_value, right_value) {
(Some(v1), Some(v2)) => {
Some(v1 == v2)
}
(None, None) => {
Some(true)
}
(_, _) => {
None
}
};

DataValue::Boolean(value)
}
BinaryOperator::NotEq => {
let value = if let (Some(v1), Some(v2)) = (left_value, right_value) {
Some(v1 != v2)
} else {
None
};

DataValue::Boolean(value)
}
// Any other operator on strings panics via `todo!` rather than returning a `TypeError`.
_ => todo!("unsupported operator")
}
}
// Utf8

_ => todo!("unsupported data type"),
};

Expand Down Expand Up @@ -1105,4 +1182,22 @@ mod test {

Ok(())
}

/// Checks `binary_op` comparison results for `DataValue::Utf8` operands.
///
/// Covers every comparison operator with two present strings, and every
/// comparison operator with an absent left operand (including `Eq`, whose
/// mixed absent/present branch must yield `Boolean(None)`).
#[test]
fn test_binary_op_Utf8_compare() -> Result<(), TypeError> {
    // Builds a Utf8 value; `None` produces the absent-string variant.
    let utf8 = |text: Option<&str>| DataValue::Utf8(text.map(str::to_string));

    // (left, right, operator, expected payload of the resulting Boolean)
    let cases = [
        // Both operands present.
        (Some("a"), Some("b"), BinaryOperator::Gt, Some(false)),
        (Some("a"), Some("b"), BinaryOperator::Lt, Some(true)),
        (Some("a"), Some("a"), BinaryOperator::GtEq, Some(true)),
        (Some("a"), Some("a"), BinaryOperator::LtEq, Some(true)),
        (Some("a"), Some("a"), BinaryOperator::NotEq, Some(false)),
        (Some("a"), Some("a"), BinaryOperator::Eq, Some(true)),
        // Left operand absent: every comparison yields an absent Boolean.
        (None, Some("a"), BinaryOperator::Gt, None),
        (None, Some("a"), BinaryOperator::Lt, None),
        (None, Some("a"), BinaryOperator::GtEq, None),
        (None, Some("a"), BinaryOperator::LtEq, None),
        (None, Some("a"), BinaryOperator::NotEq, None),
        // Previously untested: Eq with exactly one absent operand.
        (None, Some("a"), BinaryOperator::Eq, None),
    ];

    for (index, (left, right, op, expected)) in cases.iter().enumerate() {
        assert_eq!(
            binary_op(&utf8(*left), &utf8(*right), op)?,
            DataValue::Boolean(*expected),
            "case #{}",
            index
        );
    }

    Ok(())
}
}
28 changes: 5 additions & 23 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ pub enum LogicalType {
Float,
Double,
Varchar(Option<u32>),
Nvarchar(Option<u32>),
Date,
DateTime,
}
Expand All @@ -76,8 +75,8 @@ impl LogicalType {
LogicalType::UBigint => Some(8),
LogicalType::Float => Some(4),
LogicalType::Double => Some(8),
LogicalType::Varchar(len)=> len.map(|len| len as usize),
LogicalType::Nvarchar(len) => len.map(|len| len as usize),
// Note: variable-length types have no fixed raw length, so `raw_len` is `None`.
LogicalType::Varchar(_)=>None,
LogicalType::Date => Some(4),
LogicalType::DateTime => Some(8),
}
Expand Down Expand Up @@ -158,13 +157,13 @@ impl LogicalType {
if left.is_numeric() && right.is_numeric() {
return LogicalType::combine_numeric_types(left, right);
}
if matches!((left, right), (LogicalType::Date, LogicalType::Varchar(len)) | (LogicalType::Varchar(len), LogicalType::Date)) {
if matches!((left, right), (LogicalType::Date, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::Date)) {
return Ok(LogicalType::Date);
}
if matches!((left, right), (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date)) {
return Ok(LogicalType::DateTime);
}
if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar(len)) | (LogicalType::Varchar(len), LogicalType::DateTime)) {
if matches!((left, right), (LogicalType::DateTime, LogicalType::Varchar(_)) | (LogicalType::Varchar(_), LogicalType::DateTime)) {
return Ok(LogicalType::DateTime);
}
Err(TypeError::InternalError(format!(
Expand Down Expand Up @@ -267,8 +266,7 @@ impl LogicalType {
LogicalType::UBigint => matches!(to, LogicalType::Float | LogicalType::Double),
LogicalType::Float => matches!(to, LogicalType::Double),
LogicalType::Double => false,
LogicalType::Varchar(_len) => false,
LogicalType::Nvarchar(_len) => false,
LogicalType::Varchar(_) => false,
LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)),
LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)),
}
Expand All @@ -283,7 +281,6 @@ impl TryFrom<sqlparser::ast::DataType> for LogicalType {
match value {
sqlparser::ast::DataType::Char(len)
| sqlparser::ast::DataType::Varchar(len)=> Ok(LogicalType::Varchar(len.map(|len| len.length as u32))),
sqlparser::ast::DataType::Nvarchar(len) => Ok(LogicalType::Nvarchar(len.map(|len| len as u32))),
sqlparser::ast::DataType::Float(_) => Ok(LogicalType::Float),
sqlparser::ast::DataType::Double => Ok(LogicalType::Double),
sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint),
Expand Down Expand Up @@ -339,19 +336,4 @@ mod test {
fn test_id_generator_reset() {
ID_BUF.store(0, Release)
}


#[test]
fn test_logical_type() {
assert_eq!(LogicalType::Integer.raw_len(), Some(4));
assert_eq!(LogicalType::UInteger.raw_len(), Some(4));
assert_eq!(LogicalType::Bigint.raw_len(), Some(8));
assert_eq!(LogicalType::UBigint.raw_len(), Some(8));
assert_eq!(LogicalType::Float.raw_len(), Some(4));
assert_eq!(LogicalType::Double.raw_len(), Some(8));
assert_eq!(LogicalType::Varchar(Some(10)).raw_len(), Some(10));
assert_eq!(LogicalType::Nvarchar(Some(10)).raw_len(), Some(10));
assert_eq!(LogicalType::Date.raw_len(), Some(4));
assert_eq!(LogicalType::DateTime.raw_len(), Some(8));
}
}
4 changes: 3 additions & 1 deletion src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ impl Tuple {
if bit_index(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) {
values.push(Arc::new(DataValue::none(logic_type)));
} else if let Some(len) = logic_type.raw_len() {
// Fixed-length value (e.g. int): read exactly `len` bytes.
values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type)));
pos += len;
} else {
// Variable-length value (e.g. varchar): a 4-byte u32 length prefix precedes the payload.
let len = u32::decode_fixed(&bytes[pos..pos + 4]) as usize;
pos += 4;
values.push(Arc::new(DataValue::from_raw(&bytes[pos..pos + len], logic_type)));
Expand Down Expand Up @@ -133,7 +135,7 @@ mod tests {
Arc::new(ColumnCatalog::new(
"c3".to_string(),
false,
ColumnDesc::new(LogicalType::Varchar(None), false)
ColumnDesc::new(LogicalType::Varchar(Some(2)), false)
)),
Arc::new(ColumnCatalog::new(
"c4".to_string(),
Expand Down
Loading

0 comments on commit e66417d

Please sign in to comment.