From e067afda6eed70e07c29406df55386b7c266f6e1 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Sat, 14 Sep 2024 16:13:21 +0200 Subject: [PATCH] feat: Add a way to reorder elements during encoding --- core/build.gradle.kts | 2 +- .../encoding/ReorderingCompositeEncoder.kt | 247 +++++++++ .../ReorderingCompositeEncoderTest.kt | 467 ++++++++++++++++++ 3 files changed, 715 insertions(+), 1 deletion(-) create mode 100644 core/commonMain/src/kotlinx/serialization/encoding/ReorderingCompositeEncoder.kt create mode 100644 core/commonTest/src/kotlinx/serialization/encoding/ReorderingCompositeEncoderTest.kt diff --git a/core/build.gradle.kts b/core/build.gradle.kts index b3d885ee26..c142233474 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -37,7 +37,7 @@ kotlin { Require-Kotlin-Version is used to determine whether runtime library with new features can work with old compilers. In ideal case, its value should always be 1.4, but some refactorings (e.g. adding a method to the Encoder interface) may unexpectedly break old compilers, so it is left out as a safety net. Compiler plugins, starting from 1.4 are instructed - to reject runtime if runtime's Require-Kotlin-Version is greater then the current compiler. + to reject runtime if runtime's Require-Kotlin-Version is greater than the current compiler. */ tasks.withType().named(kotlin.jvm().artifactsTaskName) { diff --git a/core/commonMain/src/kotlinx/serialization/encoding/ReorderingCompositeEncoder.kt b/core/commonMain/src/kotlinx/serialization/encoding/ReorderingCompositeEncoder.kt new file mode 100644 index 0000000000..305dc4895c --- /dev/null +++ b/core/commonMain/src/kotlinx/serialization/encoding/ReorderingCompositeEncoder.kt @@ -0,0 +1,247 @@ +/* + * Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.serialization.encoding + +import kotlinx.serialization.* +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.modules.* + + +/** + * Encodes elements in a user-defined order managed by [mapElementIndex]. + * + * This encoder will replicate the behavior of a standard encoding, but calling the `encode*Element` methods in + * the order defined by [mapElementIndex]. It first buffers each `encode*Element` calls in an array following + * the given indexes using [mapElementIndex], then when [endStructure] is called, it encodes the buffered calls + * in the expected order by replaying the previous calls on the given [compositeEncoderDelegate]. + * + * This encoder is stateful and not designed to be reused. + * + * @param compositeEncoderDelegate the [CompositeEncoder] to be used to encode the given descriptor's elements in the expected order. + * @param encodedElementsCount The final number of elements to encode, which could be smaller than the original descriptor when [mapElementIndex] returns [SKIP_ELEMENT_INDEX] or when the index mapper has returned the same index twice. + * @param mapElementIndex maps the element index to a new positional zero-based index. + * The mapped index just helps to reorder the elements, + * but the reordered `encode*Element` method calls will still pass the original element index. + * If this mapper returns [SKIP_ELEMENT_INDEX] or -1, the element will be ignored and not encoded. + * If this mapper provides the same index for multiple elements, + * only the last one will be encoded as the previous ones will be overridden. + */ +@ExperimentalSerializationApi +public class ReorderingCompositeEncoder( + encodedElementsCount: Int, + private val compositeEncoderDelegate: CompositeEncoder, + private val mapElementIndex: (SerialDescriptor, Int) -> Int, +) : CompositeEncoder { + private val bufferedCalls = Array(encodedElementsCount) { null } + + public companion object { + @ExperimentalSerializationApi + public const val SKIP_ELEMENT_INDEX: Int = -1 + } + + override val serializersModule: SerializersModule + // No need to return a serializers module as it's not used during buffering + get() = EmptySerializersModule() + + private data class BufferedCall( + val originalElementIndex: Int, + val encoder: () -> Unit, + ) + + private fun bufferEncoding( + descriptor: SerialDescriptor, + index: Int, + encoder: () -> Unit + ) { + val newIndex = mapElementIndex(descriptor, index) + if (newIndex != SKIP_ELEMENT_INDEX) { + bufferedCalls[newIndex] = BufferedCall(index, encoder) + } + } + + override fun endStructure(descriptor: SerialDescriptor) { + bufferedCalls.forEach { fieldToEncode -> + // In case of skipped fields, overridden fields (mapped to same index) or too big [encodedElementsCount], + // the fieldToEncode may be null as no element was encoded for that index + fieldToEncode?.encoder?.invoke() + } + compositeEncoderDelegate.endStructure(descriptor) + } + + override fun encodeBooleanElement(descriptor: SerialDescriptor, index: Int, value: Boolean) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeBooleanElement(descriptor, index, value) + } + } + + override fun encodeByteElement(descriptor: SerialDescriptor, index: Int, value: Byte) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeByteElement(descriptor, index, value) + } + } + + override fun encodeCharElement(descriptor: SerialDescriptor, index: Int, value: Char) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeCharElement(descriptor, index, value) + } + } + + override fun encodeDoubleElement(descriptor: SerialDescriptor, index: Int, value: Double) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeDoubleElement(descriptor, index, value) + } + } + + override fun encodeFloatElement(descriptor: SerialDescriptor, index: Int, value: Float) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeFloatElement(descriptor, index, value) + } + } + + override fun encodeIntElement(descriptor: SerialDescriptor, index: Int, value: Int) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeIntElement(descriptor, index, value) + } + } + + override fun encodeLongElement(descriptor: SerialDescriptor, index: Int, value: Long) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeLongElement(descriptor, index, value) + } + } + + override fun encodeNullableSerializableElement( + descriptor: SerialDescriptor, + index: Int, + serializer: SerializationStrategy, + value: T? + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeNullableSerializableElement(descriptor, index, serializer, value) + } + } + + override fun encodeSerializableElement( + descriptor: SerialDescriptor, + index: Int, + serializer: SerializationStrategy, + value: T + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeSerializableElement(descriptor, index, serializer, value) + } + } + + override fun encodeShortElement(descriptor: SerialDescriptor, index: Int, value: Short) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeShortElement(descriptor, index, value) + } + } + + override fun encodeStringElement(descriptor: SerialDescriptor, index: Int, value: String) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeStringElement(descriptor, index, value) + } + } + + override fun encodeInlineElement(descriptor: SerialDescriptor, index: Int): Encoder { + return BufferingInlineEncoder(descriptor, index) + } + + override fun shouldEncodeElementDefault(descriptor: SerialDescriptor, index: Int): Boolean { + return compositeEncoderDelegate.shouldEncodeElementDefault(descriptor, index) + } + + private inner class BufferingInlineEncoder( + private val descriptor: SerialDescriptor, + private val elementIndex: Int, + ) : Encoder { + private var encodeNotNullMarkCalled = false + + override val serializersModule: SerializersModule + get() = this@ReorderingCompositeEncoder.serializersModule + + private fun bufferEncoding(encoder: Encoder.() -> Unit) { + bufferEncoding(descriptor, elementIndex) { + compositeEncoderDelegate.encodeInlineElement(descriptor, elementIndex).apply { + if (encodeNotNullMarkCalled) { + encodeNotNullMark() + } + encoder() + } + } + } + + override fun encodeNotNullMark() { + encodeNotNullMarkCalled = true + } + + override fun encodeNullableSerializableValue(serializer: SerializationStrategy, value: T?) { + bufferEncoding { encodeNullableSerializableValue(serializer, value) } + } + + override fun encodeSerializableValue(serializer: SerializationStrategy, value: T) { + bufferEncoding { encodeSerializableValue(serializer, value) } + } + + override fun encodeBoolean(value: Boolean) { + bufferEncoding { encodeBoolean(value) } + } + + override fun encodeByte(value: Byte) { + bufferEncoding { encodeByte(value) } + } + + override fun encodeChar(value: Char) { + bufferEncoding { encodeChar(value) } + } + + override fun encodeDouble(value: Double) { + bufferEncoding { encodeDouble(value) } + } + + override fun encodeEnum(enumDescriptor: SerialDescriptor, index: Int) { + bufferEncoding { encodeEnum(enumDescriptor, index) } + } + + override fun encodeFloat(value: Float) { + bufferEncoding { encodeFloat(value) } + } + + override fun encodeInt(value: Int) { + bufferEncoding { encodeInt(value) } + } + + override fun encodeLong(value: Long) { + bufferEncoding { encodeLong(value) } + } + + @ExperimentalSerializationApi + override fun encodeNull() { + bufferEncoding { encodeNull() } + } + + override fun encodeShort(value: Short) { + bufferEncoding { encodeShort(value) } + } + + override fun encodeString(value: String) { + bufferEncoding { encodeString(value) } + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + unexpectedCall(::beginStructure.name) + } + + override fun encodeInline(descriptor: SerialDescriptor): Encoder { + unexpectedCall(::encodeInline.name) + } + + private fun unexpectedCall(methodName: String): Nothing { + // This method is normally called from within encodeSerializableValue or encodeNullableSerializableValue which is buffered, so we should never go here during buffering as it will be delegated to the concrete CompositeEncoder + throw UnsupportedOperationException("Non-standard usage of ${CompositeEncoder::class.simpleName}: $methodName should be called from within encodeSerializableValue or encodeNullableSerializableValue") + } + } +} \ No newline at end of file diff --git a/core/commonTest/src/kotlinx/serialization/encoding/ReorderingCompositeEncoderTest.kt b/core/commonTest/src/kotlinx/serialization/encoding/ReorderingCompositeEncoderTest.kt new file mode 100644 index 0000000000..3de5e41f8a --- /dev/null +++ b/core/commonTest/src/kotlinx/serialization/encoding/ReorderingCompositeEncoderTest.kt @@ -0,0 +1,467 @@ +/* + * Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.serialization.encoding + +import kotlinx.serialization.* +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.modules.* +import kotlin.jvm.* +import kotlin.test.* + +class ReorderingCompositeEncoderTest { + @Test + fun shouldReorderWellWithDirectUsage() { + val valueClassDescriptor = buildClassSerialDescriptor("inlined") { element("value") } + val descriptor = buildClassSerialDescriptor("reordered descriptor") { + element("0") + element("1", valueClassDescriptor) + element("2") + element("3") + } + + val encoder = SimpleListEncoder() + + val mappedIndexes = mapOf( + 0 to 1, + 1 to 0, + 2 to ReorderingCompositeEncoder.SKIP_ELEMENT_INDEX, + 3 to 2 + ) + ReorderingCompositeEncoder(3, encoder) { _, index -> mappedIndexes.getValue(index) }.apply { + encodeDoubleElement(descriptor, 0, 17.0) + encodeInlineElement(descriptor, 1).encodeString("Hello") + encodeStringElement(descriptor, 2, "ignored") + encodeCharElement(descriptor, 3, '!') + endStructure(descriptor) + } + + assertContentEquals( + listOf("Hello", 17.0, '!'), + encoder.output + ) + } + + @Test + fun shouldReorderWellWithSerializationPlugin() { + val value = AllTypesExample( + int = 1, + nullableInt = null, + intWrapped = IntValue(2), + nullableIntWrapped = IntNullableValue(null), + intWrappedNullable = IntValue(3), + + string = "Hello", + nullableString = null, + stringWrapped = StringValue("World"), + nullableStringWrapped = StringNullableValue(null), + stringWrappedNullable = StringValue("!!!"), + + boolean = true, + nullableBoolean = null, + booleanWrapped = BooleanValue(true), + nullableBooleanWrapped = BooleanNullableValue(null), + booleanWrappedNullable = BooleanValue(false), + + double = 1.0, + nullableDouble = null, + doubleWrapped = DoubleValue(2.0), + nullableDoubleWrapped = DoubleNullableValue(null), + doubleWrappedNullable = DoubleValue(3.0), + + nonReorderedList = listOf(1, 2, 3), + nonReorderedNullableList = null, + nonReorderedListWrapped = ListValue(listOf(4, 5, 6)), + nonReorderedNullableListWrapped = ListNullableValue(null), + nonReorderedListWrappedNullable = ListNullableValue(listOf(7, 8, 9)), + nonReorderedListWrappedNullableValues = ListNullableValues(listOf(null, 8, 9)), + + float = 1.0f, + nullableFloat = null, + floatWrapped = FloatValue(2.0f), + nullableFloatWrapped = FloatNullableValue(null), + floatWrappedNullable = FloatValue(3.0f), + + long = 1L, + nullableLong = null, + longWrapped = LongValue(2L), + nullableLongWrapped = LongNullableValue(null), + longWrappedNullable = LongValue(3L), + + short = 1, + nullableShort = null, + shortWrapped = ShortValue(2), + nullableShortWrapped = ShortNullableValue(null), + shortWrappedNullable = ShortValue(3), + + nonReorderedSubStructure = SubStructure( + a = IntValue(10), + b = "Sub", + c = null, + d = null + ), + nonReorderedNullableSubStructure = null, + nonReorderedSubStructureWrapped = SubStructureValue( + SubStructure( + a = IntValue(11), + b = "Wrapped", + c = 12L, + d = 13.toByte() + ) + ), + nonReorderedNullableSubStructureWrapped = SubStructureNullableValue(null), + nonReorderedSubStructureWrappedNullable = SubStructureValue( + SubStructure( + a = IntValue(14), + b = "Nullable-Wrapped", + c = 15L, + d = 16.toByte() + ) + ), + + byte = 1, + nullableByte = null, + byteWrapped = ByteValue(2), + nullableByteWrapped = ByteNullableValue(null), + byteWrappedNullable = ByteValue(3), + + char = 'A', + nullableChar = null, + charWrapped = CharValue('B'), + nullableCharWrapped = CharNullableValue(null), + charWrappedNullable = CharValue('C') + ) + + val output = StringBuilder() + val encoder = LightJsonEncoder( + output, + descriptorToReorder = AllTypesExample.serializer().descriptor + ) { descriptor, index -> descriptor.elementsCount - 1 - index } + + encoder.encodeSerializableValue(AllTypesExample.serializer(), value) + + // the final output should encode the fields in the reverse order, but should not reorder sub-structures + assertEquals( + actual = output.toString(), + expected = """ +{ + charWrappedNullable: "C", + nullableCharWrapped: null, + charWrapped: "B", + nullableChar: null, + char: "A", + byteWrappedNullable: 3, + nullableByteWrapped: null, + byteWrapped: 2, + nullableByte: null, + byte: 1, + nonReorderedSubStructureWrappedNullable: { + a: 14, + b: "Nullable-Wrapped", + c: 15, + d: 16 + }, + nonReorderedNullableSubStructureWrapped: null, + nonReorderedSubStructureWrapped: { + a: 11, + b: "Wrapped", + c: 12, + d: 13 + }, + nonReorderedNullableSubStructure: null, + nonReorderedSubStructure: { + a: 10, + b: "Sub", + c: null, + d: null + }, + shortWrappedNullable: 3, + nullableShortWrapped: null, + shortWrapped: 2, + nullableShort: null, + short: 1, + longWrappedNullable: 3, + nullableLongWrapped: null, + longWrapped: 2, + nullableLong: null, + long: 1, + floatWrappedNullable: 3.0, + nullableFloatWrapped: null, + floatWrapped: 2.0, + nullableFloat: null, + float: 1.0, + nonReorderedListWrappedNullableValues: [null,8,9], + nonReorderedListWrappedNullable: [7,8,9], + nonReorderedNullableListWrapped: null, + nonReorderedListWrapped: [4,5,6], + nonReorderedNullableList: null, + nonReorderedList: [1,2,3], + doubleWrappedNullable: 3.0, + nullableDoubleWrapped: null, + doubleWrapped: 2.0, + nullableDouble: null, + double: 1.0, + booleanWrappedNullable: false, + nullableBooleanWrapped: null, + booleanWrapped: true, + nullableBoolean: null, + boolean: true, + stringWrappedNullable: "!!!", + nullableStringWrapped: null, + stringWrapped: "World", + nullableString: null, + string: "Hello", + intWrappedNullable: 3, + nullableIntWrapped: null, + intWrapped: 2, + nullableInt: null, + int: 1 +} + """.replace(Regex("""\s+"""), ""), + ) + } +} + +private class SimpleListEncoder( + val output: MutableList = mutableListOf(), +) : AbstractEncoder() { + override val serializersModule: SerializersModule + get() = EmptySerializersModule() + + override fun encodeValue(value: Any) { + output += value + } +} + +@Serializable +private data class AllTypesExample( + val int: Int, + val nullableInt: Int?, + val intWrapped: IntValue, + val nullableIntWrapped: IntNullableValue, + val intWrappedNullable: IntValue?, + + val string: String, + val nullableString: String?, + val stringWrapped: StringValue, + val nullableStringWrapped: StringNullableValue, + val stringWrappedNullable: StringValue?, + + val boolean: Boolean, + val nullableBoolean: Boolean?, + val booleanWrapped: BooleanValue, + val nullableBooleanWrapped: BooleanNullableValue, + val booleanWrappedNullable: BooleanValue?, + + val double: Double, + val nullableDouble: Double?, + val doubleWrapped: DoubleValue, + val nullableDoubleWrapped: DoubleNullableValue, + val doubleWrappedNullable: DoubleValue?, + + val nonReorderedList: List, + val nonReorderedNullableList: List?, + val nonReorderedListWrapped: ListValue, + val nonReorderedNullableListWrapped: ListNullableValue, + val nonReorderedListWrappedNullable: ListNullableValue?, + val nonReorderedListWrappedNullableValues: ListNullableValues?, + + val float: Float, + val nullableFloat: Float?, + val floatWrapped: FloatValue, + val nullableFloatWrapped: FloatNullableValue, + val floatWrappedNullable: FloatValue?, + + val long: Long, + val nullableLong: Long?, + val longWrapped: LongValue, + val nullableLongWrapped: LongNullableValue, + val longWrappedNullable: LongValue?, + + val short: Short, + val nullableShort: Short?, + val shortWrapped: ShortValue, + val nullableShortWrapped: ShortNullableValue, + val shortWrappedNullable: ShortValue?, + + val nonReorderedSubStructure: SubStructure, + val nonReorderedNullableSubStructure: SubStructure?, + val nonReorderedSubStructureWrapped: SubStructureValue, + val nonReorderedNullableSubStructureWrapped: SubStructureNullableValue, + val nonReorderedSubStructureWrappedNullable: SubStructureValue?, + + val byte: Byte, + val nullableByte: Byte?, + val byteWrapped: ByteValue, + val nullableByteWrapped: ByteNullableValue, + val byteWrappedNullable: ByteValue?, + + val char: Char, + val nullableChar: Char?, + val charWrapped: CharValue, + val nullableCharWrapped: CharNullableValue, + val charWrappedNullable: CharValue?, +) + +@Serializable +private data class SubStructure( + val a: IntValue, + val b: String, + val c: Long?, + val d: Byte?, +) + +@Serializable +@JvmInline +private value class IntValue(val value: Int) + +@Serializable +@JvmInline +private value class StringValue(val value: String) + +@Serializable +@JvmInline +private value class BooleanValue(val value: Boolean) + +@Serializable +@JvmInline +private value class DoubleValue(val value: Double) + +@Serializable +@JvmInline +private value class FloatValue(val value: Float) + +@Serializable +@JvmInline +private value class LongValue(val value: Long) + +@Serializable +@JvmInline +private value class ShortValue(val value: Short) + +@Serializable +@JvmInline +private value class ByteValue(val value: Byte) + +@Serializable +@JvmInline +private value class CharValue(val value: Char) + +@Serializable +@JvmInline +private value class SubStructureValue(val value: SubStructure) + +@Serializable +@JvmInline +private value class ListValue(val value: List) + + +@Serializable +@JvmInline +private value class IntNullableValue(val value: Int?) + +@Serializable +@JvmInline +private value class StringNullableValue(val value: String?) + +@Serializable +@JvmInline +private value class BooleanNullableValue(val value: Boolean?) + +@Serializable +@JvmInline +private value class DoubleNullableValue(val value: Double?) + +@Serializable +@JvmInline +private value class FloatNullableValue(val value: Float?) + +@Serializable +@JvmInline +private value class LongNullableValue(val value: Long?) + +@Serializable +@JvmInline +private value class ShortNullableValue(val value: Short?) + +@Serializable +@JvmInline +private value class ByteNullableValue(val value: Byte?) + +@Serializable +@JvmInline +private value class CharNullableValue(val value: Char?) + +@Serializable +@JvmInline +private value class SubStructureNullableValue(val value: SubStructure?) + +@Serializable +@JvmInline +private value class ListNullableValue(val value: List?) + +@Serializable +@JvmInline +private value class ListNullableValues(val value: List) + + +private class LightJsonEncoder( + val sb: StringBuilder, + val descriptorToReorder: SerialDescriptor, + val mapElementIndex: (SerialDescriptor, Int) -> Int, +) : AbstractEncoder() { + override val serializersModule: SerializersModule = EmptySerializersModule() + var previousDescriptor: SerialDescriptor? = null + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + if (descriptor.kind == StructureKind.LIST) { + sb.append('[') + } else { + sb.append('{') + } + previousDescriptor = null + if (descriptor == descriptorToReorder) { + return ReorderingCompositeEncoder(descriptor.elementsCount, this, mapElementIndex) + } + return this + } + + override fun endStructure(descriptor: SerialDescriptor) { + if (descriptor.kind == StructureKind.LIST) { + sb.append(']') + } else { + sb.append('}') + } + } + + override fun encodeElement(descriptor: SerialDescriptor, index: Int): Boolean { + if (previousDescriptor == null) { + previousDescriptor = descriptor + } else { + previousDescriptor = descriptor + sb.append(",") + } + if (descriptor.kind != StructureKind.LIST) { + sb.append(descriptor.getElementName(index)) + sb.append(':') + } + return true + } + + override fun encodeNull() { + sb.append("null") + } + + override fun encodeValue(value: Any) { + sb.append(value) + } + + override fun encodeString(value: String) { + sb.append('"') + sb.append(value) + sb.append('"') + } + + override fun encodeChar(value: Char) = encodeString(value.toString()) +}