Skip to content

Commit

Permalink
Add operator node to AST and parser (#1499)
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 authored Jul 15, 2024
1 parent 35271b1 commit ff03b0a
Show file tree
Hide file tree
Showing 16 changed files with 736 additions and 408 deletions.
291 changes: 172 additions & 119 deletions partiql-ast/api/partiql-ast.api

Large diffs are not rendered by default.

118 changes: 67 additions & 51 deletions partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -336,70 +336,86 @@ private class AstTranslator(val metas: Map<String, MetaContainer>) : AstBaseVisi
return aggregates.contains(this)
}

override fun visitExprUnary(node: Expr.Unary, ctx: Ctx) = translate(node) { metas ->
val arg = visitExpr(node.expr, ctx)
when (node.op) {
Expr.Unary.Op.NOT -> not(arg, metas)
Expr.Unary.Op.POS -> {
when {
arg !is PartiqlAst.Expr.Lit -> pos(arg)
arg.value is IntElement -> arg
arg.value is FloatElement -> arg
arg.value is DecimalElement -> arg
else -> pos(arg)
override fun visitExprOperator(node: Expr.Operator, ctx: Ctx) = translate(node) { metas ->
val lhs = node.lhs?.let { visitExpr(it, ctx) }
val rhs = visitExpr(node.rhs, ctx)
if (lhs == null) {
when (node.symbol) {
"+" -> {
when {
rhs !is PartiqlAst.Expr.Lit -> pos(rhs)
rhs.value is IntElement -> rhs
rhs.value is FloatElement -> rhs
rhs.value is DecimalElement -> rhs
else -> pos(rhs)
}
}
}
Expr.Unary.Op.NEG -> {
when {
arg !is PartiqlAst.Expr.Lit -> neg(arg, metas)
arg.value is IntElement -> {
val intValue = when (arg.value.integerSize) {
IntElementSize.LONG -> ionInt(-arg.value.longValue)
IntElementSize.BIG_INTEGER -> when (arg.value.bigIntegerValue) {
Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE)
else -> ionInt(arg.value.bigIntegerValue * BigInteger.valueOf(-1L))
"-" -> {
when {
rhs !is PartiqlAst.Expr.Lit -> neg(rhs, metas)
rhs.value is IntElement -> {
val intValue = when (rhs.value.integerSize) {
IntElementSize.LONG -> ionInt(-rhs.value.longValue)
IntElementSize.BIG_INTEGER -> when (rhs.value.bigIntegerValue) {
Long.MAX_VALUE.toBigInteger() + (1L).toBigInteger() -> ionInt(Long.MIN_VALUE)
else -> ionInt(rhs.value.bigIntegerValue * BigInteger.valueOf(-1L))
}
}
rhs.copy(
value = intValue.asAnyElement(),
metas = metas,
)
}
arg.copy(
value = intValue.asAnyElement(),
rhs.value is FloatElement -> rhs.copy(
value = ionFloat(-(rhs.value.doubleValue)).asAnyElement(),
metas = metas,
)
rhs.value is DecimalElement -> rhs.copy(
value = ionDecimal(Decimal.valueOf(-(rhs.value.decimalValue))).asAnyElement(),
metas = metas,
)
else -> neg(rhs, metas)
}
arg.value is FloatElement -> arg.copy(
value = ionFloat(-(arg.value.doubleValue)).asAnyElement(),
metas = metas,
)
arg.value is DecimalElement -> arg.copy(
value = ionDecimal(Decimal.valueOf(-(arg.value.decimalValue))).asAnyElement(),
metas = metas,
)
else -> neg(arg, metas)
}
else -> error("unsupported unary expr operator $node")
}
} else {
val operands = listOf(lhs, rhs)
when (node.symbol) {
"+" -> plus(operands, metas)
"-" -> minus(operands, metas)
"*" -> times(operands, metas)
"/" -> divide(operands, metas)
"%" -> modulo(operands, metas)
"||" -> concat(operands, metas)
"=" -> eq(operands, metas)
"<>" -> ne(operands, metas)
"!=" -> ne(operands, metas)
">" -> gt(operands, metas)
">=" -> gte(operands, metas)
"<" -> lt(operands, metas)
"<=" -> lte(operands, metas)
"&" -> bitwiseAnd(operands, metas)
else -> error("unsupported binary expr operator $node")
}
}
}

override fun visitExprBinary(node: Expr.Binary, ctx: Ctx) = translate(node) { metas ->
override fun visitExprAnd(node: Expr.And, ctx: Ctx) = translate(node) { metas ->
val lhs = visitExpr(node.lhs, ctx)
val rhs = visitExpr(node.rhs, ctx)
val operands = listOf(lhs, rhs)
when (node.op) {
Expr.Binary.Op.PLUS -> plus(operands, metas)
Expr.Binary.Op.MINUS -> minus(operands, metas)
Expr.Binary.Op.TIMES -> times(operands, metas)
Expr.Binary.Op.DIVIDE -> divide(operands, metas)
Expr.Binary.Op.MODULO -> modulo(operands, metas)
Expr.Binary.Op.CONCAT -> concat(operands, metas)
Expr.Binary.Op.AND -> and(operands, metas)
Expr.Binary.Op.OR -> or(operands, metas)
Expr.Binary.Op.EQ -> eq(operands, metas)
Expr.Binary.Op.NE -> ne(operands, metas)
Expr.Binary.Op.GT -> gt(operands, metas)
Expr.Binary.Op.GTE -> gte(operands, metas)
Expr.Binary.Op.LT -> lt(operands, metas)
Expr.Binary.Op.LTE -> lte(operands, metas)
Expr.Binary.Op.BITWISE_AND -> bitwiseAnd(operands, metas)
}
and(lhs, rhs)
}

override fun visitExprOr(node: Expr.Or, ctx: Ctx) = translate(node) { metas ->
val lhs = visitExpr(node.lhs, ctx)
val rhs = visitExpr(node.rhs, ctx)
or(lhs, rhs)
}

override fun visitExprNot(node: Expr.Not, ctx: Ctx) = translate(node) { metas ->
val rhs = visitExpr(node.value, ctx)
not(rhs)
}

override fun visitExprPath(node: Expr.Path, ctx: Ctx) = translate(node) { metas ->
Expand Down
57 changes: 30 additions & 27 deletions partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt
Original file line number Diff line number Diff line change
Expand Up @@ -229,44 +229,47 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
return head concat r("`$value`")
}

override fun visitExprUnary(node: Expr.Unary, head: SqlBlock): SqlBlock {
val op = when (node.op) {
Expr.Unary.Op.NOT -> "NOT ("
Expr.Unary.Op.POS -> "+("
Expr.Unary.Op.NEG -> "-("
override fun visitExprOperator(node: Expr.Operator, head: SqlBlock): SqlBlock {
val lhs = node.lhs
return if (lhs != null) {
var h = head
h = visitExprWrapped(node.lhs, h)
h = h concat r(" ${node.symbol} ")
h = visitExprWrapped(node.rhs, h)
h
} else {
var h = head
h = h concat r(node.symbol + "(")
h = visitExprWrapped(node.rhs, h)
h = h concat r(")")
return h
}
}

override fun visitExprAnd(node: Expr.And, head: SqlBlock): SqlBlock {
var h = head
h = h concat r(op)
h = visitExprWrapped(node.expr, h)
h = h concat r(")")
h = visitExprWrapped(node.lhs, h)
h = h concat r(" AND ")
h = visitExprWrapped(node.rhs, h)
return h
}

override fun visitExprBinary(node: Expr.Binary, head: SqlBlock): SqlBlock {
val op = when (node.op) {
Expr.Binary.Op.PLUS -> "+"
Expr.Binary.Op.MINUS -> "-"
Expr.Binary.Op.TIMES -> "*"
Expr.Binary.Op.DIVIDE -> "/"
Expr.Binary.Op.MODULO -> "%"
Expr.Binary.Op.CONCAT -> "||"
Expr.Binary.Op.AND -> "AND"
Expr.Binary.Op.OR -> "OR"
Expr.Binary.Op.EQ -> "="
Expr.Binary.Op.NE -> "<>"
Expr.Binary.Op.GT -> ">"
Expr.Binary.Op.GTE -> ">="
Expr.Binary.Op.LT -> "<"
Expr.Binary.Op.LTE -> "<="
Expr.Binary.Op.BITWISE_AND -> "&"
}
override fun visitExprOr(node: Expr.Or, head: SqlBlock): SqlBlock {
var h = head
h = visitExprWrapped(node.lhs, h)
h = h concat r(" $op ")
h = h concat r(" OR ")
h = visitExprWrapped(node.rhs, h)
return h
}

override fun visitExprNot(node: Expr.Not, head: SqlBlock): SqlBlock {
var h = head
h = h concat r("NOT (")
h = visitExprWrapped(node.value, h)
h = h concat r(")")
return h
}

override fun visitExprVar(node: Expr.Var, head: SqlBlock): SqlBlock {
var h = head
// Prepend @
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,44 +255,47 @@ internal abstract class InternalSqlDialect : AstBaseVisitor<InternalSqlBlock, In
return tail concat "`$value`"
}

override fun visitExprUnary(node: Expr.Unary, tail: InternalSqlBlock): InternalSqlBlock {
val op = when (node.op) {
Expr.Unary.Op.NOT -> "NOT ("
Expr.Unary.Op.POS -> "+("
Expr.Unary.Op.NEG -> "-("
override fun visitExprOperator(node: Expr.Operator, tail: InternalSqlBlock): InternalSqlBlock {
val lhs = node.lhs
return if (lhs != null) {
var t = tail
t = visitExprWrapped(node.lhs, t)
t = t concat " ${node.symbol} "
t = visitExprWrapped(node.rhs, t)
t
} else {
var t = tail
t = t concat node.symbol + "("
t = visitExprWrapped(node.rhs, t)
t = t concat ")"
return t
}
}

override fun visitExprAnd(node: Expr.And, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = t concat op
t = visitExprWrapped(node.expr, t)
t = t concat ")"
t = visitExprWrapped(node.lhs, t)
t = t concat " AND "
t = visitExprWrapped(node.rhs, t)
return t
}

override fun visitExprBinary(node: Expr.Binary, tail: InternalSqlBlock): InternalSqlBlock {
val op = when (node.op) {
Expr.Binary.Op.PLUS -> "+"
Expr.Binary.Op.MINUS -> "-"
Expr.Binary.Op.TIMES -> "*"
Expr.Binary.Op.DIVIDE -> "/"
Expr.Binary.Op.MODULO -> "%"
Expr.Binary.Op.CONCAT -> "||"
Expr.Binary.Op.AND -> "AND"
Expr.Binary.Op.OR -> "OR"
Expr.Binary.Op.EQ -> "="
Expr.Binary.Op.NE -> "<>"
Expr.Binary.Op.GT -> ">"
Expr.Binary.Op.GTE -> ">="
Expr.Binary.Op.LT -> "<"
Expr.Binary.Op.LTE -> "<="
Expr.Binary.Op.BITWISE_AND -> "&"
}
override fun visitExprOr(node: Expr.Or, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = visitExprWrapped(node.lhs, t)
t = t concat " $op "
t = t concat " OR "
t = visitExprWrapped(node.rhs, t)
return t
}

override fun visitExprNot(node: Expr.Not, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
t = t concat "NOT ("
t = visitExprWrapped(node.value, t)
t = t concat ")"
return t
}

override fun visitExprVar(node: Expr.Var, tail: InternalSqlBlock): InternalSqlBlock {
var t = tail
// Prepend @
Expand Down
29 changes: 18 additions & 11 deletions partiql-ast/src/main/resources/partiql_ast.ion
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,26 @@ expr::[
index: int,
},

// Unary Operators
unary::{
op: [ NOT, POS, NEG ],
expr: expr,
// Operator expr node
operator::{
symbol: string,
lhs: optional::expr,
rhs: expr
},

// Binary Operators
binary::{
op: [
PLUS, MINUS, TIMES, DIVIDE, MODULO, CONCAT, BITWISE_AND,
AND, OR,
EQ, NE, GT, GTE, LT, LTE,
],
// SQL special form `NOT`
not::{
value: expr,
},

// SQL special form `AND`
and::{
lhs: expr,
rhs: expr,
},

// SQL special form `OR`
or::{
lhs: expr,
rhs: expr,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,27 +286,26 @@ class ToLegacyAstTest {
@JvmStatic
fun operators() = listOf(
expect("(not (lit null))") {
exprUnary {
op = Expr.Unary.Op.NOT
expr = NULL
exprNot {
value = NULL
}
},
expect("(pos (lit null))") {
exprUnary {
op = Expr.Unary.Op.POS
expr = NULL
exprOperator {
symbol = "+"
rhs = NULL
}
},
expect("(neg (lit null))") {
exprUnary {
op = Expr.Unary.Op.NEG
expr = NULL
exprOperator {
symbol = "-"
rhs = NULL
}
},
// we don't really need to test _all_ binary operators
expect("(plus (lit null) (lit null))") {
exprBinary {
op = Expr.Binary.Op.PLUS
exprOperator {
symbol = "+"
lhs = NULL
rhs = NULL
}
Expand Down
Loading

0 comments on commit ff03b0a

Please sign in to comment.