diff --git a/build.gradle.kts b/build.gradle.kts index 2e27c55ed..ee56df352 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -105,7 +105,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -tasks.compileKotlin { +tasks.withType(KotlinCompile::class) { kotlinOptions.jvmTarget = "1.8" kotlinOptions.apiVersion = "1.4" } @@ -122,9 +122,11 @@ dependencies { testImplementation(TestLibraries.assertk) testImplementation(TestLibraries.junit) testImplementation(TestLibraries.imageComparison) + testImplementation(TestLibraries.coroutinesDebug) testImplementation(kotlin("reflect", version = Versions.kotlin)) integrationTestImplementation(TestLibraries.assertk) integrationTestImplementation(TestLibraries.junit) + integrationTestImplementation(TestLibraries.coroutinesDebug) integrationTestImplementation(kotlin("reflect", version = Versions.kotlin)) } diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt index 50e213a93..44993821f 100644 --- a/buildSrc/src/main/kotlin/Versions.kt +++ b/buildSrc/src/main/kotlin/Versions.kt @@ -13,6 +13,7 @@ object Versions { val imageComparison = "4.3.0" val dokka = kotlin val pdbank = "0.9.1" + val coroutinesDebug = "1.4.0" } object BuildPlugins { @@ -32,4 +33,5 @@ object TestLibraries { val assertk = "com.willowtreeapps.assertk:assertk:${Versions.assertk}" val junit = "junit:junit:${Versions.junit}" val imageComparison = "com.github.romankh3:image-comparison:${Versions.imageComparison}" + val coroutinesDebug = "org.jetbrains.kotlinx:kotlinx-coroutines-debug:${Versions.coroutinesDebug}" } diff --git a/src/integrationTest/kotlin/com/malinskiy/adam/rule/AdbDeviceRule.kt b/src/integrationTest/kotlin/com/malinskiy/adam/rule/AdbDeviceRule.kt index 543a31dc7..202288dcb 100644 --- a/src/integrationTest/kotlin/com/malinskiy/adam/rule/AdbDeviceRule.kt +++ b/src/integrationTest/kotlin/com/malinskiy/adam/rule/AdbDeviceRule.kt @@ -25,6 +25,7 @@ import com.malinskiy.adam.request.device.ListDevicesRequest import com.malinskiy.adam.request.misc.GetAdbServerVersionRequest import com.malinskiy.adam.request.prop.GetSinglePropRequest import com.malinskiy.adam.request.shell.v1.ShellCommandRequest +import com.malinskiy.adam.transport.NioSocketFactory import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.isActive import kotlinx.coroutines.runBlocking @@ -46,8 +47,10 @@ class AdbDeviceRule(val deviceType: DeviceType = DeviceType.ANY, vararg val requ lateinit var deviceSerial: String lateinit var supportedFeatures: List lateinit var lineSeparator: String - - val adb = AndroidDebugBridgeClientFactory().build() + + val adb = AndroidDebugBridgeClientFactory().apply { + socketFactory = NioSocketFactory() + }.build() val initTimeout = Duration.ofSeconds(10) override fun apply(base: Statement, description: Description): Statement { diff --git a/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeClient.kt b/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeClient.kt index e4ffd5550..81d93080e 100644 --- a/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeClient.kt +++ b/src/main/kotlin/com/malinskiy/adam/AndroidDebugBridgeClient.kt @@ -16,9 +16,7 @@ package com.malinskiy.adam -import com.malinskiy.adam.exception.RequestRejectedException import com.malinskiy.adam.exception.RequestValidationException -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.interactor.DiscoverAdbSocketInteractor import com.malinskiy.adam.log.AdamLogging import com.malinskiy.adam.request.AsyncChannelRequest @@ -26,12 +24,9 @@ import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.MultiRequest import com.malinskiy.adam.request.emu.EmulatorCommandRequest import com.malinskiy.adam.request.misc.SetDeviceRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel import com.malinskiy.adam.transport.KtorSocketFactory import com.malinskiy.adam.transport.SocketFactory -import io.ktor.network.sockets.* -import io.ktor.utils.io.* +import com.malinskiy.adam.transport.use import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.NonCancellable @@ -57,22 +52,10 @@ class AndroidDebugBridgeClient( throw RequestValidationException("Request $requestSimpleClassName did not pass validation: ${validationResponse.message}") } socketFactory.tcp(socketAddress).use { socket -> - val readChannel = socket.openReadChannel().toAndroidChannel() - var writeChannel: AndroidWriteChannel? = null - try { - writeChannel = socket.openWriteChannel(autoFlush = true).toAndroidChannel() - serial?.let { - processRequest(writeChannel, SetDeviceRequest(it).serialize(), readChannel) - } - return request.process(readChannel, writeChannel) - } finally { - try { - writeChannel?.close() - readChannel.cancel() - } catch (e: Exception) { - log.debug(e) { "Exception during cleanup. Ignoring" } - } + serial?.let { + SetDeviceRequest(it).handshake(socket) } + return request.process(socket) } } @@ -84,32 +67,28 @@ class AndroidDebugBridgeClient( } return scope.produce { socketFactory.tcp(socketAddress).use { socket -> - val readChannel = socket.openReadChannel().toAndroidChannel() - var writeChannel: AndroidWriteChannel? = null var backChannel = request.channel try { - writeChannel = socket.openWriteChannel(autoFlush = true).toAndroidChannel() serial?.let { - processRequest(writeChannel, SetDeviceRequest(it).serialize(), readChannel) + SetDeviceRequest(it).handshake(socket) } - request.handshake(readChannel, writeChannel) + request.handshake(socket) while (true) { if (isClosedForSend || - readChannel.isClosedForRead || - writeChannel.isClosedForWrite || + socket.isClosedForRead || + socket.isClosedForWrite || request.channel?.isClosedForReceive == true ) { break } - request.readElement(readChannel, writeChannel)?.let { - send(it) - } + val finished = request.readElement(socket, this) + if (finished) break backChannel?.poll()?.let { - request.writeElement(it, readChannel, writeChannel) + request.writeElement(it, socket) } } } finally { @@ -117,8 +96,6 @@ class AndroidDebugBridgeClient( withContext(NonCancellable) { request.close(channel) } - writeChannel?.close() - readChannel.cancel() } catch (e: Exception) { log.debug(e) { "Exception during cleanup. Ignoring" } } @@ -129,18 +106,7 @@ class AndroidDebugBridgeClient( suspend fun execute(request: EmulatorCommandRequest): String { socketFactory.tcp(request.address).use { socket -> - var readChannel: ByteReadChannel? = null - var writeChannel: ByteWriteChannel? = null - - try { - readChannel = socket.openReadChannel() - writeChannel = socket.openWriteChannel(true) - - return request.process(readChannel, writeChannel) - } finally { - readChannel?.cancel() - writeChannel?.close() - } + return request.process(socket) } } @@ -154,19 +120,6 @@ class AndroidDebugBridgeClient( return request.execute(this, serial) } - private suspend fun processRequest( - writeChannel: AndroidWriteChannel, - request: ByteArray, - readChannel: AndroidReadChannel - ) { - writeChannel.write(request) - val response = readChannel.read() - if (!response.okay) { - log.warn { "adb server rejected command ${String(request, Const.DEFAULT_TRANSPORT_ENCODING)}" } - throw RequestRejectedException(response.message ?: "no message received") - } - } - companion object { private val log = AdamLogging.logger {} } diff --git a/src/main/kotlin/com/malinskiy/adam/Const.kt b/src/main/kotlin/com/malinskiy/adam/Const.kt index 3cd9d7ff8..ec53f2cfe 100644 --- a/src/main/kotlin/com/malinskiy/adam/Const.kt +++ b/src/main/kotlin/com/malinskiy/adam/Const.kt @@ -19,7 +19,6 @@ package com.malinskiy.adam object Const { const val MAX_REMOTE_PATH_LENGTH = 1024 const val DEFAULT_BUFFER_SIZE = 1024 - const val READ_DELAY = 100L val DEFAULT_TRANSPORT_ENCODING = Charsets.UTF_8 const val DEFAULT_ADB_HOST = "127.0.0.1" const val DEFAULT_ADB_PORT = 5037 @@ -27,6 +26,7 @@ object Const { const val SERVER_PORT_ENV_VAR = "ANDROID_ADB_SERVER_PORT" const val MAX_PACKET_LENGTH = 16384 const val MAX_FILE_PACKET_LENGTH = 64 * 1024 + const val KTOR_INTERNAL_BUFFER_LENGTH = 4088 const val MAX_PROTOBUF_LOGCAT_LENGTH = 10_000 const val MAX_PROTOBUF_PACKET_LENGTH = 10 * 1024 * 1024L //10Mb diff --git a/src/main/kotlin/com/malinskiy/adam/extension/ByteReadChannel.kt b/src/main/kotlin/com/malinskiy/adam/extension/ByteReadChannel.kt index f5fd2e97a..f82c021c3 100644 --- a/src/main/kotlin/com/malinskiy/adam/extension/ByteReadChannel.kt +++ b/src/main/kotlin/com/malinskiy/adam/extension/ByteReadChannel.kt @@ -16,11 +16,12 @@ package com.malinskiy.adam.extension -import com.malinskiy.adam.request.transform.ResponseTransformer +import com.malinskiy.adam.transport.Socket import io.ktor.utils.io.* import java.nio.ByteBuffer -suspend fun ByteReadChannel.copyTo(channel: ByteWriteChannel, buffer: ByteArray): Long { +suspend fun ByteReadChannel.copyTo(socket: Socket, buffer: ByteBuffer) = copyTo(socket, buffer.array()) +suspend fun ByteReadChannel.copyTo(socket: Socket, buffer: ByteArray): Long { var processed = 0L loop@ while (true) { val available = readAvailable(buffer, 0, buffer.size) @@ -29,50 +30,11 @@ suspend fun ByteReadChannel.copyTo(channel: ByteWriteChannel, buffer: ByteArray) break@loop } available > 0 -> { - channel.writeFully(buffer, 0, available) + socket.writeFully(buffer, 0, available) processed += available } else -> continue@loop } } return processed -} - -/** - * Copies up to limit bytes into transformer using buffer. If limit is null - copy until EOF - */ -suspend fun ByteReadChannel.copyTo(transformer: ResponseTransformer, buffer: ByteArray, limit: Long? = null): Long { - var processed = 0L - loop@ while (true) { - val toRead = when { - limit == null || (limit - processed) > buffer.size -> { - buffer.size - } - else -> { - (limit - processed).toInt() - } - } - val available = readAvailable(buffer, 0, toRead) - when { - processed == limit -> break@loop - available < 0 -> { - break@loop - } - available > 0 -> { - transformer.process(buffer, 0, available) - processed += available - } - else -> continue@loop - } - } - return processed -} - -/** - * TODO: rewrite - * Assumes buffer hasArray == true - */ -suspend fun ByteReadChannel.copyTo(channel: ByteWriteChannel, buffer: ByteBuffer) = copyTo(channel, buffer.array()) -suspend fun ByteReadChannel.copyTo(transformer: ResponseTransformer, buffer: ByteBuffer) = copyTo(transformer, buffer.array()) -suspend fun ByteReadChannel.copyTo(transformer: ResponseTransformer, buffer: ByteBuffer, limit: Long? = null) = - copyTo(transformer, buffer.array(), limit) +} \ No newline at end of file diff --git a/src/main/kotlin/com/malinskiy/adam/extension/Socket.kt b/src/main/kotlin/com/malinskiy/adam/extension/Socket.kt new file mode 100644 index 000000000..6c0b88e9c --- /dev/null +++ b/src/main/kotlin/com/malinskiy/adam/extension/Socket.kt @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2021 Anton Malinskiy + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.malinskiy.adam.extension + +import com.malinskiy.adam.Const +import com.malinskiy.adam.request.transform.ResponseTransformer +import com.malinskiy.adam.request.transform.StringResponseTransformer +import com.malinskiy.adam.transport.Socket +import com.malinskiy.adam.transport.TransportResponse +import com.malinskiy.adam.transport.withDefaultBuffer +import io.ktor.util.cio.* +import io.ktor.utils.io.* +import io.ktor.utils.io.bits.* +import io.ktor.utils.io.core.* +import java.io.File +import java.nio.ByteBuffer +import kotlin.coroutines.CoroutineContext +import kotlin.text.toByteArray + +suspend fun Socket.copyTo(channel: ByteWriteChannel, buffer: ByteArray): Long { + var processed = 0L + loop@ while (true) { + val available = readAvailable(buffer, 0, buffer.size) + when { + available < 0 -> { + break@loop + } + available > 0 -> { + channel.writeFully(buffer, 0, available) + processed += available + } + else -> continue@loop + } + } + return processed +} + +/** + * Copies up to limit bytes into transformer using buffer. If limit is null - copy until EOF + */ +suspend fun Socket.copyTo(transformer: ResponseTransformer, buffer: ByteArray, limit: Long? = null): Long { + var processed = 0L + loop@ while (true) { + val toRead = when { + limit == null || (limit - processed) > buffer.size -> { + buffer.size + } + else -> { + (limit - processed).toInt() + } + } + val available = readAvailable(buffer, 0, toRead) + when { + processed == limit -> break@loop + available < 0 -> { + break@loop + } + available > 0 -> { + transformer.process(buffer, 0, available) + processed += available + } + else -> continue@loop + } + } + return processed +} + +/** + * TODO: rewrite + * Assumes buffer hasArray == true + */ +suspend fun Socket.copyTo(channel: ByteWriteChannel, buffer: ByteBuffer) = copyTo(channel, buffer.array()) +suspend fun Socket.copyTo(transformer: ResponseTransformer, buffer: ByteBuffer) = copyTo(transformer, buffer.array()) +suspend fun Socket.copyTo(transformer: ResponseTransformer, buffer: ByteBuffer, limit: Long? = null) = + copyTo(transformer, buffer.array(), limit) + +suspend fun Socket.readOptionalProtocolString(): String? { + val responseLength = withDefaultBuffer { + val transformer = StringResponseTransformer() + copyTo(transformer, this, limit = 4L) + transformer.transform() + } + val errorMessageLength = responseLength.toIntOrNull(16) + return if (errorMessageLength == null) { + readStatus() + } else { + val errorBytes = ByteArray(errorMessageLength) + readFully(errorBytes, 0, errorMessageLength) + String(errorBytes, Const.DEFAULT_TRANSPORT_ENCODING) + } +} + +suspend fun Socket.read(): TransportResponse { + val bytes = ByteArray(4) + readFully(bytes, 0, 4) + + val ok = bytes.isOkay() + val message = if (!ok) { + readOptionalProtocolString() + } else { + null + } + + return TransportResponse(ok, message) +} + +private fun ByteArray.isOkay() = contentEquals(Const.Message.OKAY) + +suspend fun Socket.readStatus(): String { + withDefaultBuffer { + val transformer = StringResponseTransformer() + copyTo(transformer, this) + return transformer.transform() + } +} + +suspend fun Socket.write(request: ByteArray, length: Int? = null) { + writeFully(request, 0, length ?: request.size) +} + +suspend fun Socket.writeFile(file: File, coroutineContext: CoroutineContext) = withDefaultBuffer { + var fileChannel: ByteReadChannel? = null + try { + val fileChannel = file.readChannel(coroutineContext = coroutineContext) + fileChannel.copyTo(this@writeFile, this) + } finally { + fileChannel?.cancel() + } +} + +suspend fun Socket.writeSyncRequest(type: ByteArray, remotePath: String) { + val path = remotePath.toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) + val size = path.size.toByteArray().reversedArray() + + val cmd = ByteArray(8 + path.size) + + type.copyInto(cmd) + size.copyInto(cmd, 4) + path.copyInto(cmd, 8) + write(cmd) +} + +suspend fun Socket.writeSyncV2Request(type: ByteArray, remotePath: String, flags: Int, mode: Int? = null) { + val path = remotePath.toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) + + withDefaultBuffer { + compatLimit(4 + 4) + put(type) + putInt(path.size.reverseByteOrder()) + compatRewind() + writeFully(this) + + writeFully(path) + + compatRewind() + mode?.let { compatLimit(4 + 4 + 4) } ?: compatLimit(4 + 4) + put(type) + mode?.let { putInt(it.reverseByteOrder()) } + putInt(flags.reverseByteOrder()) + compatRewind() + writeFully(this) + } +} + +suspend fun Socket.readTransportResponse(): TransportResponse { + val bytes = ByteArray(4) + readFully(bytes, 0, 4) + + val ok = bytes.isOkay() + val message = if (!ok) { + readOptionalProtocolString() + } else { + null + } + + return TransportResponse(ok, message) +} diff --git a/src/main/kotlin/com/malinskiy/adam/request/AsyncChannelRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/AsyncChannelRequest.kt index e1426cc4d..e62dac6b5 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/AsyncChannelRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/AsyncChannelRequest.kt @@ -16,8 +16,7 @@ package com.malinskiy.adam.request -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel @@ -33,12 +32,12 @@ abstract class AsyncChannelRequest( /** * Called after the initial OKAY confirmation */ - abstract suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): T? + abstract suspend fun readElement(socket: Socket, sendChannel: SendChannel): Boolean /** * Called after each readElement */ - abstract suspend fun writeElement(element: I, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) + abstract suspend fun writeElement(element: I, socket: Socket) /** * Optionally send a message diff --git a/src/main/kotlin/com/malinskiy/adam/request/ComplexRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/ComplexRequest.kt index 91a677be8..3400fb775 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/ComplexRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/ComplexRequest.kt @@ -16,8 +16,7 @@ package com.malinskiy.adam.request -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket /** * This type of request starts with single serialized request @@ -28,10 +27,10 @@ abstract class ComplexRequest(target: Target = NonSpecifiedTarget) : R * Some requests ignore the initial OKAY/FAIL response and instead stream the actual response * To implement these we allow overriding this method */ - open suspend fun process(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): T { - handshake(readChannel, writeChannel) - return readElement(readChannel, writeChannel) + open suspend fun process(socket: Socket): T { + handshake(socket) + return readElement(socket) } - abstract suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): T + abstract suspend fun readElement(socket: Socket): T } \ No newline at end of file diff --git a/src/main/kotlin/com/malinskiy/adam/request/Request.kt b/src/main/kotlin/com/malinskiy/adam/request/Request.kt index 3786120d9..21a7a77b9 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/Request.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/Request.kt @@ -18,9 +18,10 @@ package com.malinskiy.adam.request import com.malinskiy.adam.Const import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readTransportResponse +import com.malinskiy.adam.extension.write import com.malinskiy.adam.log.AdamLogging -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.io.UnsupportedEncodingException /** @@ -35,10 +36,10 @@ open abstract class Request(val target: Target = HostTarget) { */ abstract fun serialize(): ByteArray - open suspend fun handshake(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { + open suspend fun handshake(socket: Socket) { val request = serialize() - writeChannel.write(request) - val response = readChannel.read() + socket.write(request) + val response = socket.readTransportResponse() if (!response.okay) { log.warn { "adb server rejected command ${String(request, Const.DEFAULT_TRANSPORT_ENCODING)}" } throw RequestRejectedException(response.message ?: "no message received") diff --git a/src/main/kotlin/com/malinskiy/adam/request/SynchronousRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/SynchronousRequest.kt index c27238748..76775e3a4 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/SynchronousRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/SynchronousRequest.kt @@ -18,16 +18,15 @@ package com.malinskiy.adam.request import com.malinskiy.adam.Const import com.malinskiy.adam.request.transform.ResponseTransformer -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket abstract class SynchronousRequest(target: Target = NonSpecifiedTarget) : ComplexRequest(target), ResponseTransformer { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): T { + override suspend fun readElement(socket: Socket): T { val data = ByteArray(Const.MAX_PACKET_LENGTH) loop@ do { - if (writeChannel.isClosedForWrite || readChannel.isClosedForRead) break@loop + if (socket.isClosedForWrite || socket.isClosedForRead) break@loop - val count = readChannel.readAvailable(data, 0, Const.MAX_PACKET_LENGTH) + val count = socket.readAvailable(data, 0, Const.MAX_PACKET_LENGTH) when { count == 0 -> { continue@loop diff --git a/src/main/kotlin/com/malinskiy/adam/request/abb/AbbExecRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/abb/AbbExecRequest.kt index fb0d460d2..8e4841446 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/abb/AbbExecRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/abb/AbbExecRequest.kt @@ -22,17 +22,16 @@ import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.transform.StringResponseTransformer -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import com.malinskiy.adam.transport.withDefaultBuffer @Features(Feature.ABB_EXEC) open class AbbExecRequest(private val args: List, private val supportedFeatures: List) : ComplexRequest() { private val transformer = StringResponseTransformer() - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { + override suspend fun readElement(socket: Socket): String { withDefaultBuffer { - readChannel.copyTo(transformer, this) + socket.copyTo(transformer, this) } return transformer.transform() } diff --git a/src/main/kotlin/com/malinskiy/adam/request/device/AsyncDeviceMonitorRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/device/AsyncDeviceMonitorRequest.kt index 050b31ffb..21660d5e4 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/device/AsyncDeviceMonitorRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/device/AsyncDeviceMonitorRequest.kt @@ -19,31 +19,33 @@ package com.malinskiy.adam.request.device import com.malinskiy.adam.Const import com.malinskiy.adam.request.AsyncChannelRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket +import kotlinx.coroutines.channels.SendChannel import java.nio.ByteBuffer class AsyncDeviceMonitorRequest : AsyncChannelRequest, Unit>(target = HostTarget) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List? { + override suspend fun readElement(socket: Socket, sendChannel: SendChannel>): Boolean { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) - return payload.lines() - .filter { it.isNotEmpty() } - .map { - val line = it.trim() - val split = line.split("\t") - Device( - serial = split[0], - state = DeviceState.from(split[1]) - ) - } + sendChannel.send(payload.lines() + .filter { it.isNotEmpty() } + .map { + val line = it.trim() + val split = line.split("\t") + Device( + serial = split[0], + state = DeviceState.from(split[1]) + ) + } + ) + return false } override fun serialize() = createBaseRequest("track-devices") - override suspend fun writeElement(element: Unit, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) = Unit + override suspend fun writeElement(element: Unit, socket: Socket) = Unit } diff --git a/src/main/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequest.kt index 16af534fb..13a0d12d0 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequest.kt @@ -20,8 +20,7 @@ import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.SerialTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer @@ -29,13 +28,13 @@ class FetchDeviceFeaturesRequest(serial: String) : ComplexRequest> override fun serialize() = createBaseRequest("features") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { + override suspend fun readElement(socket: Socket): List { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).split(',').mapNotNull { Feature.of(it) } } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/device/ListDevicesRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/device/ListDevicesRequest.kt index dfa455c65..3a33f2929 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/device/ListDevicesRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/device/ListDevicesRequest.kt @@ -19,20 +19,19 @@ package com.malinskiy.adam.request.device import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer class ListDevicesRequest : ComplexRequest>(target = HostTarget) { override fun serialize() = createBaseRequest("devices") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { + override suspend fun readElement(socket: Socket): List { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) return payload.lines() .filter { it.isNotEmpty() } diff --git a/src/main/kotlin/com/malinskiy/adam/request/emu/EmulatorCommandRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/emu/EmulatorCommandRequest.kt index 433690717..d42094faf 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/emu/EmulatorCommandRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/emu/EmulatorCommandRequest.kt @@ -16,7 +16,8 @@ package com.malinskiy.adam.request.emu -import com.malinskiy.adam.Const +import com.malinskiy.adam.transport.Socket +import com.malinskiy.adam.transport.withDefaultBuffer import io.ktor.util.cio.* import io.ktor.utils.io.* import java.io.File @@ -50,7 +51,7 @@ class EmulatorCommandRequest( } } - suspend fun process(readChannel: ByteReadChannel, writeChannel: ByteWriteChannel): String { + suspend fun process(socket: Socket): String { val sessionBuilder = StringBuilder() val token = authToken ?: readAuthToken() ?: "" if (token.isNotEmpty()) { @@ -59,27 +60,29 @@ class EmulatorCommandRequest( sessionBuilder.append("$cmd\n") sessionBuilder.append("quit\n") - writeChannel.writeFully(sessionBuilder.toString().toByteArray()) + socket.writeFully(sessionBuilder.toString().toByteArray()) - val buffer = ByteArray(1024) - val output = StringBuilder() - loop@ do { - if (writeChannel.isClosedForWrite || readChannel.isClosedForRead) break@loop + withDefaultBuffer { + val buffer = array() + val output = StringBuilder() + loop@ do { + if (socket.isClosedForWrite || socket.isClosedForRead) break@loop - val count = readChannel.readAvailable(buffer, 0, Const.MAX_PACKET_LENGTH) - when { - count == 0 -> { - continue@loop + val count = socket.readAvailable(buffer, 0, buffer.size) + when { + count == 0 -> { + continue@loop + } + count > 0 -> { + output.append(String(buffer, 0, count, Charsets.UTF_8)) + } } - count > 0 -> { - output.append(String(buffer, 0, count, Charsets.UTF_8)) - } - } - } while (count >= 0) + } while (count >= 0) - val firstOkPosition = output.indexOf(OUTPUT_DELIMITER) - val secondOkPosition = output.indexOf(OUTPUT_DELIMITER, firstOkPosition + 1) - return output.substring(secondOkPosition + OUTPUT_DELIMITER.length) + val firstOkPosition = output.indexOf(OUTPUT_DELIMITER) + val secondOkPosition = output.indexOf(OUTPUT_DELIMITER, firstOkPosition + 1) + return output.substring(secondOkPosition + OUTPUT_DELIMITER.length) + } } companion object { diff --git a/src/main/kotlin/com/malinskiy/adam/request/forwarding/ListPortForwardsRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/forwarding/ListPortForwardsRequest.kt index 9723e44a2..e6bb64a1d 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/forwarding/ListPortForwardsRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/forwarding/ListPortForwardsRequest.kt @@ -19,18 +19,17 @@ package com.malinskiy.adam.request.forwarding import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.SerialTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer class ListPortForwardsRequest(serial: String) : ComplexRequest>(target = SerialTarget(serial)) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { + override suspend fun readElement(socket: Socket): List { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) return payload.lines().mapNotNull { line -> if (line.isNotEmpty()) { diff --git a/src/main/kotlin/com/malinskiy/adam/request/forwarding/PortForwardRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/forwarding/PortForwardRequest.kt index fb019d503..67fb465f2 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/forwarding/PortForwardRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/forwarding/PortForwardRequest.kt @@ -17,10 +17,11 @@ package com.malinskiy.adam.request.forwarding import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readOptionalProtocolString +import com.malinskiy.adam.extension.readTransportResponse import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.SerialTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket /** * Optionally returns a local TCP port that is occupied now if using LocalTcpPortSpec without any parameters @@ -36,13 +37,13 @@ class PortForwardRequest( override fun serialize() = createBaseRequest("forward${mode.value}:${local.toSpec()};${remote.toSpec()}") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Int? { - val transportResponse = readChannel.read() + override suspend fun readElement(socket: Socket): Int? { + val transportResponse = socket.readTransportResponse() if (!transportResponse.okay) { throw RequestRejectedException("Can't establish port forwarding: ${transportResponse.message ?: ""}") } - return readChannel.readOptionalProtocolString()?.toIntOrNull() + return socket.readOptionalProtocolString()?.toIntOrNull() } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/BufferedImageScreenCaptureAdapter.kt b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/BufferedImageScreenCaptureAdapter.kt index 9aeae2cde..d3f523032 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/BufferedImageScreenCaptureAdapter.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/BufferedImageScreenCaptureAdapter.kt @@ -17,7 +17,7 @@ package com.malinskiy.adam.request.framebuffer import com.malinskiy.adam.extension.compatRewind -import com.malinskiy.adam.transport.AndroidReadChannel +import com.malinskiy.adam.transport.Socket import io.ktor.utils.io.bits.* import java.awt.image.BufferedImage import java.nio.ByteBuffer @@ -43,9 +43,9 @@ class BufferedImageScreenCaptureAdapter( alphaOffset: Int, alphaLength: Int, colorSpace: ColorSpace?, - channel: AndroidReadChannel + socket: Socket ): BufferedImage { - val imageBuffer: ByteBuffer = read(channel, size) + val imageBuffer: ByteBuffer = read(socket, size) imageBuffer.compatRewind() return when (bitsPerPixel) { 16 -> { diff --git a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/RawImageScreenCaptureAdapter.kt b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/RawImageScreenCaptureAdapter.kt index 1cadf1f6f..291a31ffc 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/RawImageScreenCaptureAdapter.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/RawImageScreenCaptureAdapter.kt @@ -16,7 +16,7 @@ package com.malinskiy.adam.request.framebuffer -import com.malinskiy.adam.transport.AndroidReadChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer class RawImageScreenCaptureAdapter(buffer: ByteBuffer? = null) : ScreenCaptureAdapter(buffer = buffer) { @@ -36,9 +36,9 @@ class RawImageScreenCaptureAdapter(buffer: ByteBuffer? = null) : ScreenCaptureAd alphaOffset: Int, alphaLength: Int, colorSpace: ColorSpace?, - channel: AndroidReadChannel + socket: Socket ): RawImage { - val imageBuffer = read(channel, size) + val imageBuffer = read(socket, size) return RawImage( version = version, diff --git a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureAdapter.kt b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureAdapter.kt index f6e9fb5ad..07e000aa1 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureAdapter.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureAdapter.kt @@ -17,8 +17,7 @@ package com.malinskiy.adam.request.framebuffer import com.malinskiy.adam.extension.compatRewind -import com.malinskiy.adam.transport.AndroidReadChannel -import io.ktor.utils.io.* +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** @@ -32,7 +31,7 @@ abstract class ScreenCaptureAdapter( private var buffer: ByteBuffer? = null, protected val colorModelFactory: ColorModelFactory = ColorModelFactory() ) { - suspend fun read(channel: ByteReadChannel, size: Int): ByteBuffer { + suspend fun read(socket: Socket, size: Int): ByteBuffer { val localBuffer = buffer val imageBuffer = if (localBuffer != null && localBuffer.capacity() == size) { @@ -43,7 +42,7 @@ abstract class ScreenCaptureAdapter( ByteBuffer.allocate(size).also { buffer = it } } - channel.readFully(imageBuffer) + socket.readFully(imageBuffer) return imageBuffer } @@ -62,6 +61,6 @@ abstract class ScreenCaptureAdapter( alphaOffset: Int, alphaLength: Int, colorSpace: ColorSpace? = null, - channel: AndroidReadChannel + socket: Socket ): T } diff --git a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureRequest.kt index 21dfebb4e..fb90cf66a 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/framebuffer/ScreenCaptureRequest.kt @@ -19,15 +19,14 @@ package com.malinskiy.adam.request.framebuffer import com.malinskiy.adam.exception.UnsupportedImageProtocolException import com.malinskiy.adam.extension.compatRewind import com.malinskiy.adam.request.ComplexRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer import java.nio.ByteOrder class ScreenCaptureRequest(private val adapter: ScreenCaptureAdapter) : ComplexRequest() { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): T { + override suspend fun readElement(socket: Socket): T { val protocolBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(protocolBuffer) + socket.readFully(protocolBuffer) protocolBuffer.compatRewind() val protocolVersion = protocolBuffer.order(ByteOrder.LITTLE_ENDIAN).int @@ -42,9 +41,9 @@ class ScreenCaptureRequest(private val adapter: ScreenCaptureAdapter) : Co else -> throw UnsupportedImageProtocolException(protocolVersion) } val headerBuffer = ByteBuffer.allocate(headerSize * 4) - readChannel.readFully(headerBuffer) + socket.readFully(headerBuffer) headerBuffer.compatRewind() - writeChannel.writeFully(ByteArray(1) { 0.toByte() }, 0, 1) + socket.writeFully(ByteArray(1) { 0.toByte() }, 0, 1) headerBuffer.order(ByteOrder.LITTLE_ENDIAN) headerBuffer.compatRewind() @@ -63,7 +62,7 @@ class ScreenCaptureRequest(private val adapter: ScreenCaptureAdapter) : Co blueLength = 5, alphaOffset = 0, alphaLength = 0, - channel = readChannel + socket = socket ) 1 -> adapter.process( version = protocolVersion, @@ -79,7 +78,7 @@ class ScreenCaptureRequest(private val adapter: ScreenCaptureAdapter) : Co greenLength = headerBuffer.int, alphaOffset = headerBuffer.int, alphaLength = headerBuffer.int, - channel = readChannel + socket = socket ) 2 -> adapter.process( version = protocolVersion, @@ -96,7 +95,7 @@ class ScreenCaptureRequest(private val adapter: ScreenCaptureAdapter) : Co greenLength = headerBuffer.int, alphaOffset = headerBuffer.int, alphaLength = headerBuffer.int, - channel = readChannel + socket = socket ) else -> throw UnsupportedImageProtocolException(protocolVersion) } diff --git a/src/main/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequest.kt index 6af5d6602..c53eff455 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequest.kt @@ -19,18 +19,17 @@ package com.malinskiy.adam.request.mdns import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer class ListMdnsServicesRequest : ComplexRequest>(target = HostTarget) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { + override suspend fun readElement(socket: Socket): List { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) .lines() .filterNot { it.isEmpty() } diff --git a/src/main/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequest.kt index b26b09b06..31a93b231 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequest.kt @@ -19,21 +19,20 @@ package com.malinskiy.adam.request.mdns import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** * check if mdns discovery is available */ class MdnsCheckRequest : ComplexRequest(target = HostTarget) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): MdnsStatus { + override suspend fun readElement(socket: Socket): MdnsStatus { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) val string = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) return if (string.contains("mdns daemon unavailable")) { MdnsStatus(false) diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/ConnectDeviceRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/ConnectDeviceRequest.kt index cc8425199..5a6050ab7 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/ConnectDeviceRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/ConnectDeviceRequest.kt @@ -19,8 +19,7 @@ package com.malinskiy.adam.request.misc import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** @@ -33,13 +32,13 @@ class ConnectDeviceRequest( override fun serialize() = createBaseRequest("connect:$host:$port") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { + override suspend fun readElement(socket: Socket): String { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/DisconnectDeviceRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/DisconnectDeviceRequest.kt index 33574fc9f..c3348101f 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/DisconnectDeviceRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/DisconnectDeviceRequest.kt @@ -19,8 +19,7 @@ package com.malinskiy.adam.request.misc import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** @@ -43,13 +42,13 @@ class DisconnectDeviceRequest( }" ) - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { + override suspend fun readElement(socket: Socket): String { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/ExecInRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/ExecInRequest.kt index 12c3312d0..fad9c0802 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/ExecInRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/ExecInRequest.kt @@ -18,20 +18,20 @@ package com.malinskiy.adam.request.misc import com.malinskiy.adam.Const import com.malinskiy.adam.extension.copyTo +import com.malinskiy.adam.extension.readStatus import com.malinskiy.adam.request.ComplexRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import io.ktor.utils.io.ByteReadChannel +import com.malinskiy.adam.transport.Socket +import io.ktor.utils.io.* /** * Executes the command and provides the channel as the input to the command. Does not return anything */ class ExecInRequest(private val cmd: String, private val channel: ByteReadChannel) : ComplexRequest() { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { + override suspend fun readElement(socket: Socket) { val buffer = ByteArray(Const.MAX_FILE_PACKET_LENGTH) - channel.copyTo(writeChannel, buffer) + channel.copyTo(socket, buffer) //Have to poll - readChannel.readStatus() + socket.readStatus() } override fun serialize() = createBaseRequest("exec:$cmd") diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequest.kt index f88b1b98a..bebb2f9be 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequest.kt @@ -20,21 +20,20 @@ import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer class FetchHostFeaturesRequest : ComplexRequest>(target = HostTarget) { override fun serialize() = createBaseRequest("host-features") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { + override suspend fun readElement(socket: Socket): List { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).split(',').mapNotNull { Feature.of(it) } } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/GetAdbServerVersionRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/GetAdbServerVersionRequest.kt index 18712c124..e27ffb756 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/GetAdbServerVersionRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/GetAdbServerVersionRequest.kt @@ -17,17 +17,17 @@ package com.malinskiy.adam.request.misc import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readOptionalProtocolString import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket /** * @see https://android.googlesource.com/platform/system/core/+/refs/heads/master/adb/adb.h#62 */ class GetAdbServerVersionRequest : ComplexRequest(target = HostTarget) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Int { - val version = readChannel.readOptionalProtocolString() + override suspend fun readElement(socket: Socket): Int { + val version = socket.readOptionalProtocolString() return version?.toIntOrNull(radix = 16) ?: throw RequestRejectedException("Empty/corrupt response") } diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/PairDeviceRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/PairDeviceRequest.kt index 90117114b..51d4c7187 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/PairDeviceRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/PairDeviceRequest.kt @@ -19,8 +19,7 @@ package com.malinskiy.adam.request.misc import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.HostTarget -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** @@ -34,13 +33,13 @@ class PairDeviceRequest( private val pairingCode: String ) : ComplexRequest(target = HostTarget) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { + override suspend fun readElement(socket: Socket): String { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) } diff --git a/src/main/kotlin/com/malinskiy/adam/request/misc/ReconnectRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/misc/ReconnectRequest.kt index 00652d731..20208aae9 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/misc/ReconnectRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/misc/ReconnectRequest.kt @@ -21,8 +21,7 @@ import com.malinskiy.adam.extension.compatRewind import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.NonSpecifiedTarget import com.malinskiy.adam.request.Target -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** @@ -40,9 +39,9 @@ class ReconnectRequest( ) : ComplexRequest(target = target) { private val buffer = ByteBuffer.allocate(4) - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { + override suspend fun readElement(socket: Socket): String { - readChannel.readFully(buffer) + socket.readFully(buffer) val array = buffer.array() return if (array.contentEquals(done)) { "done" @@ -50,7 +49,7 @@ class ReconnectRequest( //This is length of a response string val size = String(array, Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) payloadBuffer.compatRewind() String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) } diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/LegacySideloadRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/LegacySideloadRequest.kt index 34a40ab79..304ac0a6e 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/LegacySideloadRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/LegacySideloadRequest.kt @@ -18,13 +18,12 @@ package com.malinskiy.adam.request.pkg import com.malinskiy.adam.Const import com.malinskiy.adam.extension.copyTo +import com.malinskiy.adam.extension.readTransportResponse import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.ValidationResponse -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import io.ktor.util.cio.readChannel -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.cancel +import com.malinskiy.adam.transport.Socket +import io.ktor.util.cio.* +import io.ktor.utils.io.* import kotlinx.coroutines.Dispatchers import java.io.File import kotlin.coroutines.CoroutineContext @@ -51,16 +50,16 @@ class LegacySideloadRequest( override fun serialize() = createBaseRequest("sideload:${pkg.length()}") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Boolean { + override suspend fun readElement(socket: Socket): Boolean { val buffer = ByteArray(Const.MAX_FILE_PACKET_LENGTH) var fileChannel: ByteReadChannel? = null try { val fileChannel = pkg.readChannel(coroutineContext = coroutineContext) - fileChannel.copyTo(writeChannel, buffer) + fileChannel.copyTo(socket, buffer) } finally { fileChannel?.cancel() } - return readChannel.read().okay + return socket.readTransportResponse().okay } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/SideloadRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/SideloadRequest.kt index 55c32a0fa..6138a1647 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/SideloadRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/SideloadRequest.kt @@ -19,18 +19,16 @@ package com.malinskiy.adam.request.pkg import com.malinskiy.adam.Const import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.ValidationResponse -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import io.ktor.util.cio.readChannel -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.cancel +import com.malinskiy.adam.transport.Socket +import io.ktor.util.cio.* +import io.ktor.utils.io.* import java.io.File class SideloadRequest( private val pkg: File, private var blockSize: Int = Const.MAX_FILE_PACKET_LENGTH ) : ComplexRequest() { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Boolean { + override suspend fun readElement(socket: Socket): Boolean { val buffer = ByteArray(blockSize) var pkgChannel: ByteReadChannel? = null try { @@ -38,7 +36,7 @@ class SideloadRequest( var currentOffset = 0L while (true) { - readChannel.readFully(buffer, 0, 8) + socket.readFully(buffer, 0, 8) val bytes = buffer.copyOfRange(0, 8) when { bytes.contentEquals(Const.Message.DONEDONE) -> return true @@ -59,7 +57,7 @@ class SideloadRequest( blockSize.toLong() } pkgChannel?.readFully(buffer, 0, expectedLength.toInt()) - writeChannel.writeFully(buffer, 0, expectedLength.toInt()) + socket.writeFully(buffer, 0, expectedLength.toInt()) currentOffset += expectedLength } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/StreamingPackageInstallRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/StreamingPackageInstallRequest.kt index e03746f7d..5903e2b01 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/StreamingPackageInstallRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/StreamingPackageInstallRequest.kt @@ -25,8 +25,7 @@ import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.abb.AbbExecRequest import com.malinskiy.adam.request.transform.StringResponseTransformer -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import io.ktor.util.cio.* import io.ktor.utils.io.* import kotlinx.coroutines.Dispatchers @@ -110,17 +109,17 @@ class StreamingPackageInstallRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Boolean { + override suspend fun readElement(socket: Socket): Boolean { val buffer = ByteArray(Const.MAX_FILE_PACKET_LENGTH) var fileChannel: ByteReadChannel? = null try { val fileChannel = pkg.readChannel(coroutineContext = coroutineContext) - fileChannel.copyTo(writeChannel, buffer) + fileChannel.copyTo(socket, buffer) } finally { fileChannel?.cancel() } - readChannel.copyTo(transformer, buffer) + socket.copyTo(transformer, buffer) return transformer.transform().startsWith("Success") } diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequest.kt index 45c756898..ef36cdbf1 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequest.kt @@ -17,11 +17,11 @@ package com.malinskiy.adam.request.pkg.multi import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readStatus import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.abb.AbbExecRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket class AddSessionRequest( private val childSessions: List, @@ -50,8 +50,8 @@ class AddSessionRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - val response = readChannel.readStatus() + override suspend fun readElement(socket: Socket) { + val response = socket.readStatus() if (!response.contains("Success")) { throw RequestRejectedException( "Failed to add child sessions ${childSessions.joinToString()} to a parent session " + diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequest.kt index fca4cb424..e109b2e5f 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequest.kt @@ -18,11 +18,11 @@ package com.malinskiy.adam.request.pkg.multi import com.malinskiy.adam.exception.RequestRejectedException import com.malinskiy.adam.extension.bashEscape +import com.malinskiy.adam.extension.readStatus import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.abb.AbbExecRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket class CreateIndividualPackageSessionRequest( private val pkg: InstallationPackage, @@ -84,8 +84,8 @@ class CreateIndividualPackageSessionRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { - val response = readChannel.readStatus() + override suspend fun readElement(socket: Socket): String { + val response = socket.readStatus() if (!response.contains("Success")) { throw RequestRejectedException("Failed to create multi-package session: $response") } diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequest.kt index e4379190b..4ee74be77 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequest.kt @@ -19,12 +19,12 @@ package com.malinskiy.adam.request.pkg.multi import com.malinskiy.adam.annotation.Features import com.malinskiy.adam.exception.RequestRejectedException import com.malinskiy.adam.extension.bashEscape +import com.malinskiy.adam.extension.readStatus import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.abb.AbbExecRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.io.File @Features(Feature.CMD, Feature.ABB_EXEC) @@ -119,8 +119,8 @@ class CreateMultiPackageSessionRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { - val createSessionResponse = readChannel.readStatus() + override suspend fun readElement(socket: Socket): String { + val createSessionResponse = socket.readStatus() if (!createSessionResponse.contains("Success")) { throw RequestRejectedException("Failed to create multi-package session") } diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequest.kt index 1b3a697e4..de1a16d27 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequest.kt @@ -17,11 +17,11 @@ package com.malinskiy.adam.request.pkg.multi import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readStatus import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.abb.AbbExecRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket class InstallCommitRequest( private val parentSession: String, @@ -56,8 +56,8 @@ class InstallCommitRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - val result = readChannel.readStatus() + override suspend fun readElement(socket: Socket) { + val result = socket.readStatus() //Rather than checking for success, we check for Failure since some implementations of PackageManagerShellCommand ignore the //logSuccess=true in doCommitSession if (result.contains("Failure")) { diff --git a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequest.kt index bd8c955da..8f4860b25 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequest.kt @@ -17,11 +17,12 @@ package com.malinskiy.adam.request.pkg.multi import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readStatus +import com.malinskiy.adam.extension.writeFile import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.abb.AbbExecRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.Dispatchers import java.io.File import kotlin.coroutines.CoroutineContext @@ -59,9 +60,9 @@ class WriteIndividualPackageRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - writeChannel.writeFile(file, coroutineContext) - val response = readChannel.readStatus() + override suspend fun readElement(socket: Socket) { + socket.writeFile(file, coroutineContext) + val response = socket.readStatus() if (!response.contains("Success")) { throw RequestRejectedException("Failed to write package $file: $response") } diff --git a/src/main/kotlin/com/malinskiy/adam/request/reverse/ListReversePortForwardsRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/reverse/ListReversePortForwardsRequest.kt index bbcbd40a5..92c388810 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/reverse/ListReversePortForwardsRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/reverse/ListReversePortForwardsRequest.kt @@ -21,21 +21,20 @@ import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.NonSpecifiedTarget import com.malinskiy.adam.request.forwarding.LocalPortSpec import com.malinskiy.adam.request.forwarding.RemotePortSpec -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.nio.ByteBuffer /** * Doesn't work with SerialTarget, have to use the serial as a parameter for the execute method */ class ListReversePortForwardsRequest : ComplexRequest>(target = NonSpecifiedTarget) { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { + override suspend fun readElement(socket: Socket): List { val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4) - readChannel.readFully(sizeBuffer) + socket.readFully(sizeBuffer) val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16) val payloadBuffer = ByteBuffer.allocate(size) - readChannel.readFully(payloadBuffer) + socket.readFully(payloadBuffer) val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING) return payload.lines().mapNotNull { line -> if (line.isNotEmpty()) { diff --git a/src/main/kotlin/com/malinskiy/adam/request/reverse/ReversePortForwardRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/reverse/ReversePortForwardRequest.kt index 5d470ebc0..f42b31dbf 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/reverse/ReversePortForwardRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/reverse/ReversePortForwardRequest.kt @@ -17,13 +17,14 @@ package com.malinskiy.adam.request.reverse import com.malinskiy.adam.exception.RequestRejectedException +import com.malinskiy.adam.extension.readOptionalProtocolString +import com.malinskiy.adam.extension.readTransportResponse import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.NonSpecifiedTarget import com.malinskiy.adam.request.forwarding.LocalPortSpec import com.malinskiy.adam.request.forwarding.PortForwardingMode import com.malinskiy.adam.request.forwarding.RemotePortSpec -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket /** * On some devices, this might not return the actual port if you're passing tcp:0 @@ -40,12 +41,12 @@ class ReversePortForwardRequest( override fun serialize() = createBaseRequest("reverse:forward${mode.value}:${local.toSpec()};${remote.toSpec()}") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Int? { - val transportResponse = readChannel.read() + override suspend fun readElement(socket: Socket): Int? { + val transportResponse = socket.readTransportResponse() if (!transportResponse.okay) { throw RequestRejectedException("Can't establish port forwarding: ${transportResponse.message ?: ""}") } - return readChannel.readOptionalProtocolString()?.toIntOrNull() + return socket.readOptionalProtocolString()?.toIntOrNull() } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequest.kt index bfd8a2324..ec21102db 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequest.kt @@ -20,9 +20,8 @@ import com.malinskiy.adam.Const import com.malinskiy.adam.request.AsyncChannelRequest import com.malinskiy.adam.request.NonSpecifiedTarget import com.malinskiy.adam.request.Target -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import kotlinx.coroutines.delay +import com.malinskiy.adam.transport.Socket +import kotlinx.coroutines.channels.SendChannel open class ChanneledShellCommandRequest( val cmd: String, @@ -31,19 +30,15 @@ open class ChanneledShellCommandRequest( val data = ByteArray(Const.MAX_PACKET_LENGTH) - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String? { - while (readChannel.availableForRead == 0) { - if (readChannel.isClosedForRead || writeChannel.isClosedForWrite) return null - delay(Const.READ_DELAY) - } - - val count = readChannel.readAvailable(data, 0, Const.MAX_PACKET_LENGTH) - return when { - count > 0 -> String(data, 0, count, Const.DEFAULT_TRANSPORT_ENCODING) - else -> return null + override suspend fun readElement(socket: Socket, sendChannel: SendChannel): Boolean { + val count = socket.readAvailable(data, 0, Const.MAX_PACKET_LENGTH) + when { + count > 0 -> sendChannel.send(String(data, 0, count, Const.DEFAULT_TRANSPORT_ENCODING)) + else -> Unit } + return false } override fun serialize() = createBaseRequest("shell:$cmd") - override suspend fun writeElement(element: Unit, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) = Unit + override suspend fun writeElement(element: Unit, socket: Socket) = Unit } diff --git a/src/main/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequest.kt index 1b763f572..60d7f653d 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequest.kt @@ -20,13 +20,10 @@ import com.malinskiy.adam.Const import com.malinskiy.adam.request.AsyncChannelRequest import com.malinskiy.adam.request.NonSpecifiedTarget import com.malinskiy.adam.request.Target -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import io.ktor.utils.io.readIntLittleEndian -import io.ktor.utils.io.writeByte -import io.ktor.utils.io.writeFully -import io.ktor.utils.io.writeIntLittleEndian +import com.malinskiy.adam.transport.Socket +import com.malinskiy.adam.transport.withDefaultBuffer import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel open class ChanneledShellCommandRequest( private val cmd: String, @@ -36,47 +33,52 @@ open class ChanneledShellCommandRequest( val data = ByteArray(Const.MAX_PACKET_LENGTH) - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): ShellCommandResultChunk? { - //Skip if nothing is happening - if (readChannel.availableForRead == 0) { - return null - } - return when (MessageType.of(readChannel.readByte().toInt())) { - MessageType.STDOUT -> { - val length = readChannel.readIntLittleEndian() - readChannel.readFully(data, 0, length) - ShellCommandResultChunk(stdout = String(data, 0, length, Const.DEFAULT_TRANSPORT_ENCODING)) - } - MessageType.STDERR -> { - val length = readChannel.readIntLittleEndian() - readChannel.readFully(data, 0, length) - ShellCommandResultChunk(stderr = String(data, 0, length, Const.DEFAULT_TRANSPORT_ENCODING)) - } - MessageType.EXIT -> { - val ignoredLength = readChannel.readIntLittleEndian() - val exitCode = readChannel.readByte().toInt() - ShellCommandResultChunk(exitCode = exitCode) + override suspend fun readElement(socket: Socket, sendChannel: SendChannel): Boolean { + withDefaultBuffer { + val readAvailable = socket.readAvailable(this.array(), 0, 1) + when (readAvailable) { + //Skip as if nothing is happening + 0, -1 -> return false } - else -> { - null + + val readByte = this.get(0) + when (MessageType.of(readByte.toInt())) { + MessageType.STDOUT -> { + val length = socket.readIntLittleEndian() + socket.readFully(data, 0, length) + sendChannel.send(ShellCommandResultChunk(stdout = String(data, 0, length, Const.DEFAULT_TRANSPORT_ENCODING))) + } + MessageType.STDERR -> { + val length = socket.readIntLittleEndian() + socket.readFully(data, 0, length) + sendChannel.send(ShellCommandResultChunk(stderr = String(data, 0, length, Const.DEFAULT_TRANSPORT_ENCODING))) + } + MessageType.EXIT -> { + val ignoredLength = socket.readIntLittleEndian() + val exitCode = socket.readByte().toInt() + sendChannel.send(ShellCommandResultChunk(exitCode = exitCode)) + } + else -> Unit } + + return false } } /** * Handles stdin */ - override suspend fun writeElement(element: ShellCommandInputChunk, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { + override suspend fun writeElement(element: ShellCommandInputChunk, socket: Socket) { element.stdin?.let { - writeChannel.writeByte(MessageType.STDIN.toValue()) + socket.writeByte(MessageType.STDIN.toValue()) val bytes = it.toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - writeChannel.writeIntLittleEndian(bytes.size) - writeChannel.writeFully(bytes) + socket.writeIntLittleEndian(bytes.size) + socket.writeFully(bytes) } if (element.close) { - writeChannel.writeByte(MessageType.CLOSE_STDIN.toValue()) - writeChannel.writeIntLittleEndian(0) + socket.writeByte(MessageType.CLOSE_STDIN.toValue()) + socket.writeIntLittleEndian(0) } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/shell/v2/SyncShellCommandRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/shell/v2/SyncShellCommandRequest.kt index a97dfc166..7897a178a 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/shell/v2/SyncShellCommandRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/shell/v2/SyncShellCommandRequest.kt @@ -21,8 +21,7 @@ import com.malinskiy.adam.exception.RequestValidationException import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.NonSpecifiedTarget import com.malinskiy.adam.request.Target -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import io.ktor.utils.io.* /** @@ -42,22 +41,22 @@ abstract class SyncShellCommandRequest(val cmd: String, target: Target abstract fun convertResult(response: ShellCommandResult): T override fun serialize() = createBaseRequest("shell,v2,raw:$cmd") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): T { + override suspend fun readElement(socket: Socket): T { loop@ while (true) { - when (val messageType = MessageType.of(readChannel.readByte().toInt())) { + when (val messageType = MessageType.of(socket.readByte().toInt())) { MessageType.STDOUT -> { - val length = readChannel.readIntLittleEndian() - readChannel.readFully(data, 0, length) + val length = socket.readIntLittleEndian() + socket.readFully(data, 0, length) stdoutBuilder.append(String(data, 0, length, Const.DEFAULT_TRANSPORT_ENCODING)) } MessageType.STDERR -> { - val length = readChannel.readIntLittleEndian() - readChannel.readFully(data, 0, length) + val length = socket.readIntLittleEndian() + socket.readFully(data, 0, length) stderrBuilder.append(String(data, 0, length, Const.DEFAULT_TRANSPORT_ENCODING)) } MessageType.EXIT -> { - val length = readChannel.readIntLittleEndian() - exitCode = readChannel.readByte().toInt() + val length = socket.readIntLittleEndian() + exitCode = socket.readByte().toInt() break@loop } MessageType.STDIN, MessageType.CLOSE_STDIN, MessageType.WINDOW_SIZE_CHANGE, MessageType.INVALID -> { diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePullFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePullFileRequest.kt index 32bb1628d..f6c9fcac3 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePullFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePullFileRequest.kt @@ -23,8 +23,7 @@ import com.malinskiy.adam.extension.toInt import com.malinskiy.adam.request.AsyncChannelRequest import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.v1.StatFileRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import io.ktor.util.cio.* import io.ktor.utils.io.* import kotlinx.coroutines.Dispatchers @@ -53,41 +52,39 @@ abstract class BasePullFileRequest( var totalBytes = -1L var currentPosition = 0L - override suspend fun handshake(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - super.handshake(readChannel, writeChannel) + override suspend fun handshake(socket: Socket) { + super.handshake(socket) //If we don't have expected size, fetch it - totalBytes = size ?: StatFileRequest(remotePath).readElement(readChannel, writeChannel).size.toLong() + totalBytes = size ?: StatFileRequest(remotePath).readElement(socket).size.toLong() } private val headerBuffer = ByteArray(8) private val dataBuffer = ByteArray(Const.MAX_FILE_PACKET_LENGTH) - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Double? { - readChannel.readFully(headerBuffer, 0, 8) + override suspend fun readElement(socket: Socket, sendChannel: SendChannel): Boolean { + socket.readFully(headerBuffer, 0, 8) val header = headerBuffer.copyOfRange(0, 4) when { header.contentEquals(Const.Message.DONE) -> { fileWriteChannel.close() - readChannel.cancel(null) - writeChannel.close(null) - return 1.0 + return true } header.contentEquals(Const.Message.DATA) -> { val available = headerBuffer.copyOfRange(4, 8).toInt() if (available > Const.MAX_FILE_PACKET_LENGTH) { throw UnsupportedSyncProtocolException() } - readChannel.readFully(dataBuffer, 0, available) + socket.readFully(dataBuffer, 0, available) fileWriteChannel.writeFully(dataBuffer, 0, available) currentPosition += available - return currentPosition.toDouble() / totalBytes + sendChannel.send(currentPosition.toDouble() / totalBytes) } header.contentEquals(Const.Message.FAIL) -> { val size = headerBuffer.copyOfRange(4, 8).toInt() - readChannel.readFully(dataBuffer, 0, size) + socket.readFully(dataBuffer, 0, size) val errorMessage = String(dataBuffer, 0, size) throw PullFailedException("Failed to pull file $remotePath: $errorMessage") } @@ -95,6 +92,7 @@ abstract class BasePullFileRequest( throw UnsupportedSyncProtocolException("Unexpected header message ${String(header, Const.DEFAULT_TRANSPORT_ENCODING)}") } } + return false } override suspend fun close(channel: SendChannel) { @@ -111,6 +109,6 @@ abstract class BasePullFileRequest( } } - override suspend fun writeElement(element: Unit, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) = Unit + override suspend fun writeElement(element: Unit, socket: Socket) = Unit } diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePushFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePushFileRequest.kt index 21170ed98..9b6fab852 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePushFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/base/BasePushFileRequest.kt @@ -18,11 +18,12 @@ package com.malinskiy.adam.request.sync.base import com.malinskiy.adam.Const import com.malinskiy.adam.exception.PushFailedException +import com.malinskiy.adam.extension.readTransportResponse import com.malinskiy.adam.extension.toByteArray +import com.malinskiy.adam.extension.write import com.malinskiy.adam.request.AsyncChannelRequest import com.malinskiy.adam.request.ValidationResponse -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import io.ktor.util.cio.* import io.ktor.utils.io.* import kotlinx.coroutines.Dispatchers @@ -37,25 +38,25 @@ abstract class BasePushFileRequest( coroutineContext: CoroutineContext = Dispatchers.IO ) : AsyncChannelRequest() { protected val fileReadChannel = local.readChannel(coroutineContext = coroutineContext) - protected val buffer = ByteArray(8 + Const.MAX_FILE_PACKET_LENGTH) + protected val buffer = ByteArray(Const.KTOR_INTERNAL_BUFFER_LENGTH) protected var totalBytes = local.length() protected var currentPosition = 0L protected val modeValue: Int get() = mode.toInt(8) and "0777".toInt(8) - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): Double? { - val available = fileReadChannel.readAvailable(buffer, 8, Const.MAX_FILE_PACKET_LENGTH) - return when { + override suspend fun readElement(socket: Socket, sendChannel: SendChannel): Boolean { + val available = fileReadChannel.readAvailable(buffer, 8, Const.KTOR_INTERNAL_BUFFER_LENGTH - 8) + when { available < 0 -> { Const.Message.DONE.copyInto(buffer) (local.lastModified() / 1000).toInt().toByteArray().copyInto(buffer, destinationOffset = 4) - writeChannel.write(request = buffer, length = 8) - val transportResponse = readChannel.read() - readChannel.cancel(null) - writeChannel.close(null) + socket.write(request = buffer, length = 8) + val transportResponse = socket.readTransportResponse() fileReadChannel.cancel() - return if (transportResponse.okay) { - 1.0 + + if (transportResponse.okay) { + sendChannel.send(1.0) + return true } else { throw PushFailedException("adb didn't acknowledge the file transfer: ${transportResponse.message ?: ""}") } @@ -63,12 +64,13 @@ abstract class BasePushFileRequest( available > 0 -> { Const.Message.DATA.copyInto(buffer) available.toByteArray().reversedArray().copyInto(buffer, destinationOffset = 4) - writeChannel.writeFully(buffer, 0, available + 8) + socket.writeFully(buffer, 0, available + 8) currentPosition += available - currentPosition.toDouble() / totalBytes + sendChannel.send(currentPosition.toDouble() / totalBytes) } - else -> currentPosition.toDouble() / totalBytes + else -> Unit } + return false } override fun serialize() = createBaseRequest("sync:") @@ -77,7 +79,7 @@ abstract class BasePushFileRequest( fileReadChannel.cancel() } - override suspend fun writeElement(element: Unit, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) = Unit + override suspend fun writeElement(element: Unit, socket: Socket) = Unit override fun validate(): ValidationResponse { val response = super.validate() diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/ListFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/ListFileRequest.kt index d49794899..7b4351f3c 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/ListFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/ListFileRequest.kt @@ -18,11 +18,11 @@ package com.malinskiy.adam.request.sync.v1 import com.malinskiy.adam.Const import com.malinskiy.adam.extension.toInt +import com.malinskiy.adam.extension.writeSyncRequest import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.model.FileEntryV1 -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.time.Instant class ListFileRequest( @@ -37,19 +37,19 @@ class ListFileRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { - writeChannel.writeSyncRequest(Const.Message.LIST_V1, remotePath) + override suspend fun readElement(socket: Socket): List { + socket.writeSyncRequest(Const.Message.LIST_V1, remotePath) val bytes = ByteArray(16) val stringBytes = ByteArray(Const.MAX_REMOTE_PATH_LENGTH) val result = mutableListOf() loop@ while (true) { - readChannel.readFully(bytes, 0, 4) + socket.readFully(bytes, 0, 4) when { bytes.copyOfRange(0, 4).contentEquals(Const.Message.DENT_V1) -> { - readChannel.readFully(bytes, 0, 16) + socket.readFully(bytes, 0, 16) val nameLength = bytes.copyOfRange(12, 16).toInt() - readChannel.readFully(stringBytes, 0, nameLength) + socket.readFully(stringBytes, 0, nameLength) result.add( FileEntryV1( diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequest.kt index 9ed48f7b3..f8691f342 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequest.kt @@ -17,9 +17,9 @@ package com.malinskiy.adam.request.sync.v1 import com.malinskiy.adam.Const +import com.malinskiy.adam.extension.writeSyncRequest import com.malinskiy.adam.request.sync.base.BasePullFileRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.Dispatchers import java.io.File import kotlin.coroutines.CoroutineContext @@ -33,8 +33,8 @@ class PullFileRequest( size: Long? = null, coroutineContext: CoroutineContext = Dispatchers.IO ) : BasePullFileRequest(remotePath, local, size, coroutineContext) { - override suspend fun handshake(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - super.handshake(readChannel, writeChannel) - writeChannel.writeSyncRequest(Const.Message.RECV_V1, remotePath) + override suspend fun handshake(socket: Socket) { + super.handshake(socket) + socket.writeSyncRequest(Const.Message.RECV_V1, remotePath) } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PushFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PushFileRequest.kt index d33db51ec..b746bbc24 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PushFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/PushFileRequest.kt @@ -18,9 +18,9 @@ package com.malinskiy.adam.request.sync.v1 import com.malinskiy.adam.Const import com.malinskiy.adam.extension.toByteArray +import com.malinskiy.adam.extension.write import com.malinskiy.adam.request.sync.base.BasePushFileRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.Dispatchers import java.io.File import kotlin.coroutines.CoroutineContext @@ -32,8 +32,8 @@ class PushFileRequest( coroutineContext: CoroutineContext = Dispatchers.IO ) : BasePushFileRequest(local, remotePath, mode, coroutineContext) { - override suspend fun handshake(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - super.handshake(readChannel, writeChannel) + override suspend fun handshake(socket: Socket) { + super.handshake(socket) val type = Const.Message.SEND_V1 @@ -49,6 +49,6 @@ class PushFileRequest( path.copyInto(cmd, 8) mode.copyInto(cmd, 8 + path.size) - writeChannel.write(cmd) + socket.write(cmd) } } diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/StatFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/StatFileRequest.kt index 72078c76d..e837dd186 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v1/StatFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v1/StatFileRequest.kt @@ -20,21 +20,21 @@ import com.malinskiy.adam.Const import com.malinskiy.adam.exception.UnsupportedSyncProtocolException import com.malinskiy.adam.extension.toInt import com.malinskiy.adam.extension.toUInt +import com.malinskiy.adam.extension.writeSyncRequest import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.model.FileEntryV1 -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.time.Instant class StatFileRequest( private val remotePath: String ) : ComplexRequest() { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): FileEntryV1 { - writeChannel.writeSyncRequest(Const.Message.LSTAT_V1, remotePath) + override suspend fun readElement(socket: Socket): FileEntryV1 { + socket.writeSyncRequest(Const.Message.LSTAT_V1, remotePath) val bytes = ByteArray(16) - readChannel.readFully(bytes, 0, 16) + socket.readFully(bytes, 0, 16) if (!bytes.copyOfRange(0, 4).contentEquals(Const.Message.LSTAT_V1)) throw UnsupportedSyncProtocolException() diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/ListFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/ListFileRequest.kt index c1da8cd49..ba417d178 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/ListFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/ListFileRequest.kt @@ -18,16 +18,12 @@ package com.malinskiy.adam.request.sync.v2 import com.malinskiy.adam.Const import com.malinskiy.adam.annotation.Features -import com.malinskiy.adam.extension.toInt -import com.malinskiy.adam.extension.toLong -import com.malinskiy.adam.extension.toUInt -import com.malinskiy.adam.extension.toULong +import com.malinskiy.adam.extension.* import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.model.FileEntryV2 -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.time.Instant @Features(Feature.LS_V2) @@ -49,20 +45,20 @@ class ListFileRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List { - writeChannel.writeSyncRequest(Const.Message.LIST_V2, remotePath) + override suspend fun readElement(socket: Socket): List { + socket.writeSyncRequest(Const.Message.LIST_V2, remotePath) val stringBytes = ByteArray(Const.MAX_REMOTE_PATH_LENGTH) val bytes = ByteArray(72) val result = mutableListOf() loop@ while (true) { - readChannel.readFully(bytes, 0, 4) + socket.readFully(bytes, 0, 4) when { bytes.copyOfRange(0, 4).contentEquals(Const.Message.DENT_V2) -> { - readChannel.readFully(bytes, 0, 72) + socket.readFully(bytes, 0, 72) val nameLength = bytes.copyOfRange(68, 72).toInt() - readChannel.readFully(stringBytes, 0, nameLength) + socket.readFully(stringBytes, 0, nameLength) result.add( FileEntryV2( error = bytes.copyOfRange(0, 4).toUInt(), diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PullFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PullFileRequest.kt index eec414c10..019da427c 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PullFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PullFileRequest.kt @@ -18,11 +18,11 @@ package com.malinskiy.adam.request.sync.v2 import com.malinskiy.adam.Const import com.malinskiy.adam.annotation.Features +import com.malinskiy.adam.extension.writeSyncV2Request import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.base.BasePullFileRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.Dispatchers import java.io.File import kotlin.coroutines.CoroutineContext @@ -43,9 +43,9 @@ class PullFileRequest( */ private val compressionType = CompressionType.NONE - override suspend fun handshake(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - super.handshake(readChannel, writeChannel) - writeChannel.writeSyncV2Request(Const.Message.RECV_V2, remotePath, compressionType.toFlag()) + override suspend fun handshake(socket: Socket) { + super.handshake(socket) + socket.writeSyncV2Request(Const.Message.RECV_V2, remotePath, compressionType.toFlag()) } override fun validate(): ValidationResponse { diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PushFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PushFileRequest.kt index bb860dc33..973b95091 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PushFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/PushFileRequest.kt @@ -18,11 +18,11 @@ package com.malinskiy.adam.request.sync.v2 import com.malinskiy.adam.Const import com.malinskiy.adam.annotation.Features +import com.malinskiy.adam.extension.writeSyncV2Request import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.base.BasePushFileRequest -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.Dispatchers import java.io.File import kotlin.coroutines.CoroutineContext @@ -44,14 +44,14 @@ class PushFileRequest( */ private val compressionType = CompressionType.NONE - override suspend fun handshake(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) { - super.handshake(readChannel, writeChannel) + override suspend fun handshake(socket: Socket) { + super.handshake(socket) val additionalFlags = if (dryRun) { (DRY_RUN_FLAG or compressionType.toFlag().toLong()).toInt() } else { compressionType.toFlag() } - writeChannel.writeSyncV2Request(Const.Message.SEND_V2, remotePath, additionalFlags, modeValue) + socket.writeSyncV2Request(Const.Message.SEND_V2, remotePath, additionalFlags, modeValue) } override fun validate(): ValidationResponse { diff --git a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/StatFileRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/StatFileRequest.kt index 6a89b974f..334c9c8da 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/sync/v2/StatFileRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/sync/v2/StatFileRequest.kt @@ -22,12 +22,12 @@ import com.malinskiy.adam.exception.UnsupportedSyncProtocolException import com.malinskiy.adam.extension.toLong import com.malinskiy.adam.extension.toUInt import com.malinskiy.adam.extension.toULong +import com.malinskiy.adam.extension.writeSyncRequest import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.sync.model.FileEntryV2 -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import java.time.Instant @Features(Feature.STAT_V2) @@ -35,11 +35,11 @@ class StatFileRequest( private val remotePath: String, private val supportedFeatures: List ) : ComplexRequest() { - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): FileEntryV2 { - writeChannel.writeSyncRequest(Const.Message.LSTAT_V2, remotePath) + override suspend fun readElement(socket: Socket): FileEntryV2 { + socket.writeSyncRequest(Const.Message.LSTAT_V2, remotePath) val bytes = ByteArray(72) - readChannel.readFully(bytes, 0, 72) + socket.readFully(bytes, 0, 72) if (!bytes.copyOfRange(0, 4).contentEquals(Const.Message.LSTAT_V2)) throw UnsupportedSyncProtocolException() diff --git a/src/main/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequest.kt b/src/main/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequest.kt index e3e0325ce..fa8a70ca9 100644 --- a/src/main/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequest.kt +++ b/src/main/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequest.kt @@ -21,8 +21,7 @@ import com.malinskiy.adam.request.AsyncChannelRequest import com.malinskiy.adam.request.transform.InstrumentationResponseTransformer import com.malinskiy.adam.request.transform.ProgressiveResponseTransformer import com.malinskiy.adam.request.transform.ProtoInstrumentationResponseTransformer -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import kotlinx.coroutines.channels.SendChannel /** @@ -64,20 +63,20 @@ class TestRunnerRequest( } } - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): List? { - val available = readChannel.readAvailable(buffer, 0, Const.MAX_PACKET_LENGTH) + override suspend fun readElement(socket: Socket, sendChannel: SendChannel>): Boolean { + val available = socket.readAvailable(buffer, 0, Const.MAX_PACKET_LENGTH) - return when { + when { available > 0 -> { - transformer.process(buffer, 0, available) + transformer.process(buffer, 0, available)?.let { sendChannel.send(it) } } available < 0 -> { - readChannel.cancel(null) - writeChannel.close(null) - return null + return true } else -> null } + + return false } override fun serialize() = createBaseRequest(StringBuilder().apply { @@ -126,5 +125,5 @@ class TestRunnerRequest( transformer.transform()?.let { channel.send(it) } } - override suspend fun writeElement(element: Unit, readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel) = Unit + override suspend fun writeElement(element: Unit, socket: Socket) = Unit } diff --git a/src/main/kotlin/com/malinskiy/adam/transport/AndroidReadChannel.kt b/src/main/kotlin/com/malinskiy/adam/transport/AndroidReadChannel.kt deleted file mode 100644 index 2c7ba30b8..000000000 --- a/src/main/kotlin/com/malinskiy/adam/transport/AndroidReadChannel.kt +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (C) 2019 Anton Malinskiy - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.malinskiy.adam.transport - -import com.malinskiy.adam.Const -import com.malinskiy.adam.Const.Message.OKAY -import com.malinskiy.adam.extension.copyTo -import com.malinskiy.adam.request.transform.StringResponseTransformer -import io.ktor.utils.io.* - -class AndroidReadChannel(private val delegate: ByteReadChannel) : ByteReadChannel by delegate { - suspend fun read(): TransportResponse { - val bytes = ByteArray(4) - delegate.readFully(bytes, 0, 4) - - val ok = bytes.isOkay() - val message = if (!ok) { - readOptionalProtocolString() - } else { - null - } - - return TransportResponse(ok, message) - } - - private fun ByteArray.isOkay() = contentEquals(OKAY) - - suspend fun readStatus(): String { - withDefaultBuffer { - val transformer = StringResponseTransformer() - copyTo(transformer, this) - return transformer.transform() - } - } - - suspend fun readOptionalProtocolString(): String? { - val responseLength = withDefaultBuffer { - val transformer = StringResponseTransformer() - copyTo(transformer, this, limit = 4L) - transformer.transform() - } - val errorMessageLength = responseLength.toIntOrNull(16) - return if (errorMessageLength == null) { - readStatus() - } else { - val errorBytes = ByteArray(errorMessageLength) - delegate.readFully(errorBytes, 0, errorMessageLength) - String(errorBytes, Const.DEFAULT_TRANSPORT_ENCODING) - } - } -} diff --git a/src/main/kotlin/com/malinskiy/adam/transport/AndroidWriteChannel.kt b/src/main/kotlin/com/malinskiy/adam/transport/AndroidWriteChannel.kt index da6ee0fd8..e69de29bb 100644 --- a/src/main/kotlin/com/malinskiy/adam/transport/AndroidWriteChannel.kt +++ b/src/main/kotlin/com/malinskiy/adam/transport/AndroidWriteChannel.kt @@ -1,82 +0,0 @@ -/* - * Copyright (C) 2019 Anton Malinskiy - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.malinskiy.adam.transport - -import com.malinskiy.adam.Const -import com.malinskiy.adam.extension.compatLimit -import com.malinskiy.adam.extension.compatRewind -import com.malinskiy.adam.extension.copyTo -import com.malinskiy.adam.extension.toByteArray -import io.ktor.util.cio.* -import io.ktor.utils.io.* -import io.ktor.utils.io.bits.* -import io.ktor.utils.io.core.* -import java.io.File -import java.nio.ByteBuffer -import kotlin.coroutines.CoroutineContext - -class AndroidWriteChannel(private val delegate: ByteWriteChannel) : ByteWriteChannel by delegate { - suspend fun write(request: ByteArray, length: Int? = null) { - val requestBuffer = ByteBuffer.wrap(request, 0, length ?: request.size) - delegate.writeFully(requestBuffer) - } - - suspend fun writeFile(file: File, coroutineContext: CoroutineContext) = withFileBuffer { - var fileChannel: ByteReadChannel? = null - try { - val fileChannel = file.readChannel(coroutineContext = coroutineContext) - fileChannel.copyTo(this@AndroidWriteChannel, this) - } finally { - fileChannel?.cancel() - } - } - - suspend fun writeSyncRequest(type: ByteArray, remotePath: String) { - val path = remotePath.toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val size = path.size.toByteArray().reversedArray() - - val cmd = ByteArray(8 + path.size) - - type.copyInto(cmd) - size.copyInto(cmd, 4) - path.copyInto(cmd, 8) - write(cmd) - } - - suspend fun writeSyncV2Request(type: ByteArray, remotePath: String, flags: Int, mode: Int? = null) { - val path = remotePath.toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - - withDefaultBuffer { - compatRewind() - compatLimit(4 + 4) - put(type) - putInt(path.size.reverseByteOrder()) - compatRewind() - writeFully(this) - - writeFully(path) - - compatRewind() - mode?.let { compatLimit(4 + 4 + 4) } ?: compatLimit(4 + 4) - put(type) - mode?.let { putInt(it.reverseByteOrder()) } - putInt(flags.reverseByteOrder()) - compatRewind() - writeFully(this) - } - } -} diff --git a/src/main/kotlin/com/malinskiy/adam/transport/BufferFactory.kt b/src/main/kotlin/com/malinskiy/adam/transport/BufferFactory.kt index a74fd51e9..ee2553c93 100644 --- a/src/main/kotlin/com/malinskiy/adam/transport/BufferFactory.kt +++ b/src/main/kotlin/com/malinskiy/adam/transport/BufferFactory.kt @@ -20,9 +20,8 @@ import com.malinskiy.adam.Const import io.ktor.utils.io.pool.* import java.nio.ByteBuffer -internal const val DEFAULT_BUFFER_SIZE = 4096 +internal const val DEFAULT_BUFFER_SIZE = 4088 -val AdamFilePool: ObjectPool = ByteBufferPool(DEFAULT_BUFFER_SIZE, Const.MAX_FILE_PACKET_LENGTH) val AdamDefaultPool: ObjectPool = ByteBufferPool(Const.DEFAULT_BUFFER_SIZE, DEFAULT_BUFFER_SIZE) inline fun withDefaultBuffer(block: ByteBuffer.() -> R): R { @@ -33,12 +32,3 @@ inline fun withDefaultBuffer(block: ByteBuffer.() -> R): R { AdamDefaultPool.recycle(instance) } } - -inline fun withFileBuffer(block: ByteBuffer.() -> R): R { - val instance = AdamFilePool.borrow() - return try { - block(instance) - } finally { - AdamFilePool.recycle(instance) - } -} diff --git a/src/main/kotlin/com/malinskiy/adam/transport/KtorSocket.kt b/src/main/kotlin/com/malinskiy/adam/transport/KtorSocket.kt new file mode 100644 index 000000000..ba5a2721b --- /dev/null +++ b/src/main/kotlin/com/malinskiy/adam/transport/KtorSocket.kt @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2021 Anton Malinskiy + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.malinskiy.adam.transport + +import com.malinskiy.adam.log.AdamLogging +import io.ktor.network.sockets.* +import io.ktor.utils.io.* +import java.nio.ByteBuffer +import io.ktor.network.sockets.Socket as RealKtorSocket + +class KtorSocket(private val ktorSocket: RealKtorSocket) : Socket { + private val readChannel: ByteReadChannel = ktorSocket.openReadChannel() + private val writeChannel: ByteWriteChannel = ktorSocket.openWriteChannel(autoFlush = true) + override val isClosedForWrite: Boolean + get() = writeChannel.isClosedForWrite + override val isClosedForRead: Boolean + get() = readChannel.isClosedForRead + + override suspend fun readFully(buffer: ByteBuffer): Int = readChannel.readFully(buffer) + override suspend fun readFully(buffer: ByteArray, offset: Int, limit: Int) = readChannel.readFully(buffer, offset, limit) + override suspend fun writeFully(byteBuffer: ByteBuffer) = writeChannel.writeFully(byteBuffer) + override suspend fun writeFully(toByteArray: ByteArray, offset: Int, limit: Int) = writeChannel.writeFully(toByteArray, offset, limit) + override suspend fun readAvailable(buffer: ByteArray, offset: Int, limit: Int): Int { + if (readChannel.availableForRead == 0) return 0 + return readChannel.readAvailable(buffer, offset, limit) + } + + override suspend fun readByte(): Byte = readChannel.readByte() + override suspend fun readIntLittleEndian(): Int = readChannel.readIntLittleEndian() + override suspend fun writeByte(value: Int) = writeChannel.writeByte(value) + override suspend fun writeIntLittleEndian(value: Int) = writeChannel.writeIntLittleEndian(value) + + override suspend fun close() { + try { + writeChannel.close() + readChannel.cancel() + ktorSocket.close() + } catch (e: Exception) { + log.debug(e) { "Exception during cleanup. Ignoring" } + } + } + + companion object { + private val log = AdamLogging.logger {} + } +} diff --git a/src/main/kotlin/com/malinskiy/adam/transport/KtorSocketFactory.kt b/src/main/kotlin/com/malinskiy/adam/transport/KtorSocketFactory.kt index 6a0143307..35322039f 100644 --- a/src/main/kotlin/com/malinskiy/adam/transport/KtorSocketFactory.kt +++ b/src/main/kotlin/com/malinskiy/adam/transport/KtorSocketFactory.kt @@ -18,6 +18,7 @@ package com.malinskiy.adam.transport import io.ktor.network.selector.* import io.ktor.network.sockets.* +import io.ktor.utils.io.nio.* import java.net.InetSocketAddress import kotlin.coroutines.CoroutineContext @@ -28,10 +29,10 @@ class KtorSocketFactory( private val selectorManager: SelectorManager = ActorSelectorManager(coroutineContext) override suspend fun tcp(socketAddress: InetSocketAddress): Socket { - return aSocket(selectorManager) - .tcp() - .connect(socketAddress) { - socketTimeout = this@KtorSocketFactory.socketTimeout - } + return KtorSocket(aSocket(selectorManager) + .tcp() + .connect(socketAddress) { + socketTimeout = this@KtorSocketFactory.socketTimeout + }) } -} +} \ No newline at end of file diff --git a/src/main/kotlin/com/malinskiy/adam/transport/NioSocket.kt b/src/main/kotlin/com/malinskiy/adam/transport/NioSocket.kt new file mode 100644 index 000000000..410729e88 --- /dev/null +++ b/src/main/kotlin/com/malinskiy/adam/transport/NioSocket.kt @@ -0,0 +1,323 @@ +/* + * Copyright (C) 2021 Anton Malinskiy + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.malinskiy.adam.transport + +import kotlinx.coroutines.isActive +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.yield +import java.io.IOException +import java.net.InetSocketAddress +import java.net.SocketTimeoutException +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.channels.SelectionKey +import java.nio.channels.Selector +import java.nio.channels.SocketChannel +import java.nio.channels.spi.SelectorProvider +import java.util.concurrent.atomic.AtomicReference + +class NioSocket( + private val socketAddress: InetSocketAddress, + private val connectTimeout: Long, + private val idleTimeout: Long, +) : Socket { + private val state = AtomicReference(State.CLOSED) + private val mutex = Mutex() + + override val isClosedForWrite: Boolean + get() = socketChannel.socket().isOutputShutdown || state.get() == State.CLOSING + override val isClosedForRead: Boolean + get() = socketChannel.socket().isInputShutdown || state.get() == State.CLOSE_WAIT + + private lateinit var selector: Selector + private lateinit var socketChannel: SocketChannel + private lateinit var selectionKey: SelectionKey + + @Suppress("BlockingMethodInNonBlockingContext") + suspend fun connect() { + if (!state.compareAndSet(State.CLOSED, State.SYN_SENT)) return + + socketChannel = SelectorProvider.provider().openSocketChannel().apply { + configureBlocking(false) + configure(socket()) + } + + val success = socketChannel.connect(socketAddress) + selector = SelectorProvider.provider().openSelector() + if (success) { + processAccept(selector) + } else { + processConnect(selector) + } + } + + private fun configure(socket: java.net.Socket) { + socket.tcpNoDelay = true + } + + override suspend fun writeFully(byteBuffer: ByteBuffer) { + if (state.get() != State.ESTABLISHED) return + + processWrite(selector, byteBuffer) + } + + override suspend fun writeFully(toByteArray: ByteArray, offset: Int, limit: Int) = + writeFully(ByteBuffer.wrap(toByteArray, offset, limit)) + + override suspend fun readAvailable(buffer: ByteArray, offset: Int, limit: Int): Int = + readAvailable(ByteBuffer.wrap(buffer, offset, limit)) + + suspend fun readAvailable(buffer: ByteBuffer): Int { + if (isClosedForRead) return -1 + + return processRead(selector, buffer) + } + + override suspend fun readFully(buffer: ByteBuffer): Int { + var remaining = buffer.limit() + + return withTimeoutOrNull(idleTimeout) { + while (remaining != 0) { + when (val read = readAvailable(buffer)) { + -1 -> { + if (remaining == buffer.limit()) return@withTimeoutOrNull -1 + } + 0 -> Unit + else -> { + remaining -= read + } + } + yield() + } + + return@withTimeoutOrNull remaining + } ?: throw SocketTimeoutException("Timeout reading") + } + + override suspend fun readFully(buffer: ByteArray, offset: Int, limit: Int) { + readFully(ByteBuffer.wrap(buffer, offset, limit)) + } + + override suspend fun readByte(): Byte { + val buffer = ByteBuffer.allocate(1) + val read = readFully(buffer) + //TODO: handle EOF + return buffer.array()[0] + } + + override suspend fun writeByte(value: Int) { + writeFully(ByteArray(1) { value.toByte() }) + } + + override suspend fun readIntLittleEndian(): Int { + val allocate = ByteBuffer.allocate(4) + allocate.order(ByteOrder.LITTLE_ENDIAN) + val read = readFully(allocate) + allocate.flip() + return allocate.int + } + + override suspend fun writeIntLittleEndian(value: Int) { + val allocate = ByteBuffer.allocate(4) + allocate.order(ByteOrder.LITTLE_ENDIAN) + allocate.putInt(value) + allocate.flip() + writeFully(allocate) + } + + override suspend fun close() { + mutex.withLock { + val shouldDrain = when { + state.compareAndSet(State.ESTABLISHED, State.CLOSING) -> { + true + } + state.compareAndSet(State.CLOSE_WAIT, State.CLOSING) -> { + false + } + else -> { + return + } + } + + if (!socketChannel.socket().isOutputShutdown) { + socketChannel.socket().shutdownOutput() + } + + if (shouldDrain) { + val buffer = ByteBuffer.allocate(128) + while (true) { + buffer.clear() + if (readUnsafe(selector, buffer) == -1 || state.get() == State.CLOSED || isClosedForRead) { + break + } else { + yield() + } + } + } + + state.compareAndSet(State.CLOSING, State.CLOSED) + if (!socketChannel.socket().isInputShutdown) { + socketChannel.socket().shutdownInput() + } + selectionKey.cancel() + selector.close() + socketChannel.close() + socketChannel.socket().close() + } + } + + @Suppress("BlockingMethodInNonBlockingContext") + private suspend fun processConnect(selector: Selector) { + mutex.withLock { + selectionKey = socketChannel.register(selector, SelectionKey.OP_CONNECT) + withTimeoutOrNull(connectTimeout) { + while (isActive) { + if (selector.selectNow() == 0) yield() + val iterator = selector.selectedKeys().iterator() + while (iterator.hasNext()) { + val selectionKey = iterator.next() + if (selectionKey.isConnectable) { + socketChannel.finishConnect() + selectionKey.interestOps(0) + + val success = state.compareAndSet(State.SYN_SENT, State.ESTABLISHED) + if (!success) throw IllegalStateException("Invalid state ${state.get()} after connect") + iterator.remove() + + return@withTimeoutOrNull + } + } + } + } + selectionKey.interestOps(0) + if (socketChannel.isConnectionPending) { + try { + socketChannel.close() + } catch (e: IOException) { + //ignore + } + throw SocketTimeoutException("Channel $socketChannel timeout while connecting. Closing") + } + } + } + + @Suppress("BlockingMethodInNonBlockingContext") + private suspend fun processAccept(selector: Selector) { + mutex.withLock { + selectionKey = socketChannel.register(selector, 0) + } + } + + @Suppress("BlockingMethodInNonBlockingContext") + private suspend fun processRead(selector: Selector, buffer: ByteBuffer): Int { + mutex.withLock { + if (state.get() != State.ESTABLISHED) return 0 + return readUnsafe(selector, buffer) + } + } + + private fun readUnsafe(selector: Selector, buffer: ByteBuffer): Int { + selectionKey.interestOps(SelectionKey.OP_READ) + + selector.selectNow() + val selectedKeys = when { + selector.selectedKeys().isNotEmpty() -> selector.selectedKeys() + else -> { + selectionKey.interestOps(0) + return 0 + } + } + + val iterator = selectedKeys.iterator() + while (iterator.hasNext()) { + val selectionKey = iterator.next() + if (selectionKey.isReadable) { + val read = socketChannel.read(buffer) + if (read == -1) { + when (state.get()) { + State.ESTABLISHED -> { + state.set(State.CLOSE_WAIT) + } + State.CLOSING -> state.set(State.CLOSED) + } + } + selectionKey.interestOps(0) + return read + } + } + + selectionKey.interestOps(0) + return 0 + } + + @Suppress("BlockingMethodInNonBlockingContext") + private suspend fun processWrite(selector: Selector, buffer: ByteBuffer) { + mutex.withLock { + if (state.get() != State.ESTABLISHED) return + + selectionKey.interestOps(SelectionKey.OP_WRITE) + + var remaining = buffer.limit() + val timeout = withTimeoutOrNull(idleTimeout) { + while (true) { + selector.selectNow() + + val selectedKeys = when { + selector.selectedKeys().isNotEmpty() -> selector.selectedKeys() + else -> { + yield() + continue + } + } + + val iterator = selectedKeys.iterator() + var processed = false + while (iterator.hasNext()) { + val selectionKey = iterator.next() + if (selectionKey.isWritable) { + iterator.remove() + val written = socketChannel.write(buffer) + remaining -= written + if (written != 0) { + processed = true + } + if (remaining == 0) { + selectionKey.interestOps(0) + return@withTimeoutOrNull + } + } + } + if (!processed) yield() + } + } + selectionKey.interestOps(0) + if (timeout == null) { + throw SocketTimeoutException("Timeout writing") + } + } + } + + private enum class State { + CLOSED, + SYN_SENT, + ESTABLISHED, + CLOSING, + CLOSE_WAIT + } +} diff --git a/src/main/kotlin/com/malinskiy/adam/extension/Channel.kt b/src/main/kotlin/com/malinskiy/adam/transport/NioSocketFactory.kt similarity index 56% rename from src/main/kotlin/com/malinskiy/adam/extension/Channel.kt rename to src/main/kotlin/com/malinskiy/adam/transport/NioSocketFactory.kt index 12a18225d..37d036b9b 100644 --- a/src/main/kotlin/com/malinskiy/adam/extension/Channel.kt +++ b/src/main/kotlin/com/malinskiy/adam/transport/NioSocketFactory.kt @@ -1,5 +1,5 @@ /* - * Copyright (C) 2019 Anton Malinskiy + * Copyright (C) 2021 Anton Malinskiy * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,18 @@ * limitations under the License. */ -package com.malinskiy.adam.extension +package com.malinskiy.adam.transport -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.ByteWriteChannel +import java.net.InetSocketAddress - -fun ByteReadChannel.toAndroidChannel() = AndroidReadChannel(this) -fun ByteWriteChannel.toAndroidChannel() = AndroidWriteChannel(this) \ No newline at end of file +class NioSocketFactory : SocketFactory { + override suspend fun tcp(socketAddress: InetSocketAddress): Socket { + val nioSocket = NioSocket( + socketAddress = socketAddress, + connectTimeout = 10_000, + idleTimeout = 10_000, + ) + nioSocket.connect() + return nioSocket + } +} diff --git a/src/main/kotlin/com/malinskiy/adam/transport/Socket.kt b/src/main/kotlin/com/malinskiy/adam/transport/Socket.kt new file mode 100644 index 000000000..81faa9830 --- /dev/null +++ b/src/main/kotlin/com/malinskiy/adam/transport/Socket.kt @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2021 Anton Malinskiy + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.malinskiy.adam.transport + +import java.nio.ByteBuffer + +interface Socket : SuspendCloseable { + val isClosedForWrite: Boolean + val isClosedForRead: Boolean + + suspend fun writeFully(byteBuffer: ByteBuffer) + suspend fun writeFully(toByteArray: ByteArray, offset: Int, limit: Int) + + suspend fun readAvailable(buffer: ByteArray, offset: Int, limit: Int): Int + suspend fun readFully(buffer: ByteBuffer): Int + suspend fun readFully(buffer: ByteArray, offset: Int, limit: Int) + + suspend fun readByte(): Byte + suspend fun writeByte(value: Int) + + suspend fun readIntLittleEndian(): Int + suspend fun writeIntLittleEndian(value: Int) + + suspend fun writeFully(byteArray: ByteArray) = writeFully(byteArray, 0, byteArray.size) +} \ No newline at end of file diff --git a/src/main/kotlin/com/malinskiy/adam/transport/SocketFactory.kt b/src/main/kotlin/com/malinskiy/adam/transport/SocketFactory.kt index cc0026ef9..32a25ab37 100644 --- a/src/main/kotlin/com/malinskiy/adam/transport/SocketFactory.kt +++ b/src/main/kotlin/com/malinskiy/adam/transport/SocketFactory.kt @@ -16,7 +16,6 @@ package com.malinskiy.adam.transport -import io.ktor.network.sockets.Socket import java.net.InetSocketAddress interface SocketFactory { diff --git a/src/main/kotlin/com/malinskiy/adam/transport/SuspendCloseable.kt b/src/main/kotlin/com/malinskiy/adam/transport/SuspendCloseable.kt new file mode 100644 index 000000000..722dd74fc --- /dev/null +++ b/src/main/kotlin/com/malinskiy/adam/transport/SuspendCloseable.kt @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2021 Anton Malinskiy + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.malinskiy.adam.transport + +import kotlinx.coroutines.NonCancellable +import kotlinx.coroutines.withContext + +interface SuspendCloseable { + suspend fun close() +} + +suspend inline fun C.use(block: suspend (C) -> R): R { + var closed = false + + return try { + block(this) + } catch (first: Throwable) { + try { + closed = true + withContext(NonCancellable) { + close() + } + } catch (second: Throwable) { + try { + Throwable::class.java.getMethod("addSuppressed", Throwable::class.java).invoke(this, second) + } catch (t: Throwable) { + null + } + } + + throw first + } finally { + if (!closed) { + withContext(NonCancellable) { + close() + } + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/malinskiy/adam/AndroidDebugBridgeClientTest.kt b/src/test/kotlin/com/malinskiy/adam/AndroidDebugBridgeClientTest.kt index d3a027f38..0a2e67adb 100644 --- a/src/test/kotlin/com/malinskiy/adam/AndroidDebugBridgeClientTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/AndroidDebugBridgeClientTest.kt @@ -24,8 +24,7 @@ import com.malinskiy.adam.request.ComplexRequest import com.malinskiy.adam.request.ValidationResponse import com.malinskiy.adam.request.shell.v1.ShellCommandRequest import com.malinskiy.adam.server.AndroidDebugBridgeServer -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel +import com.malinskiy.adam.transport.Socket import io.ktor.utils.io.* import kotlinx.coroutines.runBlocking import org.junit.Test @@ -134,7 +133,7 @@ class AndroidDebugBridgeClientTest { client.execute(object : ComplexRequest() { override fun validate() = ValidationResponse(false, "Fake") - override suspend fun readElement(readChannel: AndroidReadChannel, writeChannel: AndroidWriteChannel): String { + override suspend fun readElement(socket: Socket): String { TODO("Not yet implemented") } diff --git a/src/test/kotlin/com/malinskiy/adam/request/abb/AbbExecRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/abb/AbbExecRequestTest.kt index c7c45efc6..6010d0464 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/abb/AbbExecRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/abb/AbbExecRequestTest.kt @@ -20,15 +20,11 @@ import assertk.assertThat import assertk.assertions.isEqualTo import assertk.assertions.isFalse import assertk.assertions.isTrue -import com.malinskiy.adam.extension.toAndroidChannel +import com.malinskiy.adam.Const import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature -import com.malinskiy.adam.transport.AndroidReadChannel -import com.malinskiy.adam.transport.AndroidWriteChannel -import io.ktor.util.cio.readChannel -import io.ktor.util.cio.writeChannel -import io.ktor.utils.io.cancel -import io.ktor.utils.io.close +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Rule import org.junit.Test @@ -61,21 +57,10 @@ class AbbExecRequestTest { @Test fun testDummy() { runBlocking { - val newFile = temp.newFile().apply { writeText("cafebabe") } - var readChannel: AndroidReadChannel? = null - var writeChannel: AndroidWriteChannel? = null - try { - readChannel = newFile.readChannel(coroutineContext = coroutineContext).toAndroidChannel() - writeChannel = newFile.writeChannel(coroutineContext).toAndroidChannel() + StubSocket("cafebabe".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING)).use { socket -> assertThat( - AbbExecRequest(listOf(), supportedFeatures = emptyList()).readElement( - readChannel, - writeChannel - ) + AbbExecRequest(listOf(), supportedFeatures = emptyList()).readElement(socket) ).isEqualTo("cafebabe") - } finally { - readChannel?.cancel() - writeChannel?.close() } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequestTest.kt index 3c2cbdfc8..38f6b6677 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/device/FetchDeviceFeaturesRequestTest.kt @@ -20,10 +20,9 @@ import assertk.assertThat import assertk.assertions.containsExactly import assertk.assertions.isEqualTo import com.malinskiy.adam.Const -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.request.Feature -import io.ktor.utils.io.* -import io.ktor.utils.io.core.* +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test @@ -44,25 +43,23 @@ class FetchDeviceFeaturesRequestTest { fun testResponse() { val fetchDeviceFeaturesRequest = FetchDeviceFeaturesRequest("cafebabe") runBlocking { - val response = "0054fixed_push_symlink_timestamp,apex,fixed_push_mkdir,stat_v2,abb_exec,cmd,abb,shell_v2" - .toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) + StubSocket( + content = "0054fixed_push_symlink_timestamp,apex,fixed_push_mkdir,stat_v2,abb_exec,cmd,abb,shell_v2" + .toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) + ).use { socket -> + val features = fetchDeviceFeaturesRequest.readElement(socket) - val features = fetchDeviceFeaturesRequest.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) - - assertThat(features).containsExactly( - Feature.FIXED_PUSH_SYMLINK_TIMESTAMP, - Feature.APEX, - Feature.FIXED_PUSH_MKDIR, - Feature.STAT_V2, - Feature.ABB_EXEC, - Feature.CMD, - Feature.ABB, - Feature.SHELL_V2 - ) + assertThat(features).containsExactly( + Feature.FIXED_PUSH_SYMLINK_TIMESTAMP, + Feature.APEX, + Feature.FIXED_PUSH_MKDIR, + Feature.STAT_V2, + Feature.ABB_EXEC, + Feature.CMD, + Feature.ABB, + Feature.SHELL_V2 + ) + } } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/framebuffer/screencapture/BufferedImageScreenCaptureAdapterTest.kt b/src/test/kotlin/com/malinskiy/adam/request/framebuffer/screencapture/BufferedImageScreenCaptureAdapterTest.kt index e61094dd5..8110d621e 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/framebuffer/screencapture/BufferedImageScreenCaptureAdapterTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/framebuffer/screencapture/BufferedImageScreenCaptureAdapterTest.kt @@ -16,9 +16,9 @@ package com.malinskiy.adam.request.framebuffer.screencapture -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.request.framebuffer.BufferedImageScreenCaptureAdapter -import io.ktor.utils.io.* +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test @@ -27,26 +27,25 @@ class BufferedImageScreenCaptureAdapterTest { fun testThrowsExceptionIfUnsupportedImage() { val adapter = BufferedImageScreenCaptureAdapter() runBlocking { - val byteChannel = ByteChannel(autoFlush = true) - byteChannel.writeByte(0) - adapter.process( - 1, - 24, - 1, - 1, - 1, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - null, - (byteChannel as ByteReadChannel).toAndroidChannel() - ) - byteChannel.close() + StubSocket(ByteArray(1) { 0 }).use { socket -> + adapter.process( + 1, + 24, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + null, + socket + ) + } } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequestTest.kt index 613e92a10..0e8571386 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/mdns/ListMdnsServicesRequestTest.kt @@ -20,13 +20,9 @@ import assertk.assertThat import assertk.assertions.containsExactly import assertk.assertions.isEqualTo import com.malinskiy.adam.Const -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString -import io.ktor.utils.io.ByteChannelSequentialJVM -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.ByteWriteChannel -import io.ktor.utils.io.close -import io.ktor.utils.io.core.IoBuffer +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test @@ -40,26 +36,21 @@ class ListMdnsServicesRequestTest { fun testReturnsContent() { runBlocking { val response = "0027adb-serial\t_adb._tcp.\t192.168.1.2:9999\n".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) - val services = ListMdnsServicesRequest().readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) - - assertThat(services).containsExactly( - MdnsService( - name = "adb-serial", - serviceType = "_adb._tcp.", - url = "192.168.1.2:9999" + StubSocket(response).use { socket -> + val services = ListMdnsServicesRequest().readElement(socket) + assertThat(services).containsExactly( + MdnsService( + name = "adb-serial", + serviceType = "_adb._tcp.", + url = "192.168.1.2:9999" + ) ) - ) - - assertThat(services.first().name).isEqualTo("adb-serial") - assertThat(services.first().serviceType).isEqualTo("_adb._tcp.") - assertThat(services.first().url).isEqualTo("192.168.1.2:9999") - byteBufferChannel.close() + assertThat(services.first().name).isEqualTo("adb-serial") + assertThat(services.first().serviceType).isEqualTo("_adb._tcp.") + assertThat(services.first().url).isEqualTo("192.168.1.2:9999") + } } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequestTest.kt index aff1134b4..6b9e2dfb2 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/mdns/MdnsCheckRequestTest.kt @@ -19,13 +19,9 @@ package com.malinskiy.adam.request.mdns import assertk.assertThat import assertk.assertions.isEqualTo import com.malinskiy.adam.Const -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString -import io.ktor.utils.io.ByteChannelSequentialJVM -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.ByteWriteChannel -import io.ktor.utils.io.close -import io.ktor.utils.io.core.IoBuffer +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test @@ -39,15 +35,11 @@ class MdnsCheckRequestTest { fun testReturnsContent() { runBlocking { val response = "001Dmdns daemon version [8787003]".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) - val value = MdnsCheckRequest().readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) - - assertThat(value).isEqualTo(MdnsStatus(true, "8787003")) - byteBufferChannel.close() + StubSocket(response).use { socket -> + val value = MdnsCheckRequest().readElement(socket) + assertThat(value).isEqualTo(MdnsStatus(true, "8787003")) + } } } @@ -55,15 +47,12 @@ class MdnsCheckRequestTest { fun testReturnsUnavailable() { runBlocking { val response = "001Dmdns daemon unavailable87003]".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) - val value = MdnsCheckRequest().readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + val value = MdnsCheckRequest().readElement(socket) + assertThat(value).isEqualTo(MdnsStatus(false, null)) - assertThat(value).isEqualTo(MdnsStatus(false, null)) - byteBufferChannel.close() + } } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequestTest.kt index 04e3b3df0..fbde962b2 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/misc/FetchHostFeaturesRequestTest.kt @@ -20,14 +20,12 @@ import assertk.assertThat import assertk.assertions.containsExactly import assertk.assertions.isEqualTo import com.malinskiy.adam.Const -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature -import io.ktor.utils.io.* -import io.ktor.utils.io.core.* +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test -import kotlin.text.toByteArray class FetchHostFeaturesRequestTest { @Test @@ -41,23 +39,20 @@ class FetchHostFeaturesRequestTest { runBlocking { val response = "0054fixed_push_symlink_timestamp,apex,fixed_push_mkdir,stat_v2,abb_exec,cmd,abb,shell_v2" .toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) - val features = request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) - - assertThat(features).containsExactly( - Feature.FIXED_PUSH_SYMLINK_TIMESTAMP, - Feature.APEX, - Feature.FIXED_PUSH_MKDIR, - Feature.STAT_V2, - Feature.ABB_EXEC, - Feature.CMD, - Feature.ABB, - Feature.SHELL_V2 - ) + StubSocket(response).use { socket -> + val features = request.readElement(socket) + assertThat(features).containsExactly( + Feature.FIXED_PUSH_SYMLINK_TIMESTAMP, + Feature.APEX, + Feature.FIXED_PUSH_MKDIR, + Feature.STAT_V2, + Feature.ABB_EXEC, + Feature.CMD, + Feature.ABB, + Feature.SHELL_V2 + ) + } } } } \ No newline at end of file diff --git a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequestTest.kt index 3a11072e7..d4c0054a4 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/AddSessionRequestTest.kt @@ -20,16 +20,13 @@ import assertk.assertThat import assertk.assertions.isEqualTo import com.malinskiy.adam.Const import com.malinskiy.adam.exception.RequestRejectedException -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature -import io.ktor.utils.io.* -import io.ktor.utils.io.core.* +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test -import kotlin.text.toByteArray - class AddSessionRequestTest { @Test fun serialize() { @@ -66,12 +63,10 @@ class AddSessionRequestTest { ) val response = "Success".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } @@ -85,12 +80,10 @@ class AddSessionRequestTest { ) val response = "Failure".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequestTest.kt index 3ec902e1f..58615e3c8 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateIndividualPackageSessionRequestTest.kt @@ -21,13 +21,10 @@ import assertk.assertions.isEqualTo import com.malinskiy.adam.Const import com.malinskiy.adam.exception.RequestRejectedException import com.malinskiy.adam.extension.newFileWithExtension -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature -import io.ktor.utils.io.ByteChannelSequentialJVM -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.ByteWriteChannel -import io.ktor.utils.io.core.IoBuffer +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Rule import org.junit.Test @@ -104,13 +101,11 @@ class CreateIndividualPackageSessionRequestTest { fun testRead() { val request = stub() val response = "Success [my-session-id]".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - val sessionId = request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) - assertThat(sessionId).isEqualTo("my-session-id") + StubSocket(response).use { socket -> + val sessionId = request.readElement(socket) + assertThat(sessionId).isEqualTo("my-session-id") + } } } @@ -118,12 +113,10 @@ class CreateIndividualPackageSessionRequestTest { fun testReadException() { val request = stub() val response = "Failure".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } @@ -131,12 +124,10 @@ class CreateIndividualPackageSessionRequestTest { fun testReadNoSession() { val request = stub() val response = "Success no session returned".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequestTest.kt index 09dd7c728..824be6cc2 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/CreateMultiPackageSessionRequestTest.kt @@ -23,10 +23,11 @@ import assertk.assertions.isTrue import com.malinskiy.adam.Const import com.malinskiy.adam.exception.RequestRejectedException import com.malinskiy.adam.extension.newFileWithExtension -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature import com.malinskiy.adam.request.ValidationResponse +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import io.ktor.utils.io.* import io.ktor.utils.io.core.* import kotlinx.coroutines.runBlocking @@ -106,11 +107,10 @@ class CreateMultiPackageSessionRequestTest { val response = "Success [my-session-id]".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - val sessionId = request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) - assertThat(sessionId).isEqualTo("my-session-id") + StubSocket(response).use { socket -> + val sessionId = request.readElement(socket) + assertThat(sessionId).isEqualTo("my-session-id") + } } } @@ -120,10 +120,9 @@ class CreateMultiPackageSessionRequestTest { val response = "Failure".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } @@ -131,12 +130,10 @@ class CreateMultiPackageSessionRequestTest { fun testReadNoSession() { val request = stub() val response = "Success no session returned".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequestTest.kt index 466a3953f..a2c26da20 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/InstallCommitRequestTest.kt @@ -20,13 +20,10 @@ import assertk.assertThat import assertk.assertions.isEqualTo import com.malinskiy.adam.Const import com.malinskiy.adam.exception.RequestRejectedException -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature -import io.ktor.utils.io.ByteChannelSequentialJVM -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.ByteWriteChannel -import io.ktor.utils.io.core.IoBuffer +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import kotlinx.coroutines.runBlocking import org.junit.Test @@ -63,12 +60,10 @@ class InstallCommitRequestTest { fun testRead() { val request = stub() val response = "Success [my-session-id]".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } @@ -76,12 +71,10 @@ class InstallCommitRequestTest { fun testReadException() { val request = stub() val response = "Failure".toByteArray(Const.DEFAULT_TRANSPORT_ENCODING) - val byteBufferChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) runBlocking { - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(response).use { socket -> + request.readElement(socket) + } } } diff --git a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequestTest.kt index 2d1e91504..6ec5c120e 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/pkg/multi/WriteIndividualPackageRequestTest.kt @@ -21,13 +21,12 @@ import assertk.assertions.isEqualTo import com.malinskiy.adam.Const import com.malinskiy.adam.exception.RequestRejectedException import com.malinskiy.adam.extension.newFileWithExtension -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.extension.toRequestString import com.malinskiy.adam.request.Feature -import io.ktor.util.cio.writeChannel -import io.ktor.utils.io.ByteReadChannel -import io.ktor.utils.io.ByteWriteChannel -import io.ktor.utils.io.close +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use +import io.ktor.util.cio.* +import io.ktor.utils.io.* import kotlinx.coroutines.runBlocking import org.junit.Rule import org.junit.Test @@ -82,11 +81,9 @@ class WriteIndividualPackageRequestTest { runBlocking { val byteWriteChannel = actual.writeChannel(coroutineContext) - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteWriteChannel.toAndroidChannel() - ) - byteWriteChannel.close() + StubSocket(ByteReadChannel(response), byteWriteChannel).use { socket -> + request.readElement(socket) + } } assertThat(actual.readBytes()).isEqualTo(fixture.readBytes()) @@ -105,10 +102,9 @@ class WriteIndividualPackageRequestTest { val actual = temp.newFileWithExtension("apk") runBlocking { val byteBufferChannel: ByteWriteChannel = actual.writeChannel(coroutineContext) - request.readElement( - ByteReadChannel(response).toAndroidChannel(), - byteBufferChannel.toAndroidChannel() - ) + StubSocket(ByteReadChannel(response), byteBufferChannel).use { socket -> + request.readElement(socket) + } } assertThat(actual.readBytes()).isEqualTo(fixture.readBytes()) diff --git a/src/test/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequestTest.kt index d816fc0a0..2a1e5d27f 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/shell/v1/ChanneledShellCommandRequestTest.kt @@ -18,11 +18,9 @@ package com.malinskiy.adam.request.shell.v1 import assertk.assertThat import assertk.assertions.isEqualTo -import assertk.fail import com.malinskiy.adam.Const import com.malinskiy.adam.server.AndroidDebugBridgeServer -import io.ktor.utils.io.close -import kotlinx.coroutines.channels.receiveOrNull +import io.ktor.utils.io.* import kotlinx.coroutines.runBlocking import org.junit.Test @@ -51,8 +49,8 @@ class ChanneledShellCommandRequestTest { val updates = client.execute(ChanneledShellCommandRequest("logcat -v"), scope = this, serial = "emulator-5554") val stringBuffer = StringBuffer() - while (!updates.isClosedForReceive) { - stringBuffer.append(updates.receiveOrNull() ?: fail("should receive content")) + for (update in updates) { + stringBuffer.append(update) } assertThat(stringBuffer.toString()).isEqualTo("something-somethingsomething2-something2") diff --git a/src/test/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequestTest.kt index 8b442875c..3f8d25253 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/shell/v2/ChanneledShellCommandRequestTest.kt @@ -20,7 +20,7 @@ import assertk.assertThat import assertk.assertions.isEqualTo import com.malinskiy.adam.Const import com.malinskiy.adam.server.AndroidDebugBridgeServer -import io.ktor.utils.io.close +import io.ktor.utils.io.* import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.runBlocking @@ -56,6 +56,7 @@ class ChanneledShellCommandRequestTest { output.respondShellV2Exit(17) output.close() + input.discard() } val stdio = Channel() diff --git a/src/test/kotlin/com/malinskiy/adam/request/sync/compat/CompatPushFileRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/sync/compat/CompatPushFileRequestTest.kt index f2034dcb7..11d038ac2 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/sync/compat/CompatPushFileRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/sync/compat/CompatPushFileRequestTest.kt @@ -21,7 +21,7 @@ import assertk.assertions.isEqualTo import com.malinskiy.adam.Const import com.malinskiy.adam.request.Feature import com.malinskiy.adam.server.AndroidDebugBridgeServer -import io.ktor.utils.io.close +import io.ktor.utils.io.* import kotlinx.coroutines.channels.receiveOrNull import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking @@ -58,6 +58,9 @@ class CompatPushFileRequestTest { assertThat(receiveCmd).isEqualTo("/sdcard/testfile,511") input.receiveFile(receiveFile) output.respond(Const.Message.OKAY) + + input.discard() + output.close() } diff --git a/src/test/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequestTest.kt index cf897b371..72b146aed 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/sync/v1/PullFileRequestTest.kt @@ -22,7 +22,7 @@ import com.malinskiy.adam.Const import com.malinskiy.adam.exception.PullFailedException import com.malinskiy.adam.exception.UnsupportedSyncProtocolException import com.malinskiy.adam.server.AndroidDebugBridgeServer -import io.ktor.utils.io.writeIntLittleEndian +import io.ktor.utils.io.* import kotlinx.coroutines.channels.receiveOrNull import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking @@ -37,6 +37,10 @@ class PullFileRequestTest { @JvmField val temp = TemporaryFolder() +// @Rule +// @JvmField +// val coroutines = CoroutinesTimeout.seconds(5) + @Test fun testSerialize() { assertThat(String(PullFileRequest("/sdcard/testfile", File("/tmp/testfile")).serialize(), Const.DEFAULT_TRANSPORT_ENCODING)) @@ -209,6 +213,11 @@ class PullFileRequestTest { output.respond(Const.Message.DATA) output.respondData(ByteArray(Const.MAX_FILE_PACKET_LENGTH + 1)) + + input.discard() + while (input.isClosedForRead == false) { + } + output.close() } val request = PullFileRequest("/sdcard/testfile", tempFile) diff --git a/src/test/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequestTest.kt b/src/test/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequestTest.kt index ccf14d29f..15ba92b79 100644 --- a/src/test/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequestTest.kt +++ b/src/test/kotlin/com/malinskiy/adam/request/testrunner/TestRunnerRequestTest.kt @@ -20,9 +20,12 @@ import assertk.assertThat import assertk.assertions.containsOnly import assertk.assertions.isEqualTo import com.malinskiy.adam.Const -import com.malinskiy.adam.extension.toAndroidChannel import com.malinskiy.adam.server.AndroidDebugBridgeServer +import com.malinskiy.adam.server.StubSocket +import com.malinskiy.adam.transport.use import io.ktor.utils.io.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.Channel.Factory.BUFFERED import kotlinx.coroutines.channels.receiveOrNull import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking @@ -103,15 +106,12 @@ class TestRunnerRequestTest { @Test fun testChannelIsEmpty() { val request = TestRunnerRequest("com.example.test", InstrumentOptions()) - val readChannel = ByteChannel(autoFlush = true) - val writeChannel = ByteChannel(autoFlush = true) runBlocking { - readChannel.close() - val readElement = request.readElement( - (readChannel as ByteReadChannel).toAndroidChannel(), - (writeChannel as ByteWriteChannel).toAndroidChannel() - ) - assertThat(readElement).isEqualTo(null) + StubSocket(ByteChannel(autoFlush = true).apply { close() }, ByteChannel(autoFlush = true)).use { socket -> + val channel = Channel>(BUFFERED) + val readElement = request.readElement(socket, channel) + assertThat(channel.poll()).isEqualTo(null) + } } } } diff --git a/src/test/kotlin/com/malinskiy/adam/server/AndroidDebugBridgeServer.kt b/src/test/kotlin/com/malinskiy/adam/server/AndroidDebugBridgeServer.kt index 38477193c..5299e0d85 100644 --- a/src/test/kotlin/com/malinskiy/adam/server/AndroidDebugBridgeServer.kt +++ b/src/test/kotlin/com/malinskiy/adam/server/AndroidDebugBridgeServer.kt @@ -18,6 +18,7 @@ package com.malinskiy.adam.server import com.malinskiy.adam.AndroidDebugBridgeClient import com.malinskiy.adam.AndroidDebugBridgeClientFactory +import com.malinskiy.adam.transport.NioSocketFactory import io.ktor.network.selector.* import io.ktor.network.sockets.* import io.ktor.util.network.* @@ -61,6 +62,7 @@ class AndroidDebugBridgeServer : CoroutineScope { val client = AndroidDebugBridgeClientFactory().apply { port = this@AndroidDebugBridgeServer.port + socketFactory = NioSocketFactory() }.build() return client diff --git a/src/test/kotlin/com/malinskiy/adam/server/StubSocket.kt b/src/test/kotlin/com/malinskiy/adam/server/StubSocket.kt new file mode 100644 index 000000000..094989e16 --- /dev/null +++ b/src/test/kotlin/com/malinskiy/adam/server/StubSocket.kt @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2021 Anton Malinskiy + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.malinskiy.adam.server + +import com.malinskiy.adam.log.AdamLogging +import com.malinskiy.adam.transport.Socket +import io.ktor.utils.io.* +import io.ktor.utils.io.core.* +import java.nio.ByteBuffer + +class StubSocket( + val readChannel: ByteReadChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false), + val writeChannel: ByteWriteChannel = ByteChannelSequentialJVM(IoBuffer.Empty, false) +) : Socket { + override val isClosedForWrite: Boolean + get() = writeChannel.isClosedForWrite + override val isClosedForRead: Boolean + get() = readChannel.isClosedForRead + + constructor(content: ByteArray) : this(readChannel = ByteReadChannel(content)) + + override suspend fun readFully(buffer: ByteBuffer): Int = readChannel.readFully(buffer) + override suspend fun readFully(buffer: ByteArray, offset: Int, limit: Int) = readChannel.readFully(buffer, offset, limit) + override suspend fun writeFully(byteBuffer: ByteBuffer) = writeChannel.writeFully(byteBuffer) + override suspend fun writeFully(toByteArray: ByteArray, offset: Int, limit: Int) = writeChannel.writeFully(toByteArray, offset, limit) + override suspend fun readAvailable(buffer: ByteArray, offset: Int, limit: Int): Int = readChannel.readAvailable(buffer, offset, limit) + override suspend fun readByte(): Byte = readChannel.readByte() + override suspend fun readIntLittleEndian(): Int = readChannel.readIntLittleEndian() + override suspend fun writeByte(value: Int) = writeChannel.writeByte(value) + override suspend fun writeIntLittleEndian(value: Int) = writeChannel.writeIntLittleEndian(value) + + override suspend fun close() { + try { + writeChannel.close() + readChannel.cancel() + } catch (e: Exception) { + log.debug(e) { "Exception during cleanup. Ignoring" } + } + } + + companion object { + private val log = AdamLogging.logger {} + } +} \ No newline at end of file