Skip to content

Commit

Permalink
Adds CURRENT_USER and CURRENT_DATE to transpiler (#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
RCHowell authored Sep 28, 2023
1 parent ef73d18 commit c136f02
Show file tree
Hide file tree
Showing 18 changed files with 181 additions and 18 deletions.
2 changes: 1 addition & 1 deletion buildSrc/src/main/kotlin/partiql.versions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object Versions {
const val kotlinxCollections = "0.3.5"
const val picoCli = "4.7.0"
const val kasechange = "1.3.0"
const val ktlint = "11.5.0"
const val ktlint = "11.6.0"
const val pig = "0.6.2"

// Testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

package org.partiql.transpiler.cli

import com.amazon.ion.IonWriter
import com.amazon.ion.system.IonTextWriterBuilder
import com.amazon.ion.system.IonWriterBuilder
import com.amazon.ionelement.api.field
import com.amazon.ionelement.api.ionString
import com.amazon.ionelement.api.ionStructOf
Expand All @@ -34,8 +32,6 @@ import org.jline.utils.AttributedStyle
import org.jline.utils.AttributedStyle.BOLD
import org.jline.utils.InfoCmp
import org.joda.time.Duration
import org.partiql.ast.sql.SqlBlock
import org.partiql.ast.sql.SqlLayout
import org.partiql.planner.PartiQLPlanner
import org.partiql.planner.test.plugin.FsConnector
import org.partiql.planner.test.plugin.FsPlugin
Expand All @@ -47,7 +43,6 @@ import org.partiql.spi.connector.ConnectorSession
import org.partiql.transpiler.PartiQLTranspiler
import org.partiql.transpiler.TpTarget
import org.partiql.transpiler.TranspilerProblem
import org.partiql.transpiler.sql.SqlTransform
import org.partiql.transpiler.targets.partiql.PartiQLTarget
import org.partiql.transpiler.targets.redshift.RedshiftTarget
import org.partiql.transpiler.targets.trino.TrinoTarget
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ public abstract class SqlCalls {
"is_struct" to { args -> isType(PartiQLValueType.STRUCT, args) },
"is_null" to { args -> isType(PartiQLValueType.NULL, args) },
"is_missing" to { args -> isType(PartiQLValueType.MISSING, args) },
// Session Attributes
"\$__current_user" to sessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER),
"\$__current_date" to sessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE),
)

public fun retarget(name: String, args: SqlArgs): Expr {
Expand All @@ -150,6 +153,10 @@ public abstract class SqlCalls {
)
}

private fun sessionAttribute(attribute: Expr.SessionAttribute.Attribute): SqlCallFn {
return { _ -> Ast.exprSessionAttribute(attribute) }
}

