Skip to content

Commit

Permalink
Merge pull request #23 from Malinskiy/feat/buffer-pooling
Browse files Browse the repository at this point in the history
feat(transport): optimize buffer pooling + yield when possible
  • Loading branch information
Malinskiy authored Jan 21, 2021
2 parents 213a534 + 6a94700 commit e59cd6a
Show file tree
Hide file tree
Showing 41 changed files with 621 additions and 471 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
- name: gradle test jacocoTestReport
run: ./gradlew test jacocoTestReport
- name: archive test results
if: failure()
run: (cd build/reports/tests/test; zip -r -X ../../../../test-result.zip .)
- name: Save test output
uses: actions/upload-artifact@master
Expand Down Expand Up @@ -55,6 +56,7 @@ jobs:
- name: Generate integration code coverage report
run: ./gradlew jacocoIntegrationTestReport
- name: archive integration test results
if: failure()
run: (cd build/reports/tests/integrationTest; zip -r -X ../../../../integration-test-result.zip .)
- name: Save integration test output
uses: actions/upload-artifact@master
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.malinskiy.adam.request.shell.v1.ShellCommandRequest
import com.malinskiy.adam.request.sync.v1.PushFileRequest
import com.malinskiy.adam.rule.AdbDeviceRule
import kotlinx.coroutines.channels.receiveOrNull
import kotlinx.coroutines.debug.junit4.CoroutinesTimeout
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.junit.After
Expand All @@ -41,6 +42,10 @@ class ApkE2ETest {
val adb = AdbDeviceRule()
val client = adb.adb

@Rule
@JvmField
val timeout = CoroutinesTimeout.seconds(60)

@Before
fun setup() {
runBlocking {
Expand Down
1 change: 0 additions & 1 deletion src/main/kotlin/com/malinskiy/adam/Const.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ object Const {
const val SERVER_PORT_ENV_VAR = "ANDROID_ADB_SERVER_PORT"
const val MAX_PACKET_LENGTH = 16384
const val MAX_FILE_PACKET_LENGTH = 64 * 1024
const val KTOR_INTERNAL_BUFFER_LENGTH = 4088

const val MAX_PROTOBUF_LOGCAT_LENGTH = 10_000
const val MAX_PROTOBUF_PACKET_LENGTH = 10 * 1024 * 1024L //10Mb
Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/com/malinskiy/adam/extension/ByteBuffer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ import java.nio.ByteBuffer
fun ByteBuffer.compatRewind() = ((this as Buffer).rewind() as ByteBuffer)
fun ByteBuffer.compatLimit(newLimit: Int) = ((this as Buffer).limit(newLimit) as ByteBuffer)
fun ByteBuffer.compatPosition(newLimit: Int) = ((this as Buffer).position(newLimit) as ByteBuffer)
fun ByteBuffer.compatFlip() = ((this as Buffer).flip() as ByteBuffer)
fun ByteBuffer.compatClear() = ((this as Buffer).clear() as ByteBuffer)
25 changes: 24 additions & 1 deletion src/main/kotlin/com/malinskiy/adam/extension/ByteReadChannel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,27 @@ suspend fun ByteReadChannel.copyTo(socket: Socket, buffer: ByteArray): Long {
}
}
return processed
}
}

/**
* Copies up to limit bytes into transformer using buffer. If limit is null - copy until EOF
*/
suspend fun ByteReadChannel.copyTo(buffer: ByteArray, offset: Int, limit: Int): Int {
var processed = 0
loop@ while (true) {
val toRead = (buffer.size - offset) - processed
val available = readAvailable(buffer, offset + processed, toRead)
when {
processed == limit -> break@loop
available < 0 && processed != 0 -> {
break@loop
}
available < 0 -> return available
available > 0 -> {
processed += available
}
else -> continue@loop
}
}
return processed
}
82 changes: 61 additions & 21 deletions src/main/kotlin/com/malinskiy/adam/extension/Socket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.malinskiy.adam.extension

import com.malinskiy.adam.Const
import com.malinskiy.adam.exception.RequestRejectedException
import com.malinskiy.adam.request.transform.ResponseTransformer
import com.malinskiy.adam.request.transform.StringResponseTransformer
import com.malinskiy.adam.transport.Socket
Expand All @@ -26,6 +27,7 @@ import io.ktor.util.cio.*
import io.ktor.utils.io.*
import io.ktor.utils.io.bits.*
import io.ktor.utils.io.core.*
import kotlinx.coroutines.yield
import java.io.File
import java.nio.ByteBuffer
import kotlin.coroutines.CoroutineContext
Expand All @@ -43,7 +45,10 @@ suspend fun Socket.copyTo(channel: ByteWriteChannel, buffer: ByteArray): Long {
channel.writeFully(buffer, 0, available)
processed += available
}
else -> continue@loop
else -> {
yield()
continue@loop
}
}
}
return processed
Expand Down Expand Up @@ -73,7 +78,10 @@ suspend fun <T> Socket.copyTo(transformer: ResponseTransformer<T>, buffer: ByteA
transformer.process(buffer, 0, available)
processed += available
}
else -> continue@loop
else -> {
yield()
continue@loop
}
}
}
return processed
Expand All @@ -98,27 +106,56 @@ suspend fun Socket.readOptionalProtocolString(): String? {
return if (errorMessageLength == null) {
readStatus()
} else {
val errorBytes = ByteArray(errorMessageLength)
readFully(errorBytes, 0, errorMessageLength)
String(errorBytes, Const.DEFAULT_TRANSPORT_ENCODING)
withDefaultBuffer {
val transformer = StringResponseTransformer()
this@readOptionalProtocolString.copyTo(transformer, this, limit = errorMessageLength.toLong())
transformer.transform()
}
}
}

