Skip to content

Commit

Permalink
feat: re-implement using a REINTERPRET special operator
Browse files Browse the repository at this point in the history
This commit re-implements the conversion of UDT literals to Calcite
using REINTERPRET, a Calcite SQL special operator, wrapped around
a binary literal. This is needed to pass Calcite's type checking.
  • Loading branch information
cheikhachraf committed Feb 16, 2024
1 parent 8d46460 commit 9329550
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.Rel;
Expand Down Expand Up @@ -222,17 +223,15 @@ public Expression visit(io.substrait.expression.Expression.StructLiteral expr) {
@Override
public Expression visit(io.substrait.expression.Expression.UserDefinedLiteral expr) {

var typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name()));
return lit(
bldr -> {
try {
bldr.setNullable(expr.nullable())
.setUserDefined(
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(
expr.getType()
.accept(typeProtoConverter)
.getUserDefined()
.getTypeReference())
.setTypeReference(typeReference)
.setValue(Any.parseFrom(expr.value())))
.build();
} catch (InvalidProtocolBufferException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.ImmutableExpression;
import io.substrait.isthmus.*;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -32,6 +33,37 @@ public class CallConverters {
visitor.apply(call.getOperands().get(0)));
};

/**
 * Factory for a converter that handles Calcite {@code REINTERPRET} calls.
 *
 * <p>Given a {@link TypeConverter}, produces a {@link SimpleCallConverter} that turns a
 * reinterpretation of a fixed binary literal to a user-defined type into a user-defined type
 * literal. For any other call the converter returns {@code null}.
 */
public static Function<TypeConverter, SimpleCallConverter> REINTERPRET =
    typeConverter ->
        (call, visitor) -> {
          if (call.getKind() != SqlKind.REINTERPRET) {
            return null;
          }

          // Both conversions run before the shape check below.
          var convertedOperand = visitor.apply(call.getOperands().get(0));
          var targetType = typeConverter.toSubstrait(call.getType());

          // For now, only the reinterpretation of fixed binary literals to user defined type
          // literals is supported. This is a needed workaround: calcite does not support user
          // defined type literals and has strict type checking for literals, specifically
          // checking whether the value matches the calcite.sql.type.SqlTypeName.
          //
          // Note: this is tightly coupled to
          // ExpressionRexConverter.visit(Expression.UserDefinedLiteral expr). If other ways of
          // encoding user defined type literals are ever accepted (e.g. structured UDTs), this
          // will need to be updated.
          if (!(convertedOperand instanceof Expression.FixedBinaryLiteral binary)
              || !(targetType instanceof Type.UserDefined userDefined)) {
            return null;
          }

          return Expression.UserDefinedLiteral.builder()
              .uri(userDefined.uri())
              .name(userDefined.name())
              .value(binary.value())
              .build();
        };

// public static SimpleCallConverter OrAnd(FunctionConverter c) {
// return (call, visitor) -> {
// if (call.getKind() != SqlKind.AND && call.getKind() != SqlKind.OR) {
Expand Down Expand Up @@ -93,6 +125,7 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
new FieldSelectionConverter(typeConverter),
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new LiteralConstructorConverter());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ public RexNode visit(Expression.NullLiteral expr) throws RuntimeException {

/**
 * Converts a Substrait user-defined literal into a Calcite {@code RexNode}.
 *
 * <p>NOTE(review): this listing is a diff hunk (-94,9 +94,9) rendered without +/- markers, so
 * the body removed by the commit and the body added by the commit both appear below, one after
 * the other.
 *
 * @param expr the user-defined literal whose {@code value()} bytes and type are converted
 * @return the Calcite expression built from the literal
 */
@Override
public RexNode visit(Expression.UserDefinedLiteral expr) throws RuntimeException {
// Pre-commit body (removed lines): build the literal directly against the converted
// user-defined type.
return rexBuilder.makeLiteral(
new ByteString(expr.value().toByteArray()),
typeConverter.toCalcite(typeFactory, expr.getType()));
// Post-commit body (added lines): build a plain binary literal from the value bytes, then wrap
// it in a reinterpret cast to the converted user-defined type. Per the commit message this is
// needed to pass Calcite type checking.
var binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray()));
// The trailing null is presumably the optional overflow-check operand - TODO confirm against
// the Calcite RexBuilder API.
return rexBuilder.makeReinterpretCast(
typeConverter.toCalcite(typeFactory, expr.getType()), binaryLiteral, null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ public Expression.Literal convert(RexLiteral literal) {
return typedNull(type);
}

if (type instanceof Type.UserDefined t) {
return Expression.UserDefinedLiteral.builder()
.uri(t.uri())
.name(t.name())
.value(ByteString.copyFrom(literal.getValueAs(byte[].class)))
.build();
}

return switch (literal.getType().getSqlTypeName()) {
case TINYINT -> i8(n, i(literal).intValue());
case SMALLINT -> i16(n, i(literal).intValue());
Expand Down
25 changes: 25 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.protobuf.Any;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.ExpressionCreator;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.isthmus.utils.UserTypeFactory;
import io.substrait.proto.Expression;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
Expand Down Expand Up @@ -232,4 +235,26 @@ void customTypesInFunctionsRoundtrip() {
var relReturned = calciteToSubstrait.apply(calciteRel);
assertEquals(rel, relReturned);
}

/**
 * Checks that a plan using a user-defined type literal as a scalar function argument is equal
 * to itself after being converted to Calcite and back to Substrait.
 */
@Test
void customTypesLiteralInFunctionsRoundtrip() {
  // Literal payload: a proto i32 literal (value 10) packed into a protobuf Any.
  var packedI32 = Any.pack(Expression.Literal.newBuilder().setI32(10).build());
  var aTypeLiteral =
      ExpressionCreator.userDefinedLiteral(false, packedI32, "a_type", NAMESPACE);

  // Hoisted in the same order in which the original call evaluated its arguments.
  var remap = b.remap(1);
  var scan =
      b.namedScan(
          List.of("example"), List.of("a"), List.of(N.userDefined(NAMESPACE, "a_type")));

  Rel rel =
      b.project(
          input ->
              List.of(
                  b.scalarFn(
                      NAMESPACE,
                      "to_b_type:u!a_type",
                      R.userDefined(NAMESPACE, "b_type"),
                      aTypeLiteral)),
          remap,
          scan);

  // Substrait -> Calcite -> Substrait must give back an equal plan.
  RelNode calcitePlan = substraitToCalcite.convert(rel);
  var roundTripped = calciteToSubstrait.apply(calcitePlan);
  assertEquals(rel, roundTripped);
}
}

0 comments on commit 9329550

Please sign in to comment.