diff --git a/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeServer.kt b/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeServer.kt index 76d525aa2..607aeb262 100644 --- a/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeServer.kt +++ b/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeServer.kt @@ -33,51 +33,63 @@ import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.produce +import kotlinx.coroutines.io.close import java.net.InetAddress import java.net.InetSocketAddress +import kotlin.coroutines.CoroutineContext class AndroidDebugBridgeServer( val port: Int, - val host: InetAddress + val host: InetAddress, + val coroutineContext: CoroutineContext ) { private val socketAddress: InetSocketAddress = InetSocketAddress(host, port) + private val selectorManager = ActorSelectorManager(coroutineContext) suspend fun execute(request: ComplexRequest, serial: String? = null): T { - aSocket(ActorSelectorManager(Dispatchers.IO)) + aSocket(selectorManager) .tcp() .connect(socketAddress).use { socket -> val readChannel = socket.openReadChannel().toAndroidChannel() - val writeChannel = socket.openWriteChannel(autoFlush = true).toAndroidChannel() - - serial?.let { - processRequest(writeChannel, SetDeviceRequest(it).serialize(), readChannel) + var writeChannel: AndroidWriteChannel? = null + try { + writeChannel = socket.openWriteChannel(autoFlush = true).toAndroidChannel() + serial?.let { + processRequest(writeChannel, SetDeviceRequest(it).serialize(), readChannel) + } + return request.process(readChannel, writeChannel) + } finally { + writeChannel?.close() } - - return request.process(readChannel, writeChannel) } } fun execute(request: AsyncChannelRequest, scope: CoroutineScope, serial: String? = null): ReceiveChannel { return scope.produce { - aSocket(ActorSelectorManager(Dispatchers.IO)) + aSocket(selectorManager) .tcp() .connect(socketAddress).use { socket -> val readChannel = socket.openReadChannel().toAndroidChannel() - val writeChannel = socket.openWriteChannel(autoFlush = true).toAndroidChannel() + var writeChannel: AndroidWriteChannel? = null - serial?.let { - processRequest(writeChannel, SetDeviceRequest(it).serialize(), readChannel) - } + try { + writeChannel = socket.openWriteChannel(autoFlush = true).toAndroidChannel() + serial?.let { + processRequest(writeChannel, SetDeviceRequest(it).serialize(), readChannel) + } - request.handshake(readChannel, writeChannel) + request.handshake(readChannel, writeChannel) - while (true) { - if (isClosedForSend || readChannel.isClosedForRead || writeChannel.isClosedForWrite) return@produce - val element = request.readElement(readChannel, writeChannel) - send(element) - } + while (true) { + if (isClosedForSend || readChannel.isClosedForRead || writeChannel.isClosedForWrite) return@produce + val element = request.readElement(readChannel, writeChannel) + send(element) + } - request.close(channel) + request.close(channel) + } finally { + writeChannel?.close() + } } } } @@ -103,11 +115,13 @@ class AndroidDebugBridgeServer( class AndroidDebugBridgeServerFactory { var port: Int? = null var host: InetAddress? = null + var coroutineContext: CoroutineContext? = null fun build(): AndroidDebugBridgeServer { return AndroidDebugBridgeServer( port = port ?: DiscoverAdbSocketInteractor().execute(), - host = host ?: InetAddress.getByName(Const.DEFAULT_ADB_HOST) + host = host ?: InetAddress.getByName(Const.DEFAULT_ADB_HOST), + coroutineContext = coroutineContext ?: Dispatchers.IO ) } } \ No newline at end of file