diff --git a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs index 4dee42e..e6ffaa2 100644 --- a/src/sql/arrow_sql_gen/mysql.rs +++ b/src/sql/arrow_sql_gen/mysql.rs @@ -1,17 +1,17 @@ use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional; use arrow::{ array::{ - ArrayBuilder, ArrayRef, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, - Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, NullBuilder, - RecordBatch, RecordBatchOptions, Time64NanosecondBuilder, TimestampMillisecondBuilder, - UInt64Builder, + ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Float32Builder, + Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, + NullBuilder, RecordBatch, RecordBatchOptions, Time64NanosecondBuilder, + TimestampMillisecondBuilder, UInt64Builder, }, datatypes::{DataType, Date32Type, Field, Schema, TimeUnit}, }; use bigdecimal::BigDecimal; use bigdecimal::ToPrimitive; use chrono::{NaiveDate, NaiveTime, Timelike}; -use mysql_async::{consts::ColumnType, FromValueError, Row, Value}; +use mysql_async::{consts::ColumnFlags, consts::ColumnType, FromValueError, Row, Value}; use snafu::{ResultExt, Snafu}; use std::{convert, sync::Arc}; @@ -93,13 +93,15 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { let mut arrow_columns_builders: Vec>> = Vec::new(); let mut mysql_types: Vec = Vec::new(); let mut column_names: Vec = Vec::new(); + let mut column_is_binary_stats: Vec = Vec::new(); if !rows.is_empty() { let row = &rows[0]; for column in row.columns().iter() { let column_name = column.name_str(); let column_type = column.column_type(); - let data_type = map_column_to_data_type(column_type); + let column_is_binary = column.flags().contains(ColumnFlags::BINARY_FLAG); + let data_type = map_column_to_data_type(column_type, column_is_binary); arrow_fields.push( data_type .clone() @@ -109,6 +111,7 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { 
.push(map_data_type_to_array_builder_optional(data_type.as_ref())); mysql_types.push(column_type); column_names.push(column_name.to_string()); + column_is_binary_stats.push(column_is_binary); } } @@ -258,8 +261,6 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { dec_builder.append_value(val); } column_type @ (ColumnType::MYSQL_TYPE_VARCHAR - | ColumnType::MYSQL_TYPE_STRING - | ColumnType::MYSQL_TYPE_VAR_STRING | ColumnType::MYSQL_TYPE_JSON | ColumnType::MYSQL_TYPE_TINY_BLOB | ColumnType::MYSQL_TYPE_BLOB @@ -275,6 +276,28 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { i ); } + column_type @ (ColumnType::MYSQL_TYPE_STRING + | ColumnType::MYSQL_TYPE_VAR_STRING) => { + if column_is_binary_stats[i] { + handle_primitive_type!( + builder, + column_type, + BinaryBuilder, + Vec, + row, + i + ); + } else { + handle_primitive_type!( + builder, + column_type, + LargeStringBuilder, + String, + row, + i + ); + } + } ColumnType::MYSQL_TYPE_DATE => { let Some(builder) = builder else { return NoBuilderForIndexSnafu { index: i }.fail(); @@ -407,7 +430,10 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { } #[allow(clippy::unnecessary_wraps)] -pub fn map_column_to_data_type(column_type: ColumnType) -> Option { +pub fn map_column_to_data_type( + column_type: ColumnType, + column_is_binary: bool, +) -> Option { match column_type { ColumnType::MYSQL_TYPE_NULL => Some(DataType::Null), ColumnType::MYSQL_TYPE_BIT => Some(DataType::UInt64), @@ -426,8 +452,6 @@ pub fn map_column_to_data_type(column_type: ColumnType) -> Option { Some(DataType::Time64(TimeUnit::Nanosecond)) } ColumnType::MYSQL_TYPE_VARCHAR - | ColumnType::MYSQL_TYPE_STRING - | ColumnType::MYSQL_TYPE_VAR_STRING | ColumnType::MYSQL_TYPE_JSON | ColumnType::MYSQL_TYPE_ENUM | ColumnType::MYSQL_TYPE_SET @@ -435,7 +459,14 @@ pub fn map_column_to_data_type(column_type: ColumnType) -> Option { | ColumnType::MYSQL_TYPE_BLOB | ColumnType::MYSQL_TYPE_MEDIUM_BLOB | ColumnType::MYSQL_TYPE_LONG_BLOB => Some(DataType::LargeUtf8), - + 
ColumnType::MYSQL_TYPE_STRING + | ColumnType::MYSQL_TYPE_VAR_STRING => { + if column_is_binary { + Some(DataType::Binary) + } else { + Some(DataType::LargeUtf8) + } + }, // replication only ColumnType::MYSQL_TYPE_TYPED_ARRAY // internal diff --git a/src/sql/arrow_sql_gen/postgres.rs b/src/sql/arrow_sql_gen/postgres.rs index ab04d08..7b376c6 100644 --- a/src/sql/arrow_sql_gen/postgres.rs +++ b/src/sql/arrow_sql_gen/postgres.rs @@ -228,6 +228,9 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { Type::VARCHAR => { handle_primitive_type!(builder, Type::VARCHAR, StringBuilder, &str, row, i); } + Type::BYTEA => { + handle_primitive_type!(builder, Type::BYTEA, BinaryBuilder, Vec, row, i); + } Type::BPCHAR => { let Some(builder) = builder else { return NoBuilderForIndexSnafu { index: i }.fail(); @@ -524,6 +527,7 @@ fn map_column_type_to_data_type(column_type: &Type) -> Option { Type::FLOAT4 => Some(DataType::Float32), Type::FLOAT8 => Some(DataType::Float64), Type::TEXT | Type::VARCHAR | Type::BPCHAR | Type::UUID => Some(DataType::Utf8), + Type::BYTEA => Some(DataType::Binary), Type::BOOL => Some(DataType::Boolean), // Inspect the scale from the first row. Precision will always be 38 for Decimal128. 
Type::NUMERIC => None, diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index 7b5265d..bc806d3 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -167,6 +167,7 @@ fn columns_meta_to_schema(columns_meta: Vec) -> Result { })?; let column_type = map_str_type_to_column_type(&data_type)?; + let column_is_binary = map_str_type_to_is_binary(&data_type); let arrow_data_type = match column_type { // map_column_to_data_type does not support decimal mapping and uses special logic to handle conversion based on actual value @@ -177,7 +178,7 @@ fn columns_meta_to_schema(columns_meta: Vec) -> Result { // rows_to_arrow uses hardcoded precision 38 for decimal so we use it here as well DataType::Decimal128(38, scale) } - _ => map_column_to_data_type(column_type) + _ => map_column_to_data_type(column_type, column_is_binary) .context(UnsupportedDataTypeSnafu { data_type })?, }; fields.push(Field::new(&column_name, arrow_data_type, true)); @@ -234,6 +235,11 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { Ok(column_type) } +/// Returns `true` when the MySQL column type string starts with `binary` or `varbinary`. +fn map_str_type_to_is_binary(data_type: &str) -> bool { + data_type.starts_with("binary") || data_type.starts_with("varbinary") +} + fn extract_decimal_precision_and_scale(data_type: &str) -> Result<(u8, i8)> { let (start, end) = match (data_type.find('('), data_type.find(')')) { (Some(start), Some(end)) => (start, end),