diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 2043757a49fb..ffa5f17cec14 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -84,7 +84,7 @@ pub enum Volatility { /// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), /// ]); /// ``` -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeSignature { /// One or more arguments of an common type out of a list of valid types. /// @@ -127,7 +127,7 @@ pub enum TypeSignature { Numeric(usize), } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions /// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list. @@ -241,7 +241,7 @@ impl TypeSignature { /// /// DataFusion will automatically coerce (cast) argument types to one of the supported /// function signatures, if possible. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Signature { /// The data types that the function accepts. See [TypeSignature] for more information. pub type_signature: TypeSignature, @@ -418,4 +418,24 @@ mod tests { ); } } + + #[test] + fn type_signature_partial_ord() { + // Test validates that partial ord is defined for TypeSignature and Signature. + assert!(TypeSignature::UserDefined < TypeSignature::VariadicAny); + assert!(TypeSignature::UserDefined < TypeSignature::Any(1)); + + assert!( + TypeSignature::Uniform(1, vec![DataType::Null]) + < TypeSignature::Uniform(1, vec![DataType::Boolean]) + ); + assert!( + TypeSignature::Uniform(1, vec![DataType::Null]) + < TypeSignature::Uniform(2, vec![DataType::Null]) + ); + assert!( + TypeSignature::Uniform(usize::MAX, vec![DataType::Null]) + < TypeSignature::Exact(vec![DataType::Null]) + ); + } } diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 597e4e68a0c6..b136d6cacec8 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -38,7 +38,7 @@ impl fmt::Display for BuiltInWindowFunction { /// A [window function] built in to DataFusion /// /// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { /// rank of the current row with gaps; same as row_number of its first peer Rank, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 161e29e5925b..db0bfd6b1bc2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -688,7 +688,7 @@ impl AggregateFunction { } /// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] /// Defines which implementation of an aggregate function DataFusion should call. pub enum WindowFunctionDefinition { /// A built in aggregate function that leverages an aggregate function diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7b4b3bb95c46..d3eaccb2c538 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -18,6 +18,7 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions use std::any::Any; +use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; @@ -68,7 +69,7 @@ use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; /// [`create_udaf`]: crate::expr_fn::create_udaf /// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs /// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialOrd)] pub struct AggregateUDF { inner: Arc, } @@ -584,6 +585,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { } } +impl PartialEq for dyn AggregateUDFImpl { + fn eq(&self, other: &Self) -> bool { + self.equals(other) + } +} + +// manual implementation of `PartialOrd` +// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl +// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 +impl PartialOrd for dyn AggregateUDFImpl { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name().partial_cmp(other.name()) { + Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), + cmp => cmp, + } + } +} + pub enum ReversedUDAF { /// The expression is the same as the original expression, like SUM, COUNT Identical, @@ -758,3 +777,111 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } + +#[cfg(test)] +mod test { + use crate::{AggregateUDF, AggregateUDFImpl}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::Result; + use datafusion_expr_common::accumulator::Accumulator; + use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_functions_aggregate_common::accumulator::{ + AccumulatorArgs, StateFieldsArgs, + }; + use std::any::Any; + use std::cmp::Ordering; + + #[derive(Debug, Clone)] + struct AMeanUdf { + signature: Signature, + } + + impl AMeanUdf { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } + } + + impl AggregateUDFImpl for AMeanUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "a" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!() + } + fn accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { + unimplemented!() + } + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!() + } + } + + #[derive(Debug, Clone)] + struct BMeanUdf { + signature: Signature, + } + impl BMeanUdf { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } + } + + impl AggregateUDFImpl for BMeanUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "b" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!() + } + fn accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { + unimplemented!() + } + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!() + } + } + + #[test] + fn test_partial_ord() { + // Test validates that partial ord is defined for AggregateUDF using the name and signature, + // not intended to exhaustively test all possibilities + let a1 = AggregateUDF::from(AMeanUdf::new()); + let a2 = AggregateUDF::from(AMeanUdf::new()); + assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal)); + + let b1 = AggregateUDF::from(BMeanUdf::new()); + assert!(a1 < b1); + assert!(!(a1 == b1)); + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index e5fdaaceb439..b24aaf8561e2 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -18,6 +18,8 @@ //! [`WindowUDF`]: User Defined Window Functions use arrow::compute::SortOptions; +use arrow::datatypes::DataType; +use std::cmp::Ordering; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ any::Any, @@ -25,8 +27,6 @@ use std::{ sync::Arc, }; -use arrow::datatypes::DataType; - use datafusion_common::{not_impl_err, Result}; use crate::expr::WindowFunction; @@ -54,7 +54,7 @@ use crate::{ /// [`create_udwf`]: crate::expr_fn::create_udwf /// [`simple_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs /// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialOrd)] pub struct WindowUDF { inner: Arc, } @@ -386,6 +386,21 @@ pub trait WindowUDFImpl: Debug + Send + Sync { } } +impl PartialEq for dyn WindowUDFImpl { + fn eq(&self, other: &Self) -> bool { + self.equals(other) + } +} + +impl PartialOrd for dyn WindowUDFImpl { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name().partial_cmp(other.name()) { + Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), + cmp => cmp, + } + } +} + /// WindowUDF that adds an alias to the underlying function. It is better to /// implement [`WindowUDFImpl`], which supports aliases, directly if possible. #[derive(Debug)] @@ -511,3 +526,96 @@ impl WindowUDFImpl for WindowUDFLegacyWrapper { (self.partition_evaluator_factory)() } } + +#[cfg(test)] +mod test { + use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl}; + use arrow::datatypes::DataType; + use datafusion_common::Result; + use datafusion_expr_common::signature::{Signature, Volatility}; + use std::any::Any; + use std::cmp::Ordering; + + #[derive(Debug, Clone)] + struct AWindowUDF { + signature: Signature, + } + + impl AWindowUDF { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + } + } + } + + /// Implement the WindowUDFImpl trait for AddOne + impl WindowUDFImpl for AWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "a" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!() + } + fn partition_evaluator(&self) -> Result> { + unimplemented!() + } + } + + #[derive(Debug, Clone)] + struct BWindowUDF { + signature: Signature, + } + + impl BWindowUDF { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + } + } + } + + /// Implement the WindowUDFImpl trait for AddOne + impl WindowUDFImpl for BWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "b" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!() + } + fn partition_evaluator(&self) -> Result> { + unimplemented!() + } + } + + #[test] + fn test_partial_ord() { + let a1 = WindowUDF::from(AWindowUDF::new()); + let a2 = WindowUDF::from(AWindowUDF::new()); + assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal)); + + let b1 = WindowUDF::from(BWindowUDF::new()); + assert!(a1 < b1); + assert!(!(a1 == b1)); + } +}