Skip to content

Commit

Permalink
feat: allow to execute more than one expression + testing custom separ…
Browse files Browse the repository at this point in the history
…ator
  • Loading branch information
davisusanibar committed Feb 5, 2024
1 parent 123d859 commit 10e52c7
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 81 deletions.
5 changes: 4 additions & 1 deletion isthmus/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ isthmus 0.1
$ ./isthmus/build/graal/isthmus --help
Usage: isthmus [-hmV] [--crossjoinpolicy=<crossJoinPolicy>]
[-e=<sqlExpression>] [--outputformat=<outputFormat>]
[-e=<sqlExpression>] [-es=<sqlExpressionSeparator>]
[--outputformat=<outputFormat>]
[--sqlconformancemode=<sqlConformanceMode>]
[-c=<createStatements>]... [<sql>]
Substrait Java Native Image for parsing SQL Query and SQL Expressions
Expand All @@ -42,6 +43,8 @@ Substrait Java Native Image for parsing SQL Query and SQL Expressions
KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN
-e, --expression=<sqlExpression>
The sql expression we should parse.
-es, --separator=<sqlExpressionSeparator>
The separator for the sql expressions.
-h, --help Show this help message and exit.
-m, --multistatement Allow multiple statements terminated with a semicolon
--outputformat=<outputFormat>
Expand Down
44 changes: 18 additions & 26 deletions isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import picocli.CommandLine;

Expand All @@ -34,6 +32,12 @@ public class IsthmusEntryPoint implements Callable<Integer> {
description = "The sql expression we should parse.")
private String sqlExpression;

@Option(
names = {"-es", "--separator"},
defaultValue = ",",
description = "The separator for the sql expressions.")
private String sqlExpressionSeparator;

@Option(
names = {"-c", "--create"},
description =
Expand Down Expand Up @@ -87,37 +91,25 @@ public static void main(String... args) {
System.exit(exitCode);
}

private FeatureBoard featureBoard;

@Override
public Integer call() throws Exception {
this.featureBoard = buildFeatureBoard();
FeatureBoard featureBoard = buildFeatureBoard();
// Isthmus image is parsing SQL Expression if that argument is defined
if (sqlExpression != null) {
handleSQLExpression();
} else {
handleSQLPlan();
SqlExpressionToSubstrait converter =
new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults());
ExtendedExpression extendedExpression =
converter.convert(sqlExpression, sqlExpressionSeparator, createStatements);
printMessage(extendedExpression);
} else { // by default Isthmus image are parsing SQL Query
SqlToSubstrait converter = new SqlToSubstrait(featureBoard);
Plan plan = converter.execute(sql, createStatements);
printMessage(plan);
}
return 0;
}

/** Creates the extended expression via {@code createExpression()} and prints it. */
private void handleSQLExpression() throws SqlParseException, IOException {
  // The converted expression is consumed immediately, so it needs no local name.
  printExpression(createExpression());
}

/**
 * Converts the SQL query held in {@code sql}, together with {@code createStatements}, to a
 * {@link Plan} and prints it.
 */
private void handleSQLPlan() throws SqlParseException, IOException {
  Plan substraitPlan = new SqlToSubstrait(featureBoard).execute(sql, createStatements);
  printExpression(substraitPlan);
}

/**
 * Splits {@code sqlExpression} on commas and converts the resulting pieces, together with
 * {@code createStatements}, to an {@link ExtendedExpression}.
 */
private ExtendedExpression createExpression() throws IOException, SqlParseException {
  // Build the converter first so extension loading happens before the input is split.
  SqlExpressionToSubstrait substraitConverter =
      new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults());
  List<String> expressionList = Arrays.asList(sqlExpression.split(","));
  return substraitConverter.convert(expressionList, createStatements);
}

