Skip to content

Commit

Permalink
[scrooge] Speeding up serialization of collections and in particular …
Browse files Browse the repository at this point in the history
…arrays of primitives

Differential Revision: https://phabricator.twitter.biz/D1173708
  • Loading branch information
mbezoyan authored and jenkins committed Oct 3, 2024
1 parent 2468700 commit 1cbcb4a
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 30 deletions.
5 changes: 0 additions & 5 deletions scrooge-benchmark/BUILD

This file was deleted.

5 changes: 5 additions & 0 deletions scrooge-benchmark/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
target(
dependencies = [
"scrooge/scrooge-benchmark/src/main/scala:benchmark",
],
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
scala_library(
scala_benchmark_jmh(
name = "benchmark",
sources = ["**/*.scala"],
compiler_option_sets = ["fatal_warnings"],
platform = "java8",
Expand All @@ -11,20 +12,23 @@ scala_library(
"scrooge/scrooge-core/src/main/scala",
"scrooge/scrooge-serializer",
],
exports = [
"3rdparty/jvm/org/openjdk/jmh:jmh-core",
],
)

jvm_binary(
name = "jmh",
main = "org.openjdk.jmh.Main",
platform = "java8",
dependencies = [
":scala",
":benchmark_compiled_benchmark_lib",
scoped(
"3rdparty/jvm/org/slf4j:slf4j-nop",
scope = "runtime",
),
],
)

jvm_app(
name = "jmh-bundle",
basename = "scrooge-benchmark-bundle",
binary = ":jmh",
)
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.twitter.scrooge.benchmark

import com.twitter.scrooge.ThriftStruct
import com.twitter.scrooge.ThriftStructCodec
import java.io.ByteArrayOutputStream
import java.util.concurrent.TimeUnit
import java.util.Random
import org.apache.thrift.protocol.{TProtocol, TBinaryProtocol}
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.protocol.TProtocol
import org.apache.thrift.transport.TTransport
import org.openjdk.jmh.annotations._
import thrift.benchmark._
Expand Down Expand Up @@ -39,6 +41,7 @@ class TRewindable extends TTransport {

def rewind(): Unit = {
pos = 0
arr.reset()
}

def inspect: String = {
Expand All @@ -63,33 +66,57 @@ class Collections(size: Int) {
val list: TRewindable = new TRewindable
val listProt: TBinaryProtocol = new TBinaryProtocol(list)

val listDouble: TRewindable = new TRewindable
val listDoubleProt: TBinaryProtocol = new TBinaryProtocol(listDouble)

val rng: Random = new Random(31415926535897932L)

val mapVals: mutable.Builder[(Long, String), Map[Long, String]] = Map.newBuilder[Long, String]
val setVals: mutable.Builder[Long, Set[Long]] = Set.newBuilder[Long]
val listVals: mutable.Builder[Long, Seq[Long]] = Seq.newBuilder[Long]
val arrayVals = new Array[Long](size)
val arrayDoublesVals = new Array[Double](size)

val m: Unit = for (_ <- (0 until size)) {
val m: Unit = for (i <- (0 until size)) {
val num = rng.nextLong()
mapVals += (num -> num.toString)
setVals += num
listVals += num
arrayVals(i) = num
arrayDoublesVals(i) = num
}

MapCollections.encode(MapCollections(mapVals.result), mapProt)
SetCollections.encode(SetCollections(setVals.result), setProt)
ListCollections.encode(ListCollections(listVals.result), listProt)
val mapCollections: MapCollections = MapCollections(mapVals.result)
val setCollections: SetCollections = SetCollections(setVals.result)
val listCollections: ListCollections = ListCollections(listVals.result)
val arrayCollections: ListCollections = ListCollections(arrayVals)
val arrayDoubleCollections: ListDoubleCollections = ListDoubleCollections(arrayDoublesVals)

MapCollections.encode(mapCollections, mapProt)
SetCollections.encode(setCollections, setProt)
ListCollections.encode(listCollections, listProt)
ListDoubleCollections.encode(arrayDoubleCollections, listDoubleProt)

def run(codec: ThriftStructCodec[_], prot: TProtocol, buff: TRewindable): Unit = {
def decode(codec: ThriftStructCodec[_], prot: TProtocol, buff: TRewindable): Unit = {
codec.decode(prot)
buff.rewind()
}

def encode[T <: ThriftStruct](
codec: ThriftStructCodec[T],
prot: TProtocol,
buff: TRewindable,
obj: T
): Unit = {
codec.encode(obj, prot)
buff.rewind()
}
}

object CollectionsBenchmark {
@State(Scope.Thread)
class CollectionsState {
@Param(Array("1", "5", "10", "100", "500", "1000"))
@Param(Array("1", "5", "10", "100", "500"))
var size: Int = 1

var col: Collections = _
Expand All @@ -98,24 +125,53 @@ object CollectionsBenchmark {
def setup(): Unit = {
col = new Collections(size)
}

}
}

@OutputTimeUnit(TimeUnit.NANOSECONDS)
@OutputTimeUnit(TimeUnit.SECONDS)
@BenchmarkMode(Array(Mode.Throughput))
@Fork(1)
@Warmup(iterations = 3, time = 10, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 10, timeUnit = TimeUnit.SECONDS)
class CollectionsBenchmark {
import CollectionsBenchmark._

@Benchmark
def timeMap(state: CollectionsState): Unit =
state.col.run(MapCollections, state.col.mapProt, state.col.map)
def timeEncodeList(state: CollectionsState): Unit =
state.col.encode(
ListCollections,
state.col.listProt,
state.col.list,
state.col.listCollections
)

@Benchmark
def timeEncodeArray(state: CollectionsState): Unit =
state.col.encode(
ListCollections,
state.col.listProt,
state.col.list,
state.col.arrayCollections
)

@Benchmark
def timeEncodeDoubleArray(state: CollectionsState): Unit =
state.col.encode(
ListDoubleCollections,
state.col.listDoubleProt,
state.col.listDouble,
state.col.arrayDoubleCollections
)

@Benchmark
def timeDecodeMap(state: CollectionsState): Unit =
state.col.decode(MapCollections, state.col.mapProt, state.col.map)

@Benchmark
def timeSet(state: CollectionsState): Unit =
state.col.run(SetCollections, state.col.setProt, state.col.set)
def timeDecodeSet(state: CollectionsState): Unit =
state.col.decode(SetCollections, state.col.setProt, state.col.set)

@Benchmark
def timeList(state: CollectionsState): Unit =
state.col.run(ListCollections, state.col.listProt, state.col.list)
def timeDecodeList(state: CollectionsState): Unit =
state.col.decode(ListCollections, state.col.listProt, state.col.list)
}
File renamed without changes.
4 changes: 4 additions & 0 deletions scrooge-benchmark/src/main/thrift/collections.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ struct SetCollections {
struct ListCollections {
1: list<i64> longs
}

struct ListDoubleCollections {
1: list<double> doubles
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import com.twitter.scrooge.TFieldBlob
import com.twitter.scrooge.ThriftEnum
import com.twitter.scrooge.ThriftUnion
import java.nio.ByteBuffer
import java.util.function.ObjDoubleConsumer
import java.util.function.ObjLongConsumer
import org.apache.thrift.protocol._
import scala.collection.immutable
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
* Reads and writes fields for a `TProtocol`. Intended to be used
Expand Down Expand Up @@ -93,11 +96,26 @@ final class TProtocols private[TProtocols] {
elementType: Byte,
writeElement: (TProtocol, T) => Unit
): Unit = {
protocol.writeListBegin(new TList(typeForCollection(elementType), list.size))
val size = list.size
protocol.writeListBegin(new TList(typeForCollection(elementType), size))
list match {
case wrappedArray: mutable.WrappedArray[T] =>
val arr = wrappedArray.array
var i = 0
while (i < size) {
val el: T = arr(i).asInstanceOf[T]
writeElement(protocol, el)
i += 1
}
case arrayBuffer: ArrayBuffer[T] =>
var i = 0
while (i < size) {
writeElement(protocol, arrayBuffer(i))
i += 1
}
case _: IndexedSeq[_] =>
var i = 0
while (i < list.size) {
while (i < size) {
writeElement(protocol, list(i))
i += 1
}
Expand All @@ -109,6 +127,78 @@ final class TProtocols private[TProtocols] {
protocol.writeListEnd()
}

def writeListDouble(
protocol: TProtocol,
list: collection.Seq[Double],
elementType: Byte,
writeElement: ObjDoubleConsumer[TProtocol]
): Unit = {
val size = list.size
protocol.writeListBegin(new TList(typeForCollection(elementType), size))
list match {
case wrappedArray: mutable.WrappedArray.ofDouble =>
val arr = wrappedArray.array
var i = 0
while (i < size) {
writeElement.accept(protocol, arr(i))
i += 1
}
case arrayBuffer: ArrayBuffer[Double] =>
var i = 0
while (i < size) {
writeElement.accept(protocol, arrayBuffer(i))
i += 1
}
case _: IndexedSeq[_] =>
var i = 0
while (i < size) {
writeElement.accept(protocol, list(i))
i += 1
}
case _ =>
list.foreach { element =>
writeElement.accept(protocol, element)
}
}
protocol.writeListEnd()
}

def writeListI64(
protocol: TProtocol,
list: collection.Seq[Long],
elementType: Byte,
writeElement: ObjLongConsumer[TProtocol]
): Unit = {
val len = list.size
protocol.writeListBegin(new TList(typeForCollection(elementType), len))
list match {
case wrappedArray: mutable.WrappedArray.ofLong =>
val arr = wrappedArray.array
var i = 0
while (i < len) {
writeElement.accept(protocol, arr(i))
i += 1
}
case arrayBuffer: ArrayBuffer[Long] =>
var i = 0
while (i < len) {
writeElement.accept(protocol, arrayBuffer(i))
i += 1
}
case _: IndexedSeq[_] =>
var i = 0
while (i < len) {
writeElement.accept(protocol, list(i))
i += 1
}
case _ =>
list.foreach { element =>
writeElement.accept(protocol, element)
}
}
protocol.writeListEnd()
}

def writeSet[T](
protocol: TProtocol,
set: collection.Set[T],
Expand Down Expand Up @@ -193,9 +283,15 @@ object TProtocols {
val writeI64Fn: (TProtocol, Long) => Unit =
(protocol, value) => protocol.writeI64(value)

val writeI64Consumer: ObjLongConsumer[TProtocol] =
(protocol: TProtocol, value: Long) => protocol.writeI64(value)

val writeDoubleFn: (TProtocol, Double) => Unit =
(protocol, value) => protocol.writeDouble(value)

val writeDoubleConsumer: ObjDoubleConsumer[TProtocol] =
(protocol: TProtocol, value: Double) => protocol.writeDouble(value)

val writeStringFn: (TProtocol, String) => Unit =
(protocol, value) => protocol.writeString(value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,25 @@ trait StructTemplate { self: TemplateGenerator =>
}
}

@scala.annotation.tailrec
private[this] def genWriteListFn(
elementFieldType: FieldType,
fieldName: CodeFragment,
protoName: String
): CodeFragment = {
elementFieldType match {
case at: AnnotatedFieldType => genWriteListFn(at.unwrap, fieldName, protoName)
case TDouble =>
v(s"$rootProtos.writeListDouble($protoName, $fieldName, TType.DOUBLE, _root_.com.twitter.scrooge.internal.TProtocols.writeDoubleConsumer)")
case TI64 =>
v(s"$rootProtos.writeListI64($protoName, $fieldName, TType.I64, _root_.com.twitter.scrooge.internal.TProtocols.writeI64Consumer)")
case _ =>
val elemFieldType = s"TType.${genConstType(elementFieldType)}"
val writeElementFn = genWriteValueFn2(elementFieldType)
v(s"$rootProtos.writeList($protoName, $fieldName, $elemFieldType, $writeElementFn)")
}
}

@scala.annotation.tailrec
private[this] def genWriteValueFn2(fieldType: FieldType): CodeFragment = {
fieldType match {
Expand Down Expand Up @@ -306,9 +325,7 @@ trait StructTemplate { self: TemplateGenerator =>
val writeElement = genWriteValueFn2(t.eltType)
v(s"$rootProtos.writeSet($protoName, $fieldName, $elemFieldType, $writeElement)")
case t: ListType =>
val elemFieldType = s"TType.${genConstType(t.eltType)}"
val writeElement = genWriteValueFn2(t.eltType)
v(s"$rootProtos.writeList($protoName, $fieldName, $elemFieldType, $writeElement)")
genWriteListFn(t.eltType, fieldName, protoName)
case t: MapType =>
val keyType = s"TType.${genConstType(t.keyType)}"
val valType = s"TType.${genConstType(t.valueType)}"
Expand Down

0 comments on commit 1cbcb4a

Please sign in to comment.