diff --git a/vck/src/commonMain/kotlin/at/asitplus/wallet/lib/agent/SdJwtValidator.kt b/vck/src/commonMain/kotlin/at/asitplus/wallet/lib/agent/SdJwtValidator.kt index 6f26b9eb..69ad86bb 100644 --- a/vck/src/commonMain/kotlin/at/asitplus/wallet/lib/agent/SdJwtValidator.kt +++ b/vck/src/commonMain/kotlin/at/asitplus/wallet/lib/agent/SdJwtValidator.kt @@ -7,10 +7,18 @@ import at.asitplus.wallet.lib.jws.SdJwtSigned import io.matthewnelson.encoding.core.Decoder.Companion.decodeToByteArray import kotlinx.serialization.json.* -class SdJwtValidator { - - private val disclosures: Collection - private val valid = mutableMapOf() +/** + * Decodes a [SdJwtSigned], by substituting all blinded disclosure values (inside `_sd` elements of the payload) + * with the claims of the disclosures appended to the SD-JWT (by a `~`). + * + * See [Selective Disclosure for JWTs (SD-JWT)](https://www.ietf.org/archive/id/draft-ietf-oauth-selective-disclosure-jwt-13.html#name-simple-structured-sd-jwt) + */ +class SdJwtValidator(sdJwtSigned: SdJwtSigned) { + + private val disclosures: Collection = sdJwtSigned.rawDisclosures + private val _validDisclosures = mutableMapOf() + + /** Per 7.1 Verification of the SD-JWT in the spec */ private val filteredClaims = listOf("_sd_alg", "...") /** Map of serialized disclosure item (as [String]) to parsed item (as [SelectiveDisclosureItem]) */ @@ -19,23 +27,22 @@ class SdJwtValidator { /** JSON Object with claim values reconstructed from disclosures */ val reconstructedJsonObject: JsonObject? - constructor(sdJwtSigned: SdJwtSigned) { - disclosures = sdJwtSigned.rawDisclosures + init { reconstructedJsonObject = sdJwtSigned.getPayloadAsJsonObject().getOrNull()?.reconstructValues() - validDisclosures = valid.toMap() + validDisclosures = _validDisclosures.toMap() } private fun JsonObject.reconstructValues(): JsonObject = buildJsonObject { forEach { element -> - val sdArray = element.toSdArray() + val sdArray = element.asSdArray() val jsonObject = element.value as? JsonObject val jsonArray = element.value as? JsonArray if (sdArray != null) { - sdArray.forEach { sdEntry -> sdEntry.toValidatedItem()?.let { processSdItem(it) } } + sdArray.forEach { processSdItem(it) } } else if (jsonObject != null) { putIfNotEmpty(element.key, jsonObject.reconstructValues()) } else if (jsonArray != null) { - putIfNotEmpty(element.key, reconstructJsonArray(jsonArray)) + putIfNotEmpty(element.key, jsonArray.reconstructValues()) } else { if (element.key !in filteredClaims) { put(element.key, element.value) @@ -44,48 +51,50 @@ class SdJwtValidator { } } - private fun reconstructJsonArray(jsonArray: JsonArray) = buildJsonArray { - jsonArray.forEach { entry -> - if (entry is JsonObject) { - entry.asArrayDisclosure()?.let { - it.toValidatedItem()?.let { processSdItem(it) } - } ?: addIfNotEmpty(entry.reconstructValues()) + private fun JsonArray.reconstructValues() = buildJsonArray { + forEach { element -> + val sdArrayEntry = element.asArrayDisclosure() + val jsonObject = element as? JsonObject + if (sdArrayEntry != null) { + processSdItem(sdArrayEntry) + } else if (jsonObject != null) { + addIfNotEmpty(element.reconstructValues()) } else { - add(entry) + add(element) } } } - private fun JsonObject.asArrayDisclosure() = - if (this.size == 1 && this.containsKey("...") && this["..."] is JsonPrimitive) + private fun JsonElement.asArrayDisclosure() = + if (this is JsonObject && this.size == 1 && this["..."] is JsonPrimitive) this["..."] as JsonPrimitive else null - private fun JsonArrayBuilder.processSdItem(sdItem: Pair) { - with(sdItem.second) { - when (claimValue) { - is JsonObject -> add(claimValue.reconstructValues()) - else -> add(claimValue) + private fun JsonArrayBuilder.processSdItem(disclosure: JsonPrimitive) { + disclosure.toValidatedItem()?.let { sdItem -> + when (sdItem.claimValue) { + is JsonObject -> add(sdItem.claimValue.reconstructValues()) + else -> add(sdItem.claimValue) } - valid[sdItem.first] = this } } - private fun JsonObjectBuilder.processSdItem(sdItem: Pair) { - with(sdItem.second) { - when (val element = claimValue) { - is JsonObject -> claimName?.let { putIfNotEmpty(it, element.reconstructValues()) } - else -> claimName?.let { put(it, element) } + private fun JsonObjectBuilder.processSdItem(disclosure: JsonPrimitive) { + disclosure.toValidatedItem()?.let { sdItem -> + when (val element = sdItem.claimValue) { + is JsonObject -> sdItem.claimName?.let { putIfNotEmpty(it, element.reconstructValues()) } + else -> sdItem.claimName?.let { put(it, element) } } - valid[sdItem.first] = this } } - private fun JsonPrimitive.toValidatedItem(): Pair? = - disclosures.firstOrNull { it.hashDisclosure() == this.content } - ?.let { hash -> hash.toSdItem()?.let { hash to it } } + private fun JsonPrimitive.toValidatedItem(): SelectiveDisclosureItem? = + disclosures.firstOrNull { it.hashDisclosure() == this.content }?.let { disclosure -> + disclosure.toSdItem() + ?.also { _validDisclosures[disclosure] = it } + } - private fun Map.Entry.toSdArray(): List? = + private fun Map.Entry.asSdArray(): List? = if (key == "_sd") { kotlin.runCatching { value.jsonArray }.getOrNull() ?.mapNotNull { runCatching { it.jsonPrimitive }.getOrNull() }