diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index 35b62e3311..55216c15f3 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -613,7 +613,7 @@ func TestClientAuthMaxAuthTries(t *testing.T) { } serverConfig.AddHostKey(testSigners["rsa"]) - expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &DisconnectError{ Reason: 2, Message: "too many authentication failures", }) @@ -676,7 +676,7 @@ func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) { t.Fatalf("unable to dial remote side: %s", err) } - expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &DisconnectError{ Reason: 2, Message: "too many authentication failures", }) diff --git a/ssh/connection.go b/ssh/connection.go index 35661a52be..2d28b7fa39 100644 --- a/ssh/connection.go +++ b/ssh/connection.go @@ -20,6 +20,39 @@ func (e *OpenChannelError) Error() string { return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) } +// DisconnectReason is an enumeration used when closing connections to describe +// why a disconnect was sent. See RFC 4253, section 11.1. +type DisconnectReason uint32 + +const ( + HostNotAllowedToConnect DisconnectReason = 1 + ProtocolError = 2 + KeyExchangeFailed = 3 + // 4 is reserved for future use. + MacError = 5 + CompressionError = 6 + ServiceNotAvailable = 7 + ProtocolVersionNotSupported = 8 + HostKeyNotVerifiable = 9 + ConnectionLost = 10 + ByApplication = 11 + TooManyConnections = 12 + AuthCancelledByUser = 13 + NoMoreAuthMethodsAvailable = 14 + IllegalUserName = 15 +) + +// DisconnectError is returned by Conn.Wait if the other end of the connection +// explicitly closes the connection by sending a disconnect message. +type DisconnectError struct { + Reason DisconnectReason + Message string +} + +func (d *DisconnectError) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + // ConnMetadata holds metadata for the connection. type ConnMetadata interface { // User returns the user ID for this connection. @@ -66,7 +99,9 @@ type Conn interface { Close() error // Wait blocks until the connection has shut down, and returns the - // error causing the shutdown. + // error causing the shutdown. If the connection has been closed by an + // explicit disconnect message from the other end, then Wait will return a + // DisconnectError. Wait() error // TODO(hanwen): consider exposing: diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index 3d0ab5044c..88f77aee70 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -505,11 +505,15 @@ func TestDisconnect(t *testing.T) { defer trS.Close() trC.writePacket([]byte{msgRequestSuccess, 0, 0}) - errMsg := &disconnectMsg{ + errPacket := &disconnectMsg{ Reason: 42, Message: "such is life", } - trC.writePacket(Marshal(errMsg)) + errResponse := &DisconnectError{ + Reason: DisconnectReason(errPacket.Reason), + Message: errPacket.Message, + } + trC.writePacket(Marshal(errPacket)) trC.writePacket([]byte{msgRequestSuccess, 0, 0}) packet, err := trS.readPacket() @@ -523,8 +527,8 @@ func TestDisconnect(t *testing.T) { _, err = trS.readPacket() if err == nil { t.Errorf("readPacket 2 succeeded") - } else if !reflect.DeepEqual(err, errMsg) { - t.Errorf("got error %#v, want %#v", err, errMsg) + } else if !reflect.DeepEqual(err, errResponse) { + t.Errorf("got error %#v, want %#v", err, errResponse) } _, err = trS.readPacket() diff --git a/ssh/messages.go b/ssh/messages.go index 922032d952..f13a7b5834 100644 --- a/ssh/messages.go +++ b/ssh/messages.go @@ -43,10 +43,6 @@ type disconnectMsg struct { Language string } -func (d *disconnectMsg) Error() string { - return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) -} - // See RFC 4253, section 7.1. const msgKexInit = 20 diff --git a/ssh/server.go b/ssh/server.go index 9e3870292f..667de86b1c 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -416,7 +416,8 @@ userAuthLoop: return nil, err } - return nil, discMsg + err := &DisconnectError{Reason: DisconnectReason(discMsg.Reason), Message: discMsg.Message} + return nil, err } var userAuthReq userAuthRequestMsg diff --git a/ssh/transport.go b/ssh/transport.go index da015801ea..e30856ba49 100644 --- a/ssh/transport.go +++ b/ssh/transport.go @@ -154,7 +154,8 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { if err := Unmarshal(packet, &msg); err != nil { return nil, err } - return nil, &msg + err := &DisconnectError{Reason: DisconnectReason(msg.Reason), Message: msg.Message} + return nil, err } }