private void printExpression(Message message) throws IOException {
private void printMessage(Message message) throws IOException {
switch (outputFormat) {
case PROTOJSON -> System.out.println(
JsonFormat.printer().includingDefaultValueFields().print(message));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -62,13 +62,22 @@ private record Result(
*/
public ExtendedExpression convert(String sqlExpression, List<String> createStatements)
throws SqlParseException {
var result = registerCreateTablesForExtendedExpression(createStatements);
return executeInnerSQLExpression(
sqlExpression,
result.validator(),
result.catalogReader(),
result.nameToTypeMap(),
result.nameToNodeMap());
return convert(sqlExpression, ",", createStatements);
}

/**
 * Converts a separator-delimited string of SQL expressions to an {@link
 * io.substrait.proto.ExtendedExpression}.
 *
 * <p>The input is split on every literal occurrence of {@code separator}, and the resulting
 * pieces are passed to {@link #convert(List, List)}.
 *
 * @param sqlExpression one or more SQL expressions joined by {@code separator}
 * @param separator the text that delimits the individual expressions; it is matched literally,
 *     so characters with a special meaning in regular expressions (for example {@code |} or
 *     {@code .}) are safe to use
 * @param createStatements table creation statements defining fields referenced by the expression
 * @return a {@link io.substrait.proto.ExtendedExpression}
 * @throws SqlParseException propagated from {@link #convert(List, List)}
 */
public ExtendedExpression convert(
    String sqlExpression, String separator, List<String> createStatements)
    throws SqlParseException {
  // String.split interprets its argument as a regex. Quote the separator so that user-supplied
  // values such as "|" or "." split on the literal text instead of being read as a pattern.
  String literalSeparator = java.util.regex.Pattern.quote(separator);
  return convert(Arrays.asList(sqlExpression.split(literalSeparator)), createStatements);
}

/**
Expand All @@ -90,21 +99,6 @@ public ExtendedExpression convert(List<String> sqlExpressions, List<String> crea
result.nameToNodeMap());
}

/**
 * Single-expression variant: wraps {@code sqlExpression} in a one-element list and delegates to
 * {@code executeInnerSQLExpressions}, passing the remaining arguments through unchanged.
 */
private ExtendedExpression executeInnerSQLExpression(
    String sqlExpression,
    SqlValidator validator,
    CalciteCatalogReader catalogReader,
    Map<String, RelDataType> nameToTypeMap,
    Map<String, RexNode> nameToNodeMap)
    throws SqlParseException {
  List<String> singleExpression = Collections.singletonList(sqlExpression);
  return executeInnerSQLExpressions(
      singleExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap);
}

private ExtendedExpression executeInnerSQLExpressions(
List<String> sqlExpressions,
SqlValidator validator,
Expand All @@ -117,7 +111,9 @@ private ExtendedExpression executeInnerSQLExpressions(
expressionReferences = new ArrayList<>();
RexNode rexNode;
for (String sqlExpression : sqlExpressions) {
rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap);
rexNode =
sqlToRexNode(
sqlExpression.trim(), validator, catalogReader, nameToTypeMap, nameToNodeMap);
ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.expression(rexNode.accept(this.rexConverter))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,42 @@ public static List<String> tpchSchemaCreateStatements() throws IOException {
return tpchSchemaCreateStatements("tpch/schema.sql");
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query)
throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait());
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, String schemaToLoad) throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(
query, new SqlExpressionToSubstrait(), schemaToLoad);
/**
 * Converts {@code expressions} with the two-argument {@code convert} overload, using the create
 * statements from {@code tpchSchemaCreateStatements()}, and runs the proto round-trip assertion
 * on the result.
 */
protected void assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip(String expressions)
    throws SqlParseException, IOException {
  SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait();
  List<String> creates = tpchSchemaCreateStatements();
  // The converted proto is the "initial" side of the round-trip comparison.
  asserProtoExtendedExpression(converter.convert(expressions, creates));
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements());
/**
 * Same as the default-separator round-trip check, but the create statements are loaded from the
 * schema resource named by {@code schemaToLoad}.
 */
protected void assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip(
    String expressions, String schemaToLoad) throws SqlParseException, IOException {
  SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait();
  List<String> creates = tpchSchemaCreateStatements(schemaToLoad);
  // The converted proto is the "initial" side of the round-trip comparison.
  asserProtoExtendedExpression(converter.convert(expressions, creates));
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, SqlExpressionToSubstrait s, String schemaToLoad)
throws IOException, SqlParseException {
return assertProtoExtendedExpressionRoundtrip(
query, s, tpchSchemaCreateStatements(schemaToLoad));
/**
 * Converts {@code expressions} with the three-argument {@code convert} overload, passing
 * {@code separator} through, and runs the proto round-trip assertion on the result.
 */
protected void assertProtoEEForExpressionsCustomSeparatorRoundtrip(
    String expressions, String separator) throws SqlParseException, IOException {
  SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait();
  List<String> creates = tpchSchemaCreateStatements();
  // The converted proto is the "initial" side of the round-trip comparison.
  asserProtoExtendedExpression(converter.convert(expressions, separator, creates));
}

protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(
String query, SqlExpressionToSubstrait s, List<String> creates)
protected void assertProtoEEForListExpressionRoundtrip(List<String> expression)
throws SqlParseException, IOException {
// proto initial extended expression
ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates);
ExtendedExpression extendedExpressionProtoInitial =
new SqlExpressionToSubstrait().convert(expression, tpchSchemaCreateStatements());
asserProtoExtendedExpression(extendedExpressionProtoInitial);
}

private static void asserProtoExtendedExpression(
ExtendedExpression extendedExpressionProtoInitial) throws IOException {
// pojo final extended expression
io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal =
new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial);
Expand All @@ -66,7 +73,5 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(

// round-trip to validate extended expression proto initial equals to final
Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial);

return extendedExpressionProtoInitial;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -18,30 +21,55 @@ private static Stream<Arguments> expressionTypeProvider() {
Arguments.of("L_ORDERKEY"), // FieldReferenceExpression
Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter
Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection
Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn
Arguments.of("L_ORDERKEY IN (10)"), // ScalarFunctionExpressionIn
Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull
Arguments.of("L_ORDERKEY is null"), // ScalarFunctionExpressionIsNull
Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2"),
Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2", "L_ORDERKEY > 10"));
Arguments.of("L_ORDERKEY is null")); // ScalarFunctionExpressionIsNull
}

@ParameterizedTest
@MethodSource("expressionTypeProvider")
public void testExtendedExpressionsRoundTrip(String sqlExpression)
public void testExtendedExpressionsCommaSeparatorRoundTrip(String sqlExpression)
throws SqlParseException, IOException {
assertProtoExtendedExpressionRoundtrip(sqlExpression);
assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip(
sqlExpression); // comma-separator by default
}

@ParameterizedTest
@MethodSource("expressionTypeProvider")
public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) {
public void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sqlExpression) {
IllegalArgumentException illegalArgumentException =
assertThrows(
IllegalArgumentException.class,
() -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql"));
() ->
assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip(
sqlExpression, "tpch/schema_error.sql"));
assertTrue(
illegalArgumentException
.getMessage()
.startsWith("There is no support for duplicate column names"));
}

/** Round-trips several expressions joined by {@code #}, one of which contains a comma. */
@Test
public void testExtendedExpressionsCustomSeparatorRoundTrip()
    throws SqlParseException, IOException {
  assertProtoEEForExpressionsCustomSeparatorRoundtrip(
      "2#L_ORDERKEY#L_ORDERKEY > 10#L_ORDERKEY + 10#L_ORDERKEY IN (10, 20)#L_ORDERKEY is not null#L_ORDERKEY is null",
      "#");
}

/** Round-trips expressions supplied as an already-split list, so no separator is involved. */
@Test
public void testExtendedExpressionsListExpressionRoundTrip()
    throws SqlParseException, IOException {
  String[] expressionArray = {
    "2",
    "L_ORDERKEY",
    "L_ORDERKEY > 10",
    "L_ORDERKEY + 10",
    "L_ORDERKEY IN (10, 20)", // the comma won't cause any problems
    "L_ORDERKEY is not null",
    "L_ORDERKEY is null"
  };
  assertProtoEEForListExpressionRoundtrip(Arrays.asList(expressionArray));
}
}
3 changes: 3 additions & 0 deletions isthmus/src/test/script/smoke.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ $CMD --expression 'l_orderkey + 9888486986' --create "${LINEITEM}"

# SQL Expression - 03 Projection expression (column-1, column-2, column-3)
$CMD --expression 'l_orderkey + 9888486986, l_orderkey * 2, l_orderkey > 10' --create "${LINEITEM}"

# SQL Expression - 03 Projection expression (column-1, column-2, column-3) with custom separator
$CMD --expression 'l_orderkey + 9888486986 # l_orderkey * 2 # l_orderkey > 10' --create "${LINEITEM}" --separator "#"

0 comments on commit 10e52c7

Please sign in to comment.