public open fun unary(op: Expr.Unary.Op, args: SqlArgs): Expr {
assert(args.size == 1) { "Unary operator $op requires exactly 1 argument" }
return Ast.exprUnary(op, args[0].expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public class RedshiftCalls(private val onProblem: ProblemCallback) : SqlCalls()
PartiQLValueType.INT -> {
onProblem.error("PartiQL `INT` type (arbitrary precision integer) not supported in Redshift")
// this needs a extra safety renaming because int refers to int4 in redshift.
exprCast(args[0].expr, typeCustom("Arbitrary Precision Integer"))
exprCast(args[0].expr, typeCustom("UNKNOWN"))
}
PartiQLValueType.MISSING -> {
onProblem.error("PartiQL `MISSING` type not supported in Redshift")
Expand Down Expand Up @@ -114,6 +114,8 @@ public class RedshiftCalls(private val onProblem: ProblemCallback) : SqlCalls()
PartiQLValueType.BYTE -> TODO("Mapping to VARBYTE(1), do this after supporting parameterized type")
else -> super.rewriteCast(type, args)
}
}

/**
* Push the negation down if possible.
* For example : NOT 1 is NULL -> 1 is NOT NULL.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package org.partiql.transpiler.targets.redshift

import org.partiql.ast.sql.SqlDialect
import org.partiql.plan.PartiQLPlan
import org.partiql.plan.PlanNode
import org.partiql.plan.Rel
Expand Down
42 changes: 42 additions & 0 deletions lib/partiql-transpiler/src/test/resources/cases/tests_00.ion
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,47 @@ suite::{
}
}
},
'0004': {
statement: '''
SELECT s_store_sk FROM
tpc_ds.store AS store
LEFT JOIN
tpc_ds.store_returns AS returns
ON s_store_sk = sr_store_sk
''',
schema: {
type: "bag",
items: {
type: "struct",
fields: [
{
name: "s_store_sk",
type: "string"
}
]
}
}
},
'0005': {
statement: '''
SELECT CURRENT_USER, CURRENT_DATE FROM store_sales
''',
schema: {
type: "bag",
items: {
type: "struct",
fields: [
{
name: "CURRENT_USER",
type: "string",
},
{
name: "CURRENT_DATE",
type: "date",
},
],
},
},
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,10 @@ target::{
ON store.s_store_sk = returns.sr_store_sk
'''
},
'0005': {
statement: '''
SELECT CURRENT_USER AS CURRENT_USER, CURRENT_DATE AS CURRENT_DATE FROM store_sales AS store_sales
''',
},
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ target::{
},
'0001': {
statement: '''
SELECT store_sales.ss_wholesale_cost + CAST(10 AS DOUBLE PRECISION) < store_sales.ss_list_price AS x
SELECT store_sales.ss_wholesale_cost + CAST(10 AS FLOAT8) < store_sales.ss_list_price AS x
FROM store_sales AS store_sales
''',
},
Expand All @@ -31,6 +31,7 @@ target::{
// TODO: For the last projection item (CASE-WHEN), the second WHEN should not be able to compare INT32? and STR.
// TODO: However, this will be addressed when functions are resolved. When this happens, we should assert on the
// TODO: error message here.
// TODO: PartiQL integer literals get type INT (which does not exist in Redshift) causing transpilation to fail
statement: '''
SELECT
store_sales.ss_quantity AS ss_quantity,
Expand Down Expand Up @@ -64,5 +65,10 @@ target::{
ON store.s_store_sk = returns.sr_store_sk
'''
},
'0005': {
statement: '''
SELECT CURRENT_USER AS CURRENT_USER, CURRENT_DATE AS CURRENT_DATE FROM store_sales AS store_sales
''',
},
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,10 @@ target::{
ON store.s_store_sk = returns.sr_store_sk
'''
},
'0005': {
statement: '''
SELECT CURRENT_USER AS CURRENT_USER, CURRENT_DATE AS CURRENT_DATE FROM store_sales AS store_sales
''',
},
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ private val col = { index: Int -> "_${index + 1}" }
* 1. If item is an id, use the last symbol
* 2. If item is a path with a final symbol step, use the symbol — else 4
* 3. If item is a cast, use the value name
* 4. Else, use item index with prefix _
* 4. If item is a Session Attribute, use the variable name (similar to Expr.Var)
* 5. Else, use item index with prefix _
*
* See https://github.com/partiql/partiql-lang-kotlin/issues/1122
*/
public fun Expr.toBinder(index: Int): Identifier.Symbol = when (this) {
is Expr.Var -> this.identifier.toBinder()
is Expr.Path -> this.toBinder(index)
is Expr.Cast -> this.value.toBinder(index)
is Expr.SessionAttribute -> this.attribute.name.toBinder()
else -> col(index).toBinder()
}

Expand Down
1 change: 1 addition & 0 deletions partiql-ast/src/main/resources/partiql_ast.ion
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ expr::[
session_attribute::{
attribute: [
CURRENT_USER,
CURRENT_DATE,
],
},

Expand Down
1 change: 1 addition & 0 deletions partiql-parser/src/main/antlr/PartiQL.g4
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ exprPrimary
exprTerm
: PAREN_LEFT expr PAREN_RIGHT # ExprTermWrappedQuery
| CURRENT_USER # ExprTermCurrentUser
| CURRENT_DATE # ExprTermCurrentDate
| parameter # ExprTermBase
| varRefExpr # ExprTermBase
| literal # ExprTermBase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,10 @@ internal class PartiQLParserDefault : PartiQLParser {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER)
}

override fun visitExprTermCurrentDate(ctx: org.partiql.parser.antlr.PartiQLParser.ExprTermCurrentDateContext) = translate(ctx) {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE)
}

