Skip to content

Commit

Permalink
cleanup rsa public keys
Browse files Browse the repository at this point in the history
  • Loading branch information
iaik-jheher committed Nov 14, 2024
1 parent f79ecf3 commit c532f9f
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ sealed class Asn1Integer(internal val uint: VarUInt, private val sign: Sign) {
override fun twosComplement(): ByteArray = uint.bytes.let {
if (it.first().countLeadingZeroBits() == 0) listOf(0.toUByte()) + it else it
}.toUByteArray().toByteArray()
fun bitLength(): UInt = uint.bitLength().toUInt()
val magnitude = uint.bytes.toUByteArray().toByteArray()
}

class Negative internal constructor(uint: VarUInt) : Asn1Integer(uint, Sign.NEGATIVE) {
Expand Down Expand Up @@ -91,6 +93,8 @@ sealed class Asn1Integer(internal val uint: VarUInt, private val sign: Sign) {
else throw IllegalArgumentException("NaN: $input")
}

fun fromUnsignedByteArray(input: ByteArray) = Positive(VarUInt(input))

fun fromTwosComplement(input: ByteArray): Asn1Integer =
if (input.first() < 0) {
Negative(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ fun CryptoPublicKey.toCoseKey(algorithm: CoseAlgorithm? = null, keyId: ByteArray
else catching {
CoseKey(
keyParams = CoseKeyParams.RsaParams(
n = n,
e = e.toTwosComplementByteArray()
n = n.magnitude,
e = e.magnitude
),
type = CoseKeyType.RSA,
keyId = didEncoded.encodeToByteArray(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ fun CryptoPublicKey.toJsonWebKey(keyId: String? = this.jwkId): JsonWebKey =
JsonWebKey(
type = JwkType.RSA,
keyId = keyId,
n = n,
e = e.toTwosComplementByteArray()
n = n.magnitude,
e = e.magnitude
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,52 +165,36 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {

}

/**
* RSA Public key
*/
@Serializable
/** RSA Public key */
@ConsistentCopyVisibility
data class RSA
@Throws(IllegalArgumentException::class)
private constructor(
/**
* RSA key size
*/
val bits: Size,

constructor(
/**
* modulus
*/
@Serializable(with = ByteArrayBase64Serializer::class) val n: ByteArray,
val n: Asn1Integer.Positive,

/**
* public exponent
*/
val e: Int,
val e: Asn1Integer.Positive,
) : CryptoPublicKey() {

init {
val computed = Size.of(n)
if (bits != computed) throw IllegalArgumentException("Provided number of bits (${bits.number}) does not match computed number of bits (${computed.number})")
}

@Throws(IllegalArgumentException::class)
private constructor(params: RsaParams) : this(
params.size,
params.n,
params.e
)
val bits = n.bitLength().let { Size.of(it) ?: throw IllegalArgumentException("Unsupported key size $it bits") }

/**
* @throws IllegalArgumentException in case of illegal input (odd key size, for example)
*/
@Throws(IllegalArgumentException::class)
constructor(n: ByteArray, e: Int) : this(sanitizeRsaInputs(n, e))
constructor(n: ByteArray, e: ByteArray) : this(Asn1Integer.fromUnsignedByteArray(n), Asn1Integer.fromUnsignedByteArray(e))

constructor(n: ByteArray, e: Int): this(Asn1Integer.fromUnsignedByteArray(n), Asn1Integer(e) as Asn1Integer.Positive)

override val oid = RSA.oid

/**
* enum of supported RSA key sized. For sanity checks!
* enum of supported RSA key sizes. For sanity checks!
*/
enum class Size(val number: UInt) {
RSA_512(512u),
Expand All @@ -222,13 +206,6 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
companion object : Identifiable {
fun of(numBits: UInt) = entries.find { it.number == numBits }

@Throws(IllegalArgumentException::class)
fun of(n: ByteArray): Size {
val nTruncSize = n.dropWhile { it == 0.toByte() }.size
return entries.find { nTruncSize == (it.number.toInt() / 8) }
?: throw IllegalArgumentException("Unsupported key size $nTruncSize")
}

override val oid = KnownOIDs.rsaEncryption
}
}
Expand All @@ -249,11 +226,7 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
*/
val pkcsEncoded by lazy {
Asn1.Sequence {
+Asn1Primitive(
Asn1Element.Tag.INT,
n.ensureSize(bits.number / 8u)
.let { if (it.first() == 0x00.toByte()) it else byteArrayOf(0x00, *it) })

+Asn1.Int(n)
+Asn1.Int(e)
}.derEncoded
}
Expand All @@ -269,8 +242,7 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
}

override fun hashCode(): Int {
var result = bits.hashCode()
result = 31 * result + n.contentHashCode()
var result = n.hashCode()
result = 31 * result + e.hashCode()
return result
}
Expand All @@ -284,10 +256,10 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
@Throws(Asn1Exception::class)
fun fromPKCS1encoded(input: ByteArray): RSA = runRethrowing {
val conv = Asn1Element.parse(input) as Asn1Sequence
val n = (conv.nextChild() as Asn1Primitive).decode(Asn1Element.Tag.INT) { it }
val e = (conv.nextChild() as Asn1Primitive).decodeToInt()
val n = (conv.nextChild() as Asn1Primitive).decodeToAsn1Integer() as Asn1Integer.Positive
val e = (conv.nextChild() as Asn1Primitive).decodeToAsn1Integer() as Asn1Integer.Positive
if (conv.hasMoreChildren()) throw Asn1StructuralException("Superfluous bytes")
return RSA(Size.of(n), n, e)
return RSA(n, e)
}

override val oid = KnownOIDs.rsaEncryption
Expand Down Expand Up @@ -441,23 +413,6 @@ fun CryptoPublicKey.equalsCryptographically(other: SpecializedCryptoPublicKey) =
other.equalsCryptographically(this)


//Helper typealias, for helper sanitization function. Enables passing all params along constructors for constructor chaining
private typealias RsaParams = Triple<ByteArray, Int, CryptoPublicKey.RSA.Size>

private val RsaParams.n get() = first
private val RsaParams.e get() = second
private val RsaParams.size get() = third

/**
* Sanitizes RSA parameters and maps it to the correct [CryptoPublicKey.RSA.Size] enum
* This function lives here and returns a typealiased Triple to allow for constructor chaining.
* If we were to change the primary constructor, we'd need to write a custom serializer
*/
@Throws(IllegalArgumentException::class)
private fun sanitizeRsaInputs(n: ByteArray, e: Int): RsaParams = n.dropWhile { it == 0.toByte() }.toByteArray()
.let { Triple(byteArrayOf(0, *it), e, CryptoPublicKey.RSA.Size.of(it)) }


private val PREFIX_DID_KEY = "did:key"

@Throws(Throwable::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ private val rsaFactory = KeyFactory.getInstance("RSA")

fun CryptoPublicKey.RSA.getJcaPublicKey(): KmmResult<RSAPublicKey> = catching {
rsaFactory.generatePublic(
RSAPublicKeySpec(BigInteger(1, n), BigInteger.valueOf(e.toLong()))
RSAPublicKeySpec(BigInteger(1, n.magnitude), BigInteger(1, e.magnitude))
) as RSAPublicKey
}

Expand All @@ -159,7 +159,7 @@ fun CryptoPublicKey.EC.Companion.fromJcaPublicKey(publicKey: ECPublicKey): KmmRe
}

fun CryptoPublicKey.RSA.Companion.fromJcaPublicKey(publicKey: RSAPublicKey): KmmResult<CryptoPublicKey> =
catching { CryptoPublicKey.RSA(publicKey.modulus.toByteArray(), publicKey.publicExponent.toInt()) }
catching { CryptoPublicKey.RSA(publicKey.modulus.toByteArray(), publicKey.publicExponent.toByteArray()) }

fun CryptoPublicKey.Companion.fromJcaPublicKey(publicKey: PublicKey): KmmResult<CryptoPublicKey> =
when (publicKey) {
Expand Down

0 comments on commit c532f9f

Please sign in to comment.