diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
index 8c42bb935d84..9de8ec17f76e 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java
@@ -1112,12 +1112,17 @@ private static int compareDimsInRowsWithAggs(
               DimensionHandlerUtils.convertObjectToType(rhs, fieldType)
           );
         } else if (fieldType.equals(ColumnType.STRING_ARRAY)) {
-          cmp = new DimensionComparisonUtils.ArrayComparator(
-              comparator == null ? StringComparators.LEXICOGRAPHIC : comparator
-          ).compare(
-              DimensionHandlerUtils.coerceToStringArray(lhs),
-              DimensionHandlerUtils.coerceToStringArray(rhs)
-          );
+          if (useNaturalStringArrayComparator(comparator)) {
+            cmp = fieldType.getNullableStrategy().compare(
+                DimensionHandlerUtils.coerceToStringArray(lhs),
+                DimensionHandlerUtils.coerceToStringArray(rhs)
+            );
+          } else {
+            cmp = new DimensionComparisonUtils.ArrayComparator<>(comparator).compare(
+                DimensionHandlerUtils.coerceToStringArray(lhs),
+                DimensionHandlerUtils.coerceToStringArray(rhs)
+            );
+          }
         } else if (fieldType.equals(ColumnType.LONG_ARRAY)
                    || fieldType.equals(ColumnType.DOUBLE_ARRAY)) {
           cmp = fieldType.getNullableStrategy().compare(
@@ -1806,13 +1811,17 @@ private class ArrayStringRowBasedKeySerdeHelper extends DictionaryBuildingSingle
     )
     {
       super(keyBufferPosition);
+      final Comparator<Object[]> comparator;
+      if (useNaturalStringArrayComparator(stringComparator)) {
+        comparator = ColumnType.STRING_ARRAY.getNullableStrategy();
+      } else {
+        comparator = new DimensionComparisonUtils.ArrayComparator<>(stringComparator);
+      }
       bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) ->
-          new DimensionComparisonUtils.ArrayComparator<>(
-              stringComparator == null ? StringComparators.LEXICOGRAPHIC : stringComparator)
-              .compare(
-                  stringArrayDictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)),
-                  stringArrayDictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition))
-              );
+          comparator.compare(
+              stringArrayDictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)),
+              stringArrayDictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition))
+          );
     }
 
     @Override
@@ -1859,7 +1868,7 @@ private abstract class AbstractStringRowBasedKeySerdeHelper implements RowBasedK
             rankOfDictionaryIds[rhsBuffer.getInt(rhsPosition + keyBufferPosition)]
         );
       } else {
-        final StringComparator realComparator = stringComparator == null ?
+        final StringComparator realComparator = useNaturalStringArrayComparator(stringComparator) ?
                                                 StringComparators.LEXICOGRAPHIC :
                                                 stringComparator;
         bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> {
@@ -2182,4 +2191,16 @@ public Class getClazz()
       }
     }
   }
+
+  /**
+   * Check if the {@link StringComparator} is the 'natural' {@link ColumnType#STRING_ARRAY} comparator; a null
+   * comparator is treated as natural. If true, callers can safely use the column type to compare values. If false,
+   * the {@link StringComparator} must be called against each array element for the comparison of values.
+   */
+  private static boolean useNaturalStringArrayComparator(@Nullable StringComparator stringComparator)
+  {
+    return stringComparator == null
+           || StringComparators.NATURAL.equals(stringComparator)
+           || StringComparators.LEXICOGRAPHIC.equals(stringComparator);
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/NestedGroupByArrayQueryTest.java b/processing/src/test/java/org/apache/druid/query/groupby/NestedGroupByArrayQueryTest.java
index f0e581291d36..dd93978b6fdb 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/NestedGroupByArrayQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/NestedGroupByArrayQueryTest.java
@@ -32,6 +32,9 @@
 import org.apache.druid.query.aggregation.AggregationTestHelper;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.dimension.DefaultDimensionSpec;
+import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
+import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
+import org.apache.druid.query.ordering.StringComparators;
 import org.apache.druid.segment.Segment;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
@@ -147,6 +150,42 @@ public void testGroupByRootArrayString()
     );
   }
 
+  @Test
+  public void testGroupByRootArrayStringOrderAndLimit()
+  {
+    GroupByQuery groupQuery = GroupByQuery.builder()
+                                          .setDataSource("test_datasource")
+                                          .setGranularity(Granularities.ALL)
+                                          .setInterval(Intervals.ETERNITY)
+                                          .setDimensions(DefaultDimensionSpec.of("arrayString", ColumnType.STRING_ARRAY))
+                                          .setAggregatorSpecs(new CountAggregatorFactory("count"))
+                                          .setContext(getContext())
+                                          .setLimitSpec(
+                                              new DefaultLimitSpec(
+                                                  ImmutableList.of(
+                                                      new OrderByColumnSpec(
+                                                          "arrayString",
+                                                          OrderByColumnSpec.Direction.ASCENDING,
+                                                          StringComparators.NATURAL
+                                                      )
+                                                  ),
+                                                  100
+                                              )
+                                          )
+                                          .build();
+
+    runResults(
+        groupQuery,
+        ImmutableList.of(
+            new Object[]{null, 8L},
+            new Object[]{new Object[]{"a", "b"}, 8L},
+            new Object[]{new Object[]{"a", "b", "c"}, 4L},
+            new Object[]{new Object[]{"b", "c"}, 4L},
+            new Object[]{new Object[]{"d", "e"}, 4L}
+        )
+    );
+  }
+
   @Test
   public void testGroupByRootArrayLong()
   {
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java
index d98d5759b81b..f48166211344 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java
@@ -7481,4 +7481,49 @@ public void testNullArray()
         )
     );
   }
+
+  @Test
+  public void testArrayGroupStringArrayColumnLimit()
+  {
+    cannotVectorize();
+    testQuery(
+        "SELECT arrayStringNulls, SUM(cnt) FROM druid.arrays GROUP BY 1 ORDER BY 1 DESC LIMIT 10",
+        QUERY_CONTEXT_NO_STRINGIFY_ARRAY,
+        ImmutableList.of(
+            GroupByQuery.builder()
+                        .setDataSource(CalciteTests.ARRAYS_DATASOURCE)
+                        .setInterval(querySegmentSpec(Filtration.eternity()))
+                        .setGranularity(Granularities.ALL)
+                        .setDimensions(
+                            dimensions(
+                                new DefaultDimensionSpec("arrayStringNulls", "d0", ColumnType.STRING_ARRAY)
+                            )
+                        )
+                        .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt")))
+                        .setLimitSpec(
+                            new DefaultLimitSpec(
+                                ImmutableList.of(
+                                    new OrderByColumnSpec(
+                                        "d0",
+                                        OrderByColumnSpec.Direction.DESCENDING,
+                                        StringComparators.NATURAL
+                                    )
+                                ),
+                                10
+                            )
+                        )
+                        .setContext(QUERY_CONTEXT_NO_STRINGIFY_ARRAY)
+                        .build()
+        ),
+        ImmutableList.of(
+            new Object[]{Arrays.asList("d", null, "b"), 2L},
+            new Object[]{Arrays.asList("b", "b"), 2L},
+            new Object[]{Arrays.asList("a", "b"), 3L},
+            new Object[]{Arrays.asList(null, "b"), 2L},
+            new Object[]{Collections.singletonList(null), 1L},
+            new Object[]{Collections.emptyList(), 1L},
+            new Object[]{null, 3L}
+        )
+    );
+  }
 }