Skip to content

Commit

Permalink
Add support for HFS tokens and pattern modifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
jchambers committed Sep 7, 2024
1 parent 4de5aef commit 4d29db6
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 10 deletions.
111 changes: 101 additions & 10 deletions src/main/java/com/eatthepath/noise/HandshakePattern.java
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ class HandshakePattern {
}

record MessagePattern(NoiseHandshake.Role sender, Token[] tokens) {

MessagePattern withAddedToken(final Token token, final int insertionIndex) {
if (insertionIndex < 0 || insertionIndex >= this.tokens().length + 1) {
throw new IllegalArgumentException("Illegal insertion index");
}

final Token[] modifiedTokens = new Token[this.tokens().length + 1];
System.arraycopy(this.tokens(), 0, modifiedTokens, 0, insertionIndex);
modifiedTokens[insertionIndex] = token;
System.arraycopy(this.tokens(), insertionIndex, modifiedTokens,
insertionIndex + 1, this.tokens().length - insertionIndex);

return new MessagePattern(this.sender(), modifiedTokens);
}

@Override
public String toString() {
final String prefix = switch (sender()) {
Expand Down Expand Up @@ -375,18 +390,24 @@ enum Token {
ES,
SE,
SS,
PSK;
PSK,
E1,
EKEM1;

static Token fromString(final String string) {
return switch (string) {
case "e", "E" -> E;
case "s", "S" -> S;
case "ee", "EE" -> EE;
case "es", "ES" -> ES;
case "se", "SE" -> SE;
case "ss", "SS" -> SS;
case "psk", "PSK" -> PSK;
default -> throw new IllegalArgumentException("Unrecognized token: " + string);
for (final Token token : Token.values()) {
if (token.name().equalsIgnoreCase(string)) {
return token;
}
}

throw new IllegalArgumentException("Unrecognized token: " + string);
}

boolean isKeyAgreementToken() {
return switch (this) {
case EE, ES, SE, SS -> true;
default -> false;
};
}
}
Expand Down Expand Up @@ -482,6 +503,8 @@ HandshakePattern withModifier(final String modifier) {
modifiedMessagePatterns = getPatternsWithFallbackModifier();
} else if (modifier.startsWith("psk")) {
modifiedMessagePatterns = getPatternsWithPskModifier(modifier);
} else if ("hfs".equals(modifier)) {
modifiedMessagePatterns = getPatternsWithHfsModifier();
} else {
throw new IllegalArgumentException("Unrecognized modifier: " + modifier);
}
Expand Down Expand Up @@ -538,6 +561,74 @@ private MessagePattern[][] getPatternsWithPskModifier(final String modifier) {
return new MessagePattern[][] { modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns };
}

private MessagePattern[][] getPatternsWithHfsModifier() {
// Temporarily combine the pre-messages and "normal" messages to make iteration/state management easier
final MessagePattern[] messagePatterns =
new MessagePattern[getPreMessagePatterns().length + getHandshakeMessagePatterns().length];

System.arraycopy(getPreMessagePatterns(), 0, messagePatterns, 0, getPreMessagePatterns().length);
System.arraycopy(getHandshakeMessagePatterns(), 0, messagePatterns,
getPreMessagePatterns().length, getHandshakeMessagePatterns().length);

boolean insertedE1Token = false;
boolean insertedEkem1Token = false;

for (int i = 0; i < messagePatterns.length; i++) {
if (!insertedE1Token && Arrays.stream(messagePatterns[i].tokens()).anyMatch(token -> token == Token.E)) {
// We haven't inserted an E1 token yet, and this message pattern needs one. Exactly where it should go depends
// on whether this message pattern also contains a key agreement token, but either way, this pattern will wind
// up one token longer than it was when it started.
int insertionIndex = -1;

for (int t = 0; t < messagePatterns[i].tokens().length; t++) {
final Token token = messagePatterns[i].tokens()[t];

// TODO Prove that E must come before key agreement tokens
if (token == Token.E || token.isKeyAgreementToken()) {
insertionIndex = t + 1;

if (token.isKeyAgreementToken()) {
break;
}
}
}

messagePatterns[i] = messagePatterns[i].withAddedToken(Token.E1, insertionIndex);
insertedE1Token = true;
}

if (!insertedEkem1Token && Arrays.stream(messagePatterns[i].tokens()).anyMatch(token -> token == Token.EE)) {
// We haven't inserted an EKEM1 token yet, and this pattern needs one. EKEM1 tokens always go after the first
// EE token.
int insertionIndex = -1;

for (int t = 0; t < messagePatterns[i].tokens().length; t++) {
if (messagePatterns[i].tokens()[t] == Token.EE) {
insertionIndex = t + 1;
break;
}
}

messagePatterns[i] = messagePatterns[i].withAddedToken(Token.EKEM1, insertionIndex);
insertedEkem1Token = true;
}

if (insertedE1Token && insertedEkem1Token) {
// No need to inspect the rest of the message patterns if we've already inserted both of the HFS tokens
break;
}
}

final MessagePattern[] modifiedPreMessagePatterns = new MessagePattern[getPreMessagePatterns().length];
final MessagePattern[] modifiedHandshakeMessagePatterns = new MessagePattern[getHandshakeMessagePatterns().length];

System.arraycopy(messagePatterns, 0, modifiedPreMessagePatterns, 0, getPreMessagePatterns().length);
System.arraycopy(messagePatterns, getPreMessagePatterns().length,
modifiedHandshakeMessagePatterns, 0, getHandshakeMessagePatterns().length);

return new MessagePattern[][] { modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns };
}

private String getModifiedName(final String modifier) {
final String modifiedName;

Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/eatthepath/noise/NoiseHandshake.java
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ public enum Role {
}
case EE, ES, SE, SS, PSK ->
throw new IllegalArgumentException("Key-mixing tokens must not appear in pre-messages");
case E1, EKEM1 -> throw new UnsupportedOperationException();
}))
.forEach(publicKey -> mixHash(keyAgreement.serializePublicKey(publicKey)));
}
Expand Down
130 changes: 130 additions & 0 deletions src/test/java/com/eatthepath/noise/HandshakePatternTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,111 @@ void withPskModifier() throws NoSuchPatternException {
}
}

