diff --git a/src/main/java/com/eatthepath/noise/NoiseHandshake.java b/src/main/java/com/eatthepath/noise/NoiseHandshake.java index febb6ea..709a77f 100644 --- a/src/main/java/com/eatthepath/noise/NoiseHandshake.java +++ b/src/main/java/com/eatthepath/noise/NoiseHandshake.java @@ -136,6 +136,7 @@ public class NoiseHandshake { private int currentMessagePattern = 0; private boolean hasSplit = false; + private boolean hasFallenBack = false; private final CipherState cipherState; private final NoiseHash noiseHash; @@ -478,6 +479,10 @@ public boolean isOneWayHandshake() { * @see #isDone() */ public boolean isExpectingRead() { + if (hasFallenBack) { + return false; + } + if (currentMessagePattern < handshakePattern.getHandshakeMessagePatterns().length) { return handshakePattern.getHandshakeMessagePatterns()[currentMessagePattern].sender() != role; } @@ -497,6 +502,10 @@ public boolean isExpectingRead() { * @see #isDone() */ public boolean isExpectingWrite() { + if (hasFallenBack) { + return false; + } + if (currentMessagePattern < handshakePattern.getHandshakeMessagePatterns().length) { return handshakePattern.getHandshakeMessagePatterns()[currentMessagePattern].sender() == role; } @@ -514,6 +523,10 @@ public boolean isExpectingWrite() { * @see #isExpectingWrite() */ public boolean isDone() { + if (hasFallenBack) { + return false; + } + return currentMessagePattern == handshakePattern.getHandshakeMessagePatterns().length; } @@ -1246,7 +1259,6 @@ private void handleMixKeyToken(final HandshakePattern.Token token) { * @see HandshakePattern#isFallbackPattern() */ public NoiseHandshake fallbackTo(final String handshakePatternName) throws NoSuchPatternException { - // TODO Self-destruct after falling back return fallbackTo(handshakePatternName, null); } @@ -1271,6 +1283,10 @@ public NoiseHandshake fallbackTo(final String handshakePatternName) throws NoSuc * @see HandshakePattern#isFallbackPattern() */ public NoiseHandshake fallbackTo(final String handshakePatternName, @Nullable final List preSharedKeys) throws NoSuchPatternException { + if (hasFallenBack) { + throw new IllegalStateException("Handshake has already fallen back to another pattern"); + } + final HandshakePattern fallbackPattern = HandshakePattern.getInstance(handshakePatternName); if (!fallbackPattern.isFallbackPattern()) { @@ -1313,6 +1329,8 @@ public NoiseHandshake fallbackTo(final String handshakePatternName, @Nullable fi fallbackRemoteEphemeralPublicKey = null; } + hasFallenBack = true; + return new NoiseHandshake(role, fallbackPattern, keyAgreement, diff --git a/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java b/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java index 25abe76..6ea8ed9 100644 --- a/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java +++ b/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java @@ -3,9 +3,12 @@ import com.eatthepath.noise.component.NoiseKeyAgreement; import org.junit.jupiter.api.Test; +import javax.crypto.AEADBadTagException; import javax.crypto.ShortBufferException; import java.nio.ByteBuffer; +import java.security.KeyPair; import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; import static org.junit.jupiter.api.Assertions.*; @@ -130,4 +133,37 @@ void readMessageShortBuffer() throws NoSuchAlgorithmException { assertThrows(ShortBufferException.class, () -> handshake.readMessage(ByteBuffer.wrap(message), ByteBuffer.allocate(payloadLength - 1))); } + + @Test + void repeatedFallback() throws NoSuchAlgorithmException { + final NoiseKeyAgreement keyAgreement = NoiseKeyAgreement.getInstance("25519"); + + final KeyPair initiatorStaticKeyPair = keyAgreement.generateKeyPair(); + final PublicKey staleRemoteStaticPublicKey = keyAgreement.generateKeyPair().getPublic(); + final KeyPair currentResponderStaticKeyPair = keyAgreement.generateKeyPair(); + + final byte[] initiatorStaticKeyMessage; + { + final NoiseHandshake ikInitiatorHandshake = + NoiseHandshakeBuilder.forIKInitiator(initiatorStaticKeyPair, staleRemoteStaticPublicKey) + .setComponentsFromProtocolName("Noise_IK_25519_AESGCM_SHA256") + .build(); + + initiatorStaticKeyMessage = ikInitiatorHandshake.writeMessage((byte[]) null); + } + + final NoiseHandshake ikResponderHandshake = + NoiseHandshakeBuilder.forIKResponder(currentResponderStaticKeyPair) + .setComponentsFromProtocolName("Noise_IK_25519_AESGCM_SHA256") + .build(); + + assertThrows(AEADBadTagException.class, () -> ikResponderHandshake.readMessage(initiatorStaticKeyMessage)); + + assertDoesNotThrow(() -> ikResponderHandshake.fallbackTo("XXfallback")); + assertThrows(IllegalStateException.class, () -> ikResponderHandshake.fallbackTo("XXfallback")); + + assertFalse(ikResponderHandshake.isExpectingRead()); + assertFalse(ikResponderHandshake.isExpectingWrite()); + assertFalse(ikResponderHandshake.isDone()); + } }