Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support two tables join PPL command #2950

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.ML;
Expand Down Expand Up @@ -88,6 +89,7 @@
import org.opensearch.sql.planner.logical.LogicalEval;
import org.opensearch.sql.planner.logical.LogicalFetchCursor;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalJoin;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalML;
import org.opensearch.sql.planner.logical.LogicalMLCommons;
Expand Down Expand Up @@ -136,6 +138,18 @@ public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) {
return unresolved.accept(this, context);
}

/**
 * Analyzes a {@link Join} node into a {@link LogicalJoin}.
 *
 * <p>Only table-to-table joins are handled for now: both children are cast to {@link Relation},
 * so joining arbitrary sub-plans is not supported yet.
 */
@Override
public LogicalPlan visitJoin(Join node, AnalysisContext context) {
  final LogicalPlan leftPlan = visitRelation((Relation) node.getLeft(), context);
  final LogicalPlan rightPlan = visitRelation((Relation) node.getRight(), context);
  final Expression analyzedCondition =
      expressionAnalyzer.analyze(node.getJoinCondition(), context);
  // Rewrite the references of the analyzed condition with knowledge of both join sides.
  final Expression resolvedCondition =
      new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), leftPlan, rightPlan)
          .optimize(analyzedCondition, context);
  return new LogicalJoin(leftPlan, rightPlan, node.getJoinType(), resolvedCondition);
}

