Skip to content

Commit

Permalink
Adds interruption support for v0.2.8
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Jan 16, 2024
1 parent b286675 commit b2e6b87
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 2 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ allprojects {

subprojects {
group = 'org.partiql'
version = '0.2.7'
version = '0.2.8-SNAPSHOT'
}

buildDir = new File(rootProject.projectDir, "gradle-build/" + project.name)
Expand Down
27 changes: 26 additions & 1 deletion lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ internal class EvaluatingCompiler(
get() = compilationContextStack.peek() ?: throw EvaluationException(
"compilationContextStack was empty.", internal = true)


/**
* This checks whether the thread has been interrupted. Specifically, it currently checks during the compilation
* of aggregations and joins, the "evaluation" of aggregations and joins, and the materialization of joins
* and from source scans.
*
* Note: This is essentially a way to avoid constantly checking [CompileOptions.interruptible]. By writing it this
* way, we statically determine whether to introduce checks. If the compiler has specified
* [CompileOptions.interruptible], the invocation of this function will insert a Thread interruption check. If not
* specified, it will not perform the check during compilation/evaluation/materialization.
*/
private val interruptionCheck: () -> Unit = {
if (Thread.interrupted()) {
throw InterruptedException()
}
}

//Note: please don't make this inline -- it messes up [EvaluationException] stack traces and
//isn't a huge benefit because this is only used at SQL-compile time anyway.
private fun <R> nestCompilationContext(expressionContext: ExpressionContext,
Expand Down Expand Up @@ -981,7 +998,12 @@ internal class EvaluatingCompiler(
// Grouping is not needed -- simply project the results from the FROM clause directly.
thunkFactory.thunkEnv(metas) { env ->

val projectedRows = sourceThunks(env).map { (joinedValues, projectEnv) ->
val sourcedRows = sourceThunks(env).map {
interruptionCheck()
it
}

val projectedRows = sourcedRows.map { (joinedValues, projectEnv) ->
selectProjectionThunk(projectEnv, joinedValues)
}

Expand Down Expand Up @@ -1051,6 +1073,7 @@ internal class EvaluatingCompiler(
// iterate over the values from the FROM clause and populate our
// aggregate register values.
fromProductions.forEach { fromProduction ->
interruptionCheck()
compiledAggregates?.forEachIndexed { index, ca ->
registers[index].aggregator.next(ca.argThunk(fromProduction.env))
}
Expand Down Expand Up @@ -1459,6 +1482,7 @@ internal class EvaluatingCompiler(
// compute the join over the data sources
var seq = compiledSources
.foldLeftProduct({ env: Environment -> env }) { bindEnv: (Environment) -> Environment, source: CompiledFromSource ->
interruptionCheck()
fun correlatedBind(value: ExprValue): Pair<(Environment) -> Environment, ExprValue> {
// add the correlated binding environment thunk
val alias = source.alias
Expand Down Expand Up @@ -1517,6 +1541,7 @@ internal class EvaluatingCompiler(
}
.asSequence()
.map { joinedValues ->
interruptionCheck()
// bind the joined value to the bindings for the filter/project
FromProduction(joinedValues, fromEnv.nest(localsBinder.bindLocals(joinedValues)))
}
Expand Down
190 changes: 190 additions & 0 deletions lang/test/org/partiql/lang/eval/EvaluatingCompilerInterruptTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at:
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/

package org.partiql.lang.eval

import com.amazon.ion.system.IonSystemBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.partiql.lang.CompilerPipeline
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.concurrent.thread

/**
* Making sure we can interrupt the [EvaluatingCompiler].
*/
class EvaluatingCompilerInterruptTests {

private val ion = IonSystemBuilder.standard().build()
private val factory = ExprValueFactory.standard(ion)
private val session = EvaluationSession.standard()
private val pipeline = CompilerPipeline.standard(ion)

companion object {
/** How long (in millis) to wait after starting a thread to set the interrupted flag. */
const val INTERRUPT_AFTER_MS: Long = 100

/** How long (in millis) to wait for a thread to terminate after setting the interrupted flag. */
const val WAIT_FOR_THREAD_TERMINATION_MS: Long = 1000
}

/**
* Joins are only evaluated during the materialization of the ExprValue's elements. Cross Joins.
*/
@Test
fun evalCrossJoins() {
val query = """
SELECT
*
FROM
([1, 2, 3, 4]) as x1,
([1, 2, 3, 4]) as x2,
([1, 2, 3, 4]) as x3,
([1, 2, 3, 4]) as x4,
([1, 2, 3, 4]) as x5,
([1, 2, 3, 4]) as x6,
([1, 2, 3, 4]) as x7,
([1, 2, 3, 4]) as x8,
([1, 2, 3, 4]) as x9,
([1, 2, 3, 4]) as x10,
([1, 2, 3, 4]) as x11,
([1, 2, 3, 4]) as x12,
([1, 2, 3, 4]) as x13,
([1, 2, 3, 4]) as x14,
([1, 2, 3, 4]) as x15
""".trimIndent()
val expression = pipeline.compile(query)
val result = expression.eval(session)
testThreadInterrupt {
result.forEach { it }
}
}

/**
* Joins are only evaluated during the materialization of the ExprValue's elements. Making sure left
* joins can be interrupted.
*/
@Test
fun evalLeftJoins() {
val query = """
SELECT
*
FROM
[1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN
([1, 2, 3, 4] LEFT JOIN ([1, 2, 3, 4]) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE) ON TRUE
""".trimIndent()
val expression = pipeline.compile(query)
val result = expression.eval(session)
testThreadInterrupt {
result.forEach { it }
}
}

/**
* Aggregations currently get materialized during [Expression.evaluate], so we need to check that we can
* interrupt there.
*/
@Test
fun compileLargeAggregation() {
val query = """
SELECT
COUNT(*)
FROM
([1, 2, 3, 4]) as x1,
([1, 2, 3, 4]) as x2,
([1, 2, 3, 4]) as x3,
([1, 2, 3, 4]) as x4,
([1, 2, 3, 4]) as x5,
([1, 2, 3, 4]) as x6,
([1, 2, 3, 4]) as x7,
([1, 2, 3, 4]) as x8,
([1, 2, 3, 4]) as x9,
([1, 2, 3, 4]) as x10,
([1, 2, 3, 4]) as x11,
([1, 2, 3, 4]) as x12,
([1, 2, 3, 4]) as x13,
([1, 2, 3, 4]) as x14,
([1, 2, 3, 4]) as x15
""".trimIndent()
val expression = pipeline.compile(query)
testThreadInterrupt {
expression.eval(session)
}
}

/**
* We need to make sure that we can end a never-ending query. These sorts of queries get materialized during the
* iteration of [ExprValue].
*/
@Test
fun neverEndingScan() {
val indefiniteCollection = factory.newBag(
sequence {
while (true) {
yield(factory.nullValue)
}
}
)
val query = """
SELECT *
FROM ?
""".trimIndent()
val session = EvaluationSession.build {
parameters(listOf(indefiniteCollection))
}

val expression = pipeline.compile(query)
val result = expression.eval(session)
testThreadInterrupt {
result.forEach { it }
}
}

private fun testThreadInterrupt(
interruptAfter: Long = INTERRUPT_AFTER_MS,
interruptWait: Long = WAIT_FOR_THREAD_TERMINATION_MS,
block: () -> Unit
) {
val wasInterrupted = AtomicBoolean(false)
val t = thread(start = false) {
try {
block()
} catch (_: InterruptedException) {
wasInterrupted.set(true)
} catch (e: EvaluationException) {
if (e.cause is InterruptedException) {
wasInterrupted.set(true)
}
}
}
t.setUncaughtExceptionHandler { _, ex -> throw ex }
t.start()
Thread.sleep(interruptAfter)
t.interrupt()
t.join(interruptWait)
Assertions.assertTrue(wasInterrupted.get(), "Thread should have been interrupted.")
}
}

0 comments on commit b2e6b87

Please sign in to comment.