From 045c4eee723496ff4e3470a0c7864868b1f3d999 Mon Sep 17 00:00:00 2001 From: Xwg Date: Sun, 24 Sep 2023 19:23:34 +0800 Subject: [PATCH 1/3] feat(type): add support for Decimal type in database The database schema has been updated to support the Decimal data type. This includes casting from and to Decimal type in the DataValue structure, creating a new Decimal field in both LogicalType and DataValue structures, and adding appropriate match arms in the from_sqlparser_data_type and can_implicit_cast functions. A new error type for when an attempted conversion from Decimal fails has also been implemented. The round_dp_with_strategy method is used to provide flexibility with scale modification and rounding strategy when casting to Decimal type. Tests for the new functionality have been added as well. --- Cargo.toml | 1 + src/db.rs | 7 +++ src/expression/mod.rs | 2 +- src/storage/table_codec.rs | 9 ++- src/types/errors.rs | 6 ++ src/types/mod.rs | 16 ++++- src/types/value.rs | 122 +++++++++++++++++++++++++++++++++---- tests/slt/decimal | 6 ++ 8 files changed, 153 insertions(+), 16 deletions(-) create mode 100644 tests/slt/decimal diff --git a/Cargo.toml b/Cargo.toml index 35152592..a1ba0656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ comfy-table = "7.0.1" bytes = "*" kip_db = "0.1.2-alpha.15" async-recursion = "1.0.5" +rust_decimal = "1" [dev-dependencies] tokio-test = "0.4.2" diff --git a/src/db.rs b/src/db.rs index bc768d16..a8accc6e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -186,6 +186,9 @@ mod test { let _ = kipsql.run("create table t2 (c int primary key, d int unsigned null, e datetime)").await?; let _ = kipsql.run("insert into t1 (a, b, k) values (-99, 1, 1), (-1, 2, 2), (5, 2, 2)").await?; let _ = kipsql.run("insert into t2 (d, c, e) values (2, 1, '2021-05-20 21:00:00'), (3, 4, '2023-09-10 00:00:00')").await?; + let _ = kipsql.run("create table t3 (a int primary key, b decimal(4,2))").await?; + let _ = kipsql.run("insert into t3 (a, b) values (1, 
1111), (2, 2.01), (3, 3.00)").await?; + let _ = kipsql.run("insert into t3 (a, b) values (4, 4444), (5, 5222), (6, 1.00)").await?; println!("full t1:"); let tuples_full_fields_t1 = kipsql.run("select * from t1").await?; @@ -305,6 +308,10 @@ mod test { let tuples_show_tables = kipsql.run("show tables").await?; println!("{}", create_table(&tuples_show_tables)); + println!("decimal:"); + let tuples_decimal = kipsql.run("select * from t3").await?; + println!("{}", create_table(&tuples_decimal)); + Ok(()) } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 456d26a5..30678993 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,5 +1,5 @@ use std::fmt; -use std::fmt::Formatter; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use itertools::Itertools; diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 18de475a..aab27fef 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -146,6 +146,7 @@ mod tests { use std::ops::Bound; use std::sync::Arc; use itertools::Itertools; + use rust_decimal::Decimal; use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog}; use crate::storage::table_codec::{COLUMNS_ID_LEN, TableCodec}; use crate::types::errors::TypeError; @@ -159,7 +160,12 @@ mod tests { "c1".into(), false, ColumnDesc::new(LogicalType::Integer, true) - ) + ), + ColumnCatalog::new( + "c2".into(), + false, + ColumnDesc::new(LogicalType::Decimal(None,None), false) + ), ]; let table_catalog = TableCatalog::new(Arc::new("t1".to_string()), columns).unwrap(); let codec = TableCodec { table: table_catalog.clone() }; @@ -175,6 +181,7 @@ mod tests { columns: table_catalog.all_columns(), values: vec![ Arc::new(DataValue::Int32(Some(0))), + Arc::new(DataValue::Decimal(Some(Decimal::new(1, 0)))), ] }; diff --git a/src/types/errors.rs b/src/types/errors.rs index cd5aa59b..c9476f62 100644 --- a/src/types/errors.rs +++ b/src/types/errors.rs @@ -46,4 +46,10 @@ pub enum TypeError { #[from] ParseError, ), + 
#[error("try from decimal")] + TryFromDecimal( + #[source] + #[from] + rust_decimal::Error, + ), } diff --git a/src/types/mod.rs b/src/types/mod.rs index cb4f2fce..c0ba1470 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -7,6 +7,7 @@ use std::sync::atomic::Ordering::{Acquire, Release}; use serde::{Deserialize, Serialize}; use integer_encoding::FixedInt; +use sqlparser::ast::ExactNumberInfo; use strum_macros::AsRefStr; use crate::types::errors::TypeError; @@ -57,6 +58,8 @@ pub enum LogicalType { Varchar(Option), Date, DateTime, + // decimal (precision, scale) + Decimal(Option, Option), } impl LogicalType { @@ -75,8 +78,9 @@ impl LogicalType { LogicalType::UBigint => Some(8), LogicalType::Float => Some(4), LogicalType::Double => Some(8), - /// Note: The non-fixed length type's raw_len is None + /// Note: The non-fixed length type's raw_len is None e.g. Varchar and Decimal LogicalType::Varchar(_)=>None, + LogicalType::Decimal(_, _) =>None, LogicalType::Date => Some(4), LogicalType::DateTime => Some(8), } @@ -269,6 +273,7 @@ impl LogicalType { LogicalType::Varchar(_) => false, LogicalType::Date => matches!(to, LogicalType::DateTime | LogicalType::Varchar(_)), LogicalType::DateTime => matches!(to, LogicalType::Date | LogicalType::Varchar(_)), + LogicalType::Decimal(_, _) => false, } } } @@ -296,6 +301,13 @@ impl TryFrom for LogicalType { sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), sqlparser::ast::DataType::Datetime(_) => Ok(LogicalType::DateTime), + sqlparser::ast::DataType::Decimal(info) => match info { + ExactNumberInfo::None => Ok(Self::Decimal(None, None)), + ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), + ExactNumberInfo::PrecisionAndScale(p, s) => { + Ok(Self::Decimal(Some(p as u8), Some(s as u8))) + } + }, other => Err(TypeError::NotImplementedSqlparserDataType( other.to_string(), )), @@ -313,7 +325,7 @@ impl std::fmt::Display for 
LogicalType { mod test { use std::sync::atomic::Ordering::Release; - use crate::types::{IdGenerator, ID_BUF, LogicalType}; + use crate::types::{IdGenerator, ID_BUF}; /// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰 #[test] diff --git a/src/types/value.rs b/src/types/value.rs index 3e02ee19..41249b07 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -8,8 +8,10 @@ use chrono::{NaiveDateTime, Datelike, NaiveDate}; use chrono::format::{DelayedFormat, StrftimeItems}; use integer_encoding::FixedInt; use lazy_static::lazy_static; +use rust_decimal::Decimal; use ordered_float::OrderedFloat; +use rust_decimal::prelude::{FromPrimitive, Signed}; use crate::types::errors::TypeError; use super::LogicalType; @@ -44,6 +46,7 @@ pub enum DataValue { Date32(Option), /// Date stored as a signed 64bit int timestamp since UNIX epoch 1970-01-01 Date64(Option), + Decimal(Option), } impl PartialEq for DataValue { @@ -88,6 +91,8 @@ impl PartialEq for DataValue { (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), (Date64(_), _) => false, + (Decimal(v1), Decimal(v2)) => v1.eq(v2), + (Decimal(_), _) => false, } } } @@ -134,6 +139,8 @@ impl PartialOrd for DataValue { (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), (Date64(_), _) => None, + (Decimal(v1), Decimal(v2)) => v1.partial_cmp(v2), + (Decimal(_), _) => None, } } } @@ -175,6 +182,7 @@ impl Hash for DataValue { Null => 1.hash(state), Date32(v) => v.hash(state), Date64(v) => v.hash(state), + Decimal(v) => v.hash(state), } } } @@ -193,20 +201,59 @@ macro_rules! varchar_cast { }; } +macro_rules! 
check_decimal_length { + ($data_value:expr, $logic_type:expr) => { + if let LogicalType::Decimal(precision, scale) = $logic_type { + let data_value_str = $data_value.to_string(); + let data_value_precision = data_value_str.chars().filter(|c| *c >= '0' && *c <= '9').count(); + if data_value_precision > precision.unwrap() as usize { + return Err(TypeError::TooLong); + } + if $data_value.scale() > scale.unwrap() as u32 { + return Err(TypeError::TooLong); + } + }else{ + return Ok(()) + } + }; +} + impl DataValue { pub(crate) fn check_length(&self, logic_type: &LogicalType) -> Result<(), TypeError> { match self { DataValue::Boolean(_) => return Ok(()), - DataValue::Float32(_) => return Ok(()), - DataValue::Float64(_) => return Ok(()), - DataValue::Int8(_) => return Ok(()), - DataValue::Int16(_) => return Ok(()), - DataValue::Int32(_) => return Ok(()), - DataValue::Int64(_) => return Ok(()), - DataValue::UInt8(_) => return Ok(()), - DataValue::UInt16(_) => return Ok(()), - DataValue::UInt32(_) => return Ok(()), - DataValue::UInt64(_) => return Ok(()), + DataValue::Float32(v) => { + // check literal to decimal + check_decimal_length!(Decimal::from_f32(v.unwrap()).unwrap(), logic_type) + } + DataValue::Float64(v) =>{ + // check literal to decimal + check_decimal_length!(Decimal::from_f64(v.unwrap()).unwrap(), logic_type) + }, + DataValue::Int8(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } + DataValue::Int16(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } + DataValue::Int32(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } + DataValue::Int64(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } + DataValue::UInt8(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } + DataValue::UInt16(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } + DataValue::UInt32(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } 
+ DataValue::UInt64(v) => { + check_decimal_length!(Decimal::from(v.unwrap()), logic_type) + } DataValue::Date32(_) => return Ok(()), DataValue::Date64(_) => return Ok(()), DataValue::Utf8(value) => { @@ -218,6 +265,15 @@ impl DataValue { } } } + DataValue::Decimal(value) => { + if let LogicalType::Decimal(_, scale) = logic_type { + if let Some(value) = value { + if value.scale() as u8 > scale.ok_or(TypeError::InvalidType)? { + return Err(TypeError::TooLong); + } + } + } + } _ => { return Err(TypeError::InvalidType); } } Ok(()) @@ -238,6 +294,7 @@ impl DataValue { pub fn is_variable(&self) -> bool { match self { DataValue::Utf8(_) => true, + DataValue::Decimal(_) => true, _ => false } } @@ -259,6 +316,7 @@ impl DataValue { DataValue::Utf8(value) => value.is_none(), DataValue::Date32(value) => value.is_none(), DataValue::Date64(value) => value.is_none(), + DataValue::Decimal(value) => value.is_none(), } } @@ -279,7 +337,8 @@ impl DataValue { LogicalType::Double => DataValue::Float64(None), LogicalType::Varchar(_) => DataValue::Utf8(None), LogicalType::Date => DataValue::Date32(None), - LogicalType::DateTime => DataValue::Date64(None) + LogicalType::DateTime => DataValue::Date64(None), + LogicalType::Decimal(_, _) => DataValue::Decimal(None), } } @@ -300,7 +359,8 @@ impl DataValue { LogicalType::Double => DataValue::Float64(Some(0.0)), LogicalType::Varchar(_) => DataValue::Utf8(Some("".to_string())), LogicalType::Date => DataValue::Date32(Some(UNIX_DATETIME.num_days_from_ce())), - LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.timestamp())) + LogicalType::DateTime => DataValue::Date64(Some(UNIX_DATETIME.timestamp())), + LogicalType::Decimal(_, _) => DataValue::Decimal(Some(Decimal::new(0, 0))), } } @@ -321,6 +381,7 @@ impl DataValue { DataValue::Utf8(v) => v.clone().map(|v| v.into_bytes()), DataValue::Date32(v) => v.map(|v| v.encode_fixed_vec()), DataValue::Date64(v) => v.map(|v| v.encode_fixed_vec()), + DataValue::Decimal(v) => v.clone().map(|v| 
v.serialize().to_vec()), }.unwrap_or(vec![]) } @@ -350,6 +411,7 @@ impl DataValue { LogicalType::Varchar(_) => DataValue::Utf8((!bytes.is_empty()).then(|| String::from_utf8(bytes.to_owned()).unwrap())), LogicalType::Date => DataValue::Date32((!bytes.is_empty()).then(|| i32::decode_fixed(bytes))), LogicalType::DateTime => DataValue::Date64((!bytes.is_empty()).then(|| i64::decode_fixed(bytes))), + LogicalType::Decimal(_, _) => DataValue::Decimal((!bytes.is_empty()).then(|| Decimal::deserialize(<[u8; 16]>::try_from(bytes).unwrap()))), } } @@ -370,6 +432,7 @@ impl DataValue { DataValue::Utf8(_) => LogicalType::Varchar(None), DataValue::Date32(_) => LogicalType::Date, DataValue::Date64(_) => LogicalType::DateTime, + DataValue::Decimal(_) => LogicalType::Decimal(None, None), } } @@ -408,6 +471,7 @@ impl DataValue { LogicalType::Varchar(_) => Ok(DataValue::Utf8(None)), LogicalType::Date => Ok(DataValue::Date32(None)), LogicalType::DateTime => Ok(DataValue::Date64(None)), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(None)), } } DataValue::Boolean(value) => { @@ -434,6 +498,9 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value)), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) =>{ + Ok(DataValue::Decimal(value.map(|v| Decimal::from_f32(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero)))) + } _ => Err(TypeError::CastFail), } } @@ -442,6 +509,9 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::Double => Ok(DataValue::Float64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => { + Ok(DataValue::Decimal(value.map(|v| Decimal::from_f64(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero)))) + } _ => Err(TypeError::CastFail), } } @@ -459,6 +529,7 
@@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -475,6 +546,7 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -489,6 +561,7 @@ impl DataValue { LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -501,6 +574,7 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -517,6 +591,7 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -531,6 +606,7 @@ 
impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -542,6 +618,7 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -550,6 +627,7 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), + LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), _ => Err(TypeError::CastFail), } } @@ -588,6 +666,9 @@ impl DataValue { }).transpose()?; Ok(DataValue::Date64(option)) + }, + LogicalType::Decimal(_, _) => { + Ok(DataValue::Decimal(value.map(|v| Decimal::from_str(&v)).transpose()?)) } } } @@ -624,6 +705,14 @@ impl DataValue { _ => Err(TypeError::CastFail), } } + DataValue::Decimal(value) => { + match to { + LogicalType::SqlNull => Ok(DataValue::Null), + LogicalType::Decimal(_, _) => Ok(DataValue::Decimal(value)), + LogicalType::Varchar(len) => varchar_cast!(value, len), + _ => Err(TypeError::CastFail), + } + } } } @@ -636,6 +725,11 @@ impl DataValue { NaiveDateTime::from_timestamp_opt(v, 0) .map(|date_time| date_time.format(DATE_TIME_FMT)) } + + fn decimal_format(v: &Decimal) -> String { + v.to_string() + + } } macro_rules! 
impl_scalar { @@ -724,6 +818,9 @@ impl fmt::Display for DataValue { DataValue::Date64(e) => { format_option!(f, e.and_then(|s| DataValue::date_time_format(s)))? } + DataValue::Decimal(e) => { + format_option!(f, e.as_ref().map(|s| DataValue::decimal_format(s)))? + } }; Ok(()) } @@ -748,6 +845,7 @@ impl fmt::Debug for DataValue { DataValue::Null => write!(f, "null"), DataValue::Date32(_) => write!(f, "Date32({})", self), DataValue::Date64(_) => write!(f, "Date64({})", self), + DataValue::Decimal(_) => write!(f, "Decimal({})", self), } } } diff --git a/tests/slt/decimal b/tests/slt/decimal new file mode 100644 index 00000000..4d566b0a --- /dev/null +++ b/tests/slt/decimal @@ -0,0 +1,6 @@ + +statement ok +CREATE TABLE mytable ( title varchar(256) primary key, cost decimal(4,2)); + +statement ok +INSERT INTO mytable (title, cost) VALUES ('A', 1.00); \ No newline at end of file From dfcf5c6a6781fc28186efb864ada00ba6a6e1314 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 25 Sep 2023 03:35:41 +0800 Subject: [PATCH 2/3] fix: Optimize Decimal's check length logic and fix interference with other types --- src/binder/insert.rs | 2 +- src/binder/update.rs | 2 +- src/types/value.rs | 91 +++++++++----------------------------------- 3 files changed, 21 insertions(+), 74 deletions(-) diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 0a5e0b94..e5e21c93 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -50,7 +50,7 @@ impl Binder { match &self.bind_expr(expr).await? 
{ ScalarExpression::Constant(value) => { // Check if the value length is too long - value.check_length(columns[i].datatype())?; + value.check_len(columns[i].datatype())?; let cast_value = DataValue::clone(value) .cast(columns[i].datatype())?; row.push(Arc::new(cast_value)) diff --git a/src/binder/update.rs b/src/binder/update.rs index 92de8986..ea99b4cc 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -44,7 +44,7 @@ impl Binder { bind_table_name.as_ref() ).await? { ScalarExpression::ColumnRef(catalog) => { - value.check_length(catalog.datatype())?; + value.check_len(catalog.datatype())?; columns.push(catalog); row.push(value.clone()); }, diff --git a/src/types/value.rs b/src/types/value.rs index 41249b07..96ab9183 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -11,7 +11,7 @@ use lazy_static::lazy_static; use rust_decimal::Decimal; use ordered_float::OrderedFloat; -use rust_decimal::prelude::{FromPrimitive, Signed}; +use rust_decimal::prelude::FromPrimitive; use crate::types::errors::TypeError; use super::LogicalType; @@ -201,81 +201,28 @@ macro_rules! varchar_cast { }; } -macro_rules! 
check_decimal_length { - ($data_value:expr, $logic_type:expr) => { - if let LogicalType::Decimal(precision, scale) = $logic_type { - let data_value_str = $data_value.to_string(); - let data_value_precision = data_value_str.chars().filter(|c| *c >= '0' && *c <= '9').count(); - if data_value_precision > precision.unwrap() as usize { - return Err(TypeError::TooLong); - } - if $data_value.scale() > scale.unwrap() as u32 { - return Err(TypeError::TooLong); - } - }else{ - return Ok(()) - } - }; -} - impl DataValue { - pub(crate) fn check_length(&self, logic_type: &LogicalType) -> Result<(), TypeError> { - match self { - DataValue::Boolean(_) => return Ok(()), - DataValue::Float32(v) => { - // check literal to decimal - check_decimal_length!(Decimal::from_f32(v.unwrap()).unwrap(), logic_type) - } - DataValue::Float64(v) =>{ - // check literal to decimal - check_decimal_length!(Decimal::from_f64(v.unwrap()).unwrap(), logic_type) - }, - DataValue::Int8(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Int16(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Int32(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Int64(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt8(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt16(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt32(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::UInt64(v) => { - check_decimal_length!(Decimal::from(v.unwrap()), logic_type) - } - DataValue::Date32(_) => return Ok(()), - DataValue::Date64(_) => return Ok(()), - DataValue::Utf8(value) => { - if let LogicalType::Varchar(len) = logic_type { - if let Some(len) = len { - if value.as_ref().map(|v| v.len() > *len as usize).unwrap_or(false) { - return Err(TypeError::TooLong); - } - } - 
} - } - DataValue::Decimal(value) => { - if let LogicalType::Decimal(_, scale) = logic_type { - if let Some(value) = value { - if value.scale() as u8 > scale.ok_or(TypeError::InvalidType)? { - return Err(TypeError::TooLong); - } - } + pub(crate) fn check_len(&self, logic_type: &LogicalType) -> Result<(), TypeError> { + let is_over_len = match (logic_type, self) { + (LogicalType::Varchar(Some(len)), DataValue::Utf8(Some(val))) => { + val.len() > *len as usize + } + (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(Some(val))) => { + if let Some(len) = full_len { + val.mantissa().ilog10() + 1 > *len as u32 + } else if let Some(len) = scale_len { + val.scale() > *len as u32 + } else { + false } } - _ => { return Err(TypeError::InvalidType); } + _ => false + }; + + if is_over_len { + return Err(TypeError::TooLong) } + Ok(()) } From f63c8da1926ad15164a00260fd52c58ccafcbddf Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 25 Sep 2023 04:03:25 +0800 Subject: [PATCH 3/3] fix: `DataValue::cast` fixes Decimal conversion problem --- src/types/mod.rs | 4 +- src/types/value.rs | 94 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/src/types/mod.rs b/src/types/mod.rs index c0ba1470..de67aed9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -79,8 +79,8 @@ impl LogicalType { LogicalType::Float => Some(4), LogicalType::Double => Some(8), /// Note: The non-fixed length type's raw_len is None e.g. 
Varchar and Decimal - LogicalType::Varchar(_)=>None, - LogicalType::Decimal(_, _) =>None, + LogicalType::Varchar(_) => None, + LogicalType::Decimal(_, _) => Some(16), LogicalType::Date => Some(4), LogicalType::DateTime => Some(8), } diff --git a/src/types/value.rs b/src/types/value.rs index 96ab9183..94b856f3 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,5 +1,5 @@ use std::cmp::Ordering; -use std::fmt; +use std::{fmt, mem}; use std::fmt::Formatter; use std::hash::Hash; use std::str::FromStr; @@ -241,7 +241,6 @@ impl DataValue { pub fn is_variable(&self) -> bool { match self { DataValue::Utf8(_) => true, - DataValue::Decimal(_) => true, _ => false } } @@ -445,8 +444,13 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value)), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) =>{ - Ok(DataValue::Decimal(value.map(|v| Decimal::from_f32(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero)))) + LogicalType::Decimal(_, option) =>{ + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f32(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + }).transpose()?)) } _ => Err(TypeError::CastFail), } @@ -456,8 +460,13 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::Double => Ok(DataValue::Float64(value)), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => { - Ok(DataValue::Decimal(value.map(|v| Decimal::from_f64(v).unwrap().round_dp_with_strategy( s.clone().unwrap() as u32, rust_decimal::RoundingStrategy::MidpointAwayFromZero)))) + LogicalType::Decimal(_, option) => { + Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from_f64(v).ok_or(TypeError::CastFail)?; + Self::decimal_round_f(option, &mut decimal); + + Ok::(decimal) + 
}).transpose()?)) } _ => Err(TypeError::CastFail), } @@ -476,7 +485,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -493,7 +507,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -508,7 +527,12 @@ impl DataValue { LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -521,7 +545,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| u64::try_from(v)).transpose()?)), LogicalType::Bigint => Ok(DataValue::Int64(value.map(|v| v.into()))), 
LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -538,7 +567,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -553,7 +587,12 @@ impl DataValue { LogicalType::Float => Ok(DataValue::Float32(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -565,7 +604,12 @@ impl DataValue { LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Double => Ok(DataValue::Float64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut 
decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -574,7 +618,12 @@ impl DataValue { LogicalType::SqlNull => Ok(DataValue::Null), LogicalType::UBigint => Ok(DataValue::UInt64(value.map(|v| v.into()))), LogicalType::Varchar(len) => varchar_cast!(value, len), - LogicalType::Decimal(_,s) => Ok(DataValue::Decimal(value.map(|v| Decimal::from(v).trunc_with_scale(s.unwrap() as u32)))), + LogicalType::Decimal(_, option) => Ok(DataValue::Decimal(value.map(|v| { + let mut decimal = Decimal::from(v); + Self::decimal_round_i(option, &mut decimal); + + decimal + }))), _ => Err(TypeError::CastFail), } } @@ -663,6 +712,23 @@ impl DataValue { } } + fn decimal_round_i(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.trunc_with_scale(*scale as u32); + let _ = mem::replace(decimal, new_decimal); + } + } + + fn decimal_round_f(option: &Option, decimal: &mut Decimal) { + if let Some(scale) = option { + let new_decimal = decimal.round_dp_with_strategy( + *scale as u32, + rust_decimal::RoundingStrategy::MidpointAwayFromZero + ); + let _ = mem::replace(decimal, new_decimal); + } + } + fn date_format<'a>(v: i32) -> Option>> { NaiveDate::from_num_days_from_ce_opt(v) .map(|date| date.format(DATE_FMT))