From 4d29db6bd218c4e3c016bfa408f6ee5cff9b454a Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Sat, 7 Sep 2024 14:43:52 -0400 Subject: [PATCH] Add support for HFS tokens and pattern modifiers --- .../eatthepath/noise/HandshakePattern.java | 111 +++++++++++++-- .../com/eatthepath/noise/NoiseHandshake.java | 1 + .../noise/HandshakePatternTest.java | 130 ++++++++++++++++++ 3 files changed, 232 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/eatthepath/noise/HandshakePattern.java b/src/main/java/com/eatthepath/noise/HandshakePattern.java index 7c72cbc..9b17e75 100644 --- a/src/main/java/com/eatthepath/noise/HandshakePattern.java +++ b/src/main/java/com/eatthepath/noise/HandshakePattern.java @@ -340,6 +340,21 @@ class HandshakePattern { } record MessagePattern(NoiseHandshake.Role sender, Token[] tokens) { + + MessagePattern withAddedToken(final Token token, final int insertionIndex) { + if (insertionIndex < 0 || insertionIndex >= this.tokens().length + 1) { + throw new IllegalArgumentException("Illegal insertion index"); + } + + final Token[] modifiedTokens = new Token[this.tokens().length + 1]; + System.arraycopy(this.tokens(), 0, modifiedTokens, 0, insertionIndex); + modifiedTokens[insertionIndex] = token; + System.arraycopy(this.tokens(), insertionIndex, modifiedTokens, + insertionIndex + 1, this.tokens().length - insertionIndex); + + return new MessagePattern(this.sender(), modifiedTokens); + } + @Override public String toString() { final String prefix = switch (sender()) { @@ -375,18 +390,24 @@ enum Token { ES, SE, SS, - PSK; + PSK, + E1, + EKEM1; static Token fromString(final String string) { - return switch (string) { - case "e", "E" -> E; - case "s", "S" -> S; - case "ee", "EE" -> EE; - case "es", "ES" -> ES; - case "se", "SE" -> SE; - case "ss", "SS" -> SS; - case "psk", "PSK" -> PSK; - default -> throw new IllegalArgumentException("Unrecognized token: " + string); + for (final Token token : Token.values()) { + if (token.name().equalsIgnoreCase(string)) { + return token; + } + } + + throw new IllegalArgumentException("Unrecognized token: " + string); + } + + boolean isKeyAgreementToken() { + return switch (this) { + case EE, ES, SE, SS -> true; + default -> false; }; } } @@ -482,6 +503,8 @@ HandshakePattern withModifier(final String modifier) { modifiedMessagePatterns = getPatternsWithFallbackModifier(); } else if (modifier.startsWith("psk")) { modifiedMessagePatterns = getPatternsWithPskModifier(modifier); + } else if ("hfs".equals(modifier)) { + modifiedMessagePatterns = getPatternsWithHfsModifier(); } else { throw new IllegalArgumentException("Unrecognized modifier: " + modifier); } @@ -538,6 +561,74 @@ private MessagePattern[][] getPatternsWithPskModifier(final String modifier) { return new MessagePattern[][] { modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns }; } + private MessagePattern[][] getPatternsWithHfsModifier() { + // Temporarily combine the pre-messages and "normal" messages to make iteration/state management easier + final MessagePattern[] messagePatterns = + new MessagePattern[getPreMessagePatterns().length + getHandshakeMessagePatterns().length]; + + System.arraycopy(getPreMessagePatterns(), 0, messagePatterns, 0, getPreMessagePatterns().length); + System.arraycopy(getHandshakeMessagePatterns(), 0, messagePatterns, + getPreMessagePatterns().length, getHandshakeMessagePatterns().length); + + boolean insertedE1Token = false; + boolean insertedEkem1Token = false; + + for (int i = 0; i < messagePatterns.length; i++) { + if (!insertedE1Token && Arrays.stream(messagePatterns[i].tokens()).anyMatch(token -> token == Token.E)) { + // We haven't inserted an E1 token yet, and this message pattern needs one. Exactly where it should go depends + // on whether this message pattern also contains a key agreement token, but either way, this pattern will wind + // up one token longer than it was when it started. + int insertionIndex = -1; + + for (int t = 0; t < messagePatterns[i].tokens().length; t++) { + final Token token = messagePatterns[i].tokens()[t]; + + // TODO Prove that E must come before key agreement tokens + if (token == Token.E || token.isKeyAgreementToken()) { + insertionIndex = t + 1; + + if (token.isKeyAgreementToken()) { + break; + } + } + } + + messagePatterns[i] = messagePatterns[i].withAddedToken(Token.E1, insertionIndex); + insertedE1Token = true; + } + + if (!insertedEkem1Token && Arrays.stream(messagePatterns[i].tokens()).anyMatch(token -> token == Token.EE)) { + // We haven't inserted an EKEM1 token yet, and this pattern needs one. EKEM1 tokens always go after the first + // EE token. + int insertionIndex = -1; + + for (int t = 0; t < messagePatterns[i].tokens().length; t++) { + if (messagePatterns[i].tokens()[t] == Token.EE) { + insertionIndex = t + 1; + break; + } + } + + messagePatterns[i] = messagePatterns[i].withAddedToken(Token.EKEM1, insertionIndex); + insertedEkem1Token = true; + } + + if (insertedE1Token && insertedEkem1Token) { + // No need to inspect the rest of the message patterns if we've already inserted both of the HFS tokens + break; + } + } + + final MessagePattern[] modifiedPreMessagePatterns = new MessagePattern[getPreMessagePatterns().length]; + final MessagePattern[] modifiedHandshakeMessagePatterns = new MessagePattern[getHandshakeMessagePatterns().length]; + + System.arraycopy(messagePatterns, 0, modifiedPreMessagePatterns, 0, getPreMessagePatterns().length); + System.arraycopy(messagePatterns, getPreMessagePatterns().length, + modifiedHandshakeMessagePatterns, 0, getHandshakeMessagePatterns().length); + + return new MessagePattern[][] { modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns }; + } + private String getModifiedName(final String modifier) { final String modifiedName; diff --git a/src/main/java/com/eatthepath/noise/NoiseHandshake.java b/src/main/java/com/eatthepath/noise/NoiseHandshake.java index fe56737..bd7afaf 100644 --- a/src/main/java/com/eatthepath/noise/NoiseHandshake.java +++ b/src/main/java/com/eatthepath/noise/NoiseHandshake.java @@ -349,6 +349,7 @@ public enum Role { } case EE, ES, SE, SS, PSK -> throw new IllegalArgumentException("Key-mixing tokens must not appear in pre-messages"); + case E1, EKEM1 -> throw new UnsupportedOperationException(); })) .forEach(publicKey -> mixHash(keyAgreement.serializePublicKey(publicKey))); } diff --git a/src/test/java/com/eatthepath/noise/HandshakePatternTest.java b/src/test/java/com/eatthepath/noise/HandshakePatternTest.java index 24f7284..d4ee2b0 100644 --- a/src/test/java/com/eatthepath/noise/HandshakePatternTest.java +++ b/src/test/java/com/eatthepath/noise/HandshakePatternTest.java @@ -169,6 +169,111 @@ void withPskModifier() throws NoSuchPatternException { } } + @ParameterizedTest + @MethodSource + void withHfsModifier(final HandshakePattern expectedHfsPattern) throws NoSuchPatternException { + final String fundamentalPatternName = HandshakePattern.getFundamentalPatternName(expectedHfsPattern.getName()); + + assertEquals(expectedHfsPattern, HandshakePattern.getInstance(fundamentalPatternName).withModifier("hfs")); + } + + private static List withHfsModifier() { + return List.of( + HandshakePattern.fromString(""" + NNhfs: + -> e, e1 + <- e, ee, ekem1 + """), + + HandshakePattern.fromString(""" + NKhfs: + <- s + ... + -> e, es, e1 + <- e, ee, ekem1 + """), + + HandshakePattern.fromString(""" + NXhfs: + -> e, e1 + <- e, ee, ekem1, s, es + """), + + HandshakePattern.fromString(""" + XNhfs: + -> e, e1 + <- e, ee, ekem1 + -> s, se + """), + + HandshakePattern.fromString(""" + XKhfs: + <- s + ... + -> e, es, e1 + <- e, ee, ekem1 + -> s, se + """), + + HandshakePattern.fromString(""" + XXhfs: + -> e, e1 + <- e, ee, ekem1, s, es + -> s, se + """), + + HandshakePattern.fromString(""" + KNhfs: + -> s + ... + -> e, e1 + <- e, ee, ekem1, se + """), + + // Note that this is different from what's listed at https://github.com/noiseprotocol/noise_hfs_spec/blob/025f0f60cb3b94ad75b68e3a4158b9aac234f8cb/noise_hfs.md?plain=1#L130-L135; + // the specification (at the time of writing) appears to have a typo. Please see + // https://github.com/noiseprotocol/noise_hfs_spec/pull/3. + HandshakePattern.fromString(""" + KKhfs: + -> s + <- s + ... + -> e, es, e1, ss + <- e, ee, ekem1, se + """), + + HandshakePattern.fromString(""" + KXhfs: + -> s + ... + -> e, e1 + <- e, ee, ekem1, se, s, es + """), + + // This also deviates from the latest version of the spec to fix a typo (the `ee` token is missing in the + // current draft of the spec). Please see https://github.com/noiseprotocol/noise_hfs_spec/pull/4. + HandshakePattern.fromString(""" + INhfs: + -> e, e1, s + <- e, ee, ekem1, se + """), + + HandshakePattern.fromString(""" + IKhfs: + <- s + ... + -> e, es, e1, s, ss + <- e, ee, ekem1, se + """), + + HandshakePattern.fromString(""" + IXhfs: + -> e, e1, s + <- e, ee, ekem1, se, s, es + """) + ); + } + @Test void withModifierUnrecognized() { assertThrows(IllegalArgumentException.class, () -> HandshakePattern.getInstance("XX").withModifier("fancy")); @@ -253,4 +358,29 @@ void requiresRemoteStaticPublicKey() throws NoSuchPatternException { assertTrue(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.RESPONDER)); assertFalse(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.INITIATOR)); } + + @Test + void messagePatternWithAddedToken() { + final MessagePattern originalPattern = new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.EE, Token.SE }); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E1, Token.E, Token.EE, Token.SE }), + originalPattern.withAddedToken(Token.E1, 0)); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.E1, Token.EE, Token.SE }), + originalPattern.withAddedToken(Token.E1, 1)); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.EE, Token.E1, Token.SE }), + originalPattern.withAddedToken(Token.E1, 2)); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.EE, Token.SE, Token.E1 }), + originalPattern.withAddedToken(Token.E1, 3)); + + assertThrows(IllegalArgumentException.class, () -> originalPattern.withAddedToken(Token.E1, -1)); + assertThrows(IllegalArgumentException.class, () -> originalPattern.withAddedToken(Token.E1, 4)); + } }