Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes parsing of signed numeric literals #1484

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 2 additions & 42 deletions partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@
package org.partiql.ast.helpers

import com.amazon.ion.Decimal
import com.amazon.ionelement.api.DecimalElement
import com.amazon.ionelement.api.FloatElement
import com.amazon.ionelement.api.IntElement
import com.amazon.ionelement.api.IntElementSize
import com.amazon.ionelement.api.MetaContainer
import com.amazon.ionelement.api.emptyMetaContainer
import com.amazon.ionelement.api.ionDecimal
import com.amazon.ionelement.api.ionFloat
import com.amazon.ionelement.api.ionInt
import com.amazon.ionelement.api.ionString
import com.amazon.ionelement.api.ionSymbol
Expand Down Expand Up @@ -48,7 +43,6 @@ import org.partiql.value.TimestampValue
import org.partiql.value.datetime.TimeZone
import org.partiql.value.toIon
import java.math.BigDecimal
import java.math.BigInteger

/**
* Translates an [AstNode] tree to the legacy PIG AST.
Expand Down Expand Up @@ -322,42 +316,8 @@ private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisi
val arg = visitExpr(node.expr, ctx)
when (node.op) {
Expr.Unary.Op.NOT -> not(arg, metas)
Expr.Unary.Op.POS -> {
when {
arg !is PartiqlAst.Expr.Lit -> pos(arg)
arg.value is IntElement -> arg
arg.value is FloatElement -> arg
arg.value is DecimalElement -> arg
else -> pos(arg)
}
}
Expr.Unary.Op.NEG -> {
when {
arg !is PartiqlAst.Expr.Lit -> neg(arg, metas)
arg.value is IntElement -> {
val intValue = when (arg.value.integerSize) {
IntElementSize.LONG -> ionInt(-arg.value.longValue)
IntElementSize.BIG_INTEGER -> when (arg.value.bigIntegerValue) {
Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE)
else -> ionInt(arg.value.bigIntegerValue * BigInteger.valueOf(-1L))
}
}
arg.copy(
value = intValue.asAnyElement(),
metas = metas,
)
}
arg.value is FloatElement -> arg.copy(
value = ionFloat(-(arg.value.doubleValue)).asAnyElement(),
metas = metas,
)
arg.value is DecimalElement -> arg.copy(
value = ionDecimal(Decimal.valueOf(-(arg.value.decimalValue))).asAnyElement(),
metas = metas,
)
else -> neg(arg, metas)
}
}
Expr.Unary.Op.POS -> pos(arg, metas)
Expr.Unary.Op.NEG -> neg(arg, metas)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.partiql.lang.syntax

import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource

class PartiQLParserLiteralTests : PartiQLParserTestBase() {

override val targets: Array<ParserTarget> = arrayOf(ParserTarget.EXPERIMENTAL)

@ParameterizedTest
@MethodSource("cases")
@Execution(ExecutionMode.CONCURRENT)
fun testAll(case: Case) {
assertExpression(
source = case.input,
expectedPigAst = case.expect,
)
Long.MAX_VALUE
}

companion object {

@JvmStatic
fun cases() = listOf(
Case(
input = "1",
expect = "(lit 1)"
),
Case(
input = "+-1",
expect = "(lit -1)"
),
Case(
input = "-+1",
expect = "(lit -1)"
),
Case(
input = "-+-1",
expect = "(lit 1)"
),
Case(
input = "+++1",
expect = "(lit 1)"
),
Case(
input = "-1",
expect = "(lit -1)"
),
Case(
input = "+1",
expect = "(lit 1)"
),
Case(
input = "9223372036854775808", // Long.MAX_VALUE + 1
expect = "(lit 9223372036854775808)"
),
Case(
input = "-9223372036854775809", // Long.MIN_VALUE - 1
expect = "(lit -9223372036854775809)"
),
Case(
input = "+9223372036854775808",
expect = "(lit 9223372036854775808)"
),
)
}

class Case(
@JvmField val input: String,
@JvmField val expect: String,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ class PartiQLParserTest : PartiQLParserTestBase() {
}

@Test
@Ignore("Disabled because it's not clear that the parser should be pushing down negations on boxed Ion values")
fun unaryIonFloatLiteral() {
assertExpression(
"+-+-+-`-5e0`",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ import org.partiql.parser.SourceLocation
import org.partiql.parser.SourceLocations
import org.partiql.parser.antlr.PartiQLParserBaseVisitor
import org.partiql.parser.internal.util.DateTimeUtils
import org.partiql.parser.internal.util.NumberUtils.negate
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.StringValue
Expand Down Expand Up @@ -569,17 +570,18 @@ internal class PartiQLParserDefault : PartiQLParser {
}
}

override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) = translate(ctx) {
val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) }
val name = visitSymbolPrimitive(ctx.name)
if (qualifier.isEmpty()) {
name
} else {
val root = qualifier.first()
val steps = qualifier.drop(1) + listOf(name)
identifierQualified(root, steps)
override fun visitQualifiedName(ctx: org.partiql.parser.antlr.PartiQLParser.QualifiedNameContext) =
translate(ctx) {
val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) }
val name = visitSymbolPrimitive(ctx.name)
if (qualifier.isEmpty()) {
name
} else {
val root = qualifier.first()
val steps = qualifier.drop(1) + listOf(name)
identifierQualified(root, steps)
}
}
}

/**
*
Expand Down Expand Up @@ -1488,9 +1490,34 @@ internal class PartiQLParserDefault : PartiQLParser {
}

override fun visitValueExpr(ctx: GeneratedParser.ValueExprContext) = translate(ctx) {
if (ctx.parent != null) return@translate visit(ctx.parent)
val expr = visit(ctx.rhs) as Expr
exprUnary(convertUnaryOp(ctx.sign), expr)
// expression
if (ctx.parent != null) {
return@translate visit(ctx.parent)
}
// unary expression
val op = when (ctx.sign.type) {
GeneratedParser.NOT -> Expr.Unary.Op.NOT
GeneratedParser.PLUS -> Expr.Unary.Op.POS
GeneratedParser.MINUS -> Expr.Unary.Op.NEG
else -> throw error(ctx.sign, "Invalid unary operator")
}
// If argument is not a literal, then return the op.
val arg = visit(ctx.rhs) as Expr
return when (arg) {
is Expr.Lit -> arg.negate(op)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this potentially can cause overflow:

-(-2147483648) => 2147483648 which can not be represented by 32 bit integer. 

// TODO should we unwrap and negate Ion values for -`-1`? I don't think so..
is Expr.Ion -> exprUnary(op, arg)
else -> exprUnary(op, arg)
}
}

private fun Expr.Lit.negate(op: Expr.Unary.Op): Expr {
val v = this.value
return when {
op == Expr.Unary.Op.POS && v is NumericValue<*> -> exprLit(v)
op == Expr.Unary.Op.NEG && v is NumericValue<*> -> exprLit(v.negate())
else -> exprUnary(op, exprLit(v))
}
}

private fun convertBinaryExpr(lhs: ParserRuleContext, rhs: ParserRuleContext, op: Expr.Binary.Op): Expr {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package org.partiql.parser.internal.util

import org.partiql.value.DecimalValue
import org.partiql.value.Float32Value
import org.partiql.value.Float64Value
import org.partiql.value.Int16Value
import org.partiql.value.Int32Value
import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.IntValue
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.decimalValue
import org.partiql.value.float32Value
import org.partiql.value.float64Value
import org.partiql.value.int16Value
import org.partiql.value.int32Value
import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import java.math.BigInteger

internal object NumberUtils {

/**
* We might consider a `negate` method on the NumericValue but this is fine for now and is internal.
*/
@OptIn(PartiQLValueExperimental::class)
internal fun NumericValue<*>.negate(): NumericValue<*> = when (this) {
is DecimalValue -> decimalValue(value?.negate())
is Float32Value -> float32Value(value?.let { it * -1 })
is Float64Value -> float64Value(value?.let { it * -1 })
is Int8Value -> when (value) {
null -> this
Byte.MIN_VALUE -> int16Value(value?.let { (it.toInt() * -1).toShort() })
else -> int8Value(value?.let { (it.toInt() * -1).toByte() })
}
is Int16Value -> when (value) {
null -> this
Short.MIN_VALUE -> int32Value(value?.let { it.toInt() * -1 })
else -> int16Value(value?.let { (it.toInt() * -1).toShort() })
}
is Int32Value -> when (value) {
null -> this
Int.MIN_VALUE -> int64Value(value?.let { it.toLong() * -1 })
else -> int32Value(value?.let { it * -1 })
}
is Int64Value -> when (value) {
null -> this
Long.MIN_VALUE -> intValue(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE))
else -> int64Value(value?.let { it * -1 })
}
is IntValue -> intValue(value?.negate())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.partiql.parser.internal.util

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.partiql.parser.internal.util.NumberUtils.negate
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.decimalValue
import org.partiql.value.int16Value
import org.partiql.value.int32Value
import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import java.math.BigDecimal
import java.math.BigInteger

@OptIn(PartiQLValueExperimental::class)
class NumberUtilsTest {

@Test
fun negate_normal() {
assertEquals(int8Value(-1), int8Value(1).negate())
assertEquals(int16Value(-1), int16Value(1).negate())
assertEquals(int32Value(-1), int32Value(1).negate())
assertEquals(int64Value(-1), int64Value(1).negate())
assertEquals(intValue(BigInteger.valueOf(-1L)), intValue(BigInteger.valueOf(1L)).negate())
assertEquals(decimalValue(BigDecimal.valueOf(-1L)), decimalValue(BigDecimal.valueOf(1L)).negate())
}

@Test
fun negate_overflow() {
assertEquals(int16Value((Byte.MAX_VALUE.toShort() + 1).toShort()), int8Value(Byte.MIN_VALUE).negate())
assertEquals(int32Value((Short.MAX_VALUE.toInt() + 1)), int16Value(Short.MIN_VALUE).negate())
assertEquals(int64Value((Int.MAX_VALUE.toLong() + 1)), int32Value(Int.MIN_VALUE).negate())
assertEquals(intValue(BigInteger.valueOf(Long.MAX_VALUE) + BigInteger.ONE), int64Value(Long.MIN_VALUE).negate())
}
}
Loading