@ParameterizedTest
@MethodSource
void withHfsModifier(final HandshakePattern expectedHfsPattern) throws NoSuchPatternException {
final String fundamentalPatternName = HandshakePattern.getFundamentalPatternName(expectedHfsPattern.getName());

assertEquals(expectedHfsPattern, HandshakePattern.getInstance(fundamentalPatternName).withModifier("hfs"));
}

private static List<HandshakePattern> withHfsModifier() {
return List.of(
HandshakePattern.fromString("""
NNhfs:
-> e, e1
<- e, ee, ekem1
"""),

HandshakePattern.fromString("""
NKhfs:
<- s
...
-> e, es, e1
<- e, ee, ekem1
"""),

HandshakePattern.fromString("""
NXhfs:
-> e, e1
<- e, ee, ekem1, s, es
"""),

HandshakePattern.fromString("""
XNhfs:
-> e, e1
<- e, ee, ekem1
-> s, se
"""),

HandshakePattern.fromString("""
XKhfs:
<- s
...
-> e, es, e1
<- e, ee, ekem1
-> s, se
"""),

HandshakePattern.fromString("""
XXhfs:
-> e, e1
<- e, ee, ekem1, s, es
-> s, se
"""),

HandshakePattern.fromString("""
KNhfs:
-> s
...
-> e, e1
<- e, ee, ekem1, se
"""),

// Note that this is different from what's listed at https://github.com/noiseprotocol/noise_hfs_spec/blob/025f0f60cb3b94ad75b68e3a4158b9aac234f8cb/noise_hfs.md?plain=1#L130-L135;
// the specification (at the time of writing) appears to have a typo. Please see
// https://github.com/noiseprotocol/noise_hfs_spec/pull/3.
HandshakePattern.fromString("""
KKhfs:
-> s
<- s
...
-> e, es, e1, ss
<- e, ee, ekem1, se
"""),

HandshakePattern.fromString("""
KXhfs:
-> s
...
-> e, e1
<- e, ee, ekem1, se, s, es
"""),

// This also deviates from the latest version of the spec to fix a typo (the `ee` token is missing in the
// current draft of the spec). Please see https://github.com/noiseprotocol/noise_hfs_spec/pull/4.
HandshakePattern.fromString("""
INhfs:
-> e, e1, s
<- e, ee, ekem1, se
"""),

HandshakePattern.fromString("""
IKhfs:
<- s
...
-> e, es, e1, s, ss
<- e, ee, ekem1, se
"""),

HandshakePattern.fromString("""
IXhfs:
-> e, e1, s
<- e, ee, ekem1, se, s, es
""")
);
}

@Test
void withModifierUnrecognized() {
assertThrows(IllegalArgumentException.class, () -> HandshakePattern.getInstance("XX").withModifier("fancy"));
Expand Down Expand Up @@ -253,4 +358,29 @@ void requiresRemoteStaticPublicKey() throws NoSuchPatternException {
assertTrue(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.RESPONDER));
assertFalse(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.INITIATOR));
}

@Test
void messagePatternWithAddedToken() {
final MessagePattern originalPattern = new HandshakePattern.MessagePattern(Role.INITIATOR,
new Token[] { Token.E, Token.EE, Token.SE });

assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR,
new Token[] { Token.E1, Token.E, Token.EE, Token.SE }),
originalPattern.withAddedToken(Token.E1, 0));

assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR,
new Token[] { Token.E, Token.E1, Token.EE, Token.SE }),
originalPattern.withAddedToken(Token.E1, 1));

assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR,
new Token[] { Token.E, Token.EE, Token.E1, Token.SE }),
originalPattern.withAddedToken(Token.E1, 2));

assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR,
new Token[] { Token.E, Token.EE, Token.SE, Token.E1 }),
originalPattern.withAddedToken(Token.E1, 3));

assertThrows(IllegalArgumentException.class, () -> originalPattern.withAddedToken(Token.E1, -1));
assertThrows(IllegalArgumentException.class, () -> originalPattern.withAddedToken(Token.E1, 4));
}
}

0 comments on commit 4d29db6

Please sign in to comment.