Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix(isthmus): fix rel converter for sort when slot is wrapped #235

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import io.substrait.expression.Expression;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.calcite.RexVisitorFinder;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
Expand Down Expand Up @@ -359,8 +360,16 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField) {
var expression = sortField.expr();
var rex = expression.accept(expressionRexConverter);
var sortDirection = sortField.direction();
RexSlot rexSlot = (RexSlot) rex;
int fieldIndex = rexSlot.getIndex();

RexSlot slot =
new RexVisitorFinder<>(RexSlot.class)
.findUnique(rex)
.orElseThrow(
() ->
new RuntimeException(
String.format(
"No slot found in sort field, expression type: %s", rex.getKind())));
int fieldIndex = slot.getIndex();
var fieldDirection = RelFieldCollation.Direction.ASCENDING;
var nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED;
switch (sortDirection) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.substrait.isthmus.calcite;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;

/**
* Visitor that finds all instances of a given class in a RexNode tree.
*
* @param <T> Class type to find instances.
*/
public class RexVisitorFinder<T> extends RexVisitorImpl<Void> {
final List<T> found;
final Class<T> findClass;

public RexVisitorFinder(Class<T> findClass) {
super(true);
this.found = new ArrayList<>();
this.findClass = findClass;
}

@Override
public void visitEach(Iterable<? extends RexNode> expressions) {
for (RexNode expr : expressions) {
if (findClass.isInstance(expr)) {
found.add(findClass.cast(expr));
}
}
super.visitEach(expressions);
}

/**
* Find all instances of the class in the given call.
*
* @param call The call to search
* @return List of instances of the class
*/
public List<T> find(RexNode call) {
if (findClass.isInstance(call)) {
found.add(findClass.cast(call));
}
call.accept(this);
return found;
}

/**
* Find a unique instance of the class in the given call.
*
* <p>Throws an exception if more than one instance is found.
*
* @param call The call to search
* @return Optional of the instance of the class
*/
public Optional<T> findUnique(RexNode call) {
this.find(call);

if (this.found.isEmpty()) {
return Optional.empty();
}
if (this.found.size() > 1) {
throw new IllegalStateException("Found more than one instance of " + findClass);
}
return Optional.of(this.found.get(0));
}
}
Loading