@Override
public LogicalPlan visitRelation(Relation node, AnalysisContext context) {
QualifiedName qualifiedName = node.getTableQualifiedName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

package org.opensearch.sql.analysis;

import static org.opensearch.sql.common.utils.StringUtils.format;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionNodeVisitor;
import org.opensearch.sql.expression.FunctionExpression;
Expand All @@ -22,6 +27,7 @@
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor;
import org.opensearch.sql.planner.logical.LogicalRelation;
import org.opensearch.sql.planner.logical.LogicalWindow;

/**
Expand All @@ -47,12 +53,34 @@ public class ExpressionReferenceOptimizer
*/
private final Map<Expression, Expression> expressionMap = new HashMap<>();

// Join-resolution state. In the code visible here these are assigned only by the varargs
// constructor, and only when it receives exactly two plans; otherwise they remain null and
// visitReference returns reference expressions unchanged.
private String leftRelationName;
private String rightRelationName;
// Field names of the left/right table, taken from the key set of the table's field types.
private Set<String> leftSideAttributes;
private Set<String> rightSideAttributes;

/**
 * Creates an optimizer for a single logical plan.
 *
 * <p>The join-related fields are not assigned by this constructor.
 *
 * @param repository function repository kept by this optimizer
 * @param logicalPlan plan that is visited with a new {@code ExpressionMapBuilder}
 */
public ExpressionReferenceOptimizer(
BuiltinFunctionRepository repository, LogicalPlan logicalPlan) {
this.repository = repository;
// NOTE(review): presumably the builder populates expressionMap - confirm in ExpressionMapBuilder.
logicalPlan.accept(new ExpressionMapBuilder(), null);
}

public ExpressionReferenceOptimizer(
BuiltinFunctionRepository repository, LogicalPlan... logicalPlans) {
this.repository = repository;
// To resolve join condition, we store left side and left side of join.
if (logicalPlans.length == 2) {
// TODO current implementation only support two-tables join, so we can directly convert them
// to LogicalRelation. To support two-plans join, we can get the LogicalRelation by searching.
this.leftRelationName = ((LogicalRelation) logicalPlans[0]).getRelationName();
this.rightRelationName = ((LogicalRelation) logicalPlans[1]).getRelationName();
this.leftSideAttributes =
((LogicalRelation) logicalPlans[0]).getTable().getFieldTypes().keySet();
this.rightSideAttributes =
((LogicalRelation) logicalPlans[1]).getTable().getFieldTypes().keySet();
}
Arrays.stream(logicalPlans).forEach(p -> p.accept(new ExpressionMapBuilder(), null));
}

/**
 * Optimizes an analyzed expression by letting it accept this optimizer as a visitor.
 *
 * @param analyzed expression to visit
 * @param context analysis context passed through to the visit methods
 * @return the expression returned by the visit
 */
public Expression optimize(Expression analyzed, AnalysisContext context) {
return analyzed.accept(this, context);
}
Expand All @@ -62,6 +90,45 @@ public Expression visitNode(Expression node, AnalysisContext context) {
return node;
}

/**
 * Adds the index prefix to a reference attribute of a join condition.
 *
 * <p>The attribute may take one of these forms:
 *
 * <ul>
 *   <li>case 1: {@code Field} becomes {@code Index.Field}
 *   <li>case 2: {@code Field.Field} becomes {@code Index.Field.Field}
 *   <li>case 3: {@code .Index.Field}, {@code .Index.Field.Field} are left untouched
 *   <li>case 4: {@code Index.Field}, {@code Index.Field.Field} are left untouched
 * </ul>
 *
 * <p>Outside of a join (no left or right relation name recorded) the node is returned as is.
 */
@Override
public Expression visitReference(ReferenceExpression node, AnalysisContext context) {
  final boolean resolvingJoin = leftRelationName != null && rightRelationName != null;
  if (!resolvingJoin) {
    return node;
  }

  final String name = node.getAttr();
  if (name.contains(".")) {
    // case 3 or case 4: the attribute is already qualified
    if (name.startsWith(".") || isIndexPrefix(name)) {
      return node;
    }
  }
  // case 1 or case 2
  return replaceReferenceExpressionWithIndexPrefix(node, name);
}

/**
 * Qualifies {@code attr} with the name of the relation whose field set contains it.
 *
 * @param node original reference, returned unchanged when neither side has the attribute
 * @param attr attribute name to look up in both sides of the join
 * @return a new reference named {@code relation.attr} with the type of {@code node}, or {@code
 *     node} itself when the attribute belongs to neither side
 * @throws SemanticCheckException when both sides contain the attribute
 */
private ReferenceExpression replaceReferenceExpressionWithIndexPrefix(
    ReferenceExpression node, String attr) {
  final boolean onLeft = leftSideAttributes.contains(attr);
  final boolean onRight = rightSideAttributes.contains(attr);
  if (onLeft && onRight) {
    throw new SemanticCheckException(format("Reference `%s` is ambiguous", attr));
  }
  if (!onLeft && !onRight) {
    return node;
  }
  final String owner = onLeft ? leftRelationName : rightRelationName;
  return new ReferenceExpression(format("%s.%s", owner, attr), node.type());
}

/**
 * Tells whether {@code attr} is already qualified with the left or the right relation name.
 *
 * <p>The attribute counts as qualified only when it starts with the full relation name followed
 * by a dot. A mere substring match is not enough: with a relation named {@code accounts}, the
 * nested field {@code account.name} must not be mistaken for an index-qualified reference.
 * An attribute without any dot is simply reported as not qualified.
 *
 * @param attr attribute name of a reference expression
 * @return {@code true} if {@code attr} begins with {@code <leftRelationName>.} or {@code
 *     <rightRelationName>.}
 */
private boolean isIndexPrefix(String attr) {
  return attr.startsWith(leftRelationName + ".") || attr.startsWith(rightRelationName + ".");
}

@Override
public Expression visitFunction(FunctionExpression node, AnalysisContext context) {
if (expressionMap.containsKey(node)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ public Optional<ExprType> lookup(Symbol symbol) {
Map<String, ExprType> table = tableByNamespace.get(symbol.getNamespace());
ExprType type = null;
if (table != null) {
// Handle a field name that is prefixed with its index name, for example index1.field1.
// Qualified names of this form are used by join queries.
if (symbol.getNamespace() == Namespace.FIELD_NAME) {
String[] parts = symbol.getName().split("\\.");
if (parts.length == 2) {
// extract the indexName
if (tableByNamespace.get(Namespace.INDEX_NAME) != null) {
String indexName = tableByNamespace.get(Namespace.INDEX_NAME).firstKey();
if (indexName != null && indexName.equals(parts[0])) {
type = table.get(parts[1]);
return Optional.ofNullable(type);
}
}
}
}
type = table.get(symbol.getName());
}
return Optional.ofNullable(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.ML;
Expand Down Expand Up @@ -109,6 +110,10 @@ public T visitFilter(Filter node, C context) {
return visitChildren(node, context);
}

// Default handling for Join nodes: delegates to visitChildren.
public T visitJoin(Join node, C context) {
return visitChildren(node, context);
}

// Default handling for Project nodes: delegates to visitChildren.
public T visitProject(Project node, C context) {
return visitChildren(node, context);
}
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Join;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
import org.opensearch.sql.ast.tree.Project;
Expand Down Expand Up @@ -471,4 +472,12 @@ public static Parse parse(
java.util.Map<String, Literal> arguments) {
return new Parse(parseMethod, sourceField, pattern, arguments, input);
}

/**
 * Builds a {@link Join} node.
 *
 * @param left left input plan
 * @param right right input plan
 * @param joinType kind of join
 * @param condition unresolved join condition
 * @return a new {@code Join} constructed from the given arguments
 */
public static Join join(
UnresolvedPlan left,
UnresolvedPlan right,
Join.JoinType joinType,
UnresolvedExpression condition) {
return new Join(left, right, joinType, condition);
}
}
51 changes: 51 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Join.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

/**
 * AST node for a join between two unresolved plans.
 *
 * <p>Holds the left and right inputs, the {@link JoinType} and the unresolved join condition. All
 * four are final fields supplied through the Lombok-generated constructor.
 */
@RequiredArgsConstructor
@Getter
@EqualsAndHashCode(callSuper = false)
@ToString
public class Join extends UnresolvedPlan {
private final UnresolvedPlan left;
private final UnresolvedPlan right;
private final JoinType joinType;
private final UnresolvedExpression joinCondition;

/**
 * Returns this node unchanged; the given child is ignored because both inputs are already
 * supplied at construction time.
 */
@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
return this;
}

/** Returns the left and right inputs, in that order, as an immutable list. */
@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(left, right);
}

/** Dispatches to {@link AbstractNodeVisitor#visitJoin}. */
@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitJoin(this, context);
}

/** Join kinds that a {@link Join} node can carry. */
public enum JoinType {
INNER,
LEFT,
RIGHT,
SEMI,
ANTI,
CROSS,
FULL
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ public static Optional<BuiltinFunctionName> of(String str) {
return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null));
}

/**
 * Looks up a builtin function by its {@link FunctionName} in {@code ALL_NATIVE_FUNCTIONS}.
 *
 * <p>Unlike the {@code String} overload, the result is not wrapped in an {@code Optional}.
 * NOTE(review): this presumably yields {@code null} for an unknown name (assuming
 * {@code ALL_NATIVE_FUNCTIONS} is a standard {@code Map}) - callers should check.
 *
 * @param name function name to look up
 * @return the value stored for {@code name} in {@code ALL_NATIVE_FUNCTIONS}
 */
public static BuiltinFunctionName of(FunctionName name) {
return ALL_NATIVE_FUNCTIONS.get(name);
}

public static Optional<BuiltinFunctionName> ofAggregation(String functionName) {
return Optional.ofNullable(
AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null));
Expand Down
Loading
Loading