diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 26eef3520570..ea7b14a71d5a 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -725,6 +725,164 @@ async fn test_interval_expressions() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_interval_mul_div_float() -> Result<()> {
+    macro_rules! test_mul {
+        ($L:expr, $R:expr, $EXPECTED:expr) => {
+            test_expression!(format!("({}) * ({})", $L, $R), $EXPECTED);
+            test_expression!(format!("({}) * ({})", $R, $L), $EXPECTED);
+        };
+    }
+
+    test_mul!(
+        "0.5",
+        "interval '1 month'",
+        "0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "1.0",
+        "interval '1 month'",
+        "0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "1.5",
+        "interval '1 month'",
+        "0 years 1 mons 15 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "2.0",
+        "interval '1 month'",
+        "0 years 2 mons 0 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "-1.5",
+        "interval '1 month'",
+        "0 years -1 mons -15 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_expression!(
+        "1.5 * (interval '1 month') / 1.5",
+        "0 years 0 mons 30 days 0 hours 0 mins 0.000000000 secs"
+    );
+
+    test_mul!(
+        "0.5",
+        "interval '1 day'",
+        "0 years 0 mons 0 days 12 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "1.0",
+        "interval '1 day'",
+        "0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "1.5",
+        "interval '1 day'",
+        "0 years 0 mons 1 days 12 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "2.0",
+        "interval '1 day'",
+        "0 years 0 mons 2 days 0 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "-1.5",
+        "interval '1 day'",
+        "0 years 0 mons -1 days -12 hours 0 mins 0.000000000 secs"
+    );
+    test_expression!(
+        "1.5 * (interval '1 day') / 1.5",
+        "0 years 0 mons 0 days 24 hours 0 mins 0.000000000 secs"
+    );
+
+    test_mul!(
+        "0.5",
+        "interval '1 second'",
+        "0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs"
+    );
+    test_mul!(
+        "1.0",
+        "interval '1 second'",
+        "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs"
+    );
+    test_mul!(
+        "1.5",
+        "interval '1 second'",
+        "0 years 0 mons 0 days 0 hours 0 mins 1.500000000 secs"
+    );
+    test_mul!(
+        "2.0",
+        "interval '1 second'",
+        "0 years 0 mons 0 days 0 hours 0 mins 2.000000000 secs"
+    );
+    // This test prints as -1.-500000000 secs, that looks like a bug in printing
+    // TODO fix it
+    // test_mul!(
+    //     "-1.5",
+    //     "interval '1 second'",
+    //     "0 years 0 mons 0 days 0 hours 0 mins -1.500000000 secs"
+    // );
+    test_expression!(
+        "1.5 * (interval '1 second') / 1.5",
+        "0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs"
+    );
+
+    // Carry-over cases
+    test_mul!(
+        "1.5",
+        "interval '1 month 1 day'",
+        "0 years 1 mons 16 days 12 hours 0 mins 0.000000000 secs"
+    );
+    test_mul!(
+        "1.5",
+        "interval '1 month 1 second'",
+        "0 years 1 mons 15 days 0 hours 0 mins 1.500000000 secs"
+    );
+    test_mul!(
+        "1.5",
+        "interval '1 day 1 second'",
+        "0 years 0 mons 1 days 12 hours 0 mins 1.500000000 secs"
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_interval_mul_bad_float() -> Result<()> {
+    macro_rules! test_expr_error {
+        ($SQL:expr, $EXPECTED:expr) => {
+            let ctx = SessionContext::new();
+            let sql = format!("SELECT {}", $SQL);
+            let actual_err = plan_and_collect(&ctx, sql.as_str()).await.unwrap_err();
+            assert_eq!(actual_err.to_string(), $EXPECTED);
+        };
+    }
+    macro_rules! test_mul_error {
+        ($L:expr, $R:expr, $EXPECTED:expr) => {
+            test_expr_error!(format!("({}) * ({})", $L, $R), $EXPECTED);
+            test_expr_error!(format!("({}) * ({})", $R, $L), $EXPECTED);
+        };
+    }
+
+    // This behaviour was checked on PostgreSQL 15.7
+    test_mul_error!(
+        "cast('NaN' as double precision)",
+        "interval '1 month'",
+        "Execution error: interval out of range (float)"
+    );
+    test_mul_error!(
+        "cast('inf' as double precision)",
+        "interval '1 month'",
+        "Execution error: interval out of range (float)"
+    );
+    test_mul_error!(
+        "cast('-inf' as double precision)",
+        "interval '1 month'",
+        "Execution error: interval out of range (float)"
+    );
+
+    Ok(())
+}
+
 #[cfg(feature = "unicode_expressions")]
 #[tokio::test]
 async fn test_substring_expr() -> Result<()> {
diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs
index fb219d0c5aa2..876fc6f3379b 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -747,6 +747,20 @@ pub fn distinct_coercion(
             | (Interval(unit), UInt8)
             | (Null, Interval(unit))
             | (Interval(unit), Null) => Some(Interval(unit.clone())),
+            // float*interval result is always represented as MonthDayNano, to avoid precision loss
+            (Float64, Interval(_))
+            | (Interval(_), Float64)
+            | (Float32, Interval(_))
+            | (Interval(_), Float32)
+            | (Float16, Interval(_))
+            | (Interval(_), Float16) => Some(Interval(MonthDayNano)),
+            _ => None,
+        },
+        Operator::Divide => match (lhs_type, rhs_type) {
+            // interval/float result is represented as MonthDayNano, to avoid precision loss
+            (Interval(_), Float64) | (Interval(_), Float32) | (Interval(_), Float16) => {
+                Some(Interval(MonthDayNano))
+            }
             _ => None,
         },
         _ => None,
diff --git a/datafusion/physical-expr/src/expressions/binary_distinct.rs b/datafusion/physical-expr/src/expressions/binary_distinct.rs
index 1a3e74bac8ac..07dbb768f526 100644
--- a/datafusion/physical-expr/src/expressions/binary_distinct.rs
+++ b/datafusion/physical-expr/src/expressions/binary_distinct.rs
@@ -19,7 +19,7 @@ use std::sync::Arc;
 
 use arrow::{
     array::{
-        Array, ArrayRef, Date32Array, Int64Array, IntervalDayTimeArray,
+        Array, ArrayRef, Date32Array, Float64Array, Int64Array, IntervalDayTimeArray,
         IntervalMonthDayNanoArray, IntervalYearMonthArray, TimestampNanosecondArray,
     },
     datatypes::{
@@ -47,6 +47,7 @@ pub fn distinct_types_allowed(
 ) -> bool {
     use arrow::datatypes::TimeUnit::*;
     use DataType::*;
+    use IntervalUnit::*;
 
     match op {
         Operator::Plus => matches!(
@@ -62,8 +63,16 @@ pub fn distinct_types_allowed(
         ),
         Operator::Multiply => matches!(
             (left_type, right_type),
-            (Int64, Interval(_)) | (Interval(_), Int64)
+            (Int64, Interval(_))
+                | (Interval(_), Int64)
+                // Expect both operands coerced to most precise
+                | (Float64, Interval(MonthDayNano))
+                | (Interval(MonthDayNano), Float64)
         ),
+        // Expect both operands coerced to most precise
+        Operator::Divide => {
+            matches!((left_type, right_type), (Interval(MonthDayNano), Float64))
+        }
         _ => false,
     }
 }
@@ -143,6 +152,20 @@ pub fn coerce_types_distinct(
             | (Interval(unit), UInt16)
             | (Interval(unit), UInt8)
             | (Interval(unit), Null) => Some((Interval(unit.clone()), Int64)),
+            // For float * interval expression both operands are extended to most precise
+            (Float64, Interval(_)) | (Float32, Interval(_)) | (Float16, Interval(_)) => {
+                Some((Float64, Interval(MonthDayNano)))
+            }
+            (Interval(_), Float64) | (Interval(_), Float32) | (Interval(_), Float16) => {
+                Some((Interval(MonthDayNano), Float64))
+            }
+            _ => None,
+        },
+        Operator::Divide => match (lhs_type, rhs_type) {
+            // For interval / float expression both operands are extended to most precise
+            (Interval(_), Float64) | (Interval(_), Float32) | (Interval(_), Float16) => {
+                Some((Interval(MonthDayNano), Float64))
+            }
             _ => None,
         },
         _ => None,
@@ -160,6 +183,7 @@ pub fn evaluate_distinct_with_resolved_args(
 ) -> Option<Result<ArrayRef>> {
     use arrow::datatypes::TimeUnit::*;
     use DataType::*;
+    use IntervalUnit::*;
 
     match op {
         Operator::Plus => match (left_data_type, right_data_type) {
@@ -196,6 +220,18 @@ pub fn evaluate_distinct_with_resolved_args(
         Operator::Multiply => match (left_data_type, right_data_type) {
             (Int64, Interval(_)) => Some(interval_multiply_int(right, left)),
             (Interval(_), Int64) => Some(interval_multiply_int(left, right)),
+            // Expect both operands coerced to most precise
+            (Float64, Interval(MonthDayNano)) => {
+                Some(interval_multiply_float(right, left))
+            }
+            (Interval(MonthDayNano), Float64) => {
+                Some(interval_multiply_float(left, right))
+            }
+            _ => None,
+        },
+        Operator::Divide => match (left_data_type, right_data_type) {
+            // Expect both operands coerced to most precise
+            (Interval(MonthDayNano), Float64) => Some(interval_divide_float(left, right)),
             _ => None,
         },
         _ => None,
@@ -261,6 +297,78 @@ fn interval_multiply_int(
     }
 }
 
+fn interval_multiply_float(
+    intervals: Arc<dyn Array>,
+    multipliers: Arc<dyn Array>,
+) -> Result<ArrayRef> {
+    let multipliers = match multipliers.data_type() {
+        DataType::Float64 => multipliers,
+        t => {
+            return Err(DataFusionError::Execution(format!(
+                "unsupported multiplicand type {}",
+                t
+            )))
+        }
+    };
+    let multipliers = multipliers.as_any().downcast_ref::<Float64Array>().unwrap();
+
+    match intervals.data_type() {
+        // Expect both operands coerced to most precise, no need to handle other interval units
+        DataType::Interval(IntervalUnit::MonthDayNano) => {
+            let intervals = intervals
+                .as_any()
+                .downcast_ref::<IntervalMonthDayNanoArray>()
+                .unwrap();
+            let result = intervals
+                .iter()
+                .zip(multipliers.iter())
+                .map(|(i, m)| scalar_interval_month_day_nano_time_mul_float(i, m))
+                .collect::<Result<IntervalMonthDayNanoArray>>()?;
+            Ok(Arc::new(result))
+        }
+        t => Err(DataFusionError::Execution(format!(
+            "multiplication expected Interval(MonthDayNano), got {}",
+            t
+        ))),
+    }
+}
+
+fn interval_divide_float(
+    intervals: Arc<dyn Array>,
+    divisors: Arc<dyn Array>,
+) -> Result<ArrayRef> {
+    let divisors = match divisors.data_type() {
+        DataType::Float64 => divisors,
+        t => {
+            return Err(DataFusionError::Execution(format!(
+                "unsupported divisor type {}",
+                t
+            )))
+        }
+    };
+    let divisors = divisors.as_any().downcast_ref::<Float64Array>().unwrap();
+
+    match intervals.data_type() {
+        // Expect both operands coerced to most precise, no need to handle other interval units
+        DataType::Interval(IntervalUnit::MonthDayNano) => {
+            let intervals = intervals
+                .as_any()
+                .downcast_ref::<IntervalMonthDayNanoArray>()
+                .unwrap();
+            let result = intervals
+                .iter()
+                .zip(divisors.iter())
+                .map(|(i, m)| scalar_interval_month_day_nano_time_div_float(i, m))
+                .collect::<Result<IntervalMonthDayNanoArray>>()?;
+            Ok(Arc::new(result))
+        }
+        t => Err(DataFusionError::Execution(format!(
+            "division expected Interval(MonthDayNano), got {}",
+            t
+        ))),
+    }
+}
+
 fn scalar_interval_year_month_mul_int(
     interval: Option<i32>,
     multiplier: Option<i64>,
@@ -343,6 +451,148 @@ fn scalar_interval_month_day_nano_time_mul_int(
     Ok(Some(interval))
 }
 
+/// Multiplies a `MonthDayNano` interval by a float, propagating nulls.
+///
+/// Returns `Ok(None)` when either operand is null, and an "interval out of
+/// range (float)" execution error when the product cannot be represented.
+fn scalar_interval_month_day_nano_time_mul_float(
+    interval: Option<IntervalMonthDayNano>,
+    multiplier: Option<f64>,
+) -> Result<Option<IntervalMonthDayNano>> {
+    let (interval, multiplier) = match (interval, multiplier) {
+        (Some(interval), Some(multiplier)) => (interval, multiplier),
+        _ => return Ok(None),
+    };
+
+    // Infinity needs no dedicated check: the scaled fields become infinite (or NaN for zero
+    // fields), and the shared helper rejects every non-finite field value.
+    // Zero needs no dedicated check either: the product is a valid zero interval.
+
+    let result = scalar_interval_month_day_nano_time_float_mul_div(
+        interval,
+        multiplier,
+        |i, f| i * f,
+    )?;
+    Ok(Some(result))
+}
+
+/// Divides a `MonthDayNano` interval by a float, propagating nulls.
+///
+/// Returns `Ok(None)` when either operand is null, and an "interval out of
+/// range (float)" execution error when the quotient cannot be represented.
+fn scalar_interval_month_day_nano_time_div_float(
+    interval: Option<IntervalMonthDayNano>,
+    divisor: Option<f64>,
+) -> Result<Option<IntervalMonthDayNano>> {
+    let (interval, divisor) = match (interval, divisor) {
+        (Some(interval), Some(divisor)) => (interval, divisor),
+        _ => return Ok(None),
+    };
+
+    // Infinity needs no dedicated check: dividing by it yields zero, a valid interval.
+    // Zero needs no dedicated check either: dividing by it yields infinity or NaN, which the
+    // shared helper rejects like any other non-finite field value.
+
+    let result =
+        scalar_interval_month_day_nano_time_float_mul_div(interval, divisor, |i, f| {
+            i / f
+        })?;
+    Ok(Some(result))
+}
+
+/// Applies `op` (a multiplication or division by `float_operand`) to every field of a
+/// `MonthDayNano` interval, cascading fractional months into days and fractional days
+/// into nanoseconds.
+///
+/// Fractional months are converted at 30 days per month and fractional days at 24 hours
+/// per day; any remaining fractional nanoseconds are truncated.
+///
+/// # Errors
+///
+/// Returns an "interval out of range (float)" execution error when `float_operand` is NaN
+/// or when any resulting field, including a carried-over sum, does not fit its integer type.
+fn scalar_interval_month_day_nano_time_float_mul_div(
+    interval: IntervalMonthDayNano,
+    float_operand: f64,
+    op: impl Fn(f64, f64) -> f64,
+) -> Result<IntervalMonthDayNano> {
+    // https://github.com/postgres/postgres/blob/86d33987e8b0364b468c9b40c5f2a0a1aed87ef1/src/backend/utils/adt/timestamp.c#L3567-L3786
+
+    let out_of_range =
+        || DataFusionError::Execution("interval out of range (float)".to_string());
+
+    if float_operand.is_nan() {
+        return Err(out_of_range());
+    }
+
+    let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(interval);
+
+    // Splits `v` into its integer part (truncated toward zero) and the fractional remainder;
+    // `None` if `v` is not finite or lies outside the `i32` range.
+    fn try_to_i32(v: f64) -> Option<(i32, f64)> {
+        if !v.is_finite() {
+            return None;
+        }
+        if v > f64::from(i32::MAX) {
+            return None;
+        }
+        if v < f64::from(i32::MIN) {
+            return None;
+        }
+        // This cast should not saturate nor handle NaN/Inf: we've checked limits and special values
+        let vi = v as i32;
+        let rem = v - vi as f64;
+        Some((vi, rem))
+    }
+
+    // Same as `try_to_i32`, for the `i64` range.
+    fn try_to_i64(v: f64) -> Option<(i64, f64)> {
+        if !v.is_finite() {
+            return None;
+        }
+        // i64::MAX is not representable precisely in f64, because it's 2^n - 1
+        // But i64::MIN is, because it's -1 * 2^n
+        if v >= -(i64::MIN as f64) {
+            return None;
+        }
+        if v < i64::MIN as f64 {
+            return None;
+        }
+        // This cast should not saturate nor handle NaN/Inf
+        let vi = v as i64;
+        let rem = v - vi as f64;
+        Some((vi, rem))
+    }
+
+    const DAYS_PER_MONTH: f64 = 30f64;
+    const NANOS_PER_DAY: f64 = 86_400_000_000_000f64;
+
+    let (months, months_rem) =
+        try_to_i32(op(f64::from(months), float_operand)).ok_or_else(out_of_range)?;
+    let (days, days_rem) =
+        try_to_i32(op(f64::from(days), float_operand)).ok_or_else(out_of_range)?;
+    // `nanos as f64` can lose precision for high values of nanos
+    let (nanos, _nanos_rem) =
+        try_to_i64(op(nanos as f64, float_operand)).ok_or_else(out_of_range)?;
+
+    let months_rem_days = months_rem * DAYS_PER_MONTH;
+    let (months_rem_days, months_rem_days_rem) =
+        try_to_i32(months_rem_days).ok_or_else(out_of_range)?;
+    // The carried days can push an already near-limit day count past `i32`; report that as
+    // out of range instead of panicking (debug) or wrapping (release)
+    let days = days.checked_add(months_rem_days).ok_or_else(out_of_range)?;
+    let days_rem = days_rem + months_rem_days_rem;
+
+    let days_rem_nanos = days_rem * NANOS_PER_DAY;
+    let (days_rem_nanos, _days_rem_nanos_rem) =
+        try_to_i64(days_rem_nanos).ok_or_else(out_of_range)?;
+    // Same overflow guard for the nanoseconds carried over from fractional days
+    let nanos = nanos.checked_add(days_rem_nanos).ok_or_else(out_of_range)?;
+
+    let interval = IntervalMonthDayNanoType::make_value(months, days, nanos);
+    Ok(interval)
+}
+
 fn timestamp_add_interval(
     timestamps: Arc<dyn Array>,
     intervals: Arc<dyn Array>,