Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect results in BitAnd GroupsAccumulator #6957

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 117 additions & 67 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1420,65 +1420,95 @@ select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.c
2 1 1.414213562373 1


# sum / count for all nulls
statement ok
create table the_nulls as values (null::bigint, 1), (null::bigint, 1), (null::bigint, 2);

# counts should be zeros (even for nulls)
query II
SELECT count(column1), column2 from the_nulls group by column2 order by column2;
----
0 1
0 2

# sums should be null
query II
SELECT sum(column1), column2 from the_nulls group by column2 order by column2;
# aggregates on empty tables
statement ok
CREATE TABLE empty (column1 bigint, column2 int);

# no group by column
query IIRIIIII
SELECT
count(column1), -- counts should be zero, even for nulls
sum(column1), -- other aggregates should be null
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1)
FROM empty
----
0 NULL NULL NULL NULL NULL NULL NULL

# Same query but with grouping (no groups, so no output)
query IIRIIIIII
SELECT
count(column1),
sum(column1),
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1),
column2
FROM empty
GROUP BY column2
ORDER BY column2;
----
NULL 1
NULL 2

# avg should be null
query RI
SELECT avg(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2

# bit_and should be null
query II
SELECT bit_and(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2
statement ok
drop table empty

# bit_or should be null
query II
SELECT bit_or(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2
# aggregates on all nulls
statement ok
CREATE TABLE the_nulls
AS VALUES
(null::bigint, 1),
(null::bigint, 1),
(null::bigint, 2);

# bit_xor should be null
query II
SELECT bit_xor(column1), column2 from the_nulls group by column2 order by column2;
select * from the_nulls
----
NULL 1
NULL 2

# min should be null
query II
SELECT min(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2

# max should be null
query II
SELECT max(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2
# no group by column
query IIRIIIII
SELECT
count(column1), -- counts should be zero, even for nulls
sum(column1), -- other aggregates should be null
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1)
FROM the_nulls
----
0 NULL NULL NULL NULL NULL NULL NULL

# Same query but with grouping
query IIRIIIIII
SELECT
count(column1), -- counts should be zero, even for nulls
sum(column1), -- other aggregates should be null
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1),
column2
FROM the_nulls
GROUP BY column2
ORDER BY column2;
----
0 NULL NULL NULL NULL NULL NULL NULL 1
0 NULL NULL NULL NULL NULL NULL NULL 2


statement ok
Expand All @@ -1489,29 +1519,49 @@ create table bit_aggregate_functions (
c1 SMALLINT NOT NULL,
c2 SMALLINT NOT NULL,
c3 SMALLINT,
tag varchar
)
as values
(5, 10, 11),
(33, 11, null),
(9, 12, null);

# query_bit_and
query III
SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions
----
1 8 11

# query_bit_or
query III
SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions
----
45 15 11
(5, 10, 11, 'A'),
(33, 11, null, 'B'),
(9, 12, null, 'A');

# query_bit_and, query_bit_or, query_bit_xor
query IIIIIIIII
SELECT
bit_and(c1),
bit_and(c2),
bit_and(c3),
bit_or(c1),
bit_or(c2),
bit_or(c3),
bit_xor(c1),
bit_xor(c2),
bit_xor(c3)
FROM bit_aggregate_functions
----
1 8 11 45 15 11 45 13 11

# query_bit_and, query_bit_or, query_bit_xor, with group
query IIIIIIIIIT
SELECT
bit_and(c1),
bit_and(c2),
bit_and(c3),
bit_or(c1),
bit_or(c2),
bit_or(c3),
bit_xor(c1),
bit_xor(c2),
bit_xor(c3),
tag
FROM bit_aggregate_functions
GROUP BY tag
ORDER BY tag
----
1 8 11 13 14 11 12 6 11 A
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR, this test failed with the following error:

External error: query result mismatch:
[SQL] SELECT
  bit_and(c1),
  bit_and(c2),
  bit_and(c3),
  bit_or(c1),
  bit_or(c2),
  bit_or(c3),
  bit_xor(c1),
  bit_xor(c2),
  bit_xor(c3),
  tag
FROM bit_aggregate_functions
GROUP BY tag
ORDER BY tag
[Diff] (-expected|+actual)
-   1 8 11 13 14 11 12 6 11 A
-   33 11 NULL 33 11 NULL 33 11 NULL B
+   0 0 0 13 14 11 12 6 11 A
+   0 0 NULL 33 11 NULL 33 11 NULL B
at tests/sqllogictests/test_files/aggregate.slt:1546

33 11 NULL 33 11 NULL 33 11 NULL B

# query_bit_xor
query III
SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions
----
45 13 11

statement ok
create table bool_aggregate_functions (
Expand Down
83 changes: 32 additions & 51 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor};
use datafusion_row::accessor::RowAccessor;

/// Creates a [`PrimitiveGroupsAccumulator`] with the specified
/// [`ArrowPrimitiveType`] which applies `$FN` to each element
/// [`ArrowPrimitiveType`] that initializes each accumulator to `$START`
/// and applies `$FN` to each element
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! instantiate_primitive_accumulator {
($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{
Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
&$SELF.data_type,
$FN,
)))
macro_rules! instantiate_accumulator {
($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, $FN)
.with_starting_value($START),
))
}};
}

Expand Down Expand Up @@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd {
use std::ops::BitAndAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int8Type, |x, y| x.bitand_assign(y))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix is to pass in `MAX` (unsigned types) / `-1` (signed types) here, so that each accumulator initially holds a bit pattern of all 1s.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix looks good. It would be better to leave a comment like the one above in the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment added in #6964

}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int16Type, |x, y| x.bitand_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int32Type, |x, y| x.bitand_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int64Type, |x, y| x.bitand_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
.bitand_assign(y))
}

Expand Down Expand Up @@ -517,36 +514,28 @@ impl AggregateExpr for BitOr {
use std::ops::BitOrAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitor_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitor_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitor_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitor_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitor_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitor_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitor_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitor_assign(y))
}

_ => Err(DataFusionError::NotImplemented(format!(
Expand Down Expand Up @@ -756,36 +745,28 @@ impl AggregateExpr for BitXor {
use std::ops::BitXorAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitxor_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitxor_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitxor_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitxor_assign(y))
}

_ => Err(DataFusionError::NotImplemented(format!(
Expand Down
Loading