From 4027b73f14e288c11a2cd752033462bdb75174ac Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Wed, 5 Jun 2024 20:33:59 -0700 Subject: [PATCH] Fixes parsing of signed numeric literals --- .../org/partiql/ast/helpers/ToLegacyAst.kt | 44 +---------- .../lang/syntax/PartiQLParserLiteralTests.kt | 74 ++++++++++++++++++ .../partiql/lang/syntax/PartiQLParserTest.kt | 1 + .../parser/internal/PartiQLParserDefault.kt | 78 +++++++++++++++---- 4 files changed, 142 insertions(+), 55 deletions(-) create mode 100644 partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index 0dd1a879ff..91506e2fc7 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -3,14 +3,9 @@ package org.partiql.ast.helpers import com.amazon.ion.Decimal -import com.amazon.ionelement.api.DecimalElement -import com.amazon.ionelement.api.FloatElement -import com.amazon.ionelement.api.IntElement -import com.amazon.ionelement.api.IntElementSize import com.amazon.ionelement.api.MetaContainer import com.amazon.ionelement.api.emptyMetaContainer import com.amazon.ionelement.api.ionDecimal -import com.amazon.ionelement.api.ionFloat import com.amazon.ionelement.api.ionInt import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionSymbol @@ -48,7 +43,6 @@ import org.partiql.value.TimestampValue import org.partiql.value.datetime.TimeZone import org.partiql.value.toIon import java.math.BigDecimal -import java.math.BigInteger /** * Translates an [AstNode] tree to the legacy PIG AST. @@ -322,42 +316,8 @@ private class AstTranslator(val metas: Map) : AstBaseVisi val arg = visitExpr(node.expr, ctx) when (node.op) { Expr.Unary.Op.NOT -> not(arg, metas) - Expr.Unary.Op.POS -> { - when { - arg !is PartiqlAst.Expr.Lit -> pos(arg) - arg.value is IntElement -> arg - arg.value is FloatElement -> arg - arg.value is DecimalElement -> arg - else -> pos(arg) - } - } - Expr.Unary.Op.NEG -> { - when { - arg !is PartiqlAst.Expr.Lit -> neg(arg, metas) - arg.value is IntElement -> { - val intValue = when (arg.value.integerSize) { - IntElementSize.LONG -> ionInt(-arg.value.longValue) - IntElementSize.BIG_INTEGER -> when (arg.value.bigIntegerValue) { - Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE) - else -> ionInt(arg.value.bigIntegerValue * BigInteger.valueOf(-1L)) - } - } - arg.copy( - value = intValue.asAnyElement(), - metas = metas, - ) - } - arg.value is FloatElement -> arg.copy( - value = ionFloat(-(arg.value.doubleValue)).asAnyElement(), - metas = metas, - ) - arg.value is DecimalElement -> arg.copy( - value = ionDecimal(Decimal.valueOf(-(arg.value.decimalValue))).asAnyElement(), - metas = metas, - ) - else -> neg(arg, metas) - } - } + Expr.Unary.Op.POS -> pos(arg, metas) + Expr.Unary.Op.NEG -> neg(arg, metas) } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt new file mode 100644 index 0000000000..22ca8668b6 --- /dev/null +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserLiteralTests.kt @@ -0,0 +1,74 @@ +package org.partiql.lang.syntax + +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +class PartiQLParserLiteralTests : PartiQLParserTestBase() { + + override val targets: Array = arrayOf(ParserTarget.EXPERIMENTAL) + + @ParameterizedTest + @MethodSource("cases") + @Execution(ExecutionMode.CONCURRENT) + fun testAll(case: Case) { + assertExpression( + source = case.input, + expectedPigAst = case.expect, + ) + Long.MAX_VALUE + } + + companion object { + + @JvmStatic + fun cases() = listOf( + Case( + input = "1", + expect = "(lit 1)" + ), + Case( + input = "+-1", + expect = "(lit -1)" + ), + Case( + input = "-+1", + expect = "(lit -1)" + ), + Case( + input = "-+-1", + expect = "(lit 1)" + ), + Case( + input = "+++1", + expect = "(lit 1)" + ), + Case( + input = "-1", + expect = "(lit -1)" + ), + Case( + input = "+1", + expect = "(lit 1)" + ), + Case( + input = "9223372036854775808", // Long.MAX_VALUE + 1 + expect = "(lit 9223372036854775808)" + ), + Case( + input = "-9223372036854775809", // Long.MIN_VALUE - 1 + expect = "(lit -9223372036854775809)" + ), + Case( + input = "+9223372036854775808", + expect = "(lit 9223372036854775808)" + ), + ) + } + + class Case( + @JvmField val input: String, + @JvmField val expect: String, + ) +} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt index a72fefa81f..c781ba9be7 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/syntax/PartiQLParserTest.kt @@ -295,6 +295,7 @@ class PartiQLParserTest : PartiQLParserTestBase() { } @Test + @Ignore("Disabled because it's not clear that the parser should be pushing down negations on boxed Ion values") fun unaryIonFloatLiteral() { assertExpression( "+-+-+-`-5e0`", diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index ac30689f17..354bbfa962 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -209,6 +209,14 @@ import org.partiql.parser.SourceLocation import org.partiql.parser.SourceLocations import org.partiql.parser.antlr.PartiQLParserBaseVisitor import org.partiql.parser.internal.util.DateTimeUtils +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue import org.partiql.value.NumericValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.StringValue @@ -217,8 +225,12 @@ import org.partiql.value.dateValue import org.partiql.value.datetime.DateTimeException import org.partiql.value.datetime.DateTimeValue import org.partiql.value.decimalValue +import org.partiql.value.float32Value +import org.partiql.value.float64Value +import org.partiql.value.int16Value import org.partiql.value.int32Value import org.partiql.value.int64Value +import org.partiql.value.int8Value import org.partiql.value.intValue import org.partiql.value.missingValue import org.partiql.value.nullValue @@ -569,17 +581,18 @@ internal class PartiQLParserDefault : PartiQLParser { } } - override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) = translate(ctx) { - val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } - val name = visitSymbolPrimitive(ctx.name) - if (qualifier.isEmpty()) { - name - } else { - val root = qualifier.first() - val steps = qualifier.drop(1) + listOf(name) - identifierQualified(root, steps) + override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) = + translate(ctx) { + val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } + val name = visitSymbolPrimitive(ctx.name) + if (qualifier.isEmpty()) { + name + } else { + val root = qualifier.first() + val steps = qualifier.drop(1) + listOf(name) + identifierQualified(root, steps) + } } - } /** * @@ -1488,9 +1501,48 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitValueExpr(ctx: GeneratedParser.ValueExprContext) = translate(ctx) { - if (ctx.parent != null) return@translate visit(ctx.parent) - val expr = visit(ctx.rhs) as Expr - exprUnary(convertUnaryOp(ctx.sign), expr) + // expression + if (ctx.parent != null) { + return@translate visit(ctx.parent) + } + // unary expression + val op = when (ctx.sign.type) { + GeneratedParser.NOT -> Expr.Unary.Op.NOT + GeneratedParser.PLUS -> Expr.Unary.Op.POS + GeneratedParser.MINUS -> Expr.Unary.Op.NEG + else -> throw error(ctx.sign, "Invalid unary operator") + } + // If argument is not a literal, then return the op. + val arg = visit(ctx.rhs) as Expr + return when (arg) { + is Expr.Lit -> arg.negate(op) + // TODO should we unwrap and negate Ion values for -`-1`? I don't think so.. + is Expr.Ion -> exprUnary(op, arg) + else -> exprUnary(op, arg) + } + } + + private fun Expr.Lit.negate(op: Expr.Unary.Op): Expr { + val v = this.value + return when { + op == Expr.Unary.Op.POS && v is NumericValue<*> -> exprLit(v) + op == Expr.Unary.Op.NEG && v is NumericValue<*> -> exprLit(v.negate()) + else -> exprUnary(op, exprLit(v)) + } + } + + /** + * We might consider a `negate` method on the NumericValue but this is fine for now and is private. + */ + private fun NumericValue<*>.negate(): NumericValue<*> = when (this) { + is DecimalValue -> decimalValue(value?.negate()) + is Float32Value -> float32Value(value?.let { it * -1 }) + is Float64Value -> float64Value(value?.let { it * -1 }) + is Int8Value -> int8Value(value?.let { (it.toInt() * -1).toByte() }) + is Int16Value -> int16Value(value?.let { (it.toInt() * -1).toShort() }) + is Int32Value -> int32Value(value?.let { it * -1 }) + is Int64Value -> int64Value(value?.let { it * -1 }) + is IntValue -> intValue(value?.negate()) } private fun convertBinaryExpr(lhs: ParserRuleContext, rhs: ParserRuleContext, op: Expr.Binary.Op): Expr {