From 10e52c7ca08d53b1ee8129577b9918414fda8ada Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Mon, 5 Feb 2024 12:08:06 -0500 Subject: [PATCH] feat: allow to execute more than one expression + testing custom separator --- isthmus/README.md | 5 +- .../substrait/isthmus/IsthmusEntryPoint.java | 44 +++++++---------- .../isthmus/SqlExpressionToSubstrait.java | 44 ++++++++--------- .../isthmus/ExtendedExpressionTestBase.java | 49 ++++++++++--------- .../SimpleExtendedExpressionsTest.java | 44 ++++++++++++++--- isthmus/src/test/script/smoke.sh | 3 ++ 6 files changed, 108 insertions(+), 81 deletions(-) diff --git a/isthmus/README.md b/isthmus/README.md index 37e0b2818..a0e4c8aad 100644 --- a/isthmus/README.md +++ b/isthmus/README.md @@ -29,7 +29,8 @@ isthmus 0.1 $ ./isthmus/build/graal/isthmus --help Usage: isthmus [-hmV] [--crossjoinpolicy=] - [-e=] [--outputformat=] + [-e=] [-es=] + [--outputformat=] [--sqlconformancemode=] [-c=]... [] Substrait Java Native Image for parsing SQL Query and SQL Expressions @@ -42,6 +43,8 @@ Substrait Java Native Image for parsing SQL Query and SQL Expressions KEEP_AS_CROSS_JOIN, CONVERT_TO_INNER_JOIN -e, --expression= The sql expression we should parse. + -es, --separator= + The separator for the sql expressions. -h, --help Show this help message and exit. 
-m, --multistatement Allow multiple statements terminated with a semicolon --outputformat= diff --git a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java index 80bcd627b..50173de47 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java +++ b/isthmus/src/main/java/io/substrait/isthmus/IsthmusEntryPoint.java @@ -13,10 +13,8 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import java.io.IOException; -import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; -import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.validate.SqlConformanceEnum; import picocli.CommandLine; @@ -34,6 +32,12 @@ public class IsthmusEntryPoint implements Callable { description = "The sql expression we should parse.") private String sqlExpression; + @Option( + names = {"-es", "--separator"}, + defaultValue = ",", + description = "The separator for the sql expressions.") + private String sqlExpressionSeparator; + @Option( names = {"-c", "--create"}, description = @@ -87,37 +91,25 @@ public static void main(String... 
args) { System.exit(exitCode); } - private FeatureBoard featureBoard; - @Override public Integer call() throws Exception { - this.featureBoard = buildFeatureBoard(); + FeatureBoard featureBoard = buildFeatureBoard(); + // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpression != null) { - handleSQLExpression(); - } else { - handleSQLPlan(); + SqlExpressionToSubstrait converter = + new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); + ExtendedExpression extendedExpression = + converter.convert(sqlExpression, sqlExpressionSeparator, createStatements); + printMessage(extendedExpression); + } else { // by default Isthmus image are parsing SQL Query + SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + Plan plan = converter.execute(sql, createStatements); + printMessage(plan); } return 0; } - private void handleSQLExpression() throws SqlParseException, IOException { - ExtendedExpression extendedExpression = createExpression(); - printExpression(extendedExpression); - } - - private void handleSQLPlan() throws SqlParseException, IOException { - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); - Plan plan = converter.execute(sql, createStatements); - printExpression(plan); - } - - private ExtendedExpression createExpression() throws IOException, SqlParseException { - SqlExpressionToSubstrait converter = - new SqlExpressionToSubstrait(featureBoard, SimpleExtension.loadDefaults()); - return converter.convert(Arrays.asList(sqlExpression.split(",")), createStatements); - } - - private void printExpression(Message message) throws IOException { + private void printMessage(Message message) throws IOException { switch (outputFormat) { case PROTOJSON -> System.out.println( JsonFormat.printer().includingDefaultValueFields().print(message)); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 
ce96a0ac5..9a75a3382 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -11,7 +11,7 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; import java.util.ArrayList; -import java.util.Collections; +import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -62,13 +62,22 @@ private record Result( */ public ExtendedExpression convert(String sqlExpression, List createStatements) throws SqlParseException { - var result = registerCreateTablesForExtendedExpression(createStatements); - return executeInnerSQLExpression( - sqlExpression, - result.validator(), - result.catalogReader(), - result.nameToTypeMap(), - result.nameToNodeMap()); + return convert(sqlExpression, ",", createStatements); + } + + /** + * Converts the given SQL expression string to an {@link io.substrait.proto.ExtendedExpression } + * + * @param sqlExpression a SQL expression + * @param separator the separator for the sql expressions + * @param createStatements table creation statements defining fields referenced by the expression + * @return a {@link io.substrait.proto.ExtendedExpression } + * @throws SqlParseException + */ + public ExtendedExpression convert( + String sqlExpression, String separator, List createStatements) + throws SqlParseException { + return convert(Arrays.asList(sqlExpression.split(separator)), createStatements); } /** @@ -90,21 +99,6 @@ public ExtendedExpression convert(List sqlExpressions, List crea result.nameToNodeMap()); } - private ExtendedExpression executeInnerSQLExpression( - String sqlExpression, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) - throws SqlParseException { - return executeInnerSQLExpressions( - Collections.singletonList(sqlExpression), - validator, - catalogReader, - nameToTypeMap, - nameToNodeMap); - } - private 
ExtendedExpression executeInnerSQLExpressions( List sqlExpressions, SqlValidator validator, @@ -117,7 +111,9 @@ private ExtendedExpression executeInnerSQLExpressions( expressionReferences = new ArrayList<>(); RexNode rexNode; for (String sqlExpression : sqlExpressions) { - rexNode = sqlToRexNode(sqlExpression, validator, catalogReader, nameToTypeMap, nameToNodeMap); + rexNode = + sqlToRexNode( + sqlExpression.trim(), validator, catalogReader, nameToTypeMap, nameToNodeMap); ImmutableExpressionReference expressionReference = ImmutableExpressionReference.builder() .expression(rexNode.accept(this.rexConverter)) diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index d47abcc77..fe29d6215 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -27,35 +27,42 @@ public static List tpchSchemaCreateStatements() throws IOException { return tpchSchemaCreateStatements("tpch/schema.sql"); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip(String query) - throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip(query, new SqlExpressionToSubstrait()); - } - - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, String schemaToLoad) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip( - query, new SqlExpressionToSubstrait(), schemaToLoad); + protected void assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip(String expressions) + throws SqlParseException, IOException { + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait().convert(expressions, tpchSchemaCreateStatements()); + asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected 
ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, SqlExpressionToSubstrait s) throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip(query, s, tpchSchemaCreateStatements()); + protected void assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( + String expressions, String schemaToLoad) throws SqlParseException, IOException { + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait() + .convert(expressions, tpchSchemaCreateStatements(schemaToLoad)); + asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, SqlExpressionToSubstrait s, String schemaToLoad) - throws IOException, SqlParseException { - return assertProtoExtendedExpressionRoundtrip( - query, s, tpchSchemaCreateStatements(schemaToLoad)); + protected void assertProtoEEForExpressionsCustomSeparatorRoundtrip( + String expressions, String separator) throws SqlParseException, IOException { + // proto initial extended expression + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait() + .convert(expressions, separator, tpchSchemaCreateStatements()); + asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( - String query, SqlExpressionToSubstrait s, List creates) + protected void assertProtoEEForListExpressionRoundtrip(List expression) throws SqlParseException, IOException { // proto initial extended expression - ExtendedExpression extendedExpressionProtoInitial = s.convert(query, creates); + ExtendedExpression extendedExpressionProtoInitial = + new SqlExpressionToSubstrait().convert(expression, tpchSchemaCreateStatements()); + asserProtoExtendedExpression(extendedExpressionProtoInitial); + } + private static void asserProtoExtendedExpression( + ExtendedExpression 
extendedExpressionProtoInitial) throws IOException { // pojo final extended expression io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); @@ -66,7 +73,5 @@ protected ExtendedExpression assertProtoExtendedExpressionRoundtrip( // round-trip to validate extended expression proto initial equals to final Assertions.assertEquals(extendedExpressionProtoFinal, extendedExpressionProtoInitial); - - return extendedExpressionProtoInitial; } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 5a622f769..4c38b267b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -4,8 +4,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import java.util.stream.Stream; import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -18,30 +21,55 @@ private static Stream expressionTypeProvider() { Arguments.of("L_ORDERKEY"), // FieldReferenceExpression Arguments.of("L_ORDERKEY > 10"), // ScalarFunctionExpressionFilter Arguments.of("L_ORDERKEY + 10"), // ScalarFunctionExpressionProjection - Arguments.of("L_ORDERKEY IN (10, 20)"), // ScalarFunctionExpressionIn + Arguments.of("L_ORDERKEY IN (10)"), // ScalarFunctionExpressionIn Arguments.of("L_ORDERKEY is not null"), // ScalarFunctionExpressionIsNotNull - Arguments.of("L_ORDERKEY is null"), // ScalarFunctionExpressionIsNull - Arguments.of("L_ORDERKEY + 10", "L_ORDERKEY * 2"), - Arguments.of("L_ORDERKEY + 10", 
"L_ORDERKEY * 2", "L_ORDERKEY > 10")); + Arguments.of("L_ORDERKEY is null")); // ScalarFunctionExpressionIsNull } @ParameterizedTest @MethodSource("expressionTypeProvider") - public void testExtendedExpressionsRoundTrip(String sqlExpression) + public void testExtendedExpressionsCommaSeparatorRoundTrip(String sqlExpression) throws SqlParseException, IOException { - assertProtoExtendedExpressionRoundtrip(sqlExpression); + assertProtoEEForExpressionsDefaultCommaSeparatorRoundtrip( + sqlExpression); // comma-separator by default } @ParameterizedTest @MethodSource("expressionTypeProvider") - public void testExtendedExpressionsRoundTripDuplicateColumnIdentifier(String sqlExpression) { + public void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sqlExpression) { IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, - () -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql")); + () -> + assertProtoEEForExpressionsDefaultCommaSeparatorErrorRoundtrip( + sqlExpression, "tpch/schema_error.sql")); assertTrue( illegalArgumentException .getMessage() .startsWith("There is no support for duplicate column names")); } + + @Test + public void testExtendedExpressionsCustomSeparatorRoundTrip() + throws SqlParseException, IOException { + String expressions = + "2#L_ORDERKEY#L_ORDERKEY > 10#L_ORDERKEY + 10#L_ORDERKEY IN (10, 20)#L_ORDERKEY is not null#L_ORDERKEY is null"; + String separator = "#"; + assertProtoEEForExpressionsCustomSeparatorRoundtrip(expressions, separator); + } + + @Test + public void testExtendedExpressionsListExpressionRoundTrip() + throws SqlParseException, IOException { + List expressions = + Arrays.asList( + "2", + "L_ORDERKEY", + "L_ORDERKEY > 10", + "L_ORDERKEY + 10", + "L_ORDERKEY IN (10, 20)", // the comma won't cause any problems + "L_ORDERKEY is not null", + "L_ORDERKEY is null"); + assertProtoEEForListExpressionRoundtrip(expressions); + } } diff --git 
a/isthmus/src/test/script/smoke.sh b/isthmus/src/test/script/smoke.sh index 76859204e..f64f7417d 100755 --- a/isthmus/src/test/script/smoke.sh +++ b/isthmus/src/test/script/smoke.sh @@ -29,3 +29,6 @@ $CMD --expression 'l_orderkey + 9888486986' --create "${LINEITEM}" # SQL Expression - 03 Projection expression (column-1, column-2, column-3) $CMD --expression 'l_orderkey + 9888486986, l_orderkey * 2, l_orderkey > 10' --create "${LINEITEM}" + +# SQL Expression - 03 Projection expression (column-1, column-2, column-3) with custom separator +$CMD --expression 'l_orderkey + 9888486986 # l_orderkey * 2 # l_orderkey > 10' --create "${LINEITEM}" --separator "#"