Skip to content

Commit

Permalink
more efficient oid
Browse files Browse the repository at this point in the history
  • Loading branch information
JesusMcCloud committed Oct 7, 2024
1 parent 5e2c469 commit ac8ecbb
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import at.asitplus.signum.indispensable.asn1.encoding.decodeAsn1VarBigInt
import at.asitplus.signum.indispensable.asn1.encoding.toAsn1VarInt
import at.asitplus.signum.indispensable.asn1.encoding.toBigInteger
import com.ionspin.kotlin.bignum.integer.BigInteger
import com.ionspin.kotlin.bignum.integer.Sign
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
Expand All @@ -26,16 +25,38 @@ private val BIGINT_40 = BigInteger.fromUByte(40u)
* @throws Asn1Exception if less than two nodes are supplied, the first node is >2 or the second node is >39
*/
@Serializable(with = ObjectIdSerializer::class)
class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vararg val nodes: BigInteger) :
class ObjectIdentifier @Throws(Asn1Exception::class) constructor(
val bytes: ByteArray,
@Transient private val _nodes: List<BigInteger>? = null, dontVerify: Boolean = false
) :
Asn1Encodable<Asn1Primitive> {

init {
if (nodes.size < 2) throw Asn1StructuralException("at least two nodes required!")
if ((nodes[0] * BIGINT_40) > UByte.MAX_VALUE.toUInt()) throw Asn1Exception("first node too lage!")
//TODO more sanity checks

if (nodes.first() > 2u) throw Asn1Exception("OID must start with either 1 or 2")
if (nodes[1] > 39u) throw Asn1Exception("Second segment must be <40")
if (_nodes == null || dontVerify) {
//Verify that everything can be parsed into nodes
if (bytes.isEmpty()) throw Asn1Exception("Empty OIDs are not supported")
var index = 1
while (index < bytes.size) {
if (bytes[index] >= 0) {
index++
} else {
val currentNode = mutableListOf<Byte>()
while (bytes[index] < 0) {
currentNode += bytes[index] //+= parsed
index++
}
currentNode += bytes[index]
index++
val consumed = currentNode.iterator().consumeVarIntEncoded()
@OptIn(ExperimentalStdlibApi::class)
if (consumed != currentNode) throw Asn1Exception(
"Trailing bytes in OID Node ${
currentNode.toByteArray().toHexString(HexFormat.UpperCase)
}"
)
}
}
}
}

/**
Expand All @@ -45,51 +66,79 @@ class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vara
*/
@OptIn(ExperimentalUuidApi::class)
constructor(uuid: Uuid) : this(
BigInteger.fromByte(2),
BigInteger.fromByte(25),
uuid.toBigInteger()
byteArrayOf((2 * 40 + 25).toUByte().toByte(), *uuid.toBigInteger().toAsn1VarInt())
)

/**
* @param nodes OID Tree nodes passed in order (e.g. 1u, 2u, 96u, …)
* @throws Asn1Exception if less than two nodes are supplied, the first node is >2 or the second node is >39
*/
constructor(vararg ints: UInt) : this(*(ints.map { BigInteger.fromUInt(it) }.toTypedArray()))
constructor(vararg nodes: UInt) : this(
nodes.toOidBytes(),
dontVerify = true
)

/**
* @param nodes OID Tree nodes passed in order (e.g. 1, 2, 96, …)
* @throws Asn1Exception if less than two nodes are supplied, the first node is >2, the second node is >39 or any node is negative
*/
constructor(vararg nodes: BigInteger) : this(nodes.toOidBytes(), _nodes = nodes.asList())

/**
* @param oid in human-readable format (e.g. "1.2.96")
*/
constructor(oid: String) : this(*(oid.split(if (oid.contains('.')) '.' else ' ')).map { BigInteger.parseString(it) }
.toTypedArray())

/**
* Lazily evaluated list of OID nodes (e.g. `[1, 2, 35, 4654]`)
*/
val nodes: List<BigInteger> by lazy {
if (_nodes != null) _nodes else {
val (first, second) =
if (bytes[0] >= 80) {
BigInteger.fromUByte(2u) to BigInteger.fromUInt(bytes[0].toUByte() - 80u)
} else {
BigInteger.fromUInt(bytes[0].toUByte() / 40u) to BigInteger.fromUInt(bytes[0].toUByte() % 40u)
}
var index = 1
val collected = mutableListOf(first, second)
while (index < bytes.size) {
if (bytes[index] >= 0) {
collected += BigInteger.fromUInt(bytes[index].toUInt())
index++
} else {
val currentNode = mutableListOf<Byte>()
while (bytes[index] < 0) {
currentNode += bytes[index] //+= parsed
index++
}
currentNode += bytes[index]
index++
collected += currentNode.decodeAsn1VarBigInt().first
}
}
collected
}
}

/**
* @return human-readable format (e.g. "1.2.96")
*/
override fun toString() = nodes.joinToString(separator = ".") { it.toString() }
override fun toString(): String {
return nodes.joinToString(".")
}

override fun equals(other: Any?): Boolean {
if (other == null) return false
if (other !is ObjectIdentifier) return false
return nodes contentEquals other.nodes
return bytes contentEquals other.bytes
}

override fun hashCode(): Int {
return bytes.contentHashCode()
}


/**
* Cursed encoding of OID nodes. A sacrifice of pristine numbers requested by past gods of the netherrealm
*/
val bytes: ByteArray by lazy {
nodes.slice(2..<nodes.size).map { it.toAsn1VarInt() }.fold(
byteArrayOf(
(nodes[0] * BIGINT_40 + nodes[1]).ubyteValue(exactRequired = true).toByte()
)
) { acc, bytes -> acc + bytes }
}

/**
* @return an OBJECT IDENTIFIER [Asn1Primitive]
*/
Expand All @@ -115,33 +164,37 @@ class ObjectIdentifier @Throws(Asn1Exception::class) constructor(@Transient vara
* @throws Asn1Exception all sorts of errors on invalid input
*/
@Throws(Asn1Exception::class)
fun parse(rawValue: ByteArray): ObjectIdentifier = runRethrowing {
if (rawValue.isEmpty()) throw Asn1Exception("Empty OIDs are not supported")
val (first, second) =
if (rawValue[0] >= 80) {
BigInteger.fromUByte(2u) to BigInteger.fromUInt(rawValue[0].toUByte() - 80u)
} else {
BigInteger.fromUInt(rawValue[0].toUByte() / 40u) to BigInteger.fromUInt(rawValue[0].toUByte() % 40u)
}

var index = 1
val collected = mutableListOf(first, second)
while (index < rawValue.size) {
if (rawValue[index] >= 0) {
collected += BigInteger.fromUInt(rawValue[index].toUInt())
index++
} else {
val currentNode = mutableListOf<Byte>()
while (rawValue[index] < 0) {
currentNode += rawValue[index] //+= parsed
index++
}
currentNode += rawValue[index]
index++
collected += currentNode.decodeAsn1VarBigInt().first
}
fun parse(rawValue: ByteArray): ObjectIdentifier = ObjectIdentifier(rawValue)

private fun Iterator<Byte>.consumeVarIntEncoded(): MutableList<Byte> {
val accumulator = mutableListOf<Byte>()
while (hasNext()) {
val curByte = next()
val current = BigInteger(curByte.toUByte().toInt())
accumulator += curByte
if (current < 0x80.toUByte()) break
}
return ObjectIdentifier(*collected.toTypedArray())
return accumulator
}

private fun UIntArray.toOidBytes(): ByteArray {
if (size < 2) throw Asn1StructuralException("at least two nodes required!")
if (first() > 2u) throw Asn1Exception("OID must start with either 1 or 2")
if (get(1) > 39u) throw Asn1Exception("Second segment must be <40")
return slice(2..<size).map { it.toAsn1VarInt() }.fold(
byteArrayOf((first() * 40u + get(1)).toUByte().toByte())
) { acc, bytes -> acc + bytes }
}

private fun Array<out BigInteger>.toOidBytes(): ByteArray {
if (size < 2) throw Asn1StructuralException("at least two nodes required!")
if (first() > 2u) throw Asn1Exception("OID must start with either 1 or 2")
if (get(1) > 39u) throw Asn1Exception("Second segment must be <40")

return slice(2..<size).map { if (it.isNegative) throw Asn1Exception("Negative Number encountered: $it") else it.toAsn1VarInt() }
.fold(
byteArrayOf((first().intValue() * 40 + get(1).intValue()).toUByte().toByte())
) { acc, bytes -> acc + bytes }
}
}
}
Expand All @@ -152,7 +205,7 @@ object ObjectIdSerializer : KSerializer<ObjectIdentifier> {
override fun deserialize(decoder: Decoder): ObjectIdentifier = ObjectIdentifier(decoder.decodeString())

override fun serialize(encoder: Encoder, value: ObjectIdentifier) {
encoder.encodeString(value.nodes.joinToString(separator = ".") { it.toString() })
encoder.encodeString(toString())
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,13 @@ inline fun ByteArray.decodeAsn1VarBigInt(): Pair<BigInteger, ByteArray> = iterat


/**
* Decodes an ULong from bytes using varint encoding as used within ASN.1: groups of seven bits are encoded into a byte,
* Decodes an BigInteger from bytes using varint encoding as used within ASN.1: groups of seven bits are encoded into a byte,
* while the highest bit indicates if more bytes are to come. Trailing bytes are ignored.
*
* @return the decoded ULong and the underlying varint-encoded bytes as `ByteArray`
* @return the decoded BigInteger and the underlying varint-encoded bytes as `ByteArray`
* @throws IllegalArgumentException if the number is larger than [ULong.MAX_VALUE]
*/
fun Iterator<Byte>.decodeAsn1VarBigInt(): Pair<BigInteger, ByteArray> {
var offset = 0
var result = BigInteger.ZERO
val mask = BigInteger.fromUByte(0x7Fu)
val accumulator = mutableListOf<Byte>()
Expand Down
Loading

0 comments on commit ac8ecbb

Please sign in to comment.