Skip to content

Commit

Permalink
Fix string array grouping comparator (apache#17183)
Browse files Browse the repository at this point in the history
  • Loading branch information
clintropolis authored Oct 8, 2024
1 parent a67a3c8 commit ab0d6eb
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1112,12 +1112,17 @@ private static int compareDimsInRowsWithAggs(
DimensionHandlerUtils.convertObjectToType(rhs, fieldType)
);
} else if (fieldType.equals(ColumnType.STRING_ARRAY)) {
cmp = new DimensionComparisonUtils.ArrayComparator<String>(
comparator == null ? StringComparators.LEXICOGRAPHIC : comparator
).compare(
DimensionHandlerUtils.coerceToStringArray(lhs),
DimensionHandlerUtils.coerceToStringArray(rhs)
);
if (useNaturalStringArrayComparator(comparator)) {
cmp = fieldType.getNullableStrategy().compare(
DimensionHandlerUtils.coerceToStringArray(lhs),
DimensionHandlerUtils.coerceToStringArray(rhs)
);
} else {
cmp = new DimensionComparisonUtils.ArrayComparator<>(comparator).compare(
DimensionHandlerUtils.coerceToStringArray(lhs),
DimensionHandlerUtils.coerceToStringArray(rhs)
);
}
} else if (fieldType.equals(ColumnType.LONG_ARRAY)
|| fieldType.equals(ColumnType.DOUBLE_ARRAY)) {
cmp = fieldType.getNullableStrategy().compare(
Expand Down Expand Up @@ -1806,13 +1811,17 @@ private class ArrayStringRowBasedKeySerdeHelper extends DictionaryBuildingSingle
)
{
super(keyBufferPosition);
final Comparator<Object[]> comparator;
if (useNaturalStringArrayComparator(stringComparator)) {
comparator = ColumnType.STRING_ARRAY.getNullableStrategy();
} else {
comparator = new DimensionComparisonUtils.ArrayComparator<>(stringComparator);
}
bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) ->
new DimensionComparisonUtils.ArrayComparator<>(
stringComparator == null ? StringComparators.LEXICOGRAPHIC : stringComparator)
.compare(
stringArrayDictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)),
stringArrayDictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition))
);
comparator.compare(
stringArrayDictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)),
stringArrayDictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition))
);
}

@Override
Expand Down Expand Up @@ -1859,7 +1868,7 @@ private abstract class AbstractStringRowBasedKeySerdeHelper implements RowBasedK
rankOfDictionaryIds[rhsBuffer.getInt(rhsPosition + keyBufferPosition)]
);
} else {
final StringComparator realComparator = stringComparator == null ?
final StringComparator realComparator = useNaturalStringArrayComparator(stringComparator) ?
StringComparators.LEXICOGRAPHIC :
stringComparator;
bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> {
Expand Down Expand Up @@ -2182,4 +2191,16 @@ public Class<?> getClazz()
}
}
}

/**
* Check if the {@link StringComparator} is the 'natural' {@link ColumnType#STRING_ARRAY} comparator. If so,
* callers can safely use the column type to compare values. If false, the {@link StringComparator} must be called
* against each array element for the comparison of values.
*/
private static boolean useNaturalStringArrayComparator(@Nullable StringComparator stringComparator)
{
return stringComparator == null
|| StringComparators.NATURAL.equals(stringComparator)
|| StringComparators.LEXICOGRAPHIC.equals(stringComparator);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import org.apache.druid.query.aggregation.AggregationTestHelper;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
import org.apache.druid.query.ordering.StringComparators;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
Expand Down Expand Up @@ -147,6 +150,43 @@ public void testGroupByRootArrayString()
);
}

@Test
public void testGroupByRootArrayStringOrderAndLimit()
{
GroupByQuery groupQuery = GroupByQuery.builder()
.setDataSource("test_datasource")
.setGranularity(Granularities.ALL)
.setInterval(Intervals.ETERNITY)
.setDimensions(DefaultDimensionSpec.of("arrayString", ColumnType.STRING_ARRAY))
.setAggregatorSpecs(new CountAggregatorFactory("count"))
.setContext(getContext())
.setLimitSpec(
new DefaultLimitSpec(
ImmutableList.of(
new OrderByColumnSpec(
"arrayString",
OrderByColumnSpec.Direction.ASCENDING,
StringComparators.NATURAL
)
),
100
)
)
.build();


runResults(
groupQuery,
ImmutableList.of(
new Object[]{null, 8L},
new Object[]{new Object[]{"a", "b"}, 8L},
new Object[]{new Object[]{"a", "b", "c"}, 4L},
new Object[]{new Object[]{"b", "c"}, 4L},
new Object[]{new Object[]{"d", "e"}, 4L}
)
);
}

@Test
public void testGroupByRootArrayLong()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7481,4 +7481,49 @@ public void testNullArray()
)
);
}

@Test
public void testArrayGroupStringArrayColumnLimit()
{
cannotVectorize();
testQuery(
"SELECT arrayStringNulls, SUM(cnt) FROM druid.arrays GROUP BY 1 ORDER BY 1 DESC LIMIT 10",
QUERY_CONTEXT_NO_STRINGIFY_ARRAY,
ImmutableList.of(
GroupByQuery.builder()
.setDataSource(CalciteTests.ARRAYS_DATASOURCE)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
dimensions(
new DefaultDimensionSpec("arrayStringNulls", "d0", ColumnType.STRING_ARRAY)
)
)
.setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
.setLimitSpec(
new DefaultLimitSpec(
ImmutableList.of(
new OrderByColumnSpec(
"d0",
OrderByColumnSpec.Direction.DESCENDING,
StringComparators.NATURAL
)
),
10
)
)
.setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY)
.build()
),
ImmutableList.of(
new Object[]{Arrays.asList("d", null, "b"), 2L},
new Object[]{Arrays.asList("b", "b"), 2L},
new Object[]{Arrays.asList("a", "b"), 3L},
new Object[]{Arrays.asList(null, "b"), 2L},
new Object[]{Collections.singletonList(null), 1L},
new Object[]{Collections.emptyList(), 1L},
new Object[]{null, 3L}
)
);
}
}

0 comments on commit ab0d6eb

Please sign in to comment.