Skip to content

Commit

Permalink
Add PartialOrd for the DF subfields/structs for the WindowFunction ex…
Browse files Browse the repository at this point in the history
…pr (#12421)

* Added PartialOrd implementations for AggregateUDF, AggregateUDFImpl, BuiltInWindowFunction and WindowUDF.

* Added tests for PartialOrd in udwf.rs.

* Removed manual implementation of PartialOrd for TypeSignature, replaced with derives.

* Adjusted the assertion for clarity on comparing enum variants.

* Edited assertions to use partial_cmp for clarity, and reformatted with rustfmt.

---------

Co-authored-by: M <[email protected]>
  • Loading branch information
ngli-me and M authored Sep 12, 2024
1 parent 6bf3479 commit 389f7f7
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 9 deletions.
26 changes: 23 additions & 3 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub enum Volatility {
/// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())),
/// ]);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum TypeSignature {
/// One or more arguments of an common type out of a list of valid types.
///
Expand Down Expand Up @@ -127,7 +127,7 @@ pub enum TypeSignature {
Numeric(usize),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionSignature {
/// Specialized Signature for ArrayAppend and similar functions
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list.
Expand Down Expand Up @@ -241,7 +241,7 @@ impl TypeSignature {
///
/// DataFusion will automatically coerce (cast) argument types to one of the supported
/// function signatures, if possible.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Signature {
/// The data types that the function accepts. See [TypeSignature] for more information.
pub type_signature: TypeSignature,
Expand Down Expand Up @@ -418,4 +418,24 @@ mod tests {
);
}
}

#[test]
fn type_signature_partial_ord() {
// Test validates that partial ord is defined for TypeSignature and Signature.
assert!(TypeSignature::UserDefined < TypeSignature::VariadicAny);
assert!(TypeSignature::UserDefined < TypeSignature::Any(1));

assert!(
TypeSignature::Uniform(1, vec![DataType::Null])
< TypeSignature::Uniform(1, vec![DataType::Boolean])
);
assert!(
TypeSignature::Uniform(1, vec![DataType::Null])
< TypeSignature::Uniform(2, vec![DataType::Null])
);
assert!(
TypeSignature::Uniform(usize::MAX, vec![DataType::Null])
< TypeSignature::Exact(vec![DataType::Null])
);
}
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/built_in_window_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl fmt::Display for BuiltInWindowFunction {
/// A [window function] built in to DataFusion
///
/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum BuiltInWindowFunction {
/// rank of the current row with gaps; same as row_number of its first peer
Rank,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ impl AggregateFunction {
}

/// WindowFunction
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
/// Defines which implementation of an aggregate function DataFusion should call.
pub enum WindowFunctionDefinition {
/// A built in aggregate function that leverages an aggregate function
Expand Down
129 changes: 128 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! [`AggregateUDF`]: User Defined Aggregate Functions

use std::any::Any;
use std::cmp::Ordering;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
Expand Down Expand Up @@ -68,7 +69,7 @@ use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
/// [`create_udaf`]: crate::expr_fn::create_udaf
/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialOrd)]
pub struct AggregateUDF {
inner: Arc<dyn AggregateUDFImpl>,
}
Expand Down Expand Up @@ -584,6 +585,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
}
}

impl PartialEq for dyn AggregateUDFImpl {
fn eq(&self, other: &Self) -> bool {
self.equals(other)
}
}

// manual implementation of `PartialOrd`
// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
impl PartialOrd for dyn AggregateUDFImpl {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.name().partial_cmp(other.name()) {
Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
cmp => cmp,
}
}
}

pub enum ReversedUDAF {
/// The expression is the same as the original expression, like SUM, COUNT
Identical,
Expand Down Expand Up @@ -758,3 +777,111 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
(self.accumulator)(acc_args)
}
}

