From 2b09344b1e6db296b83d331ff6fc8db002747967 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Sun, 3 Mar 2024 14:19:06 -0500 Subject: [PATCH] Simplify: in practice, AAD is always either the `hash` array or `null`, never a `ByteBuffer` or a sub-array --- .../com/eatthepath/noise/CipherState.java | 20 +++-------- .../com/eatthepath/noise/NoiseHandshake.java | 8 ++--- .../eatthepath/noise/NoiseTransportImpl.java | 4 +-- .../noise/component/AbstractNoiseCipher.java | 22 ++++-------- .../noise/component/NoiseCipher.java | 24 +++---------- .../component/AbstractNoiseCipherTest.java | 34 ++++++------------- 6 files changed, 30 insertions(+), 82 deletions(-) diff --git a/src/main/java/com/eatthepath/noise/CipherState.java b/src/main/java/com/eatthepath/noise/CipherState.java index 45399b5..b352a1d 100644 --- a/src/main/java/com/eatthepath/noise/CipherState.java +++ b/src/main/java/com/eatthepath/noise/CipherState.java @@ -39,7 +39,7 @@ public boolean hasKey() { return this.key != null; } - public ByteBuffer decrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer ciphertext) + public ByteBuffer decrypt(@Nullable final byte[] associatedData, final ByteBuffer ciphertext) throws AEADBadTagException { final ByteBuffer plaintext = ByteBuffer.allocate(getPlaintextLength(ciphertext.remaining())); @@ -54,7 +54,7 @@ public ByteBuffer decrypt(@Nullable final ByteBuffer associatedData, final ByteB return plaintext.flip(); } - public int decrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer ciphertext, final ByteBuffer plaintext) + public int decrypt(@Nullable final byte[] associatedData, final ByteBuffer ciphertext, final ByteBuffer plaintext) throws AEADBadTagException, ShortBufferException { if (hasKey()) { @@ -75,8 +75,6 @@ public byte[] decrypt(@Nullable final byte[] associatedData, final byte[] cipher try { decrypt(associatedData, - 0, - associatedData != null ? associatedData.length : 0, ciphertext, 0, ciphertext.length, @@ -91,8 +89,6 @@ public byte[] decrypt(@Nullable final byte[] associatedData, final byte[] cipher } public int decrypt(@Nullable final byte[] associatedData, - final int aadOffset, - final int aadLength, final byte[] ciphertext, final int ciphertextOffset, final int ciphertextLength, @@ -103,8 +99,6 @@ public int decrypt(@Nullable final byte[] associatedData, final int plaintextLength = cipher.decrypt(key, nonce, associatedData, - aadOffset, - aadLength, ciphertext, ciphertextOffset, ciphertextLength, @@ -119,7 +113,7 @@ public int decrypt(@Nullable final byte[] associatedData, } } - public ByteBuffer encrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer plaintext) { + public ByteBuffer encrypt(@Nullable final byte[] associatedData, final ByteBuffer plaintext) { final ByteBuffer ciphertext = ByteBuffer.allocate(getCiphertextLength(plaintext.remaining())); try { @@ -132,7 +126,7 @@ public ByteBuffer encrypt(@Nullable final ByteBuffer associatedData, final ByteB return ciphertext.flip(); } - public int encrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException { + public int encrypt(@Nullable final byte[] associatedData, final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException { if (hasKey()) { final int ciphertextLength = cipher.encrypt(key, nonce, associatedData, plaintext, ciphertext); nonce += 1; @@ -151,8 +145,6 @@ public byte[] encrypt(@Nullable final byte[] associatedData, final byte[] plaint try { encrypt(associatedData, - 0, - associatedData != null ? associatedData.length : 0, plaintext, 0, plaintext.length, @@ -167,8 +159,6 @@ public byte[] encrypt(@Nullable final byte[] associatedData, final byte[] plaint } public int encrypt(@Nullable final byte[] associatedData, - final int aadOffset, - final int aadLength, final byte[] plaintext, final int plaintextOffset, final int plaintextLength, @@ -179,8 +169,6 @@ public int encrypt(@Nullable final byte[] associatedData, final int ciphertextLength = cipher.encrypt(key, nonce, associatedData, - aadOffset, - aadLength, plaintext, plaintextOffset, plaintextLength, diff --git a/src/main/java/com/eatthepath/noise/NoiseHandshake.java b/src/main/java/com/eatthepath/noise/NoiseHandshake.java index 709a77f..7226048 100644 --- a/src/main/java/com/eatthepath/noise/NoiseHandshake.java +++ b/src/main/java/com/eatthepath/noise/NoiseHandshake.java @@ -418,7 +418,7 @@ private int encryptAndHash(final byte[] plaintext, final int ciphertextOffset) throws ShortBufferException { final int ciphertextLength = - cipherState.encrypt(hash, 0, hash.length, plaintext, plaintextOffset, plaintextLength, ciphertext, ciphertextOffset); + cipherState.encrypt(hash, plaintext, plaintextOffset, plaintextLength, ciphertext, ciphertextOffset); mixHash(ciphertext, ciphertextOffset, ciphertextLength); @@ -426,7 +426,7 @@ private int encryptAndHash(final byte[] plaintext, } private int encryptAndHash(final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException { - final int ciphertextLength = cipherState.encrypt(ByteBuffer.wrap(hash), plaintext, ciphertext); + final int ciphertextLength = cipherState.encrypt(hash, plaintext, ciphertext); mixHash(ciphertext.slice(ciphertext.position() - ciphertextLength, ciphertextLength)); @@ -440,7 +440,7 @@ private int decryptAndHash(final byte[] ciphertext, final int plaintextOffset) throws ShortBufferException, AEADBadTagException { final int plaintextLength = - cipherState.decrypt(hash, 0, hash.length, ciphertext, ciphertextOffset, ciphertextLength, plaintext, plaintextOffset); + cipherState.decrypt(hash, ciphertext, ciphertextOffset, ciphertextLength, plaintext, plaintextOffset); mixHash(ciphertext, ciphertextOffset, ciphertextLength); @@ -451,7 +451,7 @@ private int decryptAndHash(final ByteBuffer ciphertext, final ByteBuffer plaintext) throws ShortBufferException, AEADBadTagException { final int initialCiphertextPosition = ciphertext.position(); - final int plaintextLength = cipherState.decrypt(ByteBuffer.wrap(hash), ciphertext, plaintext); + final int plaintextLength = cipherState.decrypt(hash, ciphertext, plaintext); mixHash(ciphertext.slice(initialCiphertextPosition, ciphertext.position() - initialCiphertextPosition)); diff --git a/src/main/java/com/eatthepath/noise/NoiseTransportImpl.java b/src/main/java/com/eatthepath/noise/NoiseTransportImpl.java index 65de0af..7821390 100644 --- a/src/main/java/com/eatthepath/noise/NoiseTransportImpl.java +++ b/src/main/java/com/eatthepath/noise/NoiseTransportImpl.java @@ -64,7 +64,7 @@ public int readMessage(final byte[] ciphertext, throw new ShortBufferException("Plaintext array after offset is not large enough to hold plaintext"); } - return readerState.decrypt(null, 0, 0, + return readerState.decrypt(null, ciphertext, ciphertextOffset, ciphertextLength, plaintext, plaintextOffset); } @@ -113,7 +113,7 @@ public int writeMessage(final byte[] plaintext, throw new ShortBufferException("Ciphertext array after offset is not large enough to hold ciphertext"); } - return writerState.encrypt(null, 0, 0, + return writerState.encrypt(null, plaintext, plaintextOffset, plaintextLength, ciphertext, ciphertextOffset); } diff --git a/src/main/java/com/eatthepath/noise/component/AbstractNoiseCipher.java b/src/main/java/com/eatthepath/noise/component/AbstractNoiseCipher.java index 59166f9..e6ac94e 100644 --- a/src/main/java/com/eatthepath/noise/component/AbstractNoiseCipher.java +++ b/src/main/java/com/eatthepath/noise/component/AbstractNoiseCipher.java @@ -26,17 +26,14 @@ private interface CipherFinalizer { @Override public int encrypt(final Key key, final long nonce, - @Nullable final ByteBuffer associatedData, + @Nullable final byte[] associatedData, final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException { initCipher(cipher, Cipher.ENCRYPT_MODE, key, nonce); if (associatedData != null) { - final byte[] adBytes = new byte[associatedData.remaining()]; - associatedData.get(adBytes); - - cipher.updateAAD(adBytes); + cipher.updateAAD(associatedData); } return finishEncryption(() -> cipher.doFinal(plaintext, ciphertext)); @@ -46,8 +43,6 @@ public int encrypt(final Key key, public int encrypt(final Key key, final long nonce, @Nullable final byte[] associatedData, - final int aadOffset, - final int aadLength, final byte[] plaintext, final int plaintextOffset, final int plaintextLength, @@ -57,7 +52,7 @@ public int encrypt(final Key key, initCipher(cipher, Cipher.ENCRYPT_MODE, key, nonce); if (associatedData != null) { - cipher.updateAAD(associatedData, aadOffset, aadLength); + cipher.updateAAD(associatedData); } return finishEncryption(() -> @@ -67,17 +62,14 @@ public int encrypt(final Key key, @Override public int decrypt(final Key key, final long nonce, - @Nullable final ByteBuffer associatedData, + @Nullable final byte[] associatedData, final ByteBuffer ciphertext, final ByteBuffer plaintext) throws AEADBadTagException, ShortBufferException { initCipher(cipher, Cipher.DECRYPT_MODE, key, nonce); if (associatedData != null) { - final byte[] adBytes = new byte[associatedData.remaining()]; - associatedData.get(adBytes); - - cipher.updateAAD(adBytes); + cipher.updateAAD(associatedData); } return finishDecryption(() -> cipher.doFinal(ciphertext, plaintext)); @@ -87,8 +79,6 @@ public int decrypt(final Key key, public int decrypt(final Key key, final long nonce, @Nullable final byte[] associatedData, - final int aadOffset, - final int aadLength, final byte[] ciphertext, final int ciphertextOffset, final int ciphertextLength, @@ -98,7 +88,7 @@ public int decrypt(final Key key, initCipher(cipher, Cipher.DECRYPT_MODE, key, nonce); if (associatedData != null) { - cipher.updateAAD(associatedData, aadOffset, aadLength); + cipher.updateAAD(associatedData); } return finishDecryption(() -> diff --git a/src/main/java/com/eatthepath/noise/component/NoiseCipher.java b/src/main/java/com/eatthepath/noise/component/NoiseCipher.java index 7d93100..905ee4a 100644 --- a/src/main/java/com/eatthepath/noise/component/NoiseCipher.java +++ b/src/main/java/com/eatthepath/noise/component/NoiseCipher.java @@ -77,7 +77,7 @@ static NoiseCipher getInstance(final String noiseCipherName) throws NoSuchAlgori */ default ByteBuffer encrypt(final Key key, final long nonce, - @Nullable final ByteBuffer associatedData, + @Nullable final byte[] associatedData, final ByteBuffer plaintext) { final ByteBuffer ciphertext = ByteBuffer.allocate(getCiphertextLength(plaintext.remaining())); @@ -121,7 +121,7 @@ default ByteBuffer encrypt(final Key key, */ int encrypt(final Key key, final long nonce, - @Nullable final ByteBuffer associatedData, + @Nullable final byte[] associatedData, final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException; @@ -150,8 +150,6 @@ default byte[] encrypt(final Key key, encrypt(key, nonce, associatedData, - 0, - associatedData != null ? associatedData.length : 0, plaintext, 0, plaintext.length, @@ -175,10 +173,6 @@ default byte[] encrypt(final Key key, * @param nonce a nonce, which must be unique for the given key * @param associatedData a byte array containing the associated data (if any) to be used when encrypting the given * plaintext; may be {@code null} - * @param aadOffset the position within {@code associatedData} where the associated data starts; ignored if - * {@code associatedData} is {@code null} - * @param aadLength the length of the associated data within {@code associatedData}; ignored if {@code associatedData} - * is {@code null} * @param plaintext a byte array containing the plaintext to encrypt * @param plaintextOffset the offset within {@code plaintext} where the plaintext begins * @param plaintextLength the length of the plaintext within {@code plaintext} @@ -198,8 +192,6 @@ default byte[] encrypt(final Key key, int encrypt(final Key key, final long nonce, @Nullable final byte[] associatedData, - final int aadOffset, - final int aadLength, final byte[] plaintext, final int plaintextOffset, final int plaintextLength, @@ -229,7 +221,7 @@ int encrypt(final Key key, */ default ByteBuffer decrypt(final Key key, final long nonce, - @Nullable final ByteBuffer associatedData, + @Nullable final byte[] associatedData, final ByteBuffer ciphertext) throws AEADBadTagException { final ByteBuffer plaintext = ByteBuffer.allocate(getPlaintextLength(ciphertext.remaining())); @@ -272,7 +264,7 @@ default ByteBuffer decrypt(final Key key, */ int decrypt(final Key key, final long nonce, - @Nullable final ByteBuffer associatedData, + @Nullable final byte[] associatedData, final ByteBuffer ciphertext, final ByteBuffer plaintext) throws AEADBadTagException, ShortBufferException; @@ -304,8 +296,6 @@ default byte[] decrypt(final Key key, decrypt(key, nonce, associatedData, - 0, - associatedData != null ? associatedData.length : 0, ciphertext, 0, ciphertext.length, @@ -330,10 +320,6 @@ default byte[] decrypt(final Key key, * @param nonce a nonce, which must be unique for the given key * @param associatedData a byte array containing the associated data (if any) to be used when verifying the AEAD tag * for the given ciphertext; may be {@code null} - * @param aadOffset the position within {@code associatedData} where the associated data starts; ignored if - * {@code associatedData} is {@code null} - * @param aadLength the length of the associated data within {@code associatedData}; ignored if {@code associatedData} - * is {@code null} * @param ciphertext a byte array containing the ciphertext and AEAD tag to be decrypted and verified * @param ciphertextOffset the position within {@code ciphertext} at which to begin reading the ciphertext and AEAD * tag @@ -353,8 +339,6 @@ default byte[] decrypt(final Key key, int decrypt(final Key key, final long nonce, @Nullable final byte[] associatedData, - final int aadOffset, - final int aadLength, final byte[] ciphertext, final int ciphertextOffset, final int ciphertextLength, diff --git a/src/test/java/com/eatthepath/noise/component/AbstractNoiseCipherTest.java b/src/test/java/com/eatthepath/noise/component/AbstractNoiseCipherTest.java index d12f26e..f011638 100644 --- a/src/test/java/com/eatthepath/noise/component/AbstractNoiseCipherTest.java +++ b/src/test/java/com/eatthepath/noise/component/AbstractNoiseCipherTest.java @@ -47,12 +47,12 @@ void encryptDecryptByteArrayInPlace() throws AEADBadTagException, ShortBufferExc System.arraycopy(plaintextBytes, 0, buffer, 0, plaintextBytes.length); assertEquals(buffer.length, getNoiseCipher().encrypt(key, nonce, - hash, 0, hash.length, + hash, buffer, 0, plaintextBytes.length, buffer, 0)); assertEquals(plaintextBytes.length, getNoiseCipher().decrypt(key, nonce, - hash, 0, hash.length, + hash, buffer, 0, buffer.length, buffer, 0)); @@ -66,39 +66,26 @@ void encryptDecryptByteArrayInPlace() throws AEADBadTagException, ShortBufferExc void encryptDecryptNewByteBuffer() throws AEADBadTagException { final Key key = generateKey(); final long nonce = ThreadLocalRandom.current().nextLong(); - - final ByteBuffer hashBuffer; - { - final byte[] hash = new byte[32]; - ThreadLocalRandom.current().nextBytes(hash); - - hashBuffer = ByteBuffer.wrap(hash); - } + final byte[] hash = new byte[32]; + ThreadLocalRandom.current().nextBytes(hash); final ByteBuffer plaintext = ByteBuffer.wrap("Hark! Plaintext!".getBytes(StandardCharsets.UTF_8)); - final ByteBuffer ciphertext = getNoiseCipher().encrypt(key, nonce, hashBuffer, plaintext); + final ByteBuffer ciphertext = getNoiseCipher().encrypt(key, nonce, hash, plaintext); plaintext.rewind(); - hashBuffer.rewind(); assertEquals(ciphertext.remaining(), getNoiseCipher().getCiphertextLength(plaintext.remaining())); assertEquals(plaintext.remaining(), getNoiseCipher().getPlaintextLength(ciphertext.remaining())); - assertEquals(plaintext, getNoiseCipher().decrypt(key, nonce, hashBuffer, ciphertext)); + assertEquals(plaintext, getNoiseCipher().decrypt(key, nonce, hash, ciphertext)); } @Test void encryptDecryptByteBufferInPlace() throws AEADBadTagException, ShortBufferException { final Key key = generateKey(); final long nonce = ThreadLocalRandom.current().nextLong(); - - final ByteBuffer hashBuffer; - { - final byte[] hash = new byte[32]; - ThreadLocalRandom.current().nextBytes(hash); - - hashBuffer = ByteBuffer.wrap(hash); - } + final byte[] hash = new byte[32]; + ThreadLocalRandom.current().nextBytes(hash); final byte[] plaintextBytes = "Hark! Plaintext!".getBytes(StandardCharsets.UTF_8); final byte[] sharedByteArray = new byte[getNoiseCipher().getCiphertextLength(plaintextBytes.length)]; @@ -111,18 +98,17 @@ void encryptDecryptByteBufferInPlace() throws AEADBadTagException, ShortBufferEx final ByteBuffer ciphertextBuffer = ByteBuffer.wrap(sharedByteArray); assertEquals(sharedByteArray.length, - getNoiseCipher().encrypt(key, nonce, hashBuffer, plaintextBuffer, ciphertextBuffer)); + getNoiseCipher().encrypt(key, nonce, hash, plaintextBuffer, ciphertextBuffer)); assertEquals(plaintextBytes.length, plaintextBuffer.limit()); assertEquals(plaintextBuffer.limit(), plaintextBuffer.position()); assertEquals(sharedByteArray.length, ciphertextBuffer.position()); - hashBuffer.rewind(); plaintextBuffer.rewind(); ciphertextBuffer.rewind(); assertEquals(plaintextBytes.length, - getNoiseCipher().decrypt(key, nonce, hashBuffer, ciphertextBuffer, plaintextBuffer)); + getNoiseCipher().decrypt(key, nonce, hash, ciphertextBuffer, plaintextBuffer)); assertEquals(plaintextBytes.length, plaintextBuffer.limit()); assertEquals(plaintextBuffer.limit(), plaintextBuffer.position());