diff --git a/ssh/tcpip.go b/ssh/tcpip.go index ef5059a11d..c52c840812 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -350,7 +350,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err } ch := make(chan connErr) go func() { - conn, err := c.Dial(n, addr) + conn, err := c.dialContext(ctx, n, addr) select { case ch <- connErr{conn, err}: case <-ctx.Done(): @@ -369,7 +369,13 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err // Dial initiates a connection to the addr from the remote host. // The resulting connection has a zero LocalAddr() and RemoteAddr(). +// For TCP addresses the port section of the address can be a port number or a service name. +// Service names are resolved at the client side, domain names are resolved on the server. func (c *Client) Dial(n, addr string) (net.Conn, error) { + return c.dialContext(context.Background(), n, addr) +} + +func (c *Client) dialContext(ctx context.Context, n, addr string) (net.Conn, error) { var ch Channel switch n { case "tcp", "tcp4", "tcp6": @@ -378,11 +384,11 @@ func (c *Client) Dial(n, addr string) (net.Conn, error) { if err != nil { return nil, err } - port, err := strconv.ParseUint(portString, 10, 16) + port, err := net.DefaultResolver.LookupPort(ctx, n, portString) if err != nil { return nil, err } - ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) + ch, err = c.dial(net.IPv4zero.String(), 0, host, port) if err != nil { return nil, err } @@ -441,18 +447,18 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) // RFC 4254 7.2 type channelOpenDirectMsg struct { - raddr string - rport uint32 - laddr string - lport uint32 + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 } func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { msg := channelOpenDirectMsg{ - raddr: raddr, - rport: uint32(rport), - laddr: laddr, - lport: uint32(lport), + Addr: raddr, + Port: uint32(rport), + OriginAddr: laddr, + OriginPort: uint32(lport), } ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) if err != nil { diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index 4d85114727..b8b519a9e8 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -6,6 +6,7 @@ package ssh import ( "context" + "fmt" "net" "testing" "time" @@ -51,3 +52,81 @@ func TestClientDialContextWithDeadline(t *testing.T) { t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) } } + +func TestDialNamedPort(t *testing.T) { + srvConn, clientConn, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer srvConn.Close() + defer clientConn.Close() + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + srvErr := make(chan error, 10) + go func() { + defer close(srvErr) + _, chans, req, err := NewServerConn(srvConn, serverConf) + if err != nil { + srvErr <- fmt.Errorf("NewServerConn: %w", err) + return + } + go DiscardRequests(req) + for newChan := range chans { + if newChan.ChannelType() != "direct-tcpip" { + srvErr <- fmt.Errorf("expected direct-tcpip channel, got=%s", newChan.ChannelType()) + if err := newChan.Reject(UnknownChannelType, "This test server only supports direct-tcpip"); err != nil { + srvErr <- err + } + continue + } + data := channelOpenDirectMsg{} + if err := Unmarshal(newChan.ExtraData(), &data); err != nil { + if err := newChan.Reject(ConnectionFailed, err.Error()); err != nil { + srvErr <- err + } + continue + } + // Below we dial for service `ssh` which should be translated to 22. + if data.Port != 22 { + if err := newChan.Reject(ConnectionFailed, fmt.Sprintf("expected port 22 got=%d", data.Port)); err != nil { + srvErr <- err + } + continue + } + ch, reqs, err := newChan.Accept() + if err != nil { + srvErr <- fmt.Errorf("Accept: %w", err) + continue + } + go DiscardRequests(reqs) + if err := ch.Close(); err != nil { + srvErr <- err + } + } + }() + + clientConf := &ClientConfig{ + User: "testuser", + HostKeyCallback: InsecureIgnoreHostKey(), + } + sshClientConn, newChans, reqs, err := NewClientConn(clientConn, "", clientConf) + if err != nil { + t.Fatal(err) + } + sshClient := NewClient(sshClientConn, newChans, reqs) + + // The port section in the host:port string being a named service `ssh` is the main point of the test. + _, err = sshClient.Dial("tcp", "localhost:ssh") + if err != nil { + t.Error(err) + } + + // Stop the ssh server. + clientConn.Close() + for err := range srvErr { + t.Errorf("ssh server: %s", err) + } +}