Skip to content

Commit

Permalink
fix issues with join filter pushdown and virtual column resolution (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
clintropolis authored Jul 11, 2024
1 parent 4b293fc commit d6c0727
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@
import org.apache.druid.segment.virtual.VirtualizedColumnSelectorFactory;

import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Class allowing lookup and usage of virtual columns.
Expand Down Expand Up @@ -86,11 +85,21 @@ public static Pair<String, String> splitColumnName(String columnName)
}

@JsonCreator
public static VirtualColumns create(List<VirtualColumn> virtualColumns)
public static VirtualColumns create(@Nullable List<VirtualColumn> virtualColumns)
{
if (virtualColumns == null || virtualColumns.isEmpty()) {
return EMPTY;
}
return fromIterable(virtualColumns);
}

/**
 * Varargs convenience overload: wraps the supplied virtual columns in a list view and hands it to
 * {@link #create(List)}.
 *
 * @param virtualColumns the virtual columns to wrap
 * @return whatever {@link #create(List)} returns for the wrapped list
 */
public static VirtualColumns create(VirtualColumn... virtualColumns)
{
  final List<VirtualColumn> columnList = Arrays.asList(virtualColumns);
  return create(columnList);
}

public static VirtualColumns fromIterable(Iterable<VirtualColumn> virtualColumns)
{
Map<String, VirtualColumn> withDotSupport = new HashMap<>();
Map<String, VirtualColumn> withoutDotSupport = new HashMap<>();
for (VirtualColumn vc : virtualColumns) {
Expand All @@ -115,11 +124,6 @@ public static VirtualColumns create(List<VirtualColumn> virtualColumns)
return new VirtualColumns(ImmutableList.copyOf(virtualColumns), withDotSupport, withoutDotSupport);
}

/**
 * Builds a {@link VirtualColumns} from an array (or varargs list) of virtual columns by delegating to the
 * list-based {@code create} overload.
 *
 * @param virtualColumns the virtual columns to include
 * @return the result of the list-based {@code create} for the same columns
 */
public static VirtualColumns create(VirtualColumn... virtualColumns)
{
  final List<VirtualColumn> asList = Arrays.asList(virtualColumns);
  return create(asList);
}

public static VirtualColumns nullToEmpty(@Nullable VirtualColumns virtualColumns)
{
return virtualColumns == null ? EMPTY : virtualColumns;
Expand All @@ -134,6 +138,14 @@ public static boolean shouldVectorize(Query<?> query, VirtualColumns virtualColu
}
}

// For equals, hashCode, toString, and serialization:
private final List<VirtualColumn> virtualColumns;
private final List<String> virtualColumnNames;

// For getVirtualColumn:
private final Map<String, VirtualColumn> withDotSupport;
private final Map<String, VirtualColumn> withoutDotSupport;

private VirtualColumns(
List<VirtualColumn> virtualColumns,
Map<String, VirtualColumn> withDotSupport,
Expand All @@ -143,19 +155,14 @@ private VirtualColumns(
this.virtualColumns = virtualColumns;
this.withDotSupport = withDotSupport;
this.withoutDotSupport = withoutDotSupport;
this.virtualColumnNames = new ArrayList<>(virtualColumns.size());

for (VirtualColumn virtualColumn : virtualColumns) {
detectCycles(virtualColumn, null);
virtualColumnNames.add(virtualColumn.getOutputName());
}
}

// For equals, hashCode, toString, and serialization:
private final List<VirtualColumn> virtualColumns;

// For getVirtualColumn:
private final Map<String, VirtualColumn> withDotSupport;
private final Map<String, VirtualColumn> withoutDotSupport;

/**
* Returns true if a virtual column exists with a particular columnName.
*
Expand Down Expand Up @@ -468,6 +475,16 @@ public byte[] getCacheKey()
return new CacheKeyBuilder((byte) 0).appendCacheablesIgnoringOrder(virtualColumns).build();
}

/**
 * Tells whether this object holds no virtual columns at all.
 *
 * @return true if the backing list of virtual columns is empty
 */
public boolean isEmpty()
{
  final boolean hasNoVirtualColumns = virtualColumns.isEmpty();
  return hasNoVirtualColumns;
}

/**
 * Returns the output names of the virtual columns held by this object, in the same order as the virtual
 * columns themselves. The names are computed once when the object is constructed.
 *
 * @return a read-only view of the virtual column output names; attempts to modify it throw
 *         {@link UnsupportedOperationException}
 */
public List<String> getColumnNames()
{
  // The backing list is internal state of an otherwise immutable object, and instances (including the
  // shared EMPTY instance) are handed out to many callers. Expose a read-only view so that no caller can
  // accidentally alter the names seen by everyone else.
  return java.util.Collections.unmodifiableList(virtualColumnNames);
}

private VirtualColumn getVirtualColumnForSelector(String columnName)
{
VirtualColumn virtualColumn = getVirtualColumn(columnName);
Expand Down Expand Up @@ -538,14 +555,4 @@ public boolean equals(Object obj)
((VirtualColumns) obj).virtualColumns.isEmpty();
}
}

/**
 * Reports whether there are no virtual columns in this object.
 *
 * @return true when the list of virtual columns has no elements
 */
public boolean isEmpty()
{
  final boolean nothingDefined = virtualColumns.isEmpty();
  return nothingDefined;
}

/**
 * Collects the output name of every virtual column, in list order, into a newly built list.
 *
 * @return a list with one output name per virtual column
 */
public List<String> getColumnNames()
{
  return virtualColumns.stream()
                       .map(VirtualColumn::getOutputName)
                       .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package org.apache.druid.segment.join;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Sequence;
Expand All @@ -30,7 +33,6 @@
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.Metadata;
import org.apache.druid.segment.StorageAdapter;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.data.Indexed;
Expand All @@ -46,13 +48,10 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

public class HashJoinSegmentStorageAdapter implements StorageAdapter
{
Expand Down Expand Up @@ -292,43 +291,40 @@ public Sequence<Cursor> makeCursors(
);

final JoinFilterPreAnalysisKey keyCached = joinFilterPreAnalysis.getKey();
final JoinFilterSplit joinFilterSplit;

final JoinFilterPreAnalysis preAnalysis;
if (keyIn.equals(keyCached)) {
// Common case: key used during filter pre-analysis (keyCached) matches key implied by makeCursors call (keyIn).
joinFilterSplit = JoinFilterAnalyzer.splitFilter(joinFilterPreAnalysis, baseFilter);
preAnalysis = joinFilterPreAnalysis;
} else {
// Less common case: key differs. Re-analyze the filter. This case can happen when an unnest datasource is
// layered on top of a join datasource.
joinFilterSplit = JoinFilterAnalyzer.splitFilter(
JoinFilterAnalyzer.computeJoinFilterPreAnalysis(keyIn),
baseFilter
);
preAnalysis = JoinFilterAnalyzer.computeJoinFilterPreAnalysis(keyIn);
}

final List<VirtualColumn> preJoinVirtualColumns = new ArrayList<>();
final List<VirtualColumn> postJoinVirtualColumns = new ArrayList<>();

determineBaseColumnsWithPreAndPostJoinVirtualColumns(
virtualColumns,
preJoinVirtualColumns,
postJoinVirtualColumns
final JoinFilterSplit joinFilterSplit = JoinFilterAnalyzer.splitFilter(
preAnalysis,
baseFilter
);

// We merge the filter on base table specified by the user and filter on the base table that is pushed from
// the join
preJoinVirtualColumns.addAll(joinFilterSplit.getPushDownVirtualColumns());

final Sequence<Cursor> baseCursorSequence = baseAdapter.makeCursors(
joinFilterSplit.getBaseTableFilter().isPresent() ? joinFilterSplit.getBaseTableFilter().get() : null,
interval,
VirtualColumns.create(preJoinVirtualColumns),
VirtualColumns.fromIterable(
Iterables.concat(
Sets.difference(
ImmutableSet.copyOf(virtualColumns.getVirtualColumns()),
joinFilterPreAnalysis.getPostJoinVirtualColumns()
),
joinFilterSplit.getPushDownVirtualColumns()
)
),
gran,
descending,
queryMetrics
);

Closer joinablesCloser = Closer.create();
final Closer joinablesCloser = Closer.create();
return Sequences.<Cursor, Cursor>map(
baseCursorSequence,
cursor -> {
Expand All @@ -341,7 +337,7 @@ public Sequence<Cursor> makeCursors(

return PostJoinCursor.wrap(
retVal,
VirtualColumns.create(postJoinVirtualColumns),
VirtualColumns.fromIterable(preAnalysis.getPostJoinVirtualColumns()),
joinFilterSplit.getJoinTableFilter().orElse(null)
);
}
Expand All @@ -357,47 +353,6 @@ public boolean isBaseColumn(final String column)
return !getClauseForColumn(column).isPresent();
}

/**
 * Computes the set of column names that are available on the base table side of the join, and optionally
 * sorts the query's virtual columns into "pre-join" and "post-join" buckets.
 *
 * The set starts out as the column names of the base adapter's row signature. Virtual columns are then
 * visited once, in the order they are supplied. A virtual column whose required columns are all in the set
 * at the moment it is visited is treated as pre-join: its output name is added to the set and it is
 * appended to {@code preJoinVirtualColumns}. Every other virtual column is appended to
 * {@code postJoinVirtualColumns}. Because there is only a single pass, a virtual column can build on a
 * pre-join virtual column only if that one was visited earlier.
 *
 * @param virtualColumns         virtual columns from the query
 * @param preJoinVirtualColumns  if non-null, receives the virtual columns classified as pre-join
 * @param postJoinVirtualColumns if non-null, receives the virtual columns classified as post-join
 *
 * @return the base column names, including the output names of all pre-join virtual columns
 */
public Set<String> determineBaseColumnsWithPreAndPostJoinVirtualColumns(
    VirtualColumns virtualColumns,
    @Nullable List<VirtualColumn> preJoinVirtualColumns,
    @Nullable List<VirtualColumn> postJoinVirtualColumns
)
{
  final Set<String> knownBaseColumns = new HashSet<>(baseAdapter.getRowSignature().getColumnNames());

  for (VirtualColumn candidate : virtualColumns.getVirtualColumns()) {
    final boolean computableBeforeJoin = knownBaseColumns.containsAll(candidate.requiredColumns());

    if (computableBeforeJoin) {
      // Pre-join virtual columns need nothing beyond the base side, so their outputs count as base columns
      // for the virtual columns visited after them.
      knownBaseColumns.add(candidate.getOutputName());
    }

    // Either output list may be omitted by the caller; only record the classification when it is wanted.
    final List<VirtualColumn> destination = computableBeforeJoin ? preJoinVirtualColumns : postJoinVirtualColumns;
    if (destination != null) {
      destination.add(candidate);
    }
  }

  return knownBaseColumns;
}

/**
* Returns the JoinableClause corresponding to a particular column, based on the clauses' prefixes.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@
*
* The result of this pre-analysis method should be passed into the next step of join filter analysis, described below.
*
* The {@link #splitFilter(JoinFilterPreAnalysis)} method takes the pre-analysis result and optionally applies the
* filter rewrite and push down operations on a per-segment level.
* The {@link #splitFilter(JoinFilterPreAnalysis, Filter)} method takes the pre-analysis result and optionally applies
* the filter rewrite and push down operations on a per-segment level.
*/
public class JoinFilterAnalyzer
{
Expand All @@ -90,11 +90,10 @@ public class JoinFilterAnalyzer
*/
public static JoinFilterPreAnalysis computeJoinFilterPreAnalysis(final JoinFilterPreAnalysisKey key)
{
final List<VirtualColumn> preJoinVirtualColumns = new ArrayList<>();
final List<VirtualColumn> postJoinVirtualColumns = new ArrayList<>();

final JoinableClauses joinableClauses = JoinableClauses.fromList(key.getJoinableClauses());
joinableClauses.splitVirtualColumns(key.getVirtualColumns(), preJoinVirtualColumns, postJoinVirtualColumns);
final Set<VirtualColumn> postJoinVirtualColumns = joinableClauses.getPostJoinVirtualColumns(
key.getVirtualColumns()
);

final JoinFilterPreAnalysis.Builder preAnalysisBuilder =
new JoinFilterPreAnalysis.Builder(key, postJoinVirtualColumns);
Expand Down Expand Up @@ -159,13 +158,6 @@ public static JoinFilterPreAnalysis computeJoinFilterPreAnalysis(final JoinFilte
return preAnalysisBuilder.withCorrelations(correlations).build();
}

/**
 * Convenience overload for callers that have no filter on the base table: delegates to the two-argument
 * {@code splitFilter} with a null base filter.
 *
 * @param joinFilterPreAnalysis the pre-analysis to split the filter with
 * @return the split produced by the two-argument overload
 */
public static JoinFilterSplit splitFilter(JoinFilterPreAnalysis joinFilterPreAnalysis)
{
  // No base-table filter is supplied by this overload.
  return splitFilter(
      joinFilterPreAnalysis,
      null
  );
}

/**
* @param joinFilterPreAnalysis The pre-analysis computed by {@link #computeJoinFilterPreAnalysis(JoinFilterPreAnalysisKey)}
* @param baseFilter - Filter on base table that was specified in the query itself
Expand Down Expand Up @@ -210,7 +202,8 @@ public static JoinFilterSplit splitFilter(
);
if (joinFilterAnalysis.isCanPushDown()) {
//noinspection OptionalGetWithoutIsPresent isCanPushDown checks isPresent
leftFilters.add(joinFilterAnalysis.getPushDownFilter().get());
final Filter pushDown = joinFilterAnalysis.getPushDownFilter().get();
leftFilters.add(pushDown);
}
if (joinFilterAnalysis.isRetainAfterJoin()) {
rightFilters.add(joinFilterAnalysis.getOriginalFilter());
Expand Down Expand Up @@ -519,7 +512,7 @@ private static String getCorrelatedBaseExprVirtualColumnName(int counter)
}

private static boolean isColumnFromPostJoinVirtualColumns(
List<VirtualColumn> postJoinVirtualColumns,
Set<VirtualColumn> postJoinVirtualColumns,
String column
)
{
Expand All @@ -532,7 +525,7 @@ private static boolean isColumnFromPostJoinVirtualColumns(
}

private static boolean areSomeColumnsFromPostJoinVirtualColumns(
List<VirtualColumn> postJoinVirtualColumns,
Set<VirtualColumn> postJoinVirtualColumns,
Collection<String> columns
)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ public class JoinFilterPreAnalysis
private final List<Filter> normalizedBaseTableClauses;
private final List<Filter> normalizedJoinTableClauses;
private final JoinFilterCorrelations correlations;
private final List<VirtualColumn> postJoinVirtualColumns;
private final Set<VirtualColumn> postJoinVirtualColumns;
private final Equiconditions equiconditions;

private JoinFilterPreAnalysis(
final JoinFilterPreAnalysisKey key,
final List<VirtualColumn> postJoinVirtualColumns,
final Set<VirtualColumn> postJoinVirtualColumns,
final List<Filter> normalizedBaseTableClauses,
final List<Filter> normalizedJoinTableClauses,
final JoinFilterCorrelations correlations,
Expand Down Expand Up @@ -86,7 +86,7 @@ public Filter getOriginalFilter()
return key.getFilter();
}

public List<VirtualColumn> getPostJoinVirtualColumns()
public Set<VirtualColumn> getPostJoinVirtualColumns()
{
return postJoinVirtualColumns;
}
Expand Down Expand Up @@ -140,13 +140,13 @@ public static class Builder
@Nullable
private JoinFilterCorrelations correlations;
@Nonnull
private final List<VirtualColumn> postJoinVirtualColumns;
private final Set<VirtualColumn> postJoinVirtualColumns;
@Nonnull
private Equiconditions equiconditions = new Equiconditions(Collections.emptyMap());

public Builder(
@Nonnull JoinFilterPreAnalysisKey key,
@Nonnull List<VirtualColumn> postJoinVirtualColumns
@Nonnull Set<VirtualColumn> postJoinVirtualColumns
)
{
this.key = key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public VirtualColumns getVirtualColumns()
return virtualColumns;
}

@Nullable
public Filter getFilter()
{
return filter;
Expand Down
Loading

0 comments on commit d6c0727

Please sign in to comment.