Commit

MSQ: Nicer error when sortMerge join falls back to broadcast. (apache#16002)

* MSQ: Nicer error when sortMerge join falls back to broadcast.

In certain cases, joins run as broadcast even when the user hinted
that they wanted sortMerge. This happens when the sortMerge algorithm
is unable to process the join because the join condition is not a direct
comparison between a field on the LHS and a field on the RHS.

When this happens, the error message from BroadcastTablesTooLargeFault
is confusing, since it suggests trying sortMerge as a fix, even though
the user may have already configured sortMerge.

This patch fixes the problem by providing two error messages, based on whether
broadcast join was used as a primary selection or as a fallback selection
(see the sketch below).

* Style.

* Better message.
gianm authored Mar 1, 2024
1 parent ef48ace commit 8d3ed31
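
To make the scenario concrete, here is a minimal sketch of how the sortMerge hint is expressed in the query context. The identifiers (PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE, getId()) appear in the diff below; the wrapper class and method are illustrative only, not part of the patch.

```java
import com.google.common.collect.ImmutableMap;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;

import java.util.Map;

public class SortMergeHintSketch
{
  /**
   * Builds a query context that hints sortMerge. MSQ honors the hint only when
   * the join condition is a direct column-to-column equality; otherwise it
   * silently falls back to broadcast, which is the case this patch targets.
   */
  public static Map<String, Object> sortMergeContext()
  {
    return ImmutableMap.of(
        PlannerContext.CTX_SQL_JOIN_ALGORITHM,
        JoinAlgorithm.SORT_MERGE.getId()
    );
  }
}
```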
Showing 4 changed files with 122 additions and 15 deletions.
File 1 of 4: BroadcastTablesTooLargeFault.java (org.apache.druid.msq.indexing.error)
@@ -20,11 +20,14 @@
 package org.apache.druid.msq.indexing.error;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
+import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
 import org.apache.druid.sql.calcite.planner.PlannerContext;
 
+import javax.annotation.Nullable;
 import java.util.Objects;
 
 @JsonTypeName(BroadcastTablesTooLargeFault.CODE)
@@ -34,19 +37,18 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
 
   private final long maxBroadcastTablesSize;
 
+  @Nullable
+  private final JoinAlgorithm configuredJoinAlgorithm;
+
   @JsonCreator
-  public BroadcastTablesTooLargeFault(@JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize)
+  public BroadcastTablesTooLargeFault(
+      @JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize,
+      @Nullable @JsonProperty("configuredJoinAlgorithm") final JoinAlgorithm configuredJoinAlgorithm
+  )
   {
-    super(
-        CODE,
-        "Size of broadcast tables in JOIN exceeds reserved memory limit "
-        + "(memory reserved for broadcast tables = %d bytes). "
-        + "Increase available memory, or set %s: %s in query context to use a shuffle-based join.",
-        maxBroadcastTablesSize,
-        PlannerContext.CTX_SQL_JOIN_ALGORITHM,
-        JoinAlgorithm.SORT_MERGE.toString()
-    );
+    super(CODE, makeMessage(maxBroadcastTablesSize, configuredJoinAlgorithm));
     this.maxBroadcastTablesSize = maxBroadcastTablesSize;
+    this.configuredJoinAlgorithm = configuredJoinAlgorithm;
   }
 
   @JsonProperty
@@ -55,6 +57,14 @@ public long getMaxBroadcastTablesSize()
     return maxBroadcastTablesSize;
   }
 
+  @Nullable
+  @JsonProperty
+  @JsonInclude(JsonInclude.Include.NON_NULL)
+  public JoinAlgorithm getConfiguredJoinAlgorithm()
+  {
+    return configuredJoinAlgorithm;
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -68,12 +78,38 @@ public boolean equals(Object o)
       return false;
     }
     BroadcastTablesTooLargeFault that = (BroadcastTablesTooLargeFault) o;
-    return maxBroadcastTablesSize == that.maxBroadcastTablesSize;
+    return maxBroadcastTablesSize == that.maxBroadcastTablesSize
+           && configuredJoinAlgorithm == that.configuredJoinAlgorithm;
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(super.hashCode(), maxBroadcastTablesSize);
+    return Objects.hash(super.hashCode(), maxBroadcastTablesSize, configuredJoinAlgorithm);
   }
+
+  private static String makeMessage(final long maxBroadcastTablesSize, final JoinAlgorithm configuredJoinAlgorithm)
+  {
+    if (configuredJoinAlgorithm == null || configuredJoinAlgorithm == JoinAlgorithm.BROADCAST) {
+      return StringUtils.format(
+          "Size of broadcast tables in JOIN exceeds reserved memory limit "
+          + "(memory reserved for broadcast tables = [%,d] bytes). "
+          + "Increase available memory, or set [%s: %s] in query context to use a shuffle-based join.",
+          maxBroadcastTablesSize,
+          PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+          JoinAlgorithm.SORT_MERGE.toString()
+      );
+    } else {
+      return StringUtils.format(
+          "Size of broadcast tables in JOIN exceeds reserved memory limit "
+          + "(memory reserved for broadcast tables = [%,d] bytes). "
+          + "Try increasing available memory. "
+          + "This query is using broadcast JOIN even though [%s: %s] is set in query context, because the configured "
+          + "join algorithm does not support the join condition.",
+          maxBroadcastTablesSize,
+          PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+          configuredJoinAlgorithm.toString()
+      );
+    }
+  }
 }
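
A minimal sketch of the two message variants produced by the new makeMessage above. The wrapper class is illustrative, and getErrorMessage() is assumed from the MSQFault interface; only the constructor arguments come from the diff.

```java
import org.apache.druid.msq.indexing.error.BroadcastTablesTooLargeFault;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;

public class FaultMessageSketch
{
  public static void main(String[] args)
  {
    // Broadcast chosen directly (no configured algorithm, or broadcast itself
    // configured): the message suggests switching to a shuffle-based join
    // via sortMerge.
    System.out.println(new BroadcastTablesTooLargeFault(100_000, null).getErrorMessage());

    // Broadcast used as a fallback from a configured sortMerge: the message
    // explains that the configured algorithm cannot support the join
    // condition, rather than recommending a setting the user already has.
    System.out.println(new BroadcastTablesTooLargeFault(100_000, JoinAlgorithm.SORT_MERGE).getErrorMessage());
  }
}
```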
File 2 of 4: BroadcastJoinSegmentMapFnProcessor.java (org.apache.druid.msq.querykit)
@@ -42,11 +42,14 @@
 import org.apache.druid.segment.ColumnValueSelector;
 import org.apache.druid.segment.Cursor;
 import org.apache.druid.segment.SegmentReference;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -242,7 +245,15 @@ boolean buildBroadcastTablesIncrementally(final IntSet readableInputs)
       memoryUsed += frame.numBytes();
 
       if (memoryUsed > memoryReservedForBroadcastJoin) {
-        throw new MSQException(new BroadcastTablesTooLargeFault(memoryReservedForBroadcastJoin));
+        throw new MSQException(
+            new BroadcastTablesTooLargeFault(
+                memoryReservedForBroadcastJoin,
+                Optional.ofNullable(query)
+                        .map(q -> q.context().getString(PlannerContext.CTX_SQL_JOIN_ALGORITHM))
+                        .map(JoinAlgorithm::fromString)
+                        .orElse(null)
+            )
+        );
       }
 
       addFrame(channelNumber, frame);
File 3 of 4: MSQ fault serde test
@@ -26,6 +26,7 @@
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.segment.TestHelper;
 import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -49,7 +50,8 @@ public void setUp()
   @Test
   public void testFaultSerde() throws IOException
   {
-    assertFaultSerde(new BroadcastTablesTooLargeFault(10));
+    assertFaultSerde(new BroadcastTablesTooLargeFault(10, null));
+    assertFaultSerde(new BroadcastTablesTooLargeFault(10, JoinAlgorithm.SORT_MERGE));
     assertFaultSerde(CanceledFault.INSTANCE);
     assertFaultSerde(new CannotParseExternalDataFault("the message"));
     assertFaultSerde(new ColumnTypeNotSupportedFault("the column", null));
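A note on wire compatibility, with a hedged sketch: since the new getter is annotated @JsonInclude(JsonInclude.Include.NON_NULL), a null configuredJoinAlgorithm is omitted from the serialized fault, so faults serialized before this patch still round-trip. The mapper setup below is an assumption modeled on this test's imports (TestHelper, MSQIndexingModule), not code from the patch.

```java
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.msq.indexing.error.BroadcastTablesTooLargeFault;
import org.apache.druid.segment.TestHelper;

public class FaultSerdeSketch
{
  public static void main(String[] args) throws Exception
  {
    // Assumed setup: a test mapper with the MSQ Jackson modules registered.
    final ObjectMapper mapper = TestHelper.makeJsonMapper();
    mapper.registerModules(new MSQIndexingModule().getJacksonModules());

    // With a null configuredJoinAlgorithm, NON_NULL keeps the key out of the
    // JSON entirely, matching the pre-patch serialized form of this fault.
    System.out.println(mapper.writeValueAsString(new BroadcastTablesTooLargeFault(10, null)));
  }
}
```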
File 4 of 4: test for BroadcastJoinSegmentMapFnProcessor (org.apache.druid.msq.querykit)
@@ -19,6 +19,7 @@
 
 package org.apache.druid.msq.querykit;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.ListenableFuture;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
@@ -42,12 +43,17 @@
 import org.apache.druid.query.DataSource;
 import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.JoinDataSource;
+import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContext;
 import org.apache.druid.segment.QueryableIndexStorageAdapter;
 import org.apache.druid.segment.StorageAdapter;
 import org.apache.druid.segment.TestIndex;
 import org.apache.druid.segment.join.JoinConditionAnalysis;
 import org.apache.druid.segment.join.JoinType;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.easymock.EasyMock;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 import org.junit.Assert;
@@ -232,7 +238,59 @@ public void testBuildTableMemoryLimit() throws IOException
       }
     );
 
-    Assert.assertEquals(new BroadcastTablesTooLargeFault(100_000), e.getFault());
+    Assert.assertEquals(new BroadcastTablesTooLargeFault(100_000, null), e.getFault());
   }
 
+  /**
+   * Like {@link #testBuildTableMemoryLimit()}, but with {@link JoinAlgorithm#SORT_MERGE} configured, so we can
+   * verify we get a better error message.
+   */
+  @Test
+  public void testBuildTableMemoryLimitWithSortMergeConfigured() throws IOException
+  {
+    final Int2IntMap sideStageChannelNumberMap = new Int2IntOpenHashMap();
+    sideStageChannelNumberMap.put(0, 0);
+
+    final List<ReadableFrameChannel> channels = new ArrayList<>();
+    channels.add(new ReadableFileFrameChannel(FrameFile.open(testDataFile1, ByteTracker.unboundedTracker())));
+
+    final List<FrameReader> channelReaders = new ArrayList<>();
+    channelReaders.add(frameReader1);
+
+    // Query: used only to retrieve configured join from context
+    final Query<?> mockQuery = EasyMock.mock(Query.class);
+    EasyMock.expect(mockQuery.context()).andReturn(
+        QueryContext.of(
+            ImmutableMap.of(
+                PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+                JoinAlgorithm.SORT_MERGE.getId()
+            )
+        )
+    );
+    EasyMock.replay(mockQuery);
+    final BroadcastJoinSegmentMapFnProcessor broadcastJoinHelper = new BroadcastJoinSegmentMapFnProcessor(
+        mockQuery,
+        sideStageChannelNumberMap,
+        channels,
+        channelReaders,
+        100_000 // Low memory limit; we will hit this
+    );
+
+    Assert.assertEquals(ImmutableSet.of(0), broadcastJoinHelper.getSideChannelNumbers());
+
+    final MSQException e = Assert.assertThrows(
+        MSQException.class,
+        () -> {
+          boolean doneReading = false;
+          while (!doneReading) {
+            final IntSet readableInputs = new IntOpenHashSet(new int[]{0});
+            doneReading = broadcastJoinHelper.buildBroadcastTablesIncrementally(readableInputs);
+          }
+        }
+    );
+
+    Assert.assertEquals(new BroadcastTablesTooLargeFault(100_000, JoinAlgorithm.SORT_MERGE), e.getFault());
+    EasyMock.verify(mockQuery);
+  }
+
   /**
