diff --git a/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt b/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt index 98df382ed5..52fb37f17f 100644 --- a/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt +++ b/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt @@ -37,23 +37,25 @@ fun CompletableFuture.dispatchIfNeeded( val dataLoaderRegistry = environment.dataLoaderRegistry as? KotlinDataLoaderRegistry ?: throw MissingKotlinDataLoaderRegistryException() if (dataLoaderRegistry.dataLoadersInvokedOnDispatch()) { - val cantContinueExecution = when { + when { environment.graphQlContext.hasKey(ExecutionLevelDispatchedState::class) -> { - environment - .graphQlContext.get(ExecutionLevelDispatchedState::class) - .allExecutionsDispatched(Level(environment.executionStepInfo.path.level)) + val cantContinueExecution = + environment + .graphQlContext.get(ExecutionLevelDispatchedState::class) + .allExecutionsDispatched(Level(environment.executionStepInfo.path.level)) + if (cantContinueExecution) { + dataLoaderRegistry.dispatchAll() + } } environment.graphQlContext.hasKey(SyncExecutionExhaustedState::class) -> { environment .graphQlContext.get(SyncExecutionExhaustedState::class) - .allSyncExecutionsExhausted() + .ifAllSyncExecutionsExhausted { + dataLoaderRegistry.dispatchAll() + } } else -> throw MissingInstrumentationStateException() } - - if (cantContinueExecution) { - dataLoaderRegistry.dispatchAll() - } } return this } diff --git a/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/state/SyncExecutionExhaustedState.kt b/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/state/SyncExecutionExhaustedState.kt index 95501fa33b..82bb4e8768 100644 --- a/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/state/SyncExecutionExhaustedState.kt +++ b/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/state/SyncExecutionExhaustedState.kt @@ -63,11 +63,12 @@ class SyncExecutionExhaustedState( override fun onCompleted(result: ExecutionResult?, t: Throwable?) { if ((result != null && result.errors.size > 0) || t != null) { if (executions.containsKey(parameters.executionInput.executionId)) { - executions.remove(parameters.executionInput.executionId) - totalExecutions.set(totalExecutions.get() - 1) - val allSyncExecutionsExhausted = allSyncExecutionsExhausted() - if (allSyncExecutionsExhausted) { - onSyncExecutionExhausted(executions.keys().toList()) + synchronized(executions) { + executions.remove(parameters.executionInput.executionId) + totalExecutions.set(totalExecutions.get() - 1) + } + ifAllSyncExecutionsExhausted { executionIds -> + onSyncExecutionExhausted(executionIds) } } } @@ -126,9 +127,8 @@ class SyncExecutionExhaustedState( executionState } - val allSyncExecutionsExhausted = allSyncExecutionsExhausted() - if (allSyncExecutionsExhausted) { - onSyncExecutionExhausted(executions.keys().toList()) + ifAllSyncExecutionsExhausted { executionIds -> + onSyncExecutionExhausted(executionIds) } } override fun onCompleted(result: Any?, t: Throwable?) { @@ -137,26 +137,26 @@ class SyncExecutionExhaustedState( executionState } - val allSyncExecutionsExhausted = allSyncExecutionsExhausted() - if (allSyncExecutionsExhausted) { - onSyncExecutionExhausted(executions.keys().toList()) + ifAllSyncExecutionsExhausted { executionIds -> + onSyncExecutionExhausted(executionIds) } } } } /** - * Provide the information about when all [ExecutionInput] sharing a [GraphQLContext] exhausted their execution + * execute a given [predicate] when all [ExecutionInput] sharing a [GraphQLContext] exhausted their execution. * A Synchronous Execution is considered Exhausted when all [DataFetcher]s of all paths were executed up until * a scalar leaf or a [DataFetcher] that returns a [CompletableFuture] */ - fun allSyncExecutionsExhausted(): Boolean = synchronized(executions) { - val operationsToExecute = totalExecutions.get() - when { - executions.size < operationsToExecute || !dataLoaderRegistry.onDispatchFuturesHandled() -> false - else -> { - executions.values.all(ExecutionBatchState::isSyncExecutionExhausted) + fun ifAllSyncExecutionsExhausted(predicate: (List) -> Unit) = + synchronized(executions) { + val operationsToExecute = totalExecutions.get() + if (executions.size < operationsToExecute || !dataLoaderRegistry.onDispatchFuturesHandled()) + return@synchronized + + if (executions.values.all(ExecutionBatchState::isSyncExecutionExhausted)) { + predicate(executions.keys().toList()) } } - } } diff --git a/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt b/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt index b2f972f6f3..7355c6e7e9 100644 --- a/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt +++ b/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt @@ -614,9 +614,9 @@ class DataLoaderSyncExecutionExhaustedInstrumentationTest { fun `Instrumentation should not consider executions that thrown exceptions`() { val executions = listOf( ExecutionInput.newExecutionInput("query test1 { astronaut(id: 1) { id name } }").operationName("test1").build(), - ExecutionInput.newExecutionInput("query test2 { astronaut(id: 2) { id name } }").operationName("test2").build(), - ExecutionInput.newExecutionInput("query test3 { mission(id: 3) { id designation } }").operationName("test3").build(), - ExecutionInput.newExecutionInput("query test4 { mission(id: 4) { designation } }").operationName("OPERATION_NOT_IN_DOCUMENT").build() + ExecutionInput.newExecutionInput("query test2 { astronaut(id: 2) { id name } }").operationName("OPERATION_NOT_IN_DOCUMENT").build(), + ExecutionInput.newExecutionInput("query test3 { mission(id: 3) { id designation } }").operationName("OPERATION_NOT_IN_DOCUMENT").build(), + ExecutionInput.newExecutionInput("query test4 { mission(id: 4) { designation } }").operationName("test4").build() ) val (results, kotlinDataLoaderRegistry) = AstronautGraphQL.execute( @@ -631,7 +631,7 @@ class DataLoaderSyncExecutionExhaustedInstrumentationTest { val missionStatistics = kotlinDataLoaderRegistry.dataLoadersMap["MissionDataLoader"]?.statistics assertEquals(1, astronautStatistics?.batchInvokeCount) - assertEquals(2, astronautStatistics?.batchLoadCount) + assertEquals(1, astronautStatistics?.batchLoadCount) assertEquals(1, missionStatistics?.batchInvokeCount) assertEquals(1, missionStatistics?.batchLoadCount)