Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate transport #256

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rsocket-internal-io/api/rsocket-internal-io.api
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt {
public final class io/rsocket/kotlin/internal/io/ContextKt {
public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
public static final fun ensureActive (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;)V
public static final fun launchCoroutine (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun launchCoroutine$default (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static final fun onCompletion (Lkotlinx/coroutines/Job;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/Job;
public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ public inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) {
onInactive() // should not throw
ensureActive() // will throw
}

@Suppress("SuspendFunctionOnCoroutineScope")
public suspend inline fun <T> CoroutineScope.launchCoroutine(
context: CoroutineContext = EmptyCoroutineContext,
crossinline block: suspend (CancellableContinuation<T>) -> Unit,
): T = suspendCancellableCoroutine { cont ->
val job = launch(context) { block(cont) }
job.invokeOnCompletion { if (it != null && cont.isActive) cont.resumeWithException(it) }
cont.invokeOnCancellation { job.cancel("launchCoroutine was cancelled", it) }
}
40 changes: 40 additions & 0 deletions rsocket-transports/ktor-tcp/api/rsocket-transport-ktor-tcp.api
Original file line number Diff line number Diff line change
@@ -1,3 +1,43 @@
public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport {
public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory;
public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget;
public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget;
}

public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory {
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder {
public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V
public fun inheritDispatcher ()V
public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V
public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V
public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance {
public abstract fun getLocalAddress ()Lio/ktor/network/sockets/SocketAddress;
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport {
public static final field Factory Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory;
public abstract fun target (Lio/ktor/network/sockets/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget;
public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Lio/ktor/network/sockets/SocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
public static synthetic fun target$default (Lio/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget;
}

public final class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory {
}

public abstract interface class io/rsocket/kotlin/transport/ktor/tcp/KtorTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder {
public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V
public fun inheritDispatcher ()V
public abstract fun selectorManager (Lio/ktor/network/selector/SelectorManager;Z)V
public abstract fun selectorManagerDispatcher (Lkotlin/coroutines/CoroutineContext;)V
public abstract fun socketOptions (Lkotlin/jvm/functions/Function1;)V
}

public final class io/rsocket/kotlin/transport/ktor/tcp/TcpClientTransportKt {
public static final fun TcpClientTransport (Lio/ktor/network/sockets/InetSocketAddress;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport;
public static final fun TcpClientTransport (Ljava/lang/String;ILkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/ClientTransport;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.ktor.tcp

import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import kotlinx.coroutines.*
import kotlin.coroutines.*

@OptIn(RSocketTransportApi::class)
public sealed interface KtorTcpClientTransport : RSocketTransport {
public fun target(remoteAddress: SocketAddress): RSocketClientTarget
public fun target(host: String, port: Int): RSocketClientTarget

public companion object Factory :
RSocketTransportFactory<KtorTcpClientTransport, KtorTcpClientTransportBuilder>(::KtorTcpClientTransportBuilderImpl)
}

@OptIn(RSocketTransportApi::class)
public sealed interface KtorTcpClientTransportBuilder : RSocketTransportBuilder<KtorTcpClientTransport> {
public fun dispatcher(context: CoroutineContext)
public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext)

public fun selectorManagerDispatcher(context: CoroutineContext)
public fun selectorManager(manager: SelectorManager, manage: Boolean)

public fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit)

//TODO: TLS support
}

private class KtorTcpClientTransportBuilderImpl : KtorTcpClientTransportBuilder {
private var dispatcher: CoroutineContext = Dispatchers.Default
private var selector: KtorTcpSelector = KtorTcpSelector.FromContext(Dispatchers.IO)
private var socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit = {}

override fun dispatcher(context: CoroutineContext) {
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
this.dispatcher = context
}

override fun socketOptions(block: SocketOptions.TCPClientSocketOptions.() -> Unit) {
this.socketOptions = block
}

override fun selectorManagerDispatcher(context: CoroutineContext) {
check(context[Job] == null) { "Dispatcher shouldn't contain job" }
this.selector = KtorTcpSelector.FromContext(context)
}

override fun selectorManager(manager: SelectorManager, manage: Boolean) {
this.selector = KtorTcpSelector.FromInstance(manager, manage)
}

@RSocketTransportApi
override fun buildTransport(context: CoroutineContext): KtorTcpClientTransport {
val transportContext = context.supervisorContext() + dispatcher
return KtorTcpClientTransportImpl(
coroutineContext = transportContext,
socketOptions = socketOptions,
selectorManager = selector.createFor(transportContext)
)
}
}

private class KtorTcpClientTransportImpl(
override val coroutineContext: CoroutineContext,
private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit,
private val selectorManager: SelectorManager,
) : KtorTcpClientTransport {
override fun target(remoteAddress: SocketAddress): RSocketClientTarget = KtorTcpClientTargetImpl(
coroutineContext = coroutineContext.supervisorContext(),
socketOptions = socketOptions,
selectorManager = selectorManager,
remoteAddress = remoteAddress
)

override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port))
}

@OptIn(RSocketTransportApi::class)
private class KtorTcpClientTargetImpl(
override val coroutineContext: CoroutineContext,
private val socketOptions: SocketOptions.TCPClientSocketOptions.() -> Unit,
private val selectorManager: SelectorManager,
private val remoteAddress: SocketAddress,
) : RSocketClientTarget {

@RSocketTransportApi
override fun connectClient(handler: RSocketConnectionHandler): Job = launch {
val socket = aSocket(selectorManager).tcp().connect(remoteAddress, socketOptions)
handler.handleKtorTcpConnection(socket)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.ktor.tcp

import io.ktor.network.sockets.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.rsocket.kotlin.internal.io.*
import io.rsocket.kotlin.transport.*
import io.rsocket.kotlin.transport.internal.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*

@RSocketTransportApi
internal suspend fun RSocketConnectionHandler.handleKtorTcpConnection(socket: Socket): Unit = coroutineScope {
val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED)
val inbound = channelForCloseable<ByteReadPacket>(Channel.BUFFERED)

val readerJob = launch {
val input = socket.openReadChannel()
try {
while (true) inbound.send(input.readFrame() ?: break)
input.cancel(null)
} catch (cause: Throwable) {
input.cancel(cause)
throw cause
}
}.onCompletion { inbound.cancel() }

val writerJob = launch {
val output = socket.openWriteChannel()
try {
while (true) {
// we write all available frames here, and only after it flush
// in this case, if there are several buffered frames we can send them in one go
// avoiding unnecessary flushes
output.writeFrame(outboundQueue.dequeueFrame() ?: break)
while (true) output.writeFrame(outboundQueue.tryDequeueFrame() ?: break)
output.flush()
}
output.close(null)
} catch (cause: Throwable) {
output.close(cause)
throw cause
}
}.onCompletion { outboundQueue.cancel() }

try {
handleConnection(KtorTcpConnection(outboundQueue, inbound))
} finally {
readerJob.cancel()
outboundQueue.close() // will cause `writerJob` completion
// even if it was cancelled, we still need to close socket and await it closure
withContext(NonCancellable) {
// await completion of read/write and then close socket
readerJob.join()
writerJob.join()
// close socket
socket.close()
socket.socketContext.join()
}
}
}

@RSocketTransportApi
private class KtorTcpConnection(
private val outboundQueue: PrioritizationFrameQueue,
private val inbound: ReceiveChannel<ByteReadPacket>,
) : RSocketSequentialConnection {
override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend
override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) {
return outboundQueue.enqueueFrame(streamId, frame)
}

override suspend fun receiveFrame(): ByteReadPacket? {
return inbound.receiveCatching().getOrNull()
}
}

private suspend fun ByteWriteChannel.writeFrame(frame: ByteReadPacket) {
val packet = buildPacket {
writeInt24(frame.remaining.toInt())
writePacket(frame)
}
try {
writePacket(packet)
} catch (cause: Throwable) {
packet.close()
throw cause
}
}

private suspend fun ByteReadChannel.readFrame(): ByteReadPacket? {
val lengthPacket = readRemaining(3)
if (lengthPacket.remaining == 0L) return null
return readPacket(lengthPacket.readInt24())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2015-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.rsocket.kotlin.transport.ktor.tcp

import io.ktor.network.selector.*
import kotlinx.coroutines.*
import kotlin.coroutines.*

internal sealed class KtorTcpSelector {
class FromContext(val context: CoroutineContext) : KtorTcpSelector()
class FromInstance(val selectorManager: SelectorManager, val manage: Boolean) : KtorTcpSelector()
}

internal fun KtorTcpSelector.createFor(parentContext: CoroutineContext): SelectorManager {
val selectorManager: SelectorManager
val manage: Boolean
when (this) {
is KtorTcpSelector.FromContext -> {
selectorManager = SelectorManager(parentContext + context)
manage = true
}

is KtorTcpSelector.FromInstance -> {
selectorManager = this.selectorManager
manage = this.manage
}
}
if (manage) Job(parentContext.job).invokeOnCompletion { selectorManager.close() }
return selectorManager
}
Loading
Loading