/**
*
* FUNCTIONS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package org.partiql.parser.impl

import org.junit.jupiter.api.Test
import org.partiql.ast.AstNode
import org.partiql.ast.Expr
import org.partiql.ast.builder.AstFactory
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.int64Value
import kotlin.test.assertEquals

@OptIn(PartiQLValueExperimental::class)
class PartiQLParserSessionAttributeTests {

private val parser = PartiQLParserDefault()

private fun query(body: AstFactory.() -> Expr) = AstFactory.create {
statementQuery(this.body())
}

@Test
fun currentUserUpperCase() = assertExpression(
"CURRENT_USER",
query {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER)
}
)

@Test
fun currentUserMixedCase() = assertExpression(
"CURRENT_user",
query {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER)
}
)

@Test
fun currentUserLowerCase() = assertExpression(
"current_user",
query {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER)
}
)

@Test
fun currentUserEquals() = assertExpression(
"1 = current_user",
query {
exprBinary(
op = Expr.Binary.Op.EQ,
lhs = exprLit(int64Value(1)),
rhs = exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER)
)
}
)

@Test
fun currentDateUpperCase() = assertExpression(
"CURRENT_DATE",
query {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE)
}
)

@Test
fun currentDateMixedCase() = assertExpression(
"CURRENT_date",
query {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE)
}
)

@Test
fun currentDateLowerCase() = assertExpression(
"current_date",
query {
exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE)
}
)

private fun assertExpression(input: String, expected: AstNode) {
val result = parser.parse(input)
val actual = result.root
assertEquals(expected, actual)
}
}
4 changes: 3 additions & 1 deletion partiql-plan/src/main/resources/partiql_plan_0_1.ion
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ rex::{
],
},

err::{},
err::{
message: string,
},
],
}

Expand Down
11 changes: 10 additions & 1 deletion partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ private typealias FunctionMap = Map<String, List<FunctionSignature>>
* Map session attributes to underlying function name.
*/
internal val ATTRIBUTES: Map<String, String> = mapOf(
"CURRENT_USER" to "\$__current_user"
"CURRENT_USER" to "\$__current_user",
"CURRENT_DATE" to "\$__current_date"
)

/**
Expand Down Expand Up @@ -223,6 +224,7 @@ internal class Header(

public fun system(): List<FunctionSignature> = listOf(
currentUser(),
currentDate(),
)

private val allTypes = PartiQLValueType.values()
Expand Down Expand Up @@ -635,6 +637,13 @@ internal class Header(
isNullable = true,
)

private fun currentDate() = FunctionSignature(
name = "\$__current_date",
returns = DATE,
parameters = emptyList(),
isNullable = false,
)

// Function precedence comparator
// 1. Fewest args first
// 2. Parameters are compared left-to-right
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import org.partiql.planner.typer.toNonNullStaticType
import org.partiql.planner.typer.toStaticType
import org.partiql.types.StaticType
import org.partiql.types.TimeType
import org.partiql.types.TimestampType
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.boolValue
import org.partiql.value.int32Value
Expand Down Expand Up @@ -390,7 +388,7 @@ internal object RexConverter {
override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex = transform {
val type = node.asType
val arg0 = visitExpr(node.value, ctx)
when(type) {
when (type) {
is Type.NullType -> rex(StaticType.NULL, call("cast_null", arg0))
is Type.Missing -> rex(StaticType.MISSING, call("cast_missing", arg0))
is Type.Bool -> rex(StaticType.BOOL, call("cast_bool", arg0))
Expand All @@ -416,7 +414,7 @@ internal object RexConverter {
is Type.Date -> rex(StaticType.DATE, call("cast_date", arg0))
is Type.Time -> rex(StaticType.TIME, call("cast_time", arg0))
is Type.TimeWithTz -> rex(TimeType(null, true), call("cast_timeWithTz", arg0))
is Type.Timestamp -> TODO("Need to rebase main")
is Type.Timestamp -> TODO("Need to rebase main")
is Type.TimestampWithTz -> rex(StaticType.TIMESTAMP, call("cast_timeWithTz", arg0))
is Type.Interval -> TODO("Static Type does not have Interval type")
is Type.Bag -> rex(StaticType.BAG, call("cast_bag", arg0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ internal class PlanTyper(
}
is FnMatch.Error -> {
handleUnknownFunction(match)
rex(StaticType.NULL_OR_MISSING, rexOpErr())
rex(StaticType.NULL_OR_MISSING, rexOpErr("Unknown function $fn"))
}
}
}
Expand All @@ -357,7 +357,7 @@ internal class PlanTyper(
override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex = rewrite {
if (ctx!! !is CollectionType) {
handleUnexpectedType(ctx, setOf(StaticType.LIST, StaticType.BAG, StaticType.SEXP))
return rex(StaticType.NULL_OR_MISSING, rexOpErr())
return rex(StaticType.NULL_OR_MISSING, rexOpErr("Expected collection type"))
}
val values = node.values.map { visitRex(it, null) }
val t = values.toUnionType()
Expand Down

0 comments on commit c136f02

Please sign in to comment.