From ac8ecbbd0f8359cac3190a6420091c614ed5331e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bernd=20Pr=C3=BCnster?= Date: Mon, 7 Oct 2024 12:54:34 +0200 Subject: [PATCH] more efficient oid --- .../indispensable/asn1/ObjectIdentifier.kt | 159 ++++++++----- .../asn1/encoding/NumberEncoding.kt | 5 +- .../asitplus/signum/indispensable/OidTest.kt | 208 +++++++++++++++++- 3 files changed, 308 insertions(+), 64 deletions(-) diff --git a/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/ObjectIdentifier.kt b/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/ObjectIdentifier.kt index 0b4185b0..99620c7b 100644 --- a/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/ObjectIdentifier.kt +++ b/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/ObjectIdentifier.kt @@ -5,7 +5,6 @@ import at.asitplus.signum.indispensable.asn1.encoding.decodeAsn1VarBigInt import at.asitplus.signum.indispensable.asn1.encoding.toAsn1VarInt import at.asitplus.signum.indispensable.asn1.encoding.toBigInteger import com.ionspin.kotlin.bignum.integer.BigInteger -import com.ionspin.kotlin.bignum.integer.Sign import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable import kotlinx.serialization.Transient @@ -26,16 +25,38 @@ private val BIGINT_40 = BigInteger.fromUByte(40u) * @throws Asn1Exception if less than two nodes are supplied, the first node is >2 or the second node is >39 */ @Serializable(with = ObjectIdSerializer::class) -class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vararg val nodes: BigInteger) : +class ObjectIdentifier @Throws(Asn1Exception::class) constructor( + val bytes: ByteArray, + @Transient private val _nodes: List? = null, dontVerify: Boolean = false +) : Asn1Encodable { init { - if (nodes.size < 2) throw Asn1StructuralException("at least two nodes required!") - if ((nodes[0] * BIGINT_40) > UByte.MAX_VALUE.toUInt()) throw Asn1Exception("first node too lage!") - //TODO more sanity checks - - if (nodes.first() > 2u) throw Asn1Exception("OID must start with either 1 or 2") - if (nodes[1] > 39u) throw Asn1Exception("Second segment must be <40") + if (_nodes == null || dontVerify) { + //Verify that everything can be parsed into nodes + if (bytes.isEmpty()) throw Asn1Exception("Empty OIDs are not supported") + var index = 1 + while (index < bytes.size) { + if (bytes[index] >= 0) { + index++ + } else { + val currentNode = mutableListOf() + while (bytes[index] < 0) { + currentNode += bytes[index] //+= parsed + index++ + } + currentNode += bytes[index] + index++ + val consumed = currentNode.iterator().consumeVarIntEncoded() + @OptIn(ExperimentalStdlibApi::class) + if (consumed != currentNode) throw Asn1Exception( + "Trailing bytes in OID Node ${ + currentNode.toByteArray().toHexString(HexFormat.UpperCase) + }" + ) + } + } + } } /** @@ -45,17 +66,23 @@ class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vara */ @OptIn(ExperimentalUuidApi::class) constructor(uuid: Uuid) : this( - BigInteger.fromByte(2), - BigInteger.fromByte(25), - uuid.toBigInteger() + byteArrayOf((2 * 40 + 25).toUByte().toByte(), *uuid.toBigInteger().toAsn1VarInt()) ) /** * @param nodes OID Tree nodes passed in order (e.g. 1u, 2u, 96u, …) * @throws Asn1Exception if less than two nodes are supplied, the first node is >2 or the second node is >39 */ - constructor(vararg ints: UInt) : this(*(ints.map { BigInteger.fromUInt(it) }.toTypedArray())) + constructor(vararg nodes: UInt) : this( + nodes.toOidBytes(), + dontVerify = true + ) + /** + * @param nodes OID Tree nodes passed in order (e.g. 1, 2, 96, …) + * @throws Asn1Exception if less than two nodes are supplied, the first node is >2, the second node is >39 or any node is negative + */ + constructor(vararg nodes: BigInteger) : this(nodes.toOidBytes(), _nodes = nodes.asList()) /** * @param oid in human-readable format (e.g. "1.2.96") @@ -63,33 +90,55 @@ class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vara constructor(oid: String) : this(*(oid.split(if (oid.contains('.')) '.' else ' ')).map { BigInteger.parseString(it) } .toTypedArray()) + /** + * Lazily evaluated list of OID nodes (e.g. `[1, 2, 35, 4654]`) + */ + val nodes: List by lazy { + if (_nodes != null) _nodes else { + val (first, second) = + if (bytes[0] >= 80) { + BigInteger.fromUByte(2u) to BigInteger.fromUInt(bytes[0].toUByte() - 80u) + } else { + BigInteger.fromUInt(bytes[0].toUByte() / 40u) to BigInteger.fromUInt(bytes[0].toUByte() % 40u) + } + var index = 1 + val collected = mutableListOf(first, second) + while (index < bytes.size) { + if (bytes[index] >= 0) { + collected += BigInteger.fromUInt(bytes[index].toUInt()) + index++ + } else { + val currentNode = mutableListOf() + while (bytes[index] < 0) { + currentNode += bytes[index] //+= parsed + index++ + } + currentNode += bytes[index] + index++ + collected += currentNode.decodeAsn1VarBigInt().first + } + } + collected + } + } + /** * @return human-readable format (e.g. "1.2.96") */ - override fun toString() = nodes.joinToString(separator = ".") { it.toString() } + override fun toString(): String { + return nodes.joinToString(".") + } override fun equals(other: Any?): Boolean { if (other == null) return false if (other !is ObjectIdentifier) return false - return nodes contentEquals other.nodes + return bytes contentEquals other.bytes } override fun hashCode(): Int { return bytes.contentHashCode() } - - /** - * Cursed encoding of OID nodes. A sacrifice of pristine numbers requested by past gods of the netherrealm - */ - val bytes: ByteArray by lazy { - nodes.slice(2.. acc + bytes } - } - /** * @return an OBJECT IDENTIFIER [Asn1Primitive] */ @@ -115,33 +164,37 @@ class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vara * @throws Asn1Exception all sorts of errors on invalid input */ @Throws(Asn1Exception::class) - fun parse(rawValue: ByteArray): ObjectIdentifier = runRethrowing { - if (rawValue.isEmpty()) throw Asn1Exception("Empty OIDs are not supported") - val (first, second) = - if (rawValue[0] >= 80) { - BigInteger.fromUByte(2u) to BigInteger.fromUInt(rawValue[0].toUByte() - 80u) - } else { - BigInteger.fromUInt(rawValue[0].toUByte() / 40u) to BigInteger.fromUInt(rawValue[0].toUByte() % 40u) - } - - var index = 1 - val collected = mutableListOf(first, second) - while (index < rawValue.size) { - if (rawValue[index] >= 0) { - collected += BigInteger.fromUInt(rawValue[index].toUInt()) - index++ - } else { - val currentNode = mutableListOf() - while (rawValue[index] < 0) { - currentNode += rawValue[index] //+= parsed - index++ - } - currentNode += rawValue[index] - index++ - collected += currentNode.decodeAsn1VarBigInt().first - } + fun parse(rawValue: ByteArray): ObjectIdentifier = ObjectIdentifier(rawValue) + + private fun Iterator.consumeVarIntEncoded(): MutableList { + val accumulator = mutableListOf() + while (hasNext()) { + val curByte = next() + val current = BigInteger(curByte.toUByte().toInt()) + accumulator += curByte + if (current < 0x80.toUByte()) break } - return ObjectIdentifier(*collected.toTypedArray()) + return accumulator + } + + private fun UIntArray.toOidBytes(): ByteArray { + if (size < 2) throw Asn1StructuralException("at least two nodes required!") + if (first() > 2u) throw Asn1Exception("OID must start with either 1 or 2") + if (get(1) > 39u) throw Asn1Exception("Second segment must be <40") + return slice(2.. acc + bytes } + } + + private fun Array.toOidBytes(): ByteArray { + if (size < 2) throw Asn1StructuralException("at least two nodes required!") + if (first() > 2u) throw Asn1Exception("OID must start with either 1 or 2") + if (get(1) > 39u) throw Asn1Exception("Second segment must be <40") + + return slice(2.. acc + bytes } } } } @@ -152,7 +205,7 @@ object ObjectIdSerializer : KSerializer { override fun deserialize(decoder: Decoder): ObjectIdentifier = ObjectIdentifier(decoder.decodeString()) override fun serialize(encoder: Encoder, value: ObjectIdentifier) { - encoder.encodeString(value.nodes.joinToString(separator = ".") { it.toString() }) + encoder.encodeString(toString()) } } diff --git a/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/encoding/NumberEncoding.kt b/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/encoding/NumberEncoding.kt index 20d00830..c1a1f638 100644 --- a/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/encoding/NumberEncoding.kt +++ b/indispensable/src/commonMain/kotlin/at/asitplus/signum/indispensable/asn1/encoding/NumberEncoding.kt @@ -377,14 +377,13 @@ inline fun ByteArray.decodeAsn1VarBigInt(): Pair = iterat /** - * Decodes an ULong from bytes using varint encoding as used within ASN.1: groups of seven bits are encoded into a byte, + * Decodes an BigInteger from bytes using varint encoding as used within ASN.1: groups of seven bits are encoded into a byte, * while the highest bit indicates if more bytes are to come. Trailing bytes are ignored. * - * @return the decoded ULong and the underlying varint-encoded bytes as `ByteArray` + * @return the decoded BigInteger and the underlying varint-encoded bytes as `ByteArray` * @throws IllegalArgumentException if the number is larger than [ULong.MAX_VALUE] */ fun Iterator.decodeAsn1VarBigInt(): Pair { - var offset = 0 var result = BigInteger.ZERO val mask = BigInteger.fromUByte(0x7Fu) val accumulator = mutableListOf() diff --git a/indispensable/src/jvmTest/kotlin/at/asitplus/signum/indispensable/OidTest.kt b/indispensable/src/jvmTest/kotlin/at/asitplus/signum/indispensable/OidTest.kt index 20b0172a..790de798 100644 --- a/indispensable/src/jvmTest/kotlin/at/asitplus/signum/indispensable/OidTest.kt +++ b/indispensable/src/jvmTest/kotlin/at/asitplus/signum/indispensable/OidTest.kt @@ -1,22 +1,22 @@ package at.asitplus.signum.indispensable -import at.asitplus.signum.indispensable.asn1.ObjectIdentifier -import at.asitplus.signum.indispensable.asn1.encoding.fromBigintOrNull -import at.asitplus.signum.indispensable.asn1.encoding.toBigInteger +import at.asitplus.signum.indispensable.asn1.* +import at.asitplus.signum.indispensable.asn1.encoding.* import com.ionspin.kotlin.bignum.integer.BigInteger import com.ionspin.kotlin.bignum.integer.Sign import io.kotest.assertions.withClue import io.kotest.core.spec.style.FreeSpec import io.kotest.datatest.withData +import io.kotest.matchers.comparables.shouldBeLessThan import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe import io.kotest.property.Arb -import io.kotest.property.arbitrary.bigInt -import io.kotest.property.arbitrary.int -import io.kotest.property.arbitrary.intArray -import io.kotest.property.arbitrary.positiveInt +import io.kotest.property.arbitrary.* import io.kotest.property.checkAll +import kotlinx.datetime.Clock import org.bouncycastle.asn1.ASN1ObjectIdentifier +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -84,6 +84,56 @@ class OidTest : FreeSpec({ } } + "Benchmarking fast case" - { + val optimized = mutableListOf() + val repetitions= 10 + + "Optimized" - { + repeat(repetitions) { + val before = Clock.System.now() + checkAll(iterations = 15, Arb.uInt(max = 39u)) { second -> + checkAll(iterations = 5000, Arb.uIntArray(Arb.int(0..256), Arb.uInt(UInt.MAX_VALUE))) { + listOf(1u, 2u).forEach { first -> + val oid = ObjectIdentifier(first, second, *it.toUIntArray()) + ObjectIdentifier.decodeFromTlv(oid.encodeToTlv()) + } + } + } + val duration = Clock.System.now() - before + optimized += duration + println("Optimized: $duration") + } + } + + val avgOpt = (optimized.sorted().subList(0, optimized.size - 1) + .sumOf { it.inWholeMilliseconds } / optimized.size - 2).milliseconds + println("AvgOpt: $avgOpt") + val simple = mutableListOf() + "Simple" - { + repeat(repetitions) { + val before = Clock.System.now() + checkAll(iterations = 15, Arb.uInt(max = 39u)) { second -> + checkAll(iterations = 5000, Arb.uIntArray(Arb.int(0..256), Arb.uInt(UInt.MAX_VALUE))) { + listOf(1u, 2u).forEach { first -> + val oid = OldOIDObjectIdentifier(first, second, *it.toUIntArray()) + OldOIDObjectIdentifier.decodeFromTlv(oid.encodeToTlv()) + } + } + } + val duration = Clock.System.now() - before + simple += duration + println("Simple $duration") + } + } + + val avgSimple = (simple.sorted().subList(0, simple.size - 1) + .sumOf { it.inWholeMilliseconds } / simple.size - 2).milliseconds + println("AvgSimple: $avgSimple") + + avgOpt shouldBeLessThan avgSimple + + } + "Automated BigInt" - { checkAll(iterations = 15, Arb.positiveInt(39)) { second -> checkAll(iterations = 500, Arb.bigInt(1, 358)) { @@ -154,4 +204,146 @@ class OidTest : FreeSpec({ } } } -}) \ No newline at end of file +}) + + +// old implementation for benchmarking +private val BIGINT_40 = BigInteger.fromUByte(40u) + +class OldOIDObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vararg val nodes: BigInteger) : + Asn1Encodable { + + init { + if (nodes.size < 2) throw Asn1StructuralException("at least two nodes required!") + if ((nodes[0] * BIGINT_40) > UByte.MAX_VALUE.toUInt()) throw Asn1Exception("first node too lage!") + //TODO more sanity checks + + if (nodes.first() > 2u) throw Asn1Exception("OID must start with either 1 or 2") + if (nodes[1] > 39u) throw Asn1Exception("Second segment must be <40") + } + + /** + * Creates an OID in the 2.25 subtree that requires no formal registration. + * E.g. the UUID `550e8400-e29b-41d4-a716-446655440000` results in the OID + * `2.25.113059749145936325402354257176981405696` + */ + @OptIn(ExperimentalUuidApi::class) + constructor(uuid: Uuid) : this( + BigInteger.fromByte(2), + BigInteger.fromByte(25), + uuid.toBigInteger() + ) + + /** + * @param nodes OID Tree nodes passed in order (e.g. 1u, 2u, 96u, …) + * @throws Asn1Exception if less than two nodes are supplied, the first node is >2 or the second node is >39 + */ + constructor(vararg ints: UInt) : this(*(ints.map { BigInteger.fromUInt(it) }.toTypedArray())) + + + /** + * @param oid in human-readable format (e.g. "1.2.96") + */ + constructor(oid: String) : this(*(oid.split(if (oid.contains('.')) '.' else ' ')).map { BigInteger.parseString(it) } + .toTypedArray()) + + /** + * @return human-readable format (e.g. "1.2.96") + */ + override fun toString() = nodes.joinToString(separator = ".") { it.toString() } + + override fun equals(other: Any?): Boolean { + if (other == null) return false + if (other !is OldOIDObjectIdentifier) return false + return bytes contentEquals other.bytes + } + + override fun hashCode(): Int { + return bytes.contentHashCode() + } + + + /** + * Cursed encoding of OID nodes. A sacrifice of pristine numbers requested by past gods of the netherrealm + */ + val bytes: ByteArray by lazy { + nodes.slice(2.. acc + bytes } + } + + /** + * @return an OBJECT IDENTIFIER [Asn1Primitive] + */ + override fun encodeToTlv() = Asn1Primitive(Asn1Element.Tag.OID, bytes) + + companion object : Asn1Decodable { + + /** + * Parses an OBJECT IDENTIFIER contained in [src] to an [ObjectIdentifier] + * @throws Asn1Exception all sorts of errors on invalid input + */ + @Throws(Asn1Exception::class) + override fun doDecode(src: Asn1Primitive): ObjectIdentifier { + if (src.length < 1) throw Asn1StructuralException("Empty OIDs are not supported") + + return parse(src.content) + + } + + /** + * Casts out the evil demons that haunt OID components encoded into [rawValue] + * @return ObjectIdentifier if decoding succeeded + * @throws Asn1Exception all sorts of errors on invalid input + */ + @Throws(Asn1Exception::class) + fun parse(rawValue: ByteArray): ObjectIdentifier = runRethrowing { + if (rawValue.isEmpty()) throw Asn1Exception("Empty OIDs are not supported") + val (first, second) = + if (rawValue[0] >= 80) { + BigInteger.fromUByte(2u) to BigInteger.fromUInt(rawValue[0].toUByte() - 80u) + } else { + BigInteger.fromUInt(rawValue[0].toUByte() / 40u) to BigInteger.fromUInt(rawValue[0].toUByte() % 40u) + } + + var index = 1 + val collected = mutableListOf(first, second) + while (index < rawValue.size) { + if (rawValue[index] >= 0) { + collected += BigInteger.fromUInt(rawValue[index].toUInt()) + index++ + } else { + val currentNode = mutableListOf() + while (rawValue[index] < 0) { + currentNode += rawValue[index] //+= parsed + index++ + } + currentNode += rawValue[index] + index++ + collected += currentNode.decodeAsn1VarBigInt().first + } + } + return ObjectIdentifier(*collected.toTypedArray()) + } + } +} + + +/** + * Adds [oid] to the implementing class + */ +interface Identifiable { + val oid: ObjectIdentifier +} + +/** + * decodes this [Asn1Primitive]'s content into an [ObjectIdentifier] + * + * @throws Asn1Exception on invalid input + */ +@Throws(Asn1Exception::class) +fun Asn1Primitive.readOid() = runRethrowing { + decode(Asn1Element.Tag.OID) { OldOIDObjectIdentifier.parse(it) } +} \ No newline at end of file