suspend fun Socket.read(): TransportResponse {
val bytes = ByteArray(4)
readFully(bytes, 0, 4)
/**
* @throws RequestRejectedException
*/
suspend fun Socket.readProtocolString(): String {
withDefaultBuffer {
val transformer = StringResponseTransformer()
val copied = copyTo(transformer, this, limit = 4L)
val length = transformer.transform()
if (copied != 4L) {
throw RequestRejectedException("Unexpected string length: $length")
}
val messageLength = length.toIntOrNull(16) ?: throw RequestRejectedException("Unexpected string length: $length")

compatClear()
compatLimit(messageLength)
val read = readFully(this)
if (read != messageLength) throw RequestRejectedException("Incomplete string received")
return String(array(), 0, read, Const.DEFAULT_TRANSPORT_ENCODING)
}
}

val ok = bytes.isOkay()
suspend fun Socket.read(): TransportResponse {
val ok = withDefaultBuffer {
compatLimit(4)
readFully(this)
isOkay()
}
val message = if (!ok) {
readOptionalProtocolString()
} else {
null
}

return TransportResponse(ok, message)
}

private fun ByteArray.isOkay() = contentEquals(Const.Message.OKAY)
private fun ByteBuffer.isOkay(): Boolean {
if (limit() != 4) return false
for (i in 0..3) {
if (get(i) != Const.Message.OKAY[i]) return false
}
return true
}

suspend fun Socket.readStatus(): String {
withDefaultBuffer {
Expand Down Expand Up @@ -146,12 +183,13 @@ suspend fun Socket.writeSyncRequest(type: ByteArray, remotePath: String) {
val path = remotePath.toByteArray(Const.DEFAULT_TRANSPORT_ENCODING)
val size = path.size.toByteArray().reversedArray()

val cmd = ByteArray(8 + path.size)

type.copyInto(cmd)
size.copyInto(cmd, 4)
path.copyInto(cmd, 8)
write(cmd)
withDefaultBuffer {
put(type)
put(size)
put(path)
compatFlip()
writeFully(this)
}
}

suspend fun Socket.writeSyncV2Request(type: ByteArray, remotePath: String, flags: Int, mode: Int? = null) {
Expand All @@ -177,10 +215,12 @@ suspend fun Socket.writeSyncV2Request(type: ByteArray, remotePath: String, flags
}

suspend fun Socket.readTransportResponse(): TransportResponse {
val bytes = ByteArray(4)
readFully(bytes, 0, 4)

val ok = bytes.isOkay()
val ok = withDefaultBuffer {
compatLimit(4)
readFully(this)
compatFlip()
isOkay()
}
val message = if (!ok) {
readOptionalProtocolString()
} else {
Expand Down
34 changes: 19 additions & 15 deletions src/main/kotlin/com/malinskiy/adam/request/SynchronousRequest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,31 @@

package com.malinskiy.adam.request

import com.malinskiy.adam.Const
import com.malinskiy.adam.request.transform.ResponseTransformer
import com.malinskiy.adam.transport.Socket
import com.malinskiy.adam.transport.withMaxPacketBuffer
import kotlinx.coroutines.yield

abstract class SynchronousRequest<T : Any?>(target: Target = NonSpecifiedTarget) : ComplexRequest<T>(target), ResponseTransformer<T> {
override suspend fun readElement(socket: Socket): T {
val data = ByteArray(Const.MAX_PACKET_LENGTH)
loop@ do {
if (socket.isClosedForWrite || socket.isClosedForRead) break@loop
withMaxPacketBuffer {
loop@ do {
if (socket.isClosedForWrite || socket.isClosedForRead) break@loop

val count = socket.readAvailable(data, 0, Const.MAX_PACKET_LENGTH)
when {
count == 0 -> {
continue@loop
val data = array()
val count = socket.readAvailable(data, 0, data.size)
when {
count == 0 -> {
yield()
continue@loop
}
count > 0 -> {
process(data, 0, count)
}
}
count > 0 -> {
process(data, 0, count)
}
}
} while (count >= 0)
} while (count >= 0)

return transform()
return transform()
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,15 @@

package com.malinskiy.adam.request.device

import com.malinskiy.adam.Const
import com.malinskiy.adam.extension.readProtocolString
import com.malinskiy.adam.request.AsyncChannelRequest
import com.malinskiy.adam.request.HostTarget
import com.malinskiy.adam.transport.Socket
import kotlinx.coroutines.channels.SendChannel
import java.nio.ByteBuffer

class AsyncDeviceMonitorRequest : AsyncChannelRequest<List<Device>, Unit>(target = HostTarget) {
override suspend fun readElement(socket: Socket, sendChannel: SendChannel<List<Device>>): Boolean {
val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4)
socket.readFully(sizeBuffer)
val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16)

val payloadBuffer = ByteBuffer.allocate(size)
socket.readFully(payloadBuffer)
val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING)
sendChannel.send(payload.lines()
sendChannel.send(socket.readProtocolString().lines()
.filter { it.isNotEmpty() }
.map {
val line = it.trim()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,17 @@

package com.malinskiy.adam.request.device

import com.malinskiy.adam.Const
import com.malinskiy.adam.extension.readProtocolString
import com.malinskiy.adam.request.ComplexRequest
import com.malinskiy.adam.request.Feature
import com.malinskiy.adam.request.SerialTarget
import com.malinskiy.adam.transport.Socket
import java.nio.ByteBuffer


class FetchDeviceFeaturesRequest(serial: String) : ComplexRequest<List<Feature>>(target = SerialTarget(serial)) {

override fun serialize() = createBaseRequest("features")

override suspend fun readElement(socket: Socket): List<Feature> {
val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4)
socket.readFully(sizeBuffer)
val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16)

val payloadBuffer = ByteBuffer.allocate(size)
socket.readFully(payloadBuffer)
return String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).split(',').mapNotNull { Feature.of(it) }
return socket.readProtocolString().split(',').mapNotNull { Feature.of(it) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,16 @@

package com.malinskiy.adam.request.device

import com.malinskiy.adam.Const
import com.malinskiy.adam.extension.readProtocolString
import com.malinskiy.adam.request.ComplexRequest
import com.malinskiy.adam.request.HostTarget
import com.malinskiy.adam.transport.Socket
import java.nio.ByteBuffer

class ListDevicesRequest : ComplexRequest<List<Device>>(target = HostTarget) {
override fun serialize() = createBaseRequest("devices")

override suspend fun readElement(socket: Socket): List<Device> {
val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4)
socket.readFully(sizeBuffer)
val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16)

val payloadBuffer = ByteBuffer.allocate(size)
socket.readFully(payloadBuffer)
val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING)
return payload.lines()
return socket.readProtocolString().lines()
.filter { it.isNotEmpty() }
.map {
val line = it.trim()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,14 @@

package com.malinskiy.adam.request.forwarding

import com.malinskiy.adam.Const
import com.malinskiy.adam.extension.readProtocolString
import com.malinskiy.adam.request.ComplexRequest
import com.malinskiy.adam.request.SerialTarget
import com.malinskiy.adam.transport.Socket
import java.nio.ByteBuffer

class ListPortForwardsRequest(serial: String) : ComplexRequest<List<PortForwardingRule>>(target = SerialTarget(serial)) {
override suspend fun readElement(socket: Socket): List<PortForwardingRule> {
val sizeBuffer: ByteBuffer = ByteBuffer.allocate(4)
socket.readFully(sizeBuffer)
val size = String(sizeBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING).toInt(radix = 16)

val payloadBuffer = ByteBuffer.allocate(size)
socket.readFully(payloadBuffer)
val payload = String(payloadBuffer.array(), Const.DEFAULT_TRANSPORT_ENCODING)
return payload.lines().mapNotNull { line ->
return socket.readProtocolString().lines().mapNotNull { line ->
if (line.isNotEmpty()) {
val split = line.split(" ")
PortForwardingRule(
Expand Down
Loading

0 comments on commit e59cd6a

Please sign in to comment.