Skip to content

Commit

Permalink
move useColumnFamilies into key encoder API
Browse files Browse the repository at this point in the history
  • Loading branch information
jingz-db committed Jul 4, 2024
1 parent 1393a0d commit 23f9d41
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ sealed trait RocksDBKeyStateEncoder {
def supportPrefixKeyScan: Boolean
def encodePrefixKey(prefixKey: UnsafeRow, vcfId: Option[Short]): Array[Byte]
def encodeKey(row: UnsafeRow, vcfId: Option[Short]): Array[Byte]
def decodeKey(keyBytes: Array[Byte], hasVcfPrefix: Boolean = false): UnsafeRow
def decodeKey(keyBytes: Array[Byte]): UnsafeRow
def offSetForColFamilyPrefix: Int
}

sealed trait RocksDBValueStateEncoder {
Expand All @@ -43,17 +44,19 @@ sealed trait RocksDBValueStateEncoder {
}

object RocksDBStateEncoder {
def getKeyEncoder(keyStateEncoderSpec: KeyStateEncoderSpec): RocksDBKeyStateEncoder = {
def getKeyEncoder(
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean): RocksDBKeyStateEncoder = {
// Return the key state encoder based on the requested type
keyStateEncoderSpec match {
case NoPrefixKeyStateEncoderSpec(keySchema) =>
new NoPrefixKeyStateEncoder(keySchema)
new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies)

case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey)
new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey, useColumnFamilies)

case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
new RangeKeyScanStateEncoder(keySchema, orderingOrdinals)
new RangeKeyScanStateEncoder(keySchema, orderingOrdinals, useColumnFamilies)

case _ =>
throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
Expand Down Expand Up @@ -118,10 +121,14 @@ object RocksDBStateEncoder {
*/
class PrefixKeyScanStateEncoder(
keySchema: StructType,
numColsPrefixKey: Int) extends RocksDBKeyStateEncoder {
numColsPrefixKey: Int,
useColumnFamilies: Boolean = false) extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

override def offSetForColFamilyPrefix: Int =
if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
keySchema.zipWithIndex.take(numColsPrefixKey)
}
Expand All @@ -148,15 +155,12 @@ class PrefixKeyScanStateEncoder(
private val joinedRowOnKey = new JoinedRow()

override def encodeKey(row: UnsafeRow, vcfId: Option[Short]): Array[Byte] = {
val hasVirtualColFamilyPrefix: Boolean = vcfId.isDefined
val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
val offSetForColFamilyPrefix =
if (hasVirtualColFamilyPrefix) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

val encodedBytes = new Array[Byte](prefixKeyEncoded.length +
remainingEncoded.length + 4 + offSetForColFamilyPrefix)
if (hasVirtualColFamilyPrefix) {
if (useColumnFamilies) {
Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, vcfId.get)
}

Expand All @@ -175,11 +179,7 @@ class PrefixKeyScanStateEncoder(
encodedBytes
}

override def decodeKey(
keyBytes: Array[Byte],
hasVcfPrefix: Boolean = false): UnsafeRow = {
val offSetForColFamilyPrefix =
if (hasVcfPrefix) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
val prefixKeyEncodedLen = Platform.getInt(
keyBytes, Platform.BYTE_ARRAY_OFFSET + offSetForColFamilyPrefix)
val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
Expand Down Expand Up @@ -210,15 +210,10 @@ class PrefixKeyScanStateEncoder(
}

override def encodePrefixKey(prefixKey: UnsafeRow, vcfId: Option[Short]): Array[Byte] = {
val hasVirtualColFamilyPrefix = vcfId.isDefined

val offSetForColFamilyPrefix =
if (hasVirtualColFamilyPrefix) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

val prefixKeyEncoded = encodeUnsafeRow(prefixKey)
val prefix = new Array[Byte](
prefixKeyEncoded.length + 4 + offSetForColFamilyPrefix)
if (hasVirtualColFamilyPrefix) {
if (useColumnFamilies) {
Platform.putShort(prefix, Platform.BYTE_ARRAY_OFFSET, vcfId.get)
}
Platform.putInt(prefix, Platform.BYTE_ARRAY_OFFSET + offSetForColFamilyPrefix,
Expand Down Expand Up @@ -264,10 +259,14 @@ class PrefixKeyScanStateEncoder(
*/
class RangeKeyScanStateEncoder(
keySchema: StructType,
orderingOrdinals: Seq[Int]) extends RocksDBKeyStateEncoder {
orderingOrdinals: Seq[Int],
useColumnFamilies: Boolean = false) extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

override def offSetForColFamilyPrefix: Int =
if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
orderingOrdinals.map { ordinal =>
val field = keySchema(ordinal)
Expand Down Expand Up @@ -523,20 +522,16 @@ class RangeKeyScanStateEncoder(
}

override def encodeKey(row: UnsafeRow, vcfId: Option[Short]): Array[Byte] = {
val hasVirtualColFamilyPrefix: Boolean = vcfId.isDefined
// This prefix key has the columns specified by orderingOrdinals
val prefixKey = extractPrefixKey(row)
val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))

val offSetForColFamilyPrefix =
if (hasVirtualColFamilyPrefix) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

val result = if (orderingOrdinals.length < keySchema.length) {
val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
val encodedBytes = new Array[Byte](rangeScanKeyEncoded.length +
remainingEncoded.length + 4 + offSetForColFamilyPrefix)

if (hasVirtualColFamilyPrefix) {
if (useColumnFamilies) {
Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, vcfId.get)
}

Expand All @@ -557,7 +552,7 @@ class RangeKeyScanStateEncoder(
// encode the remaining key as it's empty.
val encodedBytes = new Array[Byte](
rangeScanKeyEncoded.length + 4 + offSetForColFamilyPrefix)
if (hasVirtualColFamilyPrefix) {
if (useColumnFamilies) {
Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, vcfId.get)
}

Expand All @@ -571,12 +566,7 @@ class RangeKeyScanStateEncoder(
result
}

override def decodeKey(
keyBytes: Array[Byte],
hasVcfPrefix: Boolean = false): UnsafeRow = {
val offSetForColFamilyPrefix =
if (hasVcfPrefix) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
val prefixKeyEncodedLen = Platform.getInt(
keyBytes, Platform.BYTE_ARRAY_OFFSET + offSetForColFamilyPrefix)
val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
Expand Down Expand Up @@ -612,14 +602,10 @@ class RangeKeyScanStateEncoder(
}

override def encodePrefixKey(prefixKey: UnsafeRow, vcfId: Option[Short]): Array[Byte] = {
val hasVirtualColFamilyPrefix: Boolean = vcfId.isDefined
val offSetForColFamilyPrefix =
if (hasVirtualColFamilyPrefix) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

val rangeScanKeyEncoded = encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
val prefix = new Array[Byte](rangeScanKeyEncoded.length + 4 + offSetForColFamilyPrefix)

if (hasVirtualColFamilyPrefix) {
if (useColumnFamilies) {
Platform.putShort(prefix, Platform.BYTE_ARRAY_OFFSET, vcfId.get)
}
Platform.putInt(prefix, Platform.BYTE_ARRAY_OFFSET + offSetForColFamilyPrefix,
Expand All @@ -644,16 +630,19 @@ class RangeKeyScanStateEncoder(
* (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
* then the generated array byte will be N+1 bytes.
*/
class NoPrefixKeyStateEncoder(keySchema: StructType)
class NoPrefixKeyStateEncoder(keySchema: StructType, useColumnFamilies: Boolean = false)
extends RocksDBKeyStateEncoder {

import RocksDBStateEncoder._

override def offSetForColFamilyPrefix: Int =
if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0

// Reusable objects
private val keyRow = new UnsafeRow(keySchema.size)

override def encodeKey(row: UnsafeRow, vcfId: Option[Short]): Array[Byte] = {
if (!vcfId.isDefined) {
if (!useColumnFamilies) {
encodeUnsafeRow(row)
} else {
val bytesToEncode = row.getBytes
Expand All @@ -678,10 +667,8 @@ class NoPrefixKeyStateEncoder(keySchema: StructType)
* @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
* the given byte array.
*/
override def decodeKey(
keyBytes: Array[Byte],
hasVcfPrefix: Boolean = false): UnsafeRow = {
if (hasVcfPrefix) {
override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
if (useColumnFamilies) {
if (keyBytes != null) {
// Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
keyRow.pointTo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private[sql] class RocksDBStateStoreProvider
ColumnFamilyUtils.createColFamilyIfAbsent(colFamilyName, isInternal)

keyValueEncoderMap.putIfAbsent(colFamilyName,
(RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec),
(RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies),
RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey)))
}

Expand Down Expand Up @@ -161,7 +161,7 @@ private[sql] class RocksDBStateStoreProvider
if (useColumnFamilies) {
val cfId: Short = colFamilyNameToIdMap.get(colFamilyName)
rocksDB.prefixScan(ColumnFamilyUtils.getVcfIdBytes(cfId)).map { kv =>
rowPair.withRows(kvEncoder._1.decodeKey(kv.key, true),
rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
kvEncoder._2.decodeValue(kv.value))
if (!isValidated && rowPair.value != null && !useColumnFamilies) {
StateStoreProvider.validateStateRowFormat(
Expand Down Expand Up @@ -196,7 +196,7 @@ private[sql] class RocksDBStateStoreProvider
val prefix =
kvEncoder._1.encodePrefixKey(prefixKey, Option(colFamilyNameToIdMap.get(colFamilyName)))
rocksDB.prefixScan(prefix).map { kv =>
rowPair.withRows(kvEncoder._1.decodeKey(kv.key, useColumnFamilies),
rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
kvEncoder._2.decodeValue(kv.value))
rowPair
}
Expand Down Expand Up @@ -355,7 +355,7 @@ private[sql] class RocksDBStateStoreProvider
}

keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
(RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec),
(RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies),
RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey)))
if (useColumnFamilies) {
// put default column family only if useColumnFamilies are enabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
errorClass = "STATE_STORE_UNSUPPORTED_OPERATION",
parameters = Map(
"operationType" -> "create_col_family",
"entity" -> "multiple column families disabled in RocksDBStateStoreProvider"
"entity" -> "multiple column families is disabled in RocksDBStateStoreProvider"
),
matchPVals = true
)
Expand Down Expand Up @@ -974,7 +974,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
errorClass = "STATE_STORE_UNSUPPORTED_OPERATION",
parameters = Map(
"operationType" -> "create_col_family",
"entity" -> "multiple column families disabled in RocksDBStateStoreProvider"
"entity" -> "multiple column families is disabled in RocksDBStateStoreProvider"
),
matchPVals = true
)
Expand Down Expand Up @@ -1192,7 +1192,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
errorClass = "STATE_STORE_UNSUPPORTED_OPERATION",
parameters = Map(
"operationType" -> operationName,
"entity" -> "multiple column families disabled in RocksDBStateStoreProvider"
"entity" -> "multiple column families is disabled in RocksDBStateStoreProvider"
),
matchPVals = true
)
Expand Down

0 comments on commit 23f9d41

Please sign in to comment.