diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go index 7376a8dff2..0fdc669584 100644 --- a/ssh/knownhosts/knownhosts.go +++ b/ssh/knownhosts/knownhosts.go @@ -358,34 +358,20 @@ func (db *hostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error { // is just a key for the IP address, but not for the // hostname? - // Algorithm => key. - knownKeys := map[string]KnownKey{} - for _, l := range db.lines { - if l.match(a) { - typ := l.knownKey.Key.Type() - if _, ok := knownKeys[typ]; !ok { - knownKeys[typ] = l.knownKey - } - } - } - keyErr := &KeyError{} - for _, v := range knownKeys { - keyErr.Want = append(keyErr.Want, v) - } - // Unknown remote host. - if len(knownKeys) == 0 { - return keyErr - } + for _, l := range db.lines { + if !l.match(a) { + continue + } - // If the remote host starts using a different, unknown key type, we - // also interpret that as a mismatch. - if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) { - return keyErr + keyErr.Want = append(keyErr.Want, l.knownKey) + if keyEq(l.knownKey.Key, remoteKey) { + return nil + } } - return nil + return keyErr } // The Read function parses file contents. diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go index 464dd59249..156576ad07 100644 --- a/ssh/knownhosts/knownhosts_test.go +++ b/ssh/knownhosts/knownhosts_test.go @@ -201,17 +201,6 @@ func TestHostNamePrecedence(t *testing.T) { } } -func TestDBOrderingPrecedenceKeyType(t *testing.T) { - str := fmt.Sprintf("server.org,%s %s\nserver.org,%s %s", testAddr, edKeyStr, testAddr, alternateEdKeyStr) - db := testDB(t, str) - - if err := db.check("server.org:22", testAddr, alternateEdKey); err == nil { - t.Errorf("check succeeded") - } else if _, ok := err.(*KeyError); !ok { - t.Errorf("got %T, want *KeyError", err) - } -} - func TestNegate(t *testing.T) { str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr) db := testDB(t, str) @@ -354,3 +343,16 @@ func TestHashedHostkeyCheck(t *testing.T) { t.Errorf("got error %v, want %v", got, want) } } + +func TestIssue36126(t *testing.T) { + str := fmt.Sprintf("server.org,%s %s\nserver.org,%s %s", testAddr, edKeyStr, testAddr, alternateEdKeyStr) + db := testDB(t, str) + + if err := db.check("server.org:22", testAddr, edKey); err != nil { + t.Errorf("should have passed the check, got %v", err) + } + + if err := db.check("server.org:22", testAddr, alternateEdKey); err != nil { + t.Errorf("should have passed the check, got %v", err) + } +}