From 41719e314b175ffe8feb9f2881e4c9ee9225352e Mon Sep 17 00:00:00 2001 From: Sevenannn Date: Fri, 16 Aug 2024 00:13:04 -0700 Subject: [PATCH] Support decimal 256 in building insert table statement --- Cargo.toml | 1 + src/sql/arrow_sql_gen/arrow.rs | 17 +++++++++++------ src/sql/arrow_sql_gen/statement.rs | 20 +++++++++++++++++++- tests/arrow_record_batch_gen/mod.rs | 2 +- tests/postgres/mod.rs | 2 -- 5 files changed, 32 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 068df65..26edc8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ repository = "https://github.com/datafusion-contrib/datafusion-table-providers" arrow = "52.0.0" async-stream = { version = "0.3.5", optional = true } async-trait = "0.1.80" +num-bigint = "0.4.4" bigdecimal = "0.4.5" bigdecimal_0_3_0 = { package = "bigdecimal", version = "0.3.0" } byteorder = "1.5.0" diff --git a/src/sql/arrow_sql_gen/arrow.rs b/src/sql/arrow_sql_gen/arrow.rs index 7ab0495..7c6a46a 100644 --- a/src/sql/arrow_sql_gen/arrow.rs +++ b/src/sql/arrow_sql_gen/arrow.rs @@ -1,12 +1,12 @@ use arrow::{ array::{ ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Date64Builder, - Decimal128Builder, FixedSizeBinaryBuilder, Float32Builder, Float64Builder, Int16Builder, - Int32Builder, Int64Builder, Int8Builder, IntervalMonthDayNanoBuilder, LargeBinaryBuilder, - LargeStringBuilder, ListBuilder, NullBuilder, StringBuilder, StructBuilder, - Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampMillisecondBuilder, - TimestampNanosecondBuilder, TimestampSecondBuilder, UInt16Builder, UInt32Builder, - UInt64Builder, UInt8Builder, + Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder, Float32Builder, + Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, + IntervalMonthDayNanoBuilder, LargeBinaryBuilder, LargeStringBuilder, ListBuilder, + NullBuilder, StringBuilder, StructBuilder, Time64NanosecondBuilder, + TimestampMicrosecondBuilder, TimestampMillisecondBuilder, TimestampNanosecondBuilder, + TimestampSecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, }, datatypes::{DataType, TimeUnit}, }; @@ -43,6 +43,11 @@ pub fn map_data_type_to_array_builder(data_type: &DataType) -> Box Box::new( + Decimal256Builder::new() + .with_precision_and_scale(*precision, *scale) + .unwrap_or_default(), + ), DataType::Timestamp(time_unit, time_zone) => match time_unit { TimeUnit::Microsecond => { Box::new(TimestampMicrosecondBuilder::new().with_timezone_opt(time_zone.clone())) diff --git a/src/sql/arrow_sql_gen/statement.rs b/src/sql/arrow_sql_gen/statement.rs index 94699b2..d66b5d3 100644 --- a/src/sql/arrow_sql_gen/statement.rs +++ b/src/sql/arrow_sql_gen/statement.rs @@ -4,6 +4,7 @@ use arrow::{ util::display::array_value_to_string, }; use bigdecimal_0_3_0::BigDecimal; +use num_bigint::BigInt; use sea_query::{ Alias, ColumnDef, ColumnType, Expr, GenericBuilder, Index, InsertStatement, IntoIden, IntoIndexColumn, Keyword, MysqlQueryBuilder, OnConflict, PostgresQueryBuilder, Query, @@ -217,6 +218,21 @@ impl InsertBuilder { ); } } + DataType::Decimal256(_, scale) => { + let array = column.as_any().downcast_ref::(); + if let Some(valid_array) = array { + if valid_array.is_null(row) { + row_values.push(Keyword::Null.into()); + continue; + } + + let bigint = + BigInt::from_signed_bytes_le(&valid_array.value(row).to_le_bytes()); + + println!("{:?}", BigDecimal::new(bigint.clone(), i64::from(*scale))); + row_values.push(BigDecimal::new(bigint, i64::from(*scale)).into()); + } + } DataType::Date32 => { let array = column.as_any().downcast_ref::(); if let Some(valid_array) = array { @@ -962,7 +978,9 @@ pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType { DataType::Utf8 | DataType::LargeUtf8 => ColumnType::Text, DataType::Boolean => ColumnType::Boolean, #[allow(clippy::cast_sign_loss)] // This is safe because scale will never be negative - DataType::Decimal128(p, s) => ColumnType::Decimal(Some((u32::from(*p), *s as u32))), + DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => { + ColumnType::Decimal(Some((u32::from(*p), *s as u32))) + } DataType::Timestamp(_unit, _time_zone) => ColumnType::Timestamp, DataType::Date32 | DataType::Date64 => ColumnType::Date, DataType::Time64(_unit) | DataType::Time32(_unit) => ColumnType::Time, diff --git a/tests/arrow_record_batch_gen/mod.rs b/tests/arrow_record_batch_gen/mod.rs index 9fd1b4b..d7fcd7a 100644 --- a/tests/arrow_record_batch_gen/mod.rs +++ b/tests/arrow_record_batch_gen/mod.rs @@ -294,7 +294,7 @@ pub(crate) fn get_arrow_decimal_record_batch() -> RecordBatch { let decimal128_array = Decimal128Array::from(vec![i128::from(123), i128::from(222), i128::from(321)]); let decimal256_array = - Decimal256Array::from(vec![i256::from(123), i256::from(222), i256::from(321)]); + Decimal256Array::from(vec![i256::from(-123), i256::from(222), i256::from(0)]); let schema = Schema::new(vec![ Field::new("decimal128", DataType::Decimal128(38, 10), false), diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index bc65506..7723151 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -111,10 +111,8 @@ mod test { #[ignore] // TODO: time types are broken in Postgres #[case::time(get_arrow_time_record_batch(), "time")] #[case::timestamp(get_arrow_timestamp_record_batch(), "timestamp")] - #[ignore] // TODO: date types are broken in Postgres #[case::date(get_arrow_date_record_batch(), "date")] #[case::struct_type(get_arrow_struct_record_batch(), "struct")] - #[ignore] // TODO: decimal types are broken in Postgres #[case::decimal(get_arrow_decimal_record_batch(), "decimal")] #[case::interval(get_arrow_interval_record_batch(), "interval")] #[ignore] // TODO: duration types are broken in Postgres