diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 44f1d1694146..b1759239c588 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -904,6 +904,9 @@ impl ScalarValue { DataType::Interval(IntervalUnit::YearMonth) => { build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) + } DataType::List(fields) if fields.data_type() == &DataType::Int8 => { build_array_list_primitive!(Int8Type, Int8, i8) } diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index 4fa92c3cb98a..f84ed7fafe50 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -22,12 +22,14 @@ use ahash::RandomState; use arrow::{ array::{ - ArrayData, ArrayRef, BooleanArray, LargeStringArray, PrimitiveArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, - UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder, + ArrayData, ArrayRef, BooleanArray, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeStringArray, + PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampSecondArray, UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, + UInt64Builder, }, compute, - datatypes::{UInt32Type, UInt64Type}, + datatypes::{IntervalUnit, UInt32Type, UInt64Type}, }; use smallvec::{smallvec, SmallVec}; use std::sync::Arc; @@ -925,6 +927,38 @@ fn equal_rows( ) } }, + DataType::Interval(interval_unit) => match interval_unit { + IntervalUnit::YearMonth => { + equal_rows_elem!( + IntervalYearMonthArray, + l, + r, + left, + right, + null_equals_null + ) + } + IntervalUnit::DayTime => { + equal_rows_elem!( + IntervalDayTimeArray, + l, + r, + left, + right, + null_equals_null + ) + } + IntervalUnit::MonthDayNano => { + equal_rows_elem!( + IntervalMonthDayNanoArray, + l, + r, + left, + right, + null_equals_null + ) + } + }, DataType::Utf8 => { equal_rows_elem!(StringArray, l, r, left, right, null_equals_null) } diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs index 9562e900298a..86bd0ca6a7c2 100644 --- a/datafusion/core/src/physical_plan/hash_utils.rs +++ b/datafusion/core/src/physical_plan/hash_utils.rs @@ -22,13 +22,14 @@ use ahash::{CallHasher, RandomState}; use arrow::array::{ Array, ArrayRef, BooleanArray, Date32Array, Date64Array, DecimalArray, DictionaryArray, Float32Array, Float64Array, GenericListArray, Int16Array, - Int32Array, Int64Array, Int8Array, LargeStringArray, OffsetSizeTrait, StringArray, + Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, + IntervalYearMonthArray, LargeStringArray, OffsetSizeTrait, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, - Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + Int8Type, IntervalUnit, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use std::sync::Arc; @@ -469,6 +470,36 @@ pub fn create_hashes<'a>( multi_col ); } + DataType::Interval(IntervalUnit::YearMonth) => { + hash_array_primitive!( + IntervalYearMonthArray, + col, + i32, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Interval(IntervalUnit::DayTime) => { + hash_array_primitive!( + IntervalDayTimeArray, + col, + i64, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + hash_array_primitive!( + IntervalMonthDayNanoArray, + col, + i128, + hashes_buffer, + random_state, + multi_col + ); + } DataType::Date32 => { hash_array_primitive!( Date32Array, diff --git a/datafusion/physical-expr/src/expressions/binary_distinct.rs b/datafusion/physical-expr/src/expressions/binary_distinct.rs index a6c46a44bc39..41db971b39e5 100644 --- a/datafusion/physical-expr/src/expressions/binary_distinct.rs +++ b/datafusion/physical-expr/src/expressions/binary_distinct.rs @@ -48,6 +48,7 @@ pub fn distinct_types_allowed( Operator::Minus => matches!( (left_type, right_type), (Timestamp(Nanosecond, _), Interval(_)) + | (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) ), Operator::Multiply => matches!( (left_type, right_type), @@ -100,6 +101,10 @@ pub fn coerce_types_distinct( (Date64, Interval(iunit)) | (Date32, Interval(iunit)) => { Some((Timestamp(Nanosecond, None), Interval(iunit.clone()))) } + (Timestamp(_, tz), Timestamp(_, tz2)) => Some(( + Timestamp(Nanosecond, tz.clone()), + Timestamp(Nanosecond, tz2.clone()), + )), _ => None, }, Operator::Multiply => match (lhs_type, rhs_type) { @@ -166,6 +171,10 @@ pub fn evaluate_distinct_with_resolved_args( (Timestamp(Nanosecond, Some(tz)), Interval(_)) if tz == "UTC" => { Some(timestamp_add_interval(left, right, true)) } + (Timestamp(Nanosecond, None), Timestamp(Nanosecond, None)) => { + // TODO: Implement postgres behavior with time zones + Some(timestamp_subtract_timestamp(left, right)) + } _ => None, }, Operator::Multiply => match (left_data_type, right_data_type) { @@ -334,6 +343,27 @@ fn timestamp_add_interval( } } +fn timestamp_subtract_timestamp( + left: Arc, + right: Arc, +) -> Result { + let left = left + .as_any() + .downcast_ref::() + .unwrap(); + let right = right + .as_any() + .downcast_ref::() + .unwrap(); + + let result = left + .iter() + .zip(right.iter()) + .map(|(t_l, t_r)| scalar_timestamp_subtract_timestamp(t_l, t_r)) + .collect::>()?; + Ok(Arc::new(result)) +} + fn scalar_timestamp_add_interval_year_month( timestamp: Option, interval: Option, @@ -445,6 +475,34 @@ fn scalar_timestamp_add_interval_month_day_nano( Ok(Some(result.timestamp_nanos())) } +fn scalar_timestamp_subtract_timestamp( + timestamp_left: Option, + timestamp_right: Option, +) -> Result> { + if timestamp_left.is_none() || timestamp_right.is_none() { + return Ok(None); + } + + let datetime_left: NaiveDateTime = timestamp_ns_to_datetime(timestamp_left.unwrap()); + let datetime_right: NaiveDateTime = + timestamp_ns_to_datetime(timestamp_right.unwrap()); + let duration = datetime_left.signed_duration_since(datetime_right); + // TODO: What is Postgres behavior? E.g. if these timestamp values are i64::MAX and i64::MIN, + // we needlessly have a range error. + let nanos: i64 = duration.num_nanoseconds().ok_or_else(|| { + DataFusionError::Execution("Interval value is out of range".to_string()) + })?; + + let days = nanos / 86_400_000_000_000; + let nanos_rem = nanos % 86_400_000_000_000; + Ok(Some( + (((days as i128) & 0xFFFF_FFFF) << 64) + | ((nanos_rem as i128) & 0xFFFF_FFFF_FFFF_FFFF), + )) + + // TODO: How can day, above, in scalar_timestamp_add_interval_month_day_nano, be negative? +} + fn change_ym(t: NaiveDateTime, y: i32, m: u32) -> Result { // TODO: legacy code, check validity debug_assert!((1..=12).contains(&m));