#[cfg(test)]
mod test {
use crate::{AggregateUDF, AggregateUDFImpl};
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::signature::{Signature, Volatility};
use datafusion_functions_aggregate_common::accumulator::{
AccumulatorArgs, StateFieldsArgs,
};
use std::any::Any;
use std::cmp::Ordering;

#[derive(Debug, Clone)]
struct AMeanUdf {
signature: Signature,
}

impl AMeanUdf {
fn new() -> Self {
Self {
signature: Signature::uniform(
1,
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}

impl AggregateUDFImpl for AMeanUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"a"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
unimplemented!()
}
fn accumulator(
&self,
_acc_args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
unimplemented!()
}
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
unimplemented!()
}
}

#[derive(Debug, Clone)]
struct BMeanUdf {
signature: Signature,
}
impl BMeanUdf {
fn new() -> Self {
Self {
signature: Signature::uniform(
1,
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}

impl AggregateUDFImpl for BMeanUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"b"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
unimplemented!()
}
fn accumulator(
&self,
_acc_args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
unimplemented!()
}
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
unimplemented!()
}
}

#[test]
fn test_partial_ord() {
// Test validates that partial ord is defined for AggregateUDF using the name and signature,
// not intended to exhaustively test all possibilities
let a1 = AggregateUDF::from(AMeanUdf::new());
let a2 = AggregateUDF::from(AMeanUdf::new());
assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

let b1 = AggregateUDF::from(BMeanUdf::new());
assert!(a1 < b1);
assert!(!(a1 == b1));
}
}
114 changes: 111 additions & 3 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
//! [`WindowUDF`]: User Defined Window Functions

use arrow::compute::SortOptions;
use arrow::datatypes::DataType;
use std::cmp::Ordering;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{
any::Any,
fmt::{self, Debug, Display, Formatter},
sync::Arc,
};

use arrow::datatypes::DataType;

use datafusion_common::{not_impl_err, Result};

use crate::expr::WindowFunction;
Expand Down Expand Up @@ -54,7 +54,7 @@ use crate::{
/// [`create_udwf`]: crate::expr_fn::create_udwf
/// [`simple_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
/// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialOrd)]
pub struct WindowUDF {
inner: Arc<dyn WindowUDFImpl>,
}
Expand Down Expand Up @@ -386,6 +386,21 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
}
}

impl PartialEq for dyn WindowUDFImpl {
fn eq(&self, other: &Self) -> bool {
self.equals(other)
}
}

impl PartialOrd for dyn WindowUDFImpl {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.name().partial_cmp(other.name()) {
Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
cmp => cmp,
}
}
}

/// WindowUDF that adds an alias to the underlying function. It is better to
/// implement [`WindowUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
Expand Down Expand Up @@ -511,3 +526,96 @@ impl WindowUDFImpl for WindowUDFLegacyWrapper {
(self.partition_evaluator_factory)()
}
}

#[cfg(test)]
mod test {
use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr_common::signature::{Signature, Volatility};
use std::any::Any;
use std::cmp::Ordering;

#[derive(Debug, Clone)]
struct AWindowUDF {
signature: Signature,
}

impl AWindowUDF {
fn new() -> Self {
Self {
signature: Signature::uniform(
1,
vec![DataType::Int32],
Volatility::Immutable,
),
}
}
}

/// Implement the WindowUDFImpl trait for AddOne
impl WindowUDFImpl for AWindowUDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"a"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
unimplemented!()
}
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!()
}
}

#[derive(Debug, Clone)]
struct BWindowUDF {
signature: Signature,
}

impl BWindowUDF {
fn new() -> Self {
Self {
signature: Signature::uniform(
1,
vec![DataType::Int32],
Volatility::Immutable,
),
}
}
}

/// Implement the WindowUDFImpl trait for AddOne
impl WindowUDFImpl for BWindowUDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"b"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
unimplemented!()
}
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!()
}
}

#[test]
fn test_partial_ord() {
let a1 = WindowUDF::from(AWindowUDF::new());
let a2 = WindowUDF::from(AWindowUDF::new());
assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

let b1 = WindowUDF::from(BWindowUDF::new());
assert!(a1 < b1);
assert!(!(a1 == b1));
}
}

0 comments on commit 389f7f7

Please sign in to comment.