diff --git a/Sources/eudiWalletOidcIos/Service/CredentialValidation/CredentialValidatorService.swift b/Sources/eudiWalletOidcIos/Service/CredentialValidation/CredentialValidatorService.swift index c5a40bc..b04ea5c 100644 --- a/Sources/eudiWalletOidcIos/Service/CredentialValidation/CredentialValidatorService.swift +++ b/Sources/eudiWalletOidcIos/Service/CredentialValidation/CredentialValidatorService.swift @@ -9,6 +9,7 @@ import Base58Swift public enum ValidationError: Error { case JWTExpired case signatureExpired + case invalidKID } public class CredentialValidatorService: CredentialValidaorProtocol { public static var shared = CredentialValidatorService() @@ -16,7 +17,7 @@ public class CredentialValidatorService: CredentialValidaorProtocol { public func validateCredential(jwt: String?, jwksURI: String?, format: String) async throws { let isJWTExpired = ExpiryValidator.validateExpiryDate(jwt: jwt, format: format) ?? false - let isSignatureExpied = await SignatureValidator.validateSign(jwt: jwt, jwksURI: jwksURI, format: format) ?? false + let isSignatureExpied = try await SignatureValidator.validateSign(jwt: jwt, jwksURI: jwksURI, format: format) ?? false if isJWTExpired { throw ValidationError.JWTExpired } diff --git a/Sources/eudiWalletOidcIos/Service/CredentialValidation/SignatureValidator.swift b/Sources/eudiWalletOidcIos/Service/CredentialValidation/SignatureValidator.swift index a350d23..c602880 100644 --- a/Sources/eudiWalletOidcIos/Service/CredentialValidation/SignatureValidator.swift +++ b/Sources/eudiWalletOidcIos/Service/CredentialValidation/SignatureValidator.swift @@ -12,7 +12,7 @@ import CryptoKit class SignatureValidator { - static func validateSign(jwt: String?, jwksURI: String?, format: String) async -> Bool? { + static func validateSign(jwt: String?, jwksURI: String?, format: String) async throws-> Bool? { var jwk: [String: Any] = [:] if format == "mso_mdoc" { return true @@ -20,7 +20,15 @@ class SignatureValidator { guard let split = jwt?.split(separator: "."), split.count > 1 else { return true} guard let jsonString = "\(split[0])".decodeBase64(), let jsonObject = UIApplicationUtils.shared.convertStringToDictionary(text: jsonString) else { return false } - if var kid = jsonObject["kid"] as? String { + if (jsonObject["kid"] as? String)?.isEmpty ?? true, let x5cList = jsonObject["x5c"] as? [String]{ + if let x5cList = extractX5C(data: jsonObject) { + if X509SanRequestVerifier.shared.validateSignatureWithCertificate(jwt: jwt ?? "", x5cChain: x5cList) { + return true + } else { + throw ValidationError.invalidKID + } + } + } else if var kid = jsonObject["kid"] as? String { if kid.hasPrefix("did:jwk:") { if let parsedJWK = ProcessJWKFromKID.parseDIDJWK(kid) { jwk = parsedJWK @@ -152,4 +160,12 @@ class SignatureValidator { return false } } + + static func extractX5C(data: [String: Any]?) -> [String]?{ + var x5cData: [String] = [] + if let data = data, let x5cArray = data["x5c"] as? [String]{ + x5cData = x5cArray + } + return x5cData + } } diff --git a/Sources/eudiWalletOidcIos/Service/X509SanRequestVerifier.swift b/Sources/eudiWalletOidcIos/Service/X509SanRequestVerifier.swift index 838ac4d..7968f37 100644 --- a/Sources/eudiWalletOidcIos/Service/X509SanRequestVerifier.swift +++ b/Sources/eudiWalletOidcIos/Service/X509SanRequestVerifier.swift @@ -126,13 +126,28 @@ public class X509SanRequestVerifier { } let signedData = segments[0] + "." + segments[1] - let b64 = base64UrlToBase64(String(segments[2])) - guard let signature = Data(base64Encoded: b64) else { + var sigatureData: String = "" + if segments[2].contains("~") { + let splitData = segments[2].split(separator: "~") + sigatureData = String(splitData[0]) + } else { + sigatureData = String(segments[2]) + } + let b64 = base64UrlToBase64(sigatureData) + guard let rawSignature = Data(base64Encoded: b64) else { print("Failed to decode JWT signature") return false } + guard let jsonString = "\(segments[0])".decodeBase64(), + let jsonObject = UIApplicationUtils.shared.convertStringToDictionary(text: jsonString), let alg = jsonObject["alg"] as? String, let algorithm = mapJWTAlgToSecKeyAlgorithm(alg: alg) else { return false } + let signature: Data + if alg.starts(with: "ES"), let convertedSignature = convertRawSignatureToASN1DER(rawSignature) { + signature = convertedSignature + } else { + signature = rawSignature + } - return verifySignature(publicKey: publicKey, data: String(signedData), signature: signature, algorithm: .rsaSignatureMessagePKCS1v15SHA256) + return verifySignature(publicKey: publicKey, data: String(signedData), signature: signature, algorithm: algorithm) } func base64UrlToBase64(_ base64Url: String) -> String { @@ -206,4 +221,51 @@ public class X509SanRequestVerifier { return isValid } + func mapJWTAlgToSecKeyAlgorithm(alg: String) -> SecKeyAlgorithm? { + switch alg { + case "RS256": + return .rsaSignatureMessagePKCS1v15SHA256 + case "RS384": + return .rsaSignatureMessagePKCS1v15SHA384 + case "RS512": + return .rsaSignatureMessagePKCS1v15SHA512 + case "ES256": + return .ecdsaSignatureMessageX962SHA256 + case "ES384": + return .ecdsaSignatureMessageX962SHA384 + case "ES512": + return .ecdsaSignatureMessageX962SHA512 + default: + return nil + } + } + + func convertRawSignatureToASN1DER(_ rawSignature: Data) -> Data? { + let halfLength = rawSignature.count / 2 + let r = rawSignature.prefix(halfLength) + let s = rawSignature.suffix(halfLength) + func asn1Length(_ length: Int) -> Data { + if length < 128 { + return Data([UInt8(length)]) + } + var lengthBytes = withUnsafeBytes(of: length.bigEndian, Array.init) + while lengthBytes.first == 0 { + lengthBytes.removeFirst() + } + return Data([0x80 | UInt8(lengthBytes.count)] + lengthBytes) + } + func asn1Integer(_ data: Data) -> Data { + var bytes = Array(data) + // Add a leading zero if the MSB is set (to avoid it being interpreted as negative) + if let first = bytes.first, first & 0x80 != 0 { + bytes.insert(0, at: 0) + } + return Data([0x02, UInt8(bytes.count)] + bytes) + } + let asn1R = asn1Integer(r) + let asn1S = asn1Integer(s) + let asn1Sequence = Data([0x30]) + asn1Length(asn1R.count + asn1S.count) + asn1R + asn1S + return asn1Sequence + } + }