Skip to content

Commit

Permalink
KTXIO-based ASN.1 decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
JesusMcCloud committed Oct 17, 2024
1 parent a6bc66f commit fd85d07
Show file tree
Hide file tree
Showing 29 changed files with 376 additions and 303 deletions.
4 changes: 0 additions & 4 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ tasks.getByName("dokkaHtmlMultiModule") {
allprojects {
apply(plugin = "org.jetbrains.dokka")
group = rootProject.group

repositories {
mavenLocal()
}
}

tasks.register<Copy>("copyChangelog") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Asn1BitString private constructor(

@Throws(Asn1Exception::class)
override fun doDecode(src: Asn1Primitive): Asn1BitString {
if (src.length == 0) return Asn1BitString(0, byteArrayOf())
if (src.length == 0L) return Asn1BitString(0, byteArrayOf())
if (src.content.first() > 7) throw Asn1Exception("Number of padding bits < 7")
return Asn1BitString(src.content[0], src.content.sliceArray(1..<src.content.size))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import at.asitplus.catching
import at.asitplus.signum.indispensable.asn1.Asn1Element.Tag.Template.Companion.withClass
import at.asitplus.signum.indispensable.asn1.encoding.*
import at.asitplus.signum.indispensable.io.ByteArrayBase64Serializer
import at.asitplus.signum.indispensable.io.copyToSource
import io.matthewnelson.encoding.base16.Base16
import io.matthewnelson.encoding.core.Decoder.Companion.decodeToByteArray
import io.matthewnelson.encoding.core.Encoder.Companion.encodeToString
import kotlinx.io.Buffer
import kotlinx.io.Sink
import kotlinx.io.*
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.PrimitiveKind
Expand All @@ -23,7 +22,7 @@ import kotlin.native.ObjCName
*/
@Serializable(with = Asn1EncodableSerializer::class)
sealed class Asn1Element(
internal val tlv: TLV,
internal val tlv: TLV.Immutable,
protected open val children: List<Asn1Element>?
) {

Expand All @@ -32,11 +31,10 @@ sealed class Asn1Element(
if (other == null) return false
if (other !is Asn1Element) return false
if (tag != other.tag) return false
if (!content.contentEquals(other.content)) return false
if (this is Asn1Structure && other !is Asn1Structure) return false
if (this is Asn1Primitive && other !is Asn1Primitive) return false
return if (this is Asn1Primitive) {
(this.content contentEquals other.content)
(content.contentEquals(other.content))
} else {
this as Asn1Structure
other as Asn1Structure
Expand All @@ -60,15 +58,15 @@ sealed class Asn1Element(
* For a primitive, this is just the size of the held bytes.
* For a structure, it is the sum of the number of bytes needed to encode all held child nodes.
*/
val encodedLength by lazy { length.encodeLength() }
val encodedLength by lazy { Buffer().apply { encodeLength(length) }.snapshot().toByteArray() }

/**
* Length (as a plain `Int` to work with it in code) of the contained data.
* Length (as a plain `Long` to work with it in code) of the contained data.
* For a primitive, this is just the size of the held bytes.
* For a structure, it is the sum of the number of bytes needed to encode all held child nodes.
*/
val length: Int by lazy {
children?.fold(0) { acc, extendedTlv -> acc + extendedTlv.overallLength } ?: tlv.contentLength
val length: Long by lazy {
children?.fold(0L) { acc, extendedTlv -> acc + extendedTlv.overallLength } ?: tlv.contentLength
}

/**
Expand All @@ -81,18 +79,19 @@ sealed class Asn1Element(
val tag by lazy { tlv.tag }

val derEncoded: ByteArray by lazy {
children?.fold(byteArrayOf()) { acc, extendedTlv -> acc + extendedTlv.derEncoded }
?.let { byteArrayOf(*tlv.tag.encodedTag, *it.size.encodeLength(), *it) }
?: byteArrayOf(*tlv.tag.encodedTag, *encodedLength, *tlv.content)
(children?.fold(Buffer()) { acc, extendedTlv -> acc.apply { write(extendedTlv.derEncoded) } }
?.let {
Buffer().apply { write(tlv.tag.encodedTag); encodeLength(it.size); it.transferTo(this) }
}?.readByteArray()
?: Buffer().apply { write(tlv.tag.encodedTag); write(encodedLength);write(tlv.content) }.readByteArray())
}

override fun toString(): String = "(tag=${tlv.tag}" +
", length=${length}" +
", overallLength=${overallLength}" +
(children?.let { ", children=$children" } ?: ", content=${
content.encodeToString(Base16 {
lineBreakInterval = 0;encodeToLowercase = false
})
@OptIn(ExperimentalStdlibApi::class)
content.toHexString(HexFormat.UpperCase)
}") +
")"

Expand All @@ -103,16 +102,17 @@ sealed class Asn1Element(
", length=${length}" +
", overallLength=${overallLength}" +
((children?.joinToString(
prefix = ")\n" + (" " * indent) + "{\n",
prefix = ", ${children!!.size} elem${if (this is Asn1OctetString<*>) ", ${prettyPrintRawContent()}" else ""})\n" + (" " * indent) + "{\n",
separator = "\n",
postfix = "\n" + (" " * indent) + "}"
) { it.prettyPrint(indent + 2) }) ?: ", content=${
content.encodeToString(Base16 {
lineBreakInterval = 0;encodeToLowercase = false
})
})")
) { it.prettyPrint(indent + 2) }) ?: ", ${prettyPrintRawContent()})")


protected fun prettyPrintRawContent(): String = "content=${
@OptIn(ExperimentalStdlibApi::class)
content.toHexString(HexFormat.UpperCase)
}"

protected operator fun String.times(op: Int): String {
var s = this
kotlin.repeat(op) { s += this }
Expand All @@ -122,11 +122,9 @@ sealed class Asn1Element(
/**
* Convenience method to directly produce an HEX string of this element's ANS.1 representation
*/
fun toDerHexString(lineLen: Byte? = null) = derEncoded.encodeToString(Base16 {
lineLen?.let {
lineBreakInterval = lineLen
}
})
fun toDerHexString(lineLen: Byte? = null) = @OptIn(ExperimentalStdlibApi::class)
derEncoded.toHexString(HexFormat.UpperCase)
.let { if (lineLen == null) it else it.chunked(lineLen.toInt()).joinToString("\n") }

override fun hashCode(): Int {
var result = tlv.hashCode()
Expand Down Expand Up @@ -242,8 +240,12 @@ sealed class Asn1Element(
@Serializable(with = ByteArrayBase64Serializer::class) val encodedTag: ByteArray
) : Comparable<Tag> {
private constructor(values: Triple<ULong, Int, ByteArray>) : this(values.first, values.second, values.third)

//TODO this CTOR is internally called only using already validated inputs.
// We need another CTOR to prevent double-parsing and byte copying
constructor(derEncoded: ByteArray) : this(
derEncoded.iterator().decodeTag().let { Triple(it.first, it.second.size, derEncoded) }
derEncoded.copyToSource().decodeTag()
.let { Triple(it.first, it.second.size, derEncoded) }
)

/**
Expand All @@ -261,20 +263,21 @@ sealed class Asn1Element(

companion object {
private fun encode(tagClass: TagClass, constructed: Boolean, tagValue: ULong): ByteArray {
val derEncoded: ByteArray =
if (tagValue <= 30u) {
byteArrayOf(tagValue.toUByte().toByte())
} else {
byteArrayOf(0b11111, *tagValue.toAsn1VarInt())
val derEncoded: Buffer =
Buffer().apply {
if (tagValue <= 30u) {
writeUByte(tagValue.toUByte().withTagProperties(constructed, tagClass))
} else {
writeUByte((0b11111u).toUByte().withTagProperties(constructed, tagClass))
writeAsn1VarInt(tagValue)
}
}

derEncoded[0] = derEncoded[0].toUByte()
.let { if (constructed) (it or BERTags.CONSTRUCTED) else it }
.let { it or tagClass.berTag }
.toByte()
return derEncoded
return derEncoded.readByteArray()
}

private fun UByte.withTagProperties(constructed: Boolean, tagClass: TagClass): UByte =
(if (constructed) (this or BERTags.CONSTRUCTED) else this) or tagClass.berTag

val SET = Tag(tagValue = BERTags.SET.toULong(), constructed = true)
val SEQUENCE = Tag(tagValue = BERTags.SEQUENCE.toULong(), constructed = true)

Expand Down Expand Up @@ -314,7 +317,7 @@ sealed class Asn1Element(

override fun toString(): String =
"${tagClass.let { if (it == TagClass.UNIVERSAL) "" else it.name + " " }}${tagValue}${if (isConstructed) " CONSTRUCTED" else ""}" +
(" (=${encodedTag.encodeToString(Base16)})")
(" (=${@OptIn(ExperimentalStdlibApi::class) encodedTag.toHexString(HexFormat.UpperCase)})")

/**
* As per ITU-T X.680 8824-1 8.6
Expand Down Expand Up @@ -421,7 +424,8 @@ object Asn1EncodableSerializer : KSerializer<Asn1Element> {
}

override fun serialize(encoder: Encoder, value: Asn1Element) {
encoder.encodeString(value.derEncoded.encodeToString(Base16))
@OptIn(ExperimentalStdlibApi::class)
encoder.encodeString(value.derEncoded.toHexString(HexFormat.UpperCase))
}

}
Expand All @@ -447,8 +451,7 @@ sealed class Asn1Structure(
*/
val isSorted: Boolean = false
) :
Asn1Element(TLV(tag, byteArrayOf()), if (!isSorted) children else children.sortedBy { it.tag }) {

Asn1Element(TLV.Immutable(tag, byteArrayOf()), if (!isSorted) children else children.sortedBy { it.tag }) {

public override val children: List<Asn1Element>
get() = super.children!!
Expand Down Expand Up @@ -580,8 +583,7 @@ class Asn1CustomStructure private constructor(


override val content: ByteArray by lazy {
if (!tag.isConstructed)
children.fold(byteArrayOf()) { acc, asn1Element -> acc + asn1Element.derEncoded }
if (!tag.isConstructed) Buffer().apply { children.forEach { write(it.derEncoded) } }.readByteArray()
else super.content
}

Expand Down Expand Up @@ -618,7 +620,7 @@ class Asn1EncapsulatingOctetString(children: List<Asn1Element>) :
Asn1Structure(Tag.OCTET_STRING, children),
Asn1OctetString<Asn1EncapsulatingOctetString> {
override val content: ByteArray by lazy {
children.fold(byteArrayOf()) { acc, asn1Element -> acc + asn1Element.derEncoded }
Buffer().apply { children.forEach { write(it.derEncoded) } }.readByteArray()
}

override fun unwrap() = this
Expand Down Expand Up @@ -690,7 +692,7 @@ class Asn1SetOf @Throws(Asn1Exception::class) internal constructor(children: Lis
/**
* ASN.1 primitive. Hold o children, but [content] under [tag]
*/
open class Asn1Primitive(tag: Tag, content: ByteArray) : Asn1Element(TLV(tag, content), null) {
open class Asn1Primitive(tag: Tag, content: ByteArray) : Asn1Element(TLV.Immutable(tag, content), null) {
init {
if (tag.isConstructed) throw IllegalArgumentException("A primitive cannot have a CONSTRUCTED tag")
}
Expand Down Expand Up @@ -739,7 +741,7 @@ interface Asn1OctetString<T : Asn1Element> {


@Throws(IllegalArgumentException::class)
internal fun Int.encodeLength(): ByteArray {
internal fun Long.encodeLength(): ByteArray {
require(this >= 0)
return when {
(this < 0x80) -> byteArrayOf(this.toByte()) /* short form */
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package at.asitplus.signum.indispensable.asn1

import at.asitplus.signum.indispensable.asn1.encoding.asAsn1String
import kotlinx.io.bytestring.ByteString
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package at.asitplus.signum.indispensable.asn1

import io.matthewnelson.encoding.base16.Base16
import io.matthewnelson.encoding.core.Encoder.Companion.encodeToString
import kotlinx.io.Buffer
import kotlinx.io.bytestring.ByteString
import kotlinx.io.bytestring.toHexString
import kotlinx.io.readByteArray
import kotlinx.io.snapshot

internal data class TLV(val tag: Asn1Element.Tag, val content: ByteArray) {
internal sealed class TLV<T>(val tag: Asn1Element.Tag, val content: T) {

val encodedContentLength by lazy { contentLength.encodeLength() }
val contentLength: Int by lazy { content.size }
val overallLength: Int by lazy { contentLength + tag.encodedTagLength + encodedContentLength.size }
abstract val contentLength: Long
val overallLength: Long by lazy { contentLength + tag.encodedTagLength + encodedContentLength.size }

val tagClass: TagClass get() = tag.tagClass

Expand All @@ -16,11 +19,13 @@ internal data class TLV(val tag: Asn1Element.Tag, val content: ByteArray) {
val encodedTag get() = tag.encodedTag

override fun equals(other: Any?): Boolean {
if (other is TLV.Shallow || this is TLV.Shallow) throw IllegalStateException("Shallow TLVs may neve be compared")
if (this === other) return true
if (other == null) return false
if (this::class != other::class) return false

other as TLV
other as Immutable
this as Immutable

if (tag != other.tag) return false
if (!content.contentEquals(other.content)) return false
Expand All @@ -30,15 +35,55 @@ internal data class TLV(val tag: Asn1Element.Tag, val content: ByteArray) {

override fun hashCode(): Int {
var result = tag.hashCode()
result = 31 * result + content.contentHashCode()
result = 31 * result + if(this is Shallow) content.hashCode() else (content as ByteArray).contentHashCode()
return result
}

protected abstract val contentHexString: String

override fun toString(): String {
return "TLV(tag=$tag" +
", length=$contentLength" +
", overallLength=$overallLength" +
", content=${content.encodeToString(Base16)})"
", content=$contentHexString)"
}

/**
* Shallow TLV, containing a reference to the buffer it is based on. Once [content] is consumed, the underlying bytes are gone.
*/
class Shallow(tag: Asn1Element.Tag, content: Buffer) : TLV<Buffer>(tag, content) {

override val contentLength: Long by lazy { content.size }


override fun equals(other: Any?): Boolean {
throw IllegalStateException("Shallow TLVs may neve be compared")
}


override val contentHexString: String by lazy {
@OptIn(ExperimentalStdlibApi::class)
content.snapshot().toHexString(HexFormat.UpperCase)
}

/**
* Deep-copies this shallow TLV into an [Immutable] one. Does not consume anything from [content]
*/
fun deepCopy() = Immutable(tag, content.copy().readByteArray())

}

/**
* Immutable TLV, containing a deep copy of the parsed bytes
*/
class Immutable(tag: Asn1Element.Tag, content: ByteArray) : TLV<ByteArray>(tag, content) {

override val contentLength: Long by lazy { content.size.toLong() }

override val contentHexString: String by lazy {
@OptIn(ExperimentalStdlibApi::class)
content.toHexString(HexFormat.UpperCase)
}
}

}
Loading

0 comments on commit fd85d07

Please sign in to comment.