Skip to content

Commit

Permalink
Fixes underflow/overflow for negation
Browse files Browse the repository at this point in the history
  • Loading branch information
rchowell committed Jul 15, 2024
1 parent 4027b73 commit 94d88b2
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,7 @@ import org.partiql.parser.SourceLocation
import org.partiql.parser.SourceLocations
import org.partiql.parser.antlr.PartiQLParserBaseVisitor
import org.partiql.parser.internal.util.DateTimeUtils
import org.partiql.value.DecimalValue
import org.partiql.value.Float32Value
import org.partiql.value.Float64Value
import org.partiql.value.Int16Value
import org.partiql.value.Int32Value
import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.IntValue
import org.partiql.parser.internal.util.NumberUtils.negate
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.StringValue
Expand All @@ -225,12 +218,8 @@ import org.partiql.value.dateValue
import org.partiql.value.datetime.DateTimeException
import org.partiql.value.datetime.DateTimeValue
import org.partiql.value.decimalValue
import org.partiql.value.float32Value
import org.partiql.value.float64Value
import org.partiql.value.int16Value
import org.partiql.value.int32Value
import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import org.partiql.value.missingValue
import org.partiql.value.nullValue
Expand Down Expand Up @@ -1531,20 +1520,6 @@ internal class PartiQLParserDefault : PartiQLParser {
}
}

/**
* We might consider a `negate` method on the NumericValue but this is fine for now and is private.
*/
private fun NumericValue<*>.negate(): NumericValue<*> = when (this) {
is DecimalValue -> decimalValue(value?.negate())
is Float32Value -> float32Value(value?.let { it * -1 })
is Float64Value -> float64Value(value?.let { it * -1 })
is Int8Value -> int8Value(value?.let { (it.toInt() * -1).toByte() })
is Int16Value -> int16Value(value?.let { (it.toInt() * -1).toShort() })
is Int32Value -> int32Value(value?.let { it * -1 })
is Int64Value -> int64Value(value?.let { it * -1 })
is IntValue -> intValue(value?.negate())
}

private fun convertBinaryExpr(lhs: ParserRuleContext, rhs: ParserRuleContext, op: Expr.Binary.Op): Expr {
val l = visit(lhs) as Expr
val r = visit(rhs) as Expr
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package org.partiql.parser.internal.util

import org.partiql.value.DecimalValue
import org.partiql.value.Float32Value
import org.partiql.value.Float64Value
import org.partiql.value.Int16Value
import org.partiql.value.Int32Value
import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.IntValue
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.decimalValue
import org.partiql.value.float32Value
import org.partiql.value.float64Value
import org.partiql.value.int16Value
import org.partiql.value.int32Value
import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import java.math.BigInteger

internal object NumberUtils {

/**
* We might consider a `negate` method on the NumericValue but this is fine for now and is internal.
*/
@OptIn(PartiQLValueExperimental::class)
internal fun NumericValue<*>.negate(): NumericValue<*> = when (this) {
is DecimalValue -> decimalValue(value?.negate())
is Float32Value -> float32Value(value?.let { it * -1 })
is Float64Value -> float64Value(value?.let { it * -1 })
is Int8Value -> when (value) {
null -> this
Byte.MIN_VALUE -> int16Value(value?.let { (it.toInt() * -1).toShort() })
else -> int8Value(value?.let { (it.toInt() * -1).toByte() })
}
is Int16Value -> when (value) {
null -> this
Short.MIN_VALUE -> int32Value(value?.let { it.toInt() * -1 })
else -> int16Value(value?.let { (it.toInt() * -1).toShort() })
}
is Int32Value -> when (value) {
null -> this
Int.MIN_VALUE -> int64Value(value?.let { it.toLong() * -1 })
else -> int32Value(value?.let { it * -1 })
}
is Int64Value -> when (value) {
null -> this
Long.MIN_VALUE -> intValue(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE))
else -> int64Value(value?.let { it * -1 })
}
is IntValue -> intValue(value?.negate())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.partiql.parser.internal.util

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.partiql.parser.internal.util.NumberUtils.negate
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.decimalValue
import org.partiql.value.int16Value
import org.partiql.value.int32Value
import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import java.math.BigDecimal
import java.math.BigInteger

@OptIn(PartiQLValueExperimental::class)
class NumberUtilsTest {

@Test
fun negate_normal() {
assertEquals(int8Value(-1), int8Value(1).negate())
assertEquals(int16Value(-1), int16Value(1).negate())
assertEquals(int32Value(-1), int32Value(1).negate())
assertEquals(int64Value(-1), int64Value(1).negate())
assertEquals(intValue(BigInteger.valueOf(-1L)), intValue(BigInteger.valueOf(1L)).negate())
assertEquals(decimalValue(BigDecimal.valueOf(-1L)), decimalValue(BigDecimal.valueOf(1L)).negate())
}

@Test
fun negate_overflow() {
assertEquals(int16Value((Byte.MAX_VALUE.toShort() + 1).toShort()), int8Value(Byte.MIN_VALUE).negate())
assertEquals(int32Value((Short.MAX_VALUE.toInt() + 1)), int16Value(Short.MIN_VALUE).negate())
assertEquals(int64Value((Int.MAX_VALUE.toLong() + 1)), int32Value(Int.MIN_VALUE).negate())
assertEquals(intValue(BigInteger.valueOf(Long.MAX_VALUE) + BigInteger.ONE), int64Value(Long.MIN_VALUE).negate())
}
}

0 comments on commit 94d88b2

Please sign in to comment.