Skip to content

Commit

Permalink
Simplify: in practice, AAD is always either the hash array or `null…
Browse files Browse the repository at this point in the history
…`, never a `ByteBuffer` or a sub-array
  • Loading branch information
jchambers committed Mar 3, 2024
1 parent 0edd2e0 commit 2b09344
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 82 deletions.
20 changes: 4 additions & 16 deletions src/main/java/com/eatthepath/noise/CipherState.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public boolean hasKey() {
return this.key != null;
}

public ByteBuffer decrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer ciphertext)
public ByteBuffer decrypt(@Nullable final byte[] associatedData, final ByteBuffer ciphertext)
throws AEADBadTagException {

final ByteBuffer plaintext = ByteBuffer.allocate(getPlaintextLength(ciphertext.remaining()));
Expand All @@ -54,7 +54,7 @@ public ByteBuffer decrypt(@Nullable final ByteBuffer associatedData, final ByteB
return plaintext.flip();
}

public int decrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer ciphertext, final ByteBuffer plaintext)
public int decrypt(@Nullable final byte[] associatedData, final ByteBuffer ciphertext, final ByteBuffer plaintext)
throws AEADBadTagException, ShortBufferException {

if (hasKey()) {
Expand All @@ -75,8 +75,6 @@ public byte[] decrypt(@Nullable final byte[] associatedData, final byte[] cipher

try {
decrypt(associatedData,
0,
associatedData != null ? associatedData.length : 0,
ciphertext,
0,
ciphertext.length,
Expand All @@ -91,8 +89,6 @@ public byte[] decrypt(@Nullable final byte[] associatedData, final byte[] cipher
}

public int decrypt(@Nullable final byte[] associatedData,
final int aadOffset,
final int aadLength,
final byte[] ciphertext,
final int ciphertextOffset,
final int ciphertextLength,
Expand All @@ -103,8 +99,6 @@ public int decrypt(@Nullable final byte[] associatedData,
final int plaintextLength = cipher.decrypt(key,
nonce,
associatedData,
aadOffset,
aadLength,
ciphertext,
ciphertextOffset,
ciphertextLength,
Expand All @@ -119,7 +113,7 @@ public int decrypt(@Nullable final byte[] associatedData,
}
}

public ByteBuffer encrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer plaintext) {
public ByteBuffer encrypt(@Nullable final byte[] associatedData, final ByteBuffer plaintext) {
final ByteBuffer ciphertext = ByteBuffer.allocate(getCiphertextLength(plaintext.remaining()));

try {
Expand All @@ -132,7 +126,7 @@ public ByteBuffer encrypt(@Nullable final ByteBuffer associatedData, final ByteB
return ciphertext.flip();
}

public int encrypt(@Nullable final ByteBuffer associatedData, final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException {
public int encrypt(@Nullable final byte[] associatedData, final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException {
if (hasKey()) {
final int ciphertextLength = cipher.encrypt(key, nonce, associatedData, plaintext, ciphertext);
nonce += 1;
Expand All @@ -151,8 +145,6 @@ public byte[] encrypt(@Nullable final byte[] associatedData, final byte[] plaint

try {
encrypt(associatedData,
0,
associatedData != null ? associatedData.length : 0,
plaintext,
0,
plaintext.length,
Expand All @@ -167,8 +159,6 @@ public byte[] encrypt(@Nullable final byte[] associatedData, final byte[] plaint
}

public int encrypt(@Nullable final byte[] associatedData,
final int aadOffset,
final int aadLength,
final byte[] plaintext,
final int plaintextOffset,
final int plaintextLength,
Expand All @@ -179,8 +169,6 @@ public int encrypt(@Nullable final byte[] associatedData,
final int ciphertextLength = cipher.encrypt(key,
nonce,
associatedData,
aadOffset,
aadLength,
plaintext,
plaintextOffset,
plaintextLength,
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/com/eatthepath/noise/NoiseHandshake.java
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,15 @@ private int encryptAndHash(final byte[] plaintext,
final int ciphertextOffset) throws ShortBufferException {

final int ciphertextLength =
cipherState.encrypt(hash, 0, hash.length, plaintext, plaintextOffset, plaintextLength, ciphertext, ciphertextOffset);
cipherState.encrypt(hash, plaintext, plaintextOffset, plaintextLength, ciphertext, ciphertextOffset);

mixHash(ciphertext, ciphertextOffset, ciphertextLength);

return ciphertextLength;
}

private int encryptAndHash(final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException {
final int ciphertextLength = cipherState.encrypt(ByteBuffer.wrap(hash), plaintext, ciphertext);
final int ciphertextLength = cipherState.encrypt(hash, plaintext, ciphertext);

mixHash(ciphertext.slice(ciphertext.position() - ciphertextLength, ciphertextLength));

Expand All @@ -440,7 +440,7 @@ private int decryptAndHash(final byte[] ciphertext,
final int plaintextOffset) throws ShortBufferException, AEADBadTagException {

final int plaintextLength =
cipherState.decrypt(hash, 0, hash.length, ciphertext, ciphertextOffset, ciphertextLength, plaintext, plaintextOffset);
cipherState.decrypt(hash, ciphertext, ciphertextOffset, ciphertextLength, plaintext, plaintextOffset);

mixHash(ciphertext, ciphertextOffset, ciphertextLength);

Expand All @@ -451,7 +451,7 @@ private int decryptAndHash(final ByteBuffer ciphertext,
final ByteBuffer plaintext) throws ShortBufferException, AEADBadTagException {

final int initialCiphertextPosition = ciphertext.position();
final int plaintextLength = cipherState.decrypt(ByteBuffer.wrap(hash), ciphertext, plaintext);
final int plaintextLength = cipherState.decrypt(hash, ciphertext, plaintext);

mixHash(ciphertext.slice(initialCiphertextPosition, ciphertext.position() - initialCiphertextPosition));

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/eatthepath/noise/NoiseTransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public int readMessage(final byte[] ciphertext,
throw new ShortBufferException("Plaintext array after offset is not large enough to hold plaintext");
}

return readerState.decrypt(null, 0, 0,
return readerState.decrypt(null,
ciphertext, ciphertextOffset, ciphertextLength,
plaintext, plaintextOffset);
}
Expand Down Expand Up @@ -113,7 +113,7 @@ public int writeMessage(final byte[] plaintext,
throw new ShortBufferException("Ciphertext array after offset is not large enough to hold ciphertext");
}

return writerState.encrypt(null, 0, 0,
return writerState.encrypt(null,
plaintext, plaintextOffset, plaintextLength,
ciphertext, ciphertextOffset);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,14 @@ private interface CipherFinalizer<T> {
@Override
public int encrypt(final Key key,
final long nonce,
@Nullable final ByteBuffer associatedData,
@Nullable final byte[] associatedData,
final ByteBuffer plaintext,
final ByteBuffer ciphertext) throws ShortBufferException {

initCipher(cipher, Cipher.ENCRYPT_MODE, key, nonce);

if (associatedData != null) {
final byte[] adBytes = new byte[associatedData.remaining()];
associatedData.get(adBytes);

cipher.updateAAD(adBytes);
cipher.updateAAD(associatedData);
}

return finishEncryption(() -> cipher.doFinal(plaintext, ciphertext));
Expand All @@ -46,8 +43,6 @@ public int encrypt(final Key key,
public int encrypt(final Key key,
final long nonce,
@Nullable final byte[] associatedData,
final int aadOffset,
final int aadLength,
final byte[] plaintext,
final int plaintextOffset,
final int plaintextLength,
Expand All @@ -57,7 +52,7 @@ public int encrypt(final Key key,
initCipher(cipher, Cipher.ENCRYPT_MODE, key, nonce);

if (associatedData != null) {
cipher.updateAAD(associatedData, aadOffset, aadLength);
cipher.updateAAD(associatedData);
}

return finishEncryption(() ->
Expand All @@ -67,17 +62,14 @@ public int encrypt(final Key key,
@Override
public int decrypt(final Key key,
final long nonce,
@Nullable final ByteBuffer associatedData,
@Nullable final byte[] associatedData,
final ByteBuffer ciphertext,
final ByteBuffer plaintext) throws AEADBadTagException, ShortBufferException {

initCipher(cipher, Cipher.DECRYPT_MODE, key, nonce);

if (associatedData != null) {
final byte[] adBytes = new byte[associatedData.remaining()];
associatedData.get(adBytes);

cipher.updateAAD(adBytes);
cipher.updateAAD(associatedData);
}

return finishDecryption(() -> cipher.doFinal(ciphertext, plaintext));
Expand All @@ -87,8 +79,6 @@ public int decrypt(final Key key,
public int decrypt(final Key key,
final long nonce,
@Nullable final byte[] associatedData,
final int aadOffset,
final int aadLength,
final byte[] ciphertext,
final int ciphertextOffset,
final int ciphertextLength,
Expand All @@ -98,7 +88,7 @@ public int decrypt(final Key key,
initCipher(cipher, Cipher.DECRYPT_MODE, key, nonce);

if (associatedData != null) {
cipher.updateAAD(associatedData, aadOffset, aadLength);
cipher.updateAAD(associatedData);
}

return finishDecryption(() ->
Expand Down
24 changes: 4 additions & 20 deletions src/main/java/com/eatthepath/noise/component/NoiseCipher.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ static NoiseCipher getInstance(final String noiseCipherName) throws NoSuchAlgori
*/
default ByteBuffer encrypt(final Key key,
final long nonce,
@Nullable final ByteBuffer associatedData,
@Nullable final byte[] associatedData,
final ByteBuffer plaintext) {

final ByteBuffer ciphertext = ByteBuffer.allocate(getCiphertextLength(plaintext.remaining()));
Expand Down Expand Up @@ -121,7 +121,7 @@ default ByteBuffer encrypt(final Key key,
*/
int encrypt(final Key key,
final long nonce,
@Nullable final ByteBuffer associatedData,
@Nullable final byte[] associatedData,
final ByteBuffer plaintext,
final ByteBuffer ciphertext)
throws ShortBufferException;
Expand Down Expand Up @@ -150,8 +150,6 @@ default byte[] encrypt(final Key key,
encrypt(key,
nonce,
associatedData,
0,
associatedData != null ? associatedData.length : 0,
plaintext,
0,
plaintext.length,
Expand All @@ -175,10 +173,6 @@ default byte[] encrypt(final Key key,
* @param nonce a nonce, which must be unique for the given key
* @param associatedData a byte array containing the associated data (if any) to be used when encrypting the given
* plaintext; may be {@code null}
* @param aadOffset the position within {@code associatedData} where the associated data starts; ignored if
* {@code associatedData} is {@code null}
* @param aadLength the length of the associated data within {@code associatedData}; ignored if {@code associatedData}
* is {@code null}
* @param plaintext a byte array containing the plaintext to encrypt
* @param plaintextOffset the offset within {@code plaintext} where the plaintext begins
* @param plaintextLength the length of the plaintext within {@code plaintext}
Expand All @@ -198,8 +192,6 @@ default byte[] encrypt(final Key key,
int encrypt(final Key key,
final long nonce,
@Nullable final byte[] associatedData,
final int aadOffset,
final int aadLength,
final byte[] plaintext,
final int plaintextOffset,
final int plaintextLength,
Expand Down Expand Up @@ -229,7 +221,7 @@ int encrypt(final Key key,
*/
default ByteBuffer decrypt(final Key key,
final long nonce,
@Nullable final ByteBuffer associatedData,
@Nullable final byte[] associatedData,
final ByteBuffer ciphertext) throws AEADBadTagException {

final ByteBuffer plaintext = ByteBuffer.allocate(getPlaintextLength(ciphertext.remaining()));
Expand Down Expand Up @@ -272,7 +264,7 @@ default ByteBuffer decrypt(final Key key,
*/
int decrypt(final Key key,
final long nonce,
@Nullable final ByteBuffer associatedData,
@Nullable final byte[] associatedData,
final ByteBuffer ciphertext,
final ByteBuffer plaintext)
throws AEADBadTagException, ShortBufferException;
Expand Down Expand Up @@ -304,8 +296,6 @@ default byte[] decrypt(final Key key,
decrypt(key,
nonce,
associatedData,
0,
associatedData != null ? associatedData.length : 0,
ciphertext,
0,
ciphertext.length,
Expand All @@ -330,10 +320,6 @@ default byte[] decrypt(final Key key,
* @param nonce a nonce, which must be unique for the given key
* @param associatedData a byte array containing the associated data (if any) to be used when verifying the AEAD tag
* for the given ciphertext; may be {@code null}
* @param aadOffset the position within {@code associatedData} where the associated data starts; ignored if
* {@code associatedData} is {@code null}
* @param aadLength the length of the associated data within {@code associatedData}; ignored if {@code associatedData}
* is {@code null}
* @param ciphertext a byte array containing the ciphertext and AEAD tag to be decrypted and verified
* @param ciphertextOffset the position within {@code ciphertext} at which to begin reading the ciphertext and AEAD
* tag
Expand All @@ -353,8 +339,6 @@ default byte[] decrypt(final Key key,
int decrypt(final Key key,
final long nonce,
@Nullable final byte[] associatedData,
final int aadOffset,
final int aadLength,
final byte[] ciphertext,
final int ciphertextOffset,
final int ciphertextLength,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ void encryptDecryptByteArrayInPlace() throws AEADBadTagException, ShortBufferExc
System.arraycopy(plaintextBytes, 0, buffer, 0, plaintextBytes.length);

assertEquals(buffer.length, getNoiseCipher().encrypt(key, nonce,
hash, 0, hash.length,
hash,
buffer, 0, plaintextBytes.length,
buffer, 0));

assertEquals(plaintextBytes.length, getNoiseCipher().decrypt(key, nonce,
hash, 0, hash.length,
hash,
buffer, 0, buffer.length,
buffer, 0));

Expand All @@ -66,39 +66,26 @@ void encryptDecryptByteArrayInPlace() throws AEADBadTagException, ShortBufferExc
void encryptDecryptNewByteBuffer() throws AEADBadTagException {
final Key key = generateKey();
final long nonce = ThreadLocalRandom.current().nextLong();

final ByteBuffer hashBuffer;
{
final byte[] hash = new byte[32];
ThreadLocalRandom.current().nextBytes(hash);

hashBuffer = ByteBuffer.wrap(hash);
}
final byte[] hash = new byte[32];
ThreadLocalRandom.current().nextBytes(hash);

final ByteBuffer plaintext = ByteBuffer.wrap("Hark! Plaintext!".getBytes(StandardCharsets.UTF_8));
final ByteBuffer ciphertext = getNoiseCipher().encrypt(key, nonce, hashBuffer, plaintext);
final ByteBuffer ciphertext = getNoiseCipher().encrypt(key, nonce, hash, plaintext);

plaintext.rewind();
hashBuffer.rewind();

assertEquals(ciphertext.remaining(), getNoiseCipher().getCiphertextLength(plaintext.remaining()));
assertEquals(plaintext.remaining(), getNoiseCipher().getPlaintextLength(ciphertext.remaining()));

assertEquals(plaintext, getNoiseCipher().decrypt(key, nonce, hashBuffer, ciphertext));
assertEquals(plaintext, getNoiseCipher().decrypt(key, nonce, hash, ciphertext));
}

@Test
void encryptDecryptByteBufferInPlace() throws AEADBadTagException, ShortBufferException {
final Key key = generateKey();
final long nonce = ThreadLocalRandom.current().nextLong();

final ByteBuffer hashBuffer;
{
final byte[] hash = new byte[32];
ThreadLocalRandom.current().nextBytes(hash);

hashBuffer = ByteBuffer.wrap(hash);
}
final byte[] hash = new byte[32];
ThreadLocalRandom.current().nextBytes(hash);

final byte[] plaintextBytes = "Hark! Plaintext!".getBytes(StandardCharsets.UTF_8);
final byte[] sharedByteArray = new byte[getNoiseCipher().getCiphertextLength(plaintextBytes.length)];
Expand All @@ -111,18 +98,17 @@ void encryptDecryptByteBufferInPlace() throws AEADBadTagException, ShortBufferEx
final ByteBuffer ciphertextBuffer = ByteBuffer.wrap(sharedByteArray);

assertEquals(sharedByteArray.length,
getNoiseCipher().encrypt(key, nonce, hashBuffer, plaintextBuffer, ciphertextBuffer));
getNoiseCipher().encrypt(key, nonce, hash, plaintextBuffer, ciphertextBuffer));

assertEquals(plaintextBytes.length, plaintextBuffer.limit());
assertEquals(plaintextBuffer.limit(), plaintextBuffer.position());
assertEquals(sharedByteArray.length, ciphertextBuffer.position());

hashBuffer.rewind();
plaintextBuffer.rewind();
ciphertextBuffer.rewind();

assertEquals(plaintextBytes.length,
getNoiseCipher().decrypt(key, nonce, hashBuffer, ciphertextBuffer, plaintextBuffer));
getNoiseCipher().decrypt(key, nonce, hash, ciphertextBuffer, plaintextBuffer));

assertEquals(plaintextBytes.length, plaintextBuffer.limit());
assertEquals(plaintextBuffer.limit(), plaintextBuffer.position());
Expand Down

0 comments on commit 2b09344

Please sign in to comment.