diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index d5e8b93b13..8d8b95706f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -48,6 +48,7 @@ import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; @@ -88,6 +89,7 @@ import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalJoin; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; @@ -136,6 +138,18 @@ public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) { return unresolved.accept(this, context); } + @Override + public LogicalPlan visitJoin(Join node, AnalysisContext context) { + // TODO tables-join instead of plans-join supported only now + LogicalPlan left = visitRelation((Relation) node.getLeft(), context); + LogicalPlan right = visitRelation((Relation) node.getRight(), context); + Expression condition = expressionAnalyzer.analyze(node.getJoinCondition(), context); + ExpressionReferenceOptimizer optimizer = + new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), left, right); + Expression optimized = optimizer.optimize(condition, context); + return new LogicalJoin(left, right, node.getJoinType(), optimized); + } + @Override public LogicalPlan visitRelation(Relation node, AnalysisContext context) { QualifiedName qualifiedName = node.getTableQualifiedName(); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index 398f848f16..e598bb2efc 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -5,10 +5,15 @@ package org.opensearch.sql.analysis; +import static org.opensearch.sql.common.utils.StringUtils.format; + +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.FunctionExpression; @@ -22,6 +27,7 @@ import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.logical.LogicalWindow; /** @@ -47,12 +53,34 @@ public class ExpressionReferenceOptimizer */ private final Map expressionMap = new HashMap<>(); + private String leftRelationName; + private String rightRelationName; + private Set leftSideAttributes; + private Set rightSideAttributes; + public ExpressionReferenceOptimizer( BuiltinFunctionRepository repository, LogicalPlan logicalPlan) { this.repository = repository; logicalPlan.accept(new ExpressionMapBuilder(), null); } + public ExpressionReferenceOptimizer( + BuiltinFunctionRepository repository, LogicalPlan... logicalPlans) { + this.repository = repository; + // To resolve join condition, we store left side and left side of join. + if (logicalPlans.length == 2) { + // TODO current implementation only support two-tables join, so we can directly convert them + // to LogicalRelation. To support two-plans join, we can get the LogicalRelation by searching. + this.leftRelationName = ((LogicalRelation) logicalPlans[0]).getRelationName(); + this.rightRelationName = ((LogicalRelation) logicalPlans[1]).getRelationName(); + this.leftSideAttributes = + ((LogicalRelation) logicalPlans[0]).getTable().getFieldTypes().keySet(); + this.rightSideAttributes = + ((LogicalRelation) logicalPlans[1]).getTable().getFieldTypes().keySet(); + } + Arrays.stream(logicalPlans).forEach(p -> p.accept(new ExpressionMapBuilder(), null)); + } + public Expression optimize(Expression analyzed, AnalysisContext context) { return analyzed.accept(this, context); } @@ -62,6 +90,45 @@ public Expression visitNode(Expression node, AnalysisContext context) { return node; } + /** + * Add index prefix to reference attribute of join condition. The attribute could be: case 1: + * Field -> Index.Field case 2: Field.Field -> Index.Field.Field case 3: .Index.Field, + * .Index.Field.Field -> do nothing case 4: Index.Field, Index.Field.Field -> do nothing + */ + @Override + public Expression visitReference(ReferenceExpression node, AnalysisContext context) { + if (leftRelationName == null || rightRelationName == null) { + return node; + } + + String attr = node.getAttr(); + // case 1 or case 2 + if (!attr.contains(".") || (!attr.startsWith(".") && !isIndexPrefix(attr))) { + return replaceReferenceExpressionWithIndexPrefix(node, attr); + } + return node; + } + + private ReferenceExpression replaceReferenceExpressionWithIndexPrefix( + ReferenceExpression node, String attr) { + if (leftSideAttributes.contains(attr) && rightSideAttributes.contains(attr)) { + throw new SemanticCheckException(format("Reference `%s` is ambiguous", attr)); + } else if (leftSideAttributes.contains(attr)) { + return new ReferenceExpression(format("%s.%s", leftRelationName, attr), node.type()); + } else if (rightSideAttributes.contains(attr)) { + return new ReferenceExpression(format("%s.%s", rightRelationName, attr), node.type()); + } else { + return node; + } + } + + private boolean isIndexPrefix(String attr) { + int separator = attr.indexOf('.'); + String possibleIndexPrefix = attr.substring(0, separator); + return leftRelationName.contains(possibleIndexPrefix) + || rightRelationName.contains(possibleIndexPrefix); + } + @Override public Expression visitFunction(FunctionExpression node, AnalysisContext context) { if (expressionMap.containsKey(node)) { diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java index 8bb6824a63..e5798ba22c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java @@ -72,6 +72,21 @@ public Optional lookup(Symbol symbol) { Map table = tableByNamespace.get(symbol.getNamespace()); ExprType type = null; if (table != null) { + // To handle the field named start with [index.], for example index1.field1, + // this is used by Join query. + if (symbol.getNamespace() == Namespace.FIELD_NAME) { + String[] parts = symbol.getName().split("\\."); + if (parts.length == 2) { + // extract the indexName + if (tableByNamespace.get(Namespace.INDEX_NAME) != null) { + String indexName = tableByNamespace.get(Namespace.INDEX_NAME).firstKey(); + if (indexName != null && indexName.equals(parts[0])) { + type = table.get(parts[1]); + return Optional.ofNullable(type); + } + } + } + } type = table.get(symbol.getName()); } return Optional.ofNullable(type); diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 973b10310b..91cd4d3464 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -47,6 +47,7 @@ import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; @@ -109,6 +110,10 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitJoin(Join node, C context) { + return visitChildren(node, context); + } + public T visitProject(Project node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 4f3056b0f7..cbc88c8cae 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -48,6 +48,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; @@ -471,4 +472,12 @@ public static Parse parse( java.util.Map arguments) { return new Parse(parseMethod, sourceField, pattern, arguments, input); } + + public static Join join( + UnresolvedPlan left, + UnresolvedPlan right, + Join.JoinType joinType, + UnresolvedExpression condition) { + return new Join(left, right, joinType, condition); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Join.java b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java new file mode 100644 index 0000000000..f70d46de84 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +@RequiredArgsConstructor +@Getter +@EqualsAndHashCode(callSuper = false) +@ToString +public class Join extends UnresolvedPlan { + private final UnresolvedPlan left; + private final UnresolvedPlan right; + private final JoinType joinType; + private final UnresolvedExpression joinCondition; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(left, right); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitJoin(this, context); + } + + public enum JoinType { + INNER, + LEFT, + RIGHT, + SEMI, + ANTI, + CROSS, + FULL + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index fd5ea14a2e..747fb26ae6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -289,6 +289,10 @@ public static Optional of(String str) { return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); } + public static BuiltinFunctionName of(FunctionName name) { + return ALL_NATIVE_FUNCTIONS.get(name); + } + public static Optional ofAggregation(String functionName) { return Optional.ofNullable( AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index f962c3e4bf..54fd10010b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -5,13 +5,29 @@ package org.opensearch.sql.planner; +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildLeft; +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildRight; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalCloseCursor; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalJoin; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; import org.opensearch.sql.planner.logical.LogicalPaginate; @@ -41,6 +57,10 @@ import org.opensearch.sql.planner.physical.TakeOrderedOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; +import org.opensearch.sql.planner.physical.join.HashJoinOperator; +import org.opensearch.sql.planner.physical.join.JoinOperator; +import org.opensearch.sql.planner.physical.join.JoinPredicatesHelper; +import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; import org.opensearch.sql.storage.read.TableScanBuilder; import org.opensearch.sql.storage.write.TableWriteBuilder; @@ -54,6 +74,7 @@ * @param context type */ public class DefaultImplementor extends LogicalPlanNodeVisitor { + private static final Logger LOG = LogManager.getLogger(); @Override public PhysicalPlan visitRareTopN(LogicalRareTopN node, C context) { @@ -156,6 +177,89 @@ public PhysicalPlan visitRelation(LogicalRelation node, C context) { + "implementing and optimizing logical plan with relation involved"); } + @Override + public PhysicalPlan visitJoin(LogicalJoin join, C ctx) { + LOG.debug("join condition is {}", join.getCondition()); + List predicates = + JoinPredicatesHelper.splitConjunctivePredicates(join.getCondition()); + // Extract all equi-join key pairs + List> equiJoinKeys = new ArrayList<>(); + for (Expression predicate : predicates) { + if (JoinPredicatesHelper.isEqual(predicate)) { + Pair pair = + JoinPredicatesHelper.extractJoinKeys((FunctionExpression) predicate); + if (pair.getLeft() instanceof ReferenceExpression + && pair.getRight() instanceof ReferenceExpression) { + if (canEvaluate((ReferenceExpression) pair.getLeft(), join.getLeft()) + && canEvaluate((ReferenceExpression) pair.getRight(), join.getRight())) { + equiJoinKeys.add(pair); + } else { + throw new SemanticCheckException( + StringUtils.format("Join key must be a field of index.")); + } + } else { + throw new SemanticCheckException( + StringUtils.format( + "Join condition must contain field only. E.g. t1.field1 = t2.field2 AND" + + " t1.field3 = t2.field4. But found {}", + predicate.getClass().getSimpleName())); + } + } else { + equiJoinKeys.clear(); + break; + } + } + + // 1. Determining Join with Hint and build side. + JoinOperator.BuildSide buildSide = determineBuildSide(join.getType()); + // 2. Pick hash join if it is an equi-join and hash join supported + if (!equiJoinKeys.isEmpty()) { + Pair, List> unzipped = JoinPredicatesHelper.unzip(equiJoinKeys); + List leftKeys = unzipped.getLeft(); + List rightKeys = unzipped.getRight(); + LOG.info("EquiJoin leftKeys are {}, rightKeys are {}", leftKeys, rightKeys); + + return new HashJoinOperator( + leftKeys, + rightKeys, + join.getType(), + buildSide, + visitRelation((LogicalRelation) join.getLeft(), ctx), + visitRelation((LogicalRelation) join.getRight(), ctx), + Optional.empty()); + // 3. Pick sort merge join if the join keys are sortable. TODO + } else { + // 4. Pick Nested loop join if is a non-equi-join. TODO + return new NestedLoopJoinOperator( + visitRelation((LogicalRelation) join.getLeft(), ctx), + visitRelation((LogicalRelation) join.getRight(), ctx), + join.getType(), + buildSide, + join.getCondition()); + } + } + + /** + * Build side is right by default (except RightOuter). TODO set the smaller side as the build side + * TODO set build side from hint if provided + * + * @param joinType Join type + * @return Build side + */ + private JoinOperator.BuildSide determineBuildSide(Join.JoinType joinType) { + return joinType == Join.JoinType.RIGHT ? BuildLeft : BuildRight; + } + + /** Return true if the reference can be evaluated in relation */ + private boolean canEvaluate(ReferenceExpression expr, LogicalPlan plan) { + if (plan instanceof LogicalRelation relation) { + // TODO need fix, the attr() contains relation prefix: Index.Field + return relation.getTable().getFieldTypes().containsKey(expr.getAttr()); + } else { + throw new UnsupportedOperationException("Only relation can be used in join"); + } + } + @Override public PhysicalPlan visitFetchCursor(LogicalFetchCursor plan, C context) { return new PlanSerializer(plan.getEngine()).convertToPlan(plan.getCursor()); diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java new file mode 100644 index 0000000000..4ba86fe920 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.expression.Expression; + +@ToString +@EqualsAndHashCode(callSuper = true) +@Getter +public class LogicalJoin extends LogicalPlan { + private final LogicalPlan left; + private final LogicalPlan right; + private final Join.JoinType type; + private final Expression condition; + + public LogicalJoin( + LogicalPlan left, LogicalPlan right, Join.JoinType type, Expression condition) { + super(ImmutableList.of(left, right)); + this.left = left; + this.right = right; + this.type = type; + this.condition = condition; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitJoin(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 2a886ba0ca..40ebdc0ce3 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -13,6 +13,7 @@ import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.Expression; @@ -138,4 +139,13 @@ public LogicalPlan values(List... values) { public static LogicalPlan limit(LogicalPlan input, Integer limit, Integer offset) { return new LogicalLimit(input, limit, offset); } + + public LogicalPlan innerJoin(LogicalPlan left, LogicalPlan right, Expression condition) { + return join(left, right, Join.JoinType.INNER, condition); + } + + public LogicalPlan join( + LogicalPlan left, LogicalPlan right, Join.JoinType joinType, Expression condition) { + return new LogicalJoin(left, right, joinType, condition); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 156db35306..532dcfb734 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -115,4 +115,8 @@ public R visitFetchCursor(LogicalFetchCursor plan, C context) { public R visitCloseCursor(LogicalCloseCursor plan, C context) { return visitNode(plan, context); } + + public R visitJoin(LogicalJoin plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 67d7a05135..55771b0b15 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -5,6 +5,7 @@ package org.opensearch.sql.planner.physical; +import org.opensearch.sql.planner.physical.join.JoinOperator; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.write.TableWriteOperator; @@ -99,4 +100,8 @@ public R visitML(PhysicalPlan node, C context) { public R visitCursorClose(CursorCloseOperator node, C context) { return visitNode(node, context); } + + public R visitJoin(JoinOperator node, C context) { + return visitNode(node, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java index 5542d0f0e4..9606f1cef9 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java @@ -47,7 +47,7 @@ public static class DataSourceTableDefaultImplementor extends DefaultImplementor @Override public PhysicalPlan visitRelation(LogicalRelation node, Object context) { - return new DataSourceTableScan(dataSourceService); + return new DataSourceTableScan(dataSourceService, node); } } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java index 89e21377dc..b2c6fb737d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java @@ -14,11 +14,14 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.storage.TableScanOperator; /** @@ -29,11 +32,19 @@ public class DataSourceTableScan extends TableScanOperator { private final DataSourceService dataSourceService; + private final LogicalRelation relation; + private final String relationName; private Iterator iterator; public DataSourceTableScan(DataSourceService dataSourceService) { + this(dataSourceService, null); + } + + public DataSourceTableScan(DataSourceService dataSourceService, LogicalRelation relation) { this.dataSourceService = dataSourceService; + this.relation = relation; + this.relationName = relation.getRelationName(); this.iterator = Collections.emptyIterator(); } @@ -68,4 +79,16 @@ public boolean hasNext() { public ExprValue next() { return iterator.next(); } + + @Override + public ExecutionEngine.Schema schema() { + List columns = + relation.getTable().getFieldTypes().entrySet().stream() + .map( + (entry) -> + new ExecutionEngine.Schema.Column( + entry.getKey(), relationName + "." + entry.getKey(), entry.getValue())) + .collect(Collectors.toList()); + return new ExecutionEngine.Schema(columns); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java new file mode 100644 index 0000000000..6a032db5c1 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java @@ -0,0 +1,54 @@ +package org.opensearch.sql.planner.physical.join; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.opensearch.sql.data.model.ExprValue; + +public class DefaultHashedRelation implements HashedRelation, Serializable { + + private final Map> map = new HashMap<>(); + private int numKeys; + private int numValues; + + @Override + public List get(ExprValue key) { + return map.get(key); + } + + @Override + public ExprValue getValue(ExprValue key) { + List values = map.get(key); + return values != null && !values.isEmpty() ? values.getFirst() : null; + } + + @Override + public boolean containsKey(ExprValue key) { + return map.containsKey(key); + } + + @Override + public Iterator keyIterator() { + return map.keySet().iterator(); + } + + @Override + public boolean isUniqueKey() { + return numKeys == numValues; + } + + @Override + public void close() { + map.clear(); + } + + @Override + public void put(ExprValue key, ExprValue value) { + map.computeIfAbsent(key, k -> new ArrayList<>()).add(value); + numKeys++; + numValues++; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java new file mode 100644 index 0000000000..112484c19c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java @@ -0,0 +1,275 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical.join; + +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildRight; + +import com.google.common.collect.ImmutableList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +/** + * Hash Join Operator. For best performance, the build side should be set a smaller table, without + * hint and CBO, we treat right side as a smaller table by default and the build side set to right. + * TODO add join hint support. Best practice in PPL: source=bigger | INNER JOIN smaller ON + * bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is right + * (smaller), and the streamed side is left (bigger). For RIGHT OUTER join, the build side is always + * left. If the smaller table is left, it will get the best performance: source=smaller | RIGHT JOIN + * bigger ON bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is + * left (smaller), and the streamed side is right (bigger). + */ +public class HashJoinOperator extends JoinOperator { + private final List leftKeys; + private final List rightKeys; + private final BuildSide buildSide; + private final Optional nonEquiCond; + + // write the construct method + public HashJoinOperator( + List leftKeys, + List rightKeys, + Join.JoinType joinType, + BuildSide buildSide, + PhysicalPlan left, + PhysicalPlan right, + Optional nonEquiCond) { + super(left, right, joinType); + this.leftKeys = leftKeys; + this.rightKeys = rightKeys; + this.buildSide = buildSide; + this.nonEquiCond = nonEquiCond; + } + + private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); + private Iterator joinedIterator; + + private HashedRelation hashed; + private List buildKeys; + private List streamedKeys; + + @Override + public void open() { + left.open(); + right.open(); + if (!(leftKeys.size() == rightKeys.size() + && IntStream.range(0, leftKeys.size()) + .allMatch(i -> sameType(leftKeys.get(i), rightKeys.get(i))))) { + throw new IllegalArgumentException( + "Join keys from two sides should have same length and types"); + } + + Iterator streamed; + if (buildSide == BuildRight) { + hashed = buildHashed(right, rightKeys); + streamed = left; + buildKeys = rightKeys; + streamedKeys = leftKeys; + } else { + hashed = buildHashed(left, leftKeys); + streamed = right; + buildKeys = leftKeys; + streamedKeys = rightKeys; + } + + switch (joinType) { + case INNER -> innerJoin(streamed); + case LEFT, RIGHT -> outerJoin(streamed); + case SEMI -> semiJoin(streamed); + case ANTI -> antiJoin(streamed); + default -> throw new UnsupportedOperationException("Unsupported Join Type " + joinType); + } + } + + @Override + public void close() { + left.close(); + right.close(); + joinedIterator = null; + if (hashed != null) { + hashed.close(); + hashed = null; + } + } + + @Override + public void innerJoin(Iterator streamed) { + while (streamed.hasNext()) { + ExprValue streamedRow = streamed.next(); + + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + if (nonEquiCond.isPresent()) { + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + joinedBuilder.add(joinedRow); + } + } else { + joinedBuilder.add(joinedRow); + } + } + } + } + } + joinedIterator = joinedBuilder.build().iterator(); + } + + /** The implementation for outer join: LeftOuter with BuildRight RightOuter with BuildLeft */ + @Override + public void outerJoin(Iterator streamed) { + while (streamed.hasNext()) { + ExprValue streamedRow = streamed.next(); + boolean matched = false; + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + if (nonEquiCond.isPresent()) { + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + joinedBuilder.add(joinedRow); + matched = true; + } + } else { + joinedBuilder.add(joinedRow); + matched = true; + } + } + } else { + // if any streamedRowKey does not match, the remaining keys are not checked. + matched = false; + break; + } + } + + if (!matched) { + ExprTupleValue joinedRow = + combineExprTupleValue(buildSide, streamedRow, ExprValueUtils.nullValue()); + joinedBuilder.add(joinedRow); + } + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + @Override + public void semiJoin(Iterator streamed) { + Set matchedRows = new HashSet<>(); + + while (streamed.hasNext()) { + ExprValue streamedRow = streamed.next(); + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + if (nonEquiCond.isPresent()) { + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matchedRows.add(streamedRow); + } + } else { + matchedRows.add(streamedRow); + } + } + } else { + // if any streamedRowKey does not match, the remaining keys are not checked. + break; + } + } + } + + for (ExprValue row : matchedRows) { + joinedBuilder.add(row); + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + @Override + public void antiJoin(Iterator streamed) { + while (streamed.hasNext()) { + ExprValue streamedRow = streamed.next(); + boolean matched = false; + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + if (nonEquiCond.isPresent()) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matched = true; + } + } else { + matched = true; + } + } + } else { + // if any streamedRowKey does not match, the remaining keys are not checked. + matched = false; + break; + } + } + if (!matched) { + joinedBuilder.add(streamedRow); + } + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + private HashedRelation buildHashed(PhysicalPlan buildSide, List buildKeys) { + HashedRelation hashedRelation = new DefaultHashedRelation(); + while (buildSide.hasNext()) { + ExprValue row = buildSide.next(); + for (Expression buildKey : buildKeys) { + ExprValue rowKey = buildKey.valueOf(row.bindingTuples()); + if (rowKey != null) { + hashedRelation.put(rowKey, row); + break; + } + } + } + return hashedRelation; + } + + @Override + public boolean hasNext() { + return joinedIterator != null && joinedIterator.hasNext(); + } + + @Override + public ExprValue next() { + return joinedIterator.next(); + } + + @Override + public List getChild() { + return ImmutableList.of(left, right); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java new file mode 100644 index 0000000000..9dbb226175 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java @@ -0,0 +1,32 @@ +package org.opensearch.sql.planner.physical.join; + +import java.util.Iterator; +import java.util.List; +import org.opensearch.sql.data.model.ExprValue; + +public interface HashedRelation { + + /** Return matched rows. */ + List get(ExprValue key); + + /** + * Return the single matched row. Only used in {@link DefaultHashedRelation#isUniqueKey()} is + * true. + */ + ExprValue getValue(ExprValue key); + + /** Whether the key exists. */ + boolean containsKey(ExprValue key); + + /** Return the key iterator. */ + Iterator keyIterator(); + + /** Whether the key is unique. */ + boolean isUniqueKey(); + + /** Put the key-value pair into the relation. */ + void put(ExprValue key, ExprValue value); + + /** Release the resources */ + void close(); +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java new file mode 100644 index 0000000000..d884bb643c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical.join; + +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; + +public abstract class JoinOperator extends PhysicalPlan { + protected PhysicalPlan left; + protected PhysicalPlan right; + protected Join.JoinType joinType; + + protected ExecutionEngine.Schema leftSchema; + protected ExecutionEngine.Schema rightSchema; + protected ExecutionEngine.Schema outputSchema; + + JoinOperator(PhysicalPlan left, PhysicalPlan right, Join.JoinType joinType) { + this.left = left; + this.right = right; + this.joinType = joinType; + this.leftSchema = left.schema(); + this.rightSchema = right.schema(); + getOutputSchema(); + } + + private void getOutputSchema() { + switch (joinType) { + case INNER, LEFT, RIGHT, FULL -> { // merge left and right schemas + List columns = + Stream.of(left.schema().getColumns(), right.schema().getColumns()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + this.outputSchema = new ExecutionEngine.Schema(columns); + } + case SEMI, ANTI -> outputSchema = left.schema(); // left schema only + default -> throw new UnsupportedOperationException("Unsupported Join Type " + joinType); + } + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitJoin(this, context); + } + + @Override + public abstract List getChild(); + + public abstract void innerJoin(Iterator streamedSide); + + public abstract void outerJoin(Iterator streamedSide); + + public abstract void semiJoin(Iterator streamedSide); + + public abstract void antiJoin(Iterator streamedSide); + + protected ExprTupleValue combineExprTupleValue( + BuildSide buildSide, ExprValue streamedRow, ExprValue buildRow) { + ExprValue left = buildSide == BuildSide.BuildLeft ? buildRow : streamedRow; + ExprValue right = buildSide == BuildSide.BuildLeft ? streamedRow : buildRow; + Map leftTuple = getExprTupleMapFromSchema(left, leftSchema); + Map rightTuple = getExprTupleMapFromSchema(right, rightSchema); + Map combinedMap = new LinkedHashMap<>(leftTuple); + combinedMap.putAll(rightTuple); + return ExprTupleValue.fromExprValueMap(combinedMap); + } + + private Map getExprTupleMapFromSchema( + ExprValue row, ExecutionEngine.Schema schema) { + Map map = new LinkedHashMap<>(); + if (row.isNull()) { + schema.getColumns().forEach(col -> map.put(col.getAlias(), ExprNullValue.of())); + } else { + // replace to indexName.fieldName as tupleMap key in case the field names are same in join + // tables. + schema + .getColumns() + .forEach(col -> map.put(col.getAlias(), row.tupleValue().get(col.getName()))); + } + return map; + } + + protected boolean sameType(Expression expr1, Expression expr2) { + ExprType type1 = expr1.type(); + ExprType type2 = expr2.type(); + return type1.isCompatible(type2); + } + + public enum BuildSide { + BuildLeft, + BuildRight + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java new file mode 100644 index 0000000000..4c789750a5 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical.join; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +@UtilityClass +public class JoinPredicatesHelper { + + private static boolean instanceOf(Expression function, BuiltinFunctionName functionName) { + return function instanceof FunctionExpression + && ((FunctionExpression) function).getFunctionName().equals(functionName.getName()); + } + + private static boolean isValidJoinPredicate(FunctionExpression predicate) { + BuiltinFunctionName builtinFunctionName = BuiltinFunctionName.of(predicate.getFunctionName()); + switch (builtinFunctionName) { + case AND: + case OR: + case EQUAL: + case NOTEQUAL: + case LESS: + case LTE: + case GREATER: + case GTE: + return true; + default: + return false; + } + } + + public static ImmutablePair extractJoinKeys( + FunctionExpression predicate) { + if (isValidJoinPredicate(predicate)) { + throw new SemanticCheckException( + StringUtils.format( + "Join condition {} is an invalid function", + predicate.getFunctionName().getFunctionName())); + } else { + return ImmutablePair.of( + predicate.getArguments().getFirst(), predicate.getArguments().getLast()); + } + } + + public static List splitConjunctivePredicates(Expression condition) { + if (JoinPredicatesHelper.isAnd(condition)) { + return Stream.concat( + splitConjunctivePredicates(((FunctionExpression) condition).getArguments().getFirst()) + .stream(), + splitConjunctivePredicates(((FunctionExpression) condition).getArguments().getLast()) + .stream()) + .collect(Collectors.toList()); + } else { + return ImmutableList.of(condition); + } + } + + public static List splitDisjunctivePredicates(Expression condition) { + if (JoinPredicatesHelper.isOr(condition)) { + return Stream.concat( + splitDisjunctivePredicates(((FunctionExpression) condition).getArguments().getFirst()) + .stream(), + splitDisjunctivePredicates(((FunctionExpression) condition).getArguments().getLast()) + .stream()) + .collect(Collectors.toList()); + } else { + return ImmutableList.of(condition); + } + } + + public static Pair, List> unzip(List> pairs) { + List leftList = new ArrayList<>(); + List rightList = new ArrayList<>(); + for (Pair pair : pairs) { + leftList.add(pair.getLeft()); + rightList.add(pair.getRight()); + } + return Pair.of(leftList, rightList); + } + + public static boolean isAnd(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.AND); + } + + public static boolean isOr(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.OR); + } + + public static boolean isEqual(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.EQUAL); + } + + public static boolean isNot(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.NOT); + } + + public static boolean isXor(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.XOR); + } + + public static boolean isNotEqual(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.NOTEQUAL); + } + + public static boolean isLess(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.LESS); + } + + public static boolean isLte(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.LTE); + } + + public static boolean isGreater(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.GREATER); + } + + public static boolean isGte(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.GTE); + } + + public static boolean isLike(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.LIKE); + } + + public static boolean isNotLike(Expression expression) { + return instanceOf(expression, BuiltinFunctionName.NOT_LIKE); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java new file mode 100644 index 0000000000..800d3798a3 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java @@ -0,0 +1,198 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical.join; + +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildRight; + +import com.google.common.collect.ImmutableList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +/** + * Nested Loop Join Operator. For best performance, the build side should be set a smaller table, + * without hint and CBO, we treat right side as a smaller table by default and the build side set to + * right. TODO add join hint support. Best practice in PPL: source=bigger | INNER JOIN smaller ON + * bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is right + * (smaller), and the streamed side is left (bigger). For RIGHT OUTER join, the build side is always + * left. If the smaller table is left, it will get the best performance: source=smaller | RIGHT JOIN + * bigger ON bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is + * left (smaller), and the streamed side is right (bigger). + */ +public class NestedLoopJoinOperator extends JoinOperator { + private final BuildSide buildSide; + private final Expression condition; + + public NestedLoopJoinOperator( + PhysicalPlan left, + PhysicalPlan right, + Join.JoinType joinType, + BuildSide buildSide, + Expression condition) { + super(left, right, joinType); + this.buildSide = buildSide; + this.condition = condition; + } + + private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); + private Iterator joinedIterator; + + private List cachedBuildSide; + + @Override + public void open() { + left.open(); + right.open(); + Iterator streamed; + if (buildSide == BuildRight) { + cachedBuildSide = cacheIterator(right); + streamed = left; + } else { + cachedBuildSide = cacheIterator(left); + streamed = right; + } + + switch (joinType) { + case INNER -> innerJoin(streamed); + case LEFT, RIGHT -> outerJoin(streamed); + case SEMI -> semiJoin(streamed); + case ANTI -> antiJoin(streamed); + default -> throw new UnsupportedOperationException("Unsupported Join Type " + joinType); + } + } + + @Override + public void close() { + left.close(); + right.close(); + joinedIterator = null; + cachedBuildSide = null; + } + + /** The implementation for inner join: Inner with BuildRight */ + @Override + public void innerJoin(Iterator streamedSide) { + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); + + for (ExprValue buildRow : cachedBuildSide) { + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && (conditionValue.booleanValue())) { + joinedBuilder.add(joinedRow); + } + } + } + joinedIterator = joinedBuilder.build().iterator(); + } + + /** The implementation for outer join: LeftOuter with BuildRight RightOuter with BuildLeft */ + @Override + public void outerJoin(Iterator streamedSide) { + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); + boolean matched = false; + for (ExprValue buildRow : cachedBuildSide) { + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + joinedBuilder.add(joinedRow); + matched = true; + } + } + if (!matched) { + ExprTupleValue joinedRow = + combineExprTupleValue(buildSide, streamedRow, ExprValueUtils.nullValue()); + joinedBuilder.add(joinedRow); + } + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + /** + * The implementation for left semi join: LeftSemi with BuildRight TODO LeftSemi with buildLeft + */ + @Override + public void semiJoin(Iterator streamedSide) { + Set matchedRows = new HashSet<>(); + + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); + for (ExprValue buildRow : cachedBuildSide) { + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matchedRows.add(streamedRow); + break; + } + } + } + + for (ExprValue row : matchedRows) { + joinedBuilder.add(row); + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + /** + * The implementation for left anti join: LeftAnti with BuildRight TODO LeftAnti with buildLeft + */ + @Override + public void antiJoin(Iterator streamedSide) { + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); + boolean matched = false; + for (ExprValue buildRow : cachedBuildSide) { + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matched = true; + break; + } + } + if (!matched) { + joinedBuilder.add(streamedRow); + } + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + /** Convert iterator to a list to allow multiple iterations */ + private List cacheIterator(PhysicalPlan plan) { + ImmutableList.Builder streamedBuilder = ImmutableList.builder(); + plan.forEachRemaining(streamedBuilder::add); + return streamedBuilder.build(); + } + + @Override + public boolean hasNext() { + return joinedIterator != null && joinedIterator.hasNext(); + } + + @Override + public ExprValue next() { + return joinedIterator.next(); + } + + @Override + public List getChild() { + return ImmutableList.of(left, right); + } +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 8d935b11d2..f12ddf5fd6 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -81,6 +81,7 @@ import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.CloseCursor; import org.opensearch.sql.ast.tree.FetchCursor; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Paginate; @@ -1767,4 +1768,229 @@ public void visit_close_cursor() { () -> assertEquals("pewpew", ((LogicalFetchCursor) analyzed.getChild().get(0)).getCursor())); } + + @Test + public void inner_join() { + assertAnalyzeEqual( + LogicalPlanDSL.innerJoin( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void left_outer_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.LEFT, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.LEFT, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void right_outer_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.RIGHT, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.RIGHT, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void anti_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.ANTI, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.ANTI, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void semi_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.SEMI, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.SEMI, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void basic_SPJG() { + // Select(Filter)-Project-Join-GroupBy + // SELECT + // schema1.string_value, + // schema2.string_value, + // AVG(schema1.integer_value), + // MIN(schema2.long_value), + // FROM + // schema1 + // INNER JOIN + // schema2 + // ON + // schema1.integer_value = schema2.integer_value + // AND + // schema1.double_value = schema2.double_value + // WHERE + // schema1.integer_value > 10 + // GROUP BY + // schema1.string_value, schema2.string_value + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.filter( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.innerJoin( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + ImmutableList.of( + DSL.named( + "AVG(schema1.integer_value)", + DSL.avg(DSL.ref("schema1.integer_value", INTEGER))), + DSL.named( + "MIN(schema2.long_value)", + DSL.min(DSL.ref("schema2.long_value", LONG)))), + ImmutableList.of( + DSL.named("schema1.string_value", DSL.ref("schema1.string_value", STRING)), + DSL.named( + "schema2.string_value", DSL.ref("schema2.string_value", STRING)))), + DSL.greater( + DSL.ref("schema1.integer_value", INTEGER), DSL.literal(integerValue(10)))), + DSL.named("schema1.string_value", DSL.ref("schema1.string_value", STRING)), + DSL.named("schema2.string_value", DSL.ref("schema2.string_value", STRING))), + AstDSL.projectWithArg( + AstDSL.filter( + AstDSL.agg( + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), + AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), + AstDSL.field("schema2.double_value")))), + ImmutableList.of( + alias( + "AVG(schema1.integer_value)", + aggregate("AVG", qualifiedName("schema1.integer_value"))), + alias( + "MIN(schema2.long_value)", + aggregate("MIN", qualifiedName("schema2.long_value")))), + emptyList(), + ImmutableList.of( + alias("schema1.string_value", qualifiedName("schema1.string_value")), + alias("schema2.string_value", qualifiedName("schema2.string_value"))), + emptyList()), + compare(">", AstDSL.field("schema1.integer_value"), intLiteral(10))), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("schema1.string_value", AstDSL.field("schema1.string_value")), + AstDSL.alias("schema2.string_value", AstDSL.field("schema2.string_value")))); + } + + @Test + public void join_condition_is_ambiguous() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), + AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("double_value"), AstDSL.field("double_value")))))); + assertEquals("Reference `double_value` is ambiguous", exception.getMessage()); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index 0bf959a1b7..16da1e539d 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -177,7 +177,8 @@ protected ExpressionAnalyzer expressionAnalyzer() { } protected void assertAnalyzeEqual(LogicalPlan expected, UnresolvedPlan unresolvedPlan) { - assertEquals(expected, analyze(unresolvedPlan)); + LogicalPlan actual = analyze(unresolvedPlan); + assertEquals(expected, actual); } protected LogicalPlan analyze(UnresolvedPlan unresolvedPlan) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java index 27edc588fa..df4cdcb99d 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; @@ -132,4 +133,45 @@ public void rename_and_project_all() { AstDSL.defaultFieldsArgs(), AllFields.of())); } + + @Test + public void project_all_from_join() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.innerJoin( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + DSL.named("schema1.integer_value", DSL.ref("schema1.integer_value", INTEGER)), + DSL.named("schema1.double_value", DSL.ref("schema1.double_value", DOUBLE)), + DSL.named("schema1.string_value", DSL.ref("schema1.string_value", STRING)), + DSL.named("schema2.integer_value", DSL.ref("schema2.integer_value", INTEGER)), + DSL.named("schema2.double_value", DSL.ref("schema2.double_value", DOUBLE)), + DSL.named("schema2.string_value", DSL.ref("schema2.string_value", STRING))), + AstDSL.projectWithArg( + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), + AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), + AstDSL.field("schema2.double_value")))), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("schema1.integer_value", AstDSL.field("schema1.integer_value")), + AstDSL.alias("schema1.double_value", AstDSL.field("schema1.double_value")), + AstDSL.alias("schema1.string_value", AstDSL.field("schema1.string_value")), + AstDSL.alias("schema2.integer_value", AstDSL.field("schema2.integer_value")), + AstDSL.alias("schema2.double_value", AstDSL.field("schema2.double_value")), + AstDSL.alias("schema2.string_value", AstDSL.field("schema2.string_value")))); + } } diff --git a/core/src/test/java/org/opensearch/sql/config/TestConfig.java b/core/src/test/java/org/opensearch/sql/config/TestConfig.java index 92b6aac64f..3c12c4a1a6 100644 --- a/core/src/test/java/org/opensearch/sql/config/TestConfig.java +++ b/core/src/test/java/org/opensearch/sql/config/TestConfig.java @@ -61,6 +61,17 @@ public class TestConfig { .put("comment.data", ExprCoreType.STRING) .build(); + public static Map typeMapping2 = + new ImmutableMap.Builder() + .put("i_value", ExprCoreType.INTEGER) + .put("l_value", ExprCoreType.LONG) + .put("f_value", ExprCoreType.FLOAT) + .put("d_value", ExprCoreType.DOUBLE) + .put("msg", ExprCoreType.STRING) + .put("msg.info", ExprCoreType.STRING) + .put("msg.info.id", ExprCoreType.STRING) + .build(); + protected StorageEngine storageEngine() { return new StorageEngine() { @Override diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f212749f48..f4be576d79 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -78,6 +78,56 @@ public void logical_plan_should_be_traversable() { assertEquals(5, result); } + @Test + public void table_join_plan_should_be_traversable() { + LogicalPlan leftRelation = LogicalPlanDSL.relation("schema1", table); + LogicalPlan rightRelation = LogicalPlanDSL.relation("schema2", table); + LogicalPlan join = LogicalPlanDSL.innerJoin(leftRelation, rightRelation, expression); + LogicalPlan logicalPlan = + LogicalPlanDSL.rename( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.rareTopN( + LogicalPlanDSL.filter(join, expression), + CommandType.TOP, + ImmutableList.of(expression), + expression), + ImmutableList.of(DSL.named("avg", aggregator)), + ImmutableList.of(DSL.named("group", expression))), + ImmutableMap.of(ref, ref)); + Integer result = logicalPlan.accept(new NodesCount(), null); + assertEquals(7, result); + } + + @Test + public void complex_join_plan_should_be_traversable() { + LogicalPlan leftPlan = + LogicalPlanDSL.rename( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.rareTopN( + LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema", table), expression), + CommandType.TOP, + ImmutableList.of(expression), + expression), + ImmutableList.of(DSL.named("avg", aggregator)), + ImmutableList.of(DSL.named("group", expression))), + ImmutableMap.of(ref, ref)); + + LogicalPlan rightPlan = + LogicalPlanDSL.rename( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.rareTopN( + LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema", table), expression), + CommandType.TOP, + ImmutableList.of(expression), + expression), + ImmutableList.of(DSL.named("avg", aggregator)), + ImmutableList.of(DSL.named("group", expression))), + ImmutableMap.of(ref, ref)); + LogicalPlan join = LogicalPlanDSL.innerJoin(leftPlan, rightPlan, expression); + Integer result = join.accept(new NodesCount(), null); + assertEquals(11, result); + } + @SuppressWarnings("unchecked") private static Stream getLogicalPlansForVisitorTest() { LogicalPlan relation = LogicalPlanDSL.relation("schema", table); @@ -141,6 +191,12 @@ public TableWriteOperator build(PhysicalPlan child) { LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); + LogicalPlan relation2 = LogicalPlanDSL.relation("schema2", table); + + LogicalPlan join = + LogicalPlanDSL.innerJoin( + (LogicalRelation) relation, (LogicalRelation) relation2, expression); + return Stream.of( relation, tableScanBuilder, @@ -163,7 +219,8 @@ public TableWriteOperator build(PhysicalPlan child) { paginate, nested, cursor, - closeCursor) + closeCursor, + join) .map(Arguments::of); } @@ -214,5 +271,14 @@ public Integer visitRareTopN(LogicalRareTopN plan, Object context) { .mapToInt(Integer::intValue) .sum(); } + + @Override + public Integer visitJoin(LogicalJoin plan, Object context) { + return 1 + + plan.getChild().stream() + .map(child -> child.accept(this, context)) + .mapToInt(Integer::intValue) + .sum(); + } } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java new file mode 100644 index 0000000000..416bde5c3a --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java @@ -0,0 +1,369 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.join.JoinOperator; + +@ExtendWith(MockitoExtension.class) +public class HashJoinOperatorTest extends JoinOperatorTestHelper { + private final Optional emptyNonEquiCond = Optional.empty(); + private final Optional defaultNonEquiCond = + Optional.of( + DSL.and( + DSL.equal(DSL.ref("error_t.host", STRING), DSL.literal("h1")), + DSL.lte(DSL.ref("name_t.id", INTEGER), DSL.literal(5)))); + + @Test + public void inner_join_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10)); + } + + @Test + public void inner_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10)); + } + + @Test + public void inner_join_with_non_equi_cond_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder(error1_id1, error1_id1_duplicated, error2_id2)); + } + + @Test + public void left_join_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + error12_null, + error13_null)); + } + + @Test + public void left_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeHashJoin(Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + id4_null, + id5_null, + id7_null, + id9_null, + id11_null)); + } + + @Test + public void left_join_with_non_equi_cond_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_null, + error6_null, + error8_null, + error10_null, + error12_null, + error13_null)); + } + + @Test + public void right_join_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, emptyNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + null_id4, + null_id5, + null_id7, + null_id9, + null_id11)); + } + + @Test + public void right_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeHashJoin(Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, emptyNonEquiCond, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + null_error12, + null_error13)); + } + + @Test + public void right_join_with_non_equi_cond_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, defaultNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + null_id3, + null_id4, + null_id5, + null_id6, + null_id7, + null_id8, + null_id9, + null_id10, + null_id11)); + } + + @Test + public void semi_join_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10)))); + } + + @Test + public void semi_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeHashJoin(Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(6, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 3); + put("name", null); + } + }), + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + }))); + } + + @Test + public void semi_join_non_equi_cond_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2)))); + } + + @Test + public void anti_join_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(2, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void anti_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeHashJoin(Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void anti_join_non_equi_cond_test() { + PhysicalPlan joinPlan = + makeHashJoin( + Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(6, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java new file mode 100644 index 0000000000..72f4d9b244 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java @@ -0,0 +1,780 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.opensearch.sql.data.type.ExprCoreType.DATE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Optional; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.join.HashJoinOperator; +import org.opensearch.sql.planner.physical.join.JoinOperator; +import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; + +public class JoinOperatorTestHelper extends PhysicalPlanTestBase { + + private final List errorInputs = + new ImmutableList.Builder() + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13))) + .build(); + + private final List nameInputs = + new ImmutableList.Builder() + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 3); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k"))) + .build(); + + private final List sameNameInputs = + new ImmutableList.Builder() + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3, "name", "c"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 5); + put("name", null); + } + })) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "jj"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "jjj"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 15, "name", "o"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 16); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 17, "name", "q"))) + .build(); + + private final ExecutionEngine.Schema errorSchema = + new ExecutionEngine.Schema( + List.of( + new ExecutionEngine.Schema.Column("day", "error_t.day", DATE), + new ExecutionEngine.Schema.Column("host", "error_t.host", STRING), + new ExecutionEngine.Schema.Column("errors", "error_t.errors", INTEGER))); + + private final ExecutionEngine.Schema nameSchema = + new ExecutionEngine.Schema( + List.of( + new ExecutionEngine.Schema.Column("id", "name_t.id", INTEGER), + new ExecutionEngine.Schema.Column("name", "name_t.name", STRING))); + + private final ExecutionEngine.Schema sameNameSchema = + new ExecutionEngine.Schema( + List.of( + new ExecutionEngine.Schema.Column("id", "name_t2.id", INTEGER), + new ExecutionEngine.Schema.Column("name", "name_t2.name", STRING))); + + public PhysicalPlan makeNestedLoopJoin( + Join.JoinType joinType, JoinOperator.BuildSide buildSide, boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("error_t", errorSchema, errorInputs); + PhysicalPlan right = + reversed + ? testTableScan("error_t", errorSchema, errorInputs) + : testTableScan("name_t", nameSchema, nameInputs); + return new NestedLoopJoinOperator( + left, + right, + joinType, + buildSide, + DSL.equal(DSL.ref("error_t.errors", INTEGER), DSL.ref("name_t.id", INTEGER))); + } + + public PhysicalPlan makeNestedLoopJoinWithSameColumnNames( + Join.JoinType joinType, JoinOperator.BuildSide buildSide, boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t2", sameNameSchema, sameNameInputs) + : testTableScan("name_t", nameSchema, nameInputs); + PhysicalPlan right = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("name_t2", sameNameSchema, sameNameInputs); + return new NestedLoopJoinOperator( + left, + right, + joinType, + buildSide, + DSL.equal(DSL.ref("name_t.id", INTEGER), DSL.ref("name_t2.id", INTEGER))); + } + + public PhysicalPlan makeHashJoin( + Join.JoinType joinType, + JoinOperator.BuildSide buildSide, + Optional nonEquiCond, + boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("error_t", errorSchema, errorInputs); + PhysicalPlan right = + reversed + ? testTableScan("error_t", errorSchema, errorInputs) + : testTableScan("name_t", nameSchema, nameInputs); + List leftKeys = + reversed + ? ImmutableList.of(DSL.ref("id", INTEGER)) + : ImmutableList.of(DSL.ref("errors", INTEGER)); + + List rightKeys = + reversed + ? ImmutableList.of(DSL.ref("errors", INTEGER)) + : ImmutableList.of(DSL.ref("id", INTEGER)); + return new HashJoinOperator(leftKeys, rightKeys, joinType, buildSide, left, right, nonEquiCond); + } + + public PhysicalPlan makeHashJoinWithSameColumnNames( + Join.JoinType joinType, + JoinOperator.BuildSide buildSide, + Optional nonEquiCond, + boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("name_t2", sameNameSchema, sameNameInputs); + PhysicalPlan right = + reversed + ? testTableScan("name_t2", sameNameSchema, sameNameInputs) + : testTableScan("name_t", nameSchema, nameInputs); + List leftKeys = ImmutableList.of(DSL.ref("id", INTEGER)); + + List rightKeys = ImmutableList.of(DSL.ref("id", INTEGER)); + return new HashJoinOperator(leftKeys, rightKeys, joinType, buildSide, left, right, nonEquiCond); + } + + /** {day:DATE '2021-01-04',host:"h1",errors:1,id:1,name:"a"} */ + protected ExprValue error1_id1 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h1", + "error_t.errors", + 1, + "name_t.id", + 1, + "name_t.name", + "a")); + + /** {day:DATE '2021-01-06',host:"h1",errors:1,id:1,name:"a"} */ + protected ExprValue error1_id1_duplicated = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-06"), + "error_t.host", + "h1", + "error_t.errors", + 1, + "name_t.id", + 1, + "name_t.name", + "a")); + + /** {day:DATE '2021-01-03',host:"h1",errors:2,id:2,name:"b"} */ + protected ExprValue error2_id2 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-03"), + "error_t.host", + "h1", + "error_t.errors", + 2, + "name_t.id", + 2, + "name_t.name", + "b")); + + /** {day:DATE '2021-01-03',host:"h2",errors:3,id:3,name:NULL} */ + protected ExprValue error3_id3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-03")); + put("error_t.host", "h2"); + put("error_t.errors", 3); + put("name_t.id", 3); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-03',host:"h2",errors:3,id:NULL,name:NULL} */ + protected ExprValue error3_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-03")); + put("error_t.host", "h2"); + put("error_t.errors", 3); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h1",errors:6,id:6,name:"f"} */ + protected ExprValue error6_id6 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-07"), + "error_t.host", + "h1", + "error_t.errors", + 6, + "name_t.id", + 6, + "name_t.name", + "f")); + + /** {day:DATE '2021-01-07',host:"h1",errors:6,id:NULL,name:NULL} */ + protected ExprValue error6_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h1"); + put("error_t.errors", 6); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h2",errors:8,id:8,name:NULL} */ + protected ExprValue error8_id8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 8); + put("name_t.id", 8); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h2",errors:8,id:NULL,name:NULL} */ + protected ExprValue error8_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 8); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-04',host:"h2",errors:10,id:10,name:"j"} */ + protected ExprValue error10_id10 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h2", + "error_t.errors", + 10, + "name_t.id", + 10, + "name_t.name", + "j")); + + /** {day:DATE '2021-01-04',host:"h2",errors:10,id:NULL,name:NULL} */ + protected ExprValue error10_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-04")); + put("error_t.host", "h2"); + put("error_t.errors", 10); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h2",errors:12,id:NULL,name:NULL} */ + protected ExprValue error12_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 12); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-08',host:"h1",errors:13,id:NULL,name:NULL} */ + protected ExprValue error13_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-08")); + put("error_t.host", "h1"); + put("error_t.errors", 13); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {id:1,name:"a",day:DATE '2021-01-04',host:"h1",errors:1} */ + protected ExprValue id1_error1 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 1, + "name_t.name", + "a", + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h1", + "error_t.errors", + 1)); + + /** {id:1,name:"a",day:DATE '2021-01-06',host:"h1",errors:1} */ + protected ExprValue id1_error1_duplicated = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 1, + "name_t.name", + "a", + "error_t.day", + new ExprDateValue("2021-01-06"), + "error_t.host", + "h1", + "error_t.errors", + 1)); + + /** {id:2,name:"b",day:DATE '2021-01-03',host:"h1",errors:2} */ + protected ExprValue id2_error2 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 2, + "name_t.name", + "b", + "error_t.day", + new ExprDateValue("2021-01-03"), + "error_t.host", + "h1", + "error_t.errors", + 2)); + + /** {id:3,name:NULL,day:DATE '2021-01-03',host:"h2",errors:3} */ + protected ExprValue id3_error3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 3); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-03")); + put("error_t.host", "h2"); + put("error_t.errors", 3); + } + }); + + /** {id:4,name:"d",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id4_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 4); + put("name_t.name", "d"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:5,name:"e",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id5_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 5); + put("name_t.name", "e"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:6,name:"f",day:DATE '2021-01-07',host:"h1",errors:6} */ + protected ExprValue id6_error6 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 6, + "name_t.name", + "f", + "error_t.day", + new ExprDateValue("2021-01-07"), + "error_t.host", + "h1", + "error_t.errors", + 6)); + + /** {id:7,name:"g",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id7_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 7); + put("name_t.name", "g"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:8,name:NULL,day:DATE '2021-01-07',host:"h2",errors:8} */ + protected ExprValue id8_error8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 8); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 8); + } + }); + + /** {id:9,name:"i",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id9_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 9); + put("name_t.name", "i"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:10,name:"j",day:DATE '2021-01-04',host:"h2",errors:10} */ + protected ExprValue id10_error10 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 10, + "name_t.name", + "j", + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h2", + "error_t.errors", + 10)); + + /** {id:11,name:"k",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id11_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 11); + put("name_t.name", "k"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:3,name:NULL} */ + protected ExprValue null_id3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 3); + put("name_t.name", null); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:4,name:"d"} */ + protected ExprValue null_id4 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 4); + put("name_t.name", "d"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:5,name:"e"} */ + protected ExprValue null_id5 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 5); + put("name_t.name", "e"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:6,name:"f"} */ + protected ExprValue null_id6 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 6); + put("name_t.name", "f"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:7,name:"g"} */ + protected ExprValue null_id7 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 7); + put("name_t.name", "g"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:8,name:NULL} */ + protected ExprValue null_id8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 8); + put("name_t.name", null); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:9,name:"i"} */ + protected ExprValue null_id9 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 9); + put("name_t.name", "i"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:10,name:"j"} */ + protected ExprValue null_id10 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 10); + put("name_t.name", "j"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:11,name:"k"} */ + protected ExprValue null_id11 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 11); + put("name_t.name", "k"); + } + }); + + /** {id:NULL,name:NULL,day:DATE '2021-01-07',host:"h2",errors:12} */ + protected ExprValue null_error12 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", null); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 12); + } + }); + + /** {id:NULL,name:NULL,day:DATE '2021-01-08',host:"h1",errors:13} */ + protected ExprValue null_error13 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", null); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-08")); + put("error_t.host", "h1"); + put("error_t.errors", 13); + } + }); + + /** {name_t.id:1,name_t.name:"a",name_t2.id:1,name_t2.name:"a"} */ + ExprValue id1_same_id1 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 1, "name_t.name", "a", "name_t2.id", 1, "name_t2.name", "a")); + + /** {name_t.id:3,name_t.name:NULL,name_t2.id:3,name_t2.name:"c"} */ + ExprValue id3_same_id3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 3); + put("name_t.name", null); + put("name_t2.id", 3); + put("name_t2.name", "c"); + } + }); + + /** {name_t.id:5,name_t.name:"e",name_t2.id:5,name_t2.name:NULL} */ + ExprValue id5_same_id5 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 5); + put("name_t.name", "e"); + put("name_t2.id", 5); + put("name_t2.name", null); + } + }); + + /** {name_t.id:8,name_t.name:NULL,name_t2.id:8,name_t2.name:NULL} */ + ExprValue id8_same_id8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 8); + put("name_t.name", null); + put("name_t2.id", 8); + put("name_t2.name", null); + } + }); + + /** {name_t.id:10,name_t.name:"j",name_t2.id:10,name_t2.name:"j"} */ + ExprValue id10_same_id10 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 10, "name_t.name", "j", "name_t2.id", 10, "name_t2.name", "j")); + + /** {name_t.id:10,name_t.name:"j",name_t2.id:10,name_t2.name:"jj"} */ + ExprValue id10_same_id10_duplicated = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 10, "name_t.name", "j", "name_t2.id", 10, "name_t2.name", "jj")); + + /** {name_t.id:10,name_t.name:"j",name_t2.id:10,name_t2.name:"jjj"} */ + ExprValue id10_same_id10_duplicated2 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 10, "name_t.name", "j", "name_t2.id", 10, "name_t2.name", "jjj")); +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java new file mode 100644 index 0000000000..8e78954662 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java @@ -0,0 +1,270 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.common.collect.ImmutableMap; +import java.util.LinkedHashMap; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.planner.physical.join.JoinOperator; + +@ExtendWith(MockitoExtension.class) +public class NestedLoopJoinOperatorTest extends JoinOperatorTestHelper { + + @Test + public void inner_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10)); + } + + @Test + public void inner_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10)); + } + + @Test + public void left_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + error12_null, + error13_null)); + } + + @Test + public void left_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + id4_null, + id5_null, + id7_null, + id9_null, + id11_null)); + } + + @Test + public void right_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + null_id4, + null_id5, + null_id7, + null_id9, + null_id11)); + } + + @Test + public void right_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + null_error12, + null_error13)); + } + + @Test + public void semi_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10)))); + } + + @Test + public void semi_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(6, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 3); + put("name", null); + } + }), + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + }))); + } + + @Test + public void anti_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(2, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void anti_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, true); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + // +-----------------------------------------+ + // | Test join tables with same column names | + // +-----------------------------------------+ + + @Test + public void same_column_names_inner_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoinWithSameColumnNames( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_same_id1, + id3_same_id3, + id5_same_id5, + id8_same_id8, + id10_same_id10, + id10_same_id10_duplicated, + id10_same_id10_duplicated2)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java index 6399f945ed..14b53d434a 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java @@ -20,6 +20,7 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.env.Environment; @@ -27,21 +28,6 @@ public class PhysicalPlanTestBase { - protected static final List countTestInputs = - new ImmutableList.Builder() - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 8, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "testString", "asdf"))) - .build(); - protected static final List inputs = new ImmutableList.Builder() .add( @@ -294,8 +280,15 @@ protected static PhysicalPlan testScan(List inputs) { return new TestScan(inputs); } + protected static PhysicalPlan testTableScan( + String relationName, ExecutionEngine.Schema schema, List inputs) { + return new TestScan(inputs, relationName, schema); + } + protected static class TestScan extends PhysicalPlan implements SerializablePlan { private final Iterator iterator; + private ExecutionEngine.Schema schema; + private String relationName; public TestScan() { iterator = inputs.iterator(); @@ -305,6 +298,12 @@ public TestScan(List inputs) { iterator = inputs.iterator(); } + public TestScan(List inputs, String relationName, ExecutionEngine.Schema schema) { + iterator = inputs.iterator(); + this.relationName = relationName; + this.schema = schema; + } + @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { return null; @@ -325,6 +324,11 @@ public ExprValue next() { return iterator.next(); } + @Override + public ExecutionEngine.Schema schema() { + return this.schema; + } + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {} diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 9f707c13cd..4c4b8a9c32 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -36,6 +36,16 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; +//Native JOIN KEYWORDS +JOIN: 'JOIN'; +ON: 'ON'; +INNER: 'INNER'; +OUTER: 'OUTER'; +FULL: 'FULL'; +SEMI: 'SEMI'; +ANTI: 'ANTI'; +CROSS: 'CROSS'; + // COMMAND ASSIST KEYWORDS AS: 'AS'; BY: 'BY'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 4dc223b028..29fc53f1c0 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -170,12 +170,48 @@ fromClause | INDEX EQUAL tableSourceClause | SOURCE EQUAL tableFunction | INDEX EQUAL tableFunction + | SOURCE EQUAL relation + | INDEX EQUAL relation ; tableSourceClause : tableSource (COMMA tableSource)* ; +// TODO two-tables join only. Multi-tables join `relationExtension*` is unsupported in current implementation. +relation + : tablePrimary relationExtension + ; + +tablePrimary + : tableSource (AS alias = qualifiedName)? + ; + +relationExtension + : joinSource + ; + + // TODO joinCriteria could be none `(joinCriteria?)` for complex cases. It's unsupported in current implementation. + // TODO join hints `(hintStatement)?` is unsupported in current implementation. + // TODO directly tables jon only, join two plans is unsupported in current implementation. +joinSource + : (joinType) JOIN right = tablePrimary joinCriteria + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | RIGHT OUTER? + | FULL OUTER? + | LEFT? SEMI + | LEFT? ANTI + ; + +joinCriteria + : ON logicalExpression + ; + renameClasue : orignalField = wcFieldExpression AS renamedField = wcFieldExpression ; @@ -925,4 +961,14 @@ keywordsCanBeId | SPARKLINE | C | DC + // JOIN + | ON + | INNER + | CROSS + | OUTER + | SEMI + | LEFT + | RIGHT + | FULL + | ANTI ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 78fe28b49e..3931ee8941 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -53,6 +53,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Parse; @@ -312,6 +313,8 @@ public UnresolvedPlan visitTopCommand(TopCommandContext ctx) { public UnresolvedPlan visitFromClause(FromClauseContext ctx) { if (ctx.tableFunction() != null) { return visitTableFunction(ctx.tableFunction()); + } else if (ctx.relation() != null) { + return visitRelation(ctx.relation()); } else { return visitTableSourceClause(ctx.tableSourceClause()); } @@ -337,6 +340,47 @@ public UnresolvedPlan visitTableFunction(TableFunctionContext ctx) { return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); } + @Override + public UnresolvedPlan visitTablePrimary(OpenSearchPPLParser.TablePrimaryContext ctx) { + if (ctx.alias != null) { + return new Relation(this.internalVisitExpression(ctx.tableSource()), ctx.alias.getText()); + } else { + return new Relation(this.internalVisitExpression(ctx.tableSource())); + } + } + + @Override + public UnresolvedPlan visitRelation(OpenSearchPPLParser.RelationContext ctx) { + return withRelationExtensions(ctx, visitTablePrimary(ctx.tablePrimary())); + } + + private UnresolvedPlan withRelationExtensions( + OpenSearchPPLParser.RelationContext ctx, UnresolvedPlan tablePrimary) { + OpenSearchPPLParser.JoinSourceContext joinCtx = ctx.relationExtension().joinSource(); + Join.JoinType joinType; + if (joinCtx.joinType() == null) { + joinType = Join.JoinType.INNER; + } else if (joinCtx.joinType().INNER() != null) { + joinType = Join.JoinType.INNER; + } else if (joinCtx.joinType().LEFT() != null) { + joinType = Join.JoinType.LEFT; + } else if (joinCtx.joinType().RIGHT() != null) { + joinType = Join.JoinType.RIGHT; + } else if (joinCtx.joinType().SEMI() != null) { + joinType = Join.JoinType.SEMI; + } else if (joinCtx.joinType().ANTI() != null) { + joinType = Join.JoinType.ANTI; + } else if (joinCtx.joinType().CROSS() != null) { + joinType = Join.JoinType.CROSS; + } else if (joinCtx.joinType().FULL() != null) { + joinType = Join.JoinType.FULL; + } else { + joinType = Join.JoinType.INNER; + } + UnresolvedExpression joinCondition = this.internalVisitExpression(joinCtx.joinCriteria()); + return new Join(tablePrimary, visitTablePrimary(joinCtx.right), joinType, joinCondition); + } + /** Navigate to & build AST expression. */ private UnresolvedExpression internalVisitExpression(ParseTree tree) { return expressionBuilder.visit(tree);