Skip to content

Commit

Permalink
Added the ability to disable strict streamScoped requirement (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr3zee authored May 17, 2024
1 parent a5c09aa commit 3839aac
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 29 deletions.
6 changes: 6 additions & 0 deletions runtime/api/kotlinx-rpc-runtime.api
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ public final class kotlinx/rpc/RPCTransportMessage$StringMessage : kotlinx/rpc/R
public final fun getValue ()Ljava/lang/String;
}

public final class kotlinx/rpc/internal/DevStreamScopeKt {
}

public final class kotlinx/rpc/internal/ExceptionUtilsKt {
}

Expand All @@ -159,6 +162,9 @@ public final class kotlinx/rpc/internal/ScopedClientCallKt {
public final class kotlinx/rpc/internal/SerializationUtilsKt {
}

public final class kotlinx/rpc/internal/ServiceScopeKt {
}

public final class kotlinx/rpc/internal/StreamScopeKt {
public static final fun invokeOnStreamScopeCompletion (ZLkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun invokeOnStreamScopeCompletion$default (ZLkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,33 +201,21 @@ public abstract class KRPCClient(
handleOutgoingStreams(it, rpcCall.serialFormat, call.serviceTypeString)
}

var requestCancelled = false

val streamScope = streamScopeOrNull()

streamScope?.onScopeCompletion(rpcCall.callId) {
connector.unsubscribeFromMessages(call.serviceTypeString, rpcCall.callId)

if (!requestCancelled) {
requestCancelled = true

sendCancellation(CancellationType.REQUEST, call.serviceId.toString(), rpcCall.callId)
}
}

callResult.invokeOnCompletion { cause ->
// no streams available
if (rpcCall.streamContext.valueOrNull == null && streamScope != null) {
streamScope.cancelRequestScopeById(rpcCall.callId, "No streams provided", null)
}

if (cause != null) {
connector.unsubscribeFromMessages(call.serviceTypeString, rpcCall.callId)

rpcCall.streamContext.valueOrNull?.cancel("Request failed", cause)

if (!wrappedCallResult.callExceptionOccurred && !requestCancelled) {
requestCancelled = true
if (!wrappedCallResult.callExceptionOccurred) {
sendCancellation(CancellationType.REQUEST, call.serviceId.toString(), rpcCall.callId)
}
} else {
val streamScope = rpcCall.streamContext.valueOrNull?.streamScope

streamScope?.onScopeCompletion(rpcCall.callId) {
connector.unsubscribeFromMessages(call.serviceTypeString, rpcCall.callId)

sendCancellation(CancellationType.REQUEST, call.serviceId.toString(), rpcCall.callId)
}
}
Expand All @@ -248,7 +236,11 @@ public abstract class KRPCClient(

logger.trace { "start a call[$callId] ${callInfo.callableName}" }

val streamContext = LazyRPCStreamContext(streamScopeOrNull()) {
val fallbackScope = serviceScopeOrNull()
?.serviceCoroutineScope
?.let { streamScopeOrNull(it) }

val streamContext = LazyRPCStreamContext(streamScopeOrNull(), fallbackScope) {
RPCStreamContext(callId, config, connectionId, callInfo.serviceId, it)
}
val serialFormat = prepareSerialFormat(streamContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@ import kotlin.coroutines.CoroutineContext
@InternalRPCApi
public class LazyRPCStreamContext(
public val streamScopeOrNull: StreamScope?,
private val fallbackScope: StreamScope? = null,
private val initializer: (StreamScope) -> RPCStreamContext,
) {
private val deferred = CompletableDeferred<RPCStreamContext>()
private val lazyValue by lazy(LazyThreadSafetyMode.SYNCHRONIZED) {
if (streamScopeOrNull == null) {
if (streamScopeOrNull == null && (STREAM_SCOPES_ENABLED || fallbackScope == null)) {
noStreamScopeError()
}

initializer(streamScopeOrNull).also { deferred.complete(it) }
// null pointer is impossible
val streamScope = streamScopeOrNull ?: fallbackScope!!
initializer(streamScope).also { deferred.complete(it) }
}

public suspend fun awaitInitialized(): RPCStreamContext = deferred.await()
Expand All @@ -45,7 +48,7 @@ public class RPCStreamContext(
private val config: RPCConfig,
private val connectionId: Long?,
private val serviceId: Long?,
private val streamScope: StreamScope,
public val streamScope: StreamScope,
) {
private companion object {
private const val STREAM_ID_PREFIX = "stream:"
Expand Down
49 changes: 49 additions & 0 deletions runtime/src/commonMain/kotlin/kotlinx/rpc/internal/ServiceScope.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.rpc.internal

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext

@InternalRPCApi
public class ServiceScope(public val serviceCoroutineScope: CoroutineScope) : CoroutineContext.Element {
internal companion object Key : CoroutineContext.Key<ServiceScope>

override val key: CoroutineContext.Key<*> = Key
}

@InternalRPCApi
public suspend fun createServiceScope(serviceCoroutineScope: CoroutineScope): ServiceScope {
val context = currentCoroutineContext()

if (context[ServiceScope.Key] != null) {
error("serviceScoped nesting is not allowed")
}

return ServiceScope(serviceCoroutineScope)
}

@InternalRPCApi
public suspend fun serviceScopeOrNull(): ServiceScope? {
return currentCoroutineContext()[ServiceScope.Key]
}

@InternalRPCApi
@OptIn(ExperimentalContracts::class)
public suspend inline fun <T> serviceScoped(
serviceCoroutineScope: CoroutineScope,
noinline block: suspend CoroutineScope.() -> T,
): T {
contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
}

return withContext(createServiceScope(serviceCoroutineScope), block)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.rpc.internal

/**
* For legacy internal users ONLY.
* Special dev builds may set this value to `false`.
*
* If the value is `false`, absence of [streamScoped] for a call is replaced with service's [StreamScope]
* obtained via [withClientStreamScope].
*/
@InternalRPCApi
public const val STREAM_SCOPES_ENABLED: Boolean = true
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ import kotlinx.coroutines.*
@InternalRPCApi
@OptIn(InternalCoroutinesApi::class)
@Suppress("unused")
public suspend inline fun <T> scopedClientCall(serviceScope: CoroutineScope, body: () -> T): T {
public suspend inline fun <T> scopedClientCall(serviceScope: CoroutineScope, crossinline body: suspend () -> T): T {
val requestJob = currentCoroutineContext().job
val handle = serviceScope.coroutineContext.job.invokeOnCompletion(onCancelling = true) {
requestJob.cancel(it as CancellationException)
}

try {
return body()
return serviceScoped(serviceScope) {
body()
}
} finally {
handle.dispose()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.toList
import kotlinx.rpc.client.withService
import kotlinx.rpc.internal.STREAM_SCOPES_ENABLED
import kotlinx.rpc.internal.invokeOnStreamScopeCompletion
import kotlinx.rpc.internal.streamScoped
import kotlin.test.*
Expand Down Expand Up @@ -203,7 +204,11 @@ class CancellationTest {
fun testStreamScopeAbsentForOutgoingStream() = runCancellationTest {
val fence = CompletableDeferred<Unit>()

assertFailsWith<IllegalStateException> {
if (STREAM_SCOPES_ENABLED) {
assertFailsWith<IllegalStateException> {
service.outgoingStream(resumableFlow(fence))
}
} else {
service.outgoingStream(resumableFlow(fence))
}

Expand All @@ -212,7 +217,11 @@ class CancellationTest {

@Test
fun testStreamScopeAbsentForIncomingStream() = runCancellationTest {
assertFailsWith<IllegalStateException> {
if (STREAM_SCOPES_ENABLED) {
assertFailsWith<IllegalStateException> {
service.incomingStream()
}
} else {
service.incomingStream()
}

Expand Down

0 comments on commit 3839aac

Please sign in to comment.