Skip to content

Commit

Permalink
Merge pull request moby#46920 from dmcgowan/client-hijack-cleanup
Browse files Browse the repository at this point in the history
Replace use of httputil in client hijack
  • Loading branch information
thaJeztah authored Dec 18, 2023
2 parents 74cf9a0 + eb9ce77 commit 0751141
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 88 deletions.
8 changes: 0 additions & 8 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,6 @@ issues:
path: "api/types/(volume|container)/"
linters:
- revive
# FIXME temporarily suppress these. See #39926
- text: "SA1019: httputil.NewClientConn"
linters:
- staticcheck
# FIXME temporarily suppress these (related to the ones above)
- text: "SA1019: httputil.ErrPersistEOF"
linters:
- staticcheck
# FIXME temporarily suppress these (see https://github.com/gotestyourself/gotest.tools/issues/272)
- text: "SA1019: (assert|cmp|is)\\.ErrorType is deprecated"
linters:
Expand Down
95 changes: 26 additions & 69 deletions client/hijack.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ import (
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"

"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/versions"
"github.com/pkg/errors"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)

// postHijacked sends a POST request and hijacks the connection.
Expand Down Expand Up @@ -54,33 +50,16 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", proto)

// We aren't using the configured RoundTripper here so manually inject the trace context
tp := cli.tp
if tp == nil {
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
tp = span.TracerProvider()
} else {
tp = otel.GetTracerProvider()
}
}

ctx, span := tp.Tracer("").Start(ctx, req.Method+" "+req.URL.Path, trace.WithSpanKind(trace.SpanKindClient))
// FIXME(thaJeztah): httpconv.ClientRequest is now an internal package; replace this with alternative for semconv v1.21
// span.SetAttributes(httpconv.ClientRequest(req)...)
defer func() {
if retErr != nil {
span.RecordError(retErr)
span.SetStatus(codes.Error, retErr.Error())
}
span.End()
}()
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))

dialer := cli.Dialer()
conn, err := dialer(ctx)
if err != nil {
return nil, "", errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
}
defer func() {
if retErr != nil {
conn.Close()
}
}()

// When we set up a TCP connection for hijack, there could be long periods
// of inactivity (a long running command with no output) that in certain
Expand All @@ -92,58 +71,29 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
}

clientconn := httputil.NewClientConn(conn, nil)
defer clientconn.Close()
hc := &hijackedConn{conn, bufio.NewReader(conn)}

// Server hijacks the connection, error 'connection closed' expected
resp, err := clientconn.Do(req)
if resp != nil {
// This is a simplified variant of "httpconv.ClientStatus(resp.StatusCode))";
//
// The main purpose of httpconv.ClientStatus() is to detect whether the
// status was successful (1xx, 2xx, 3xx) or non-successful (4xx/5xx).
//
// It also provides complex logic to *validate* status-codes against
// a hard-coded list meant to exclude "bogus" status codes in "success"
// ranges (1xx, 2xx) and convert them into an error status. That code
// seemed over-reaching (and not accounting for potential future valid
// status codes). We assume we only get valid status codes, and only
// look at status-code ranges.
//
// For reference, see:
// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/v1.17.0/httpconv/http.go#L85-L89
// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L322-L330
// https://github.com/open-telemetry/opentelemetry-go/blob/v1.21.0/semconv/internal/v2/http.go#L356-L404
code := codes.Unset
if resp.StatusCode >= http.StatusBadRequest {
code = codes.Error
}
span.SetStatus(code, "")
resp, err := otelhttp.NewTransport(hc).RoundTrip(req)
if err != nil {
return nil, "", err
}

//nolint:staticcheck // ignore SA1019 for connecting to old (pre go1.8) daemons
if err != httputil.ErrPersistEOF {
if err != nil {
return nil, "", err
}
if resp.StatusCode != http.StatusSwitchingProtocols {
_ = resp.Body.Close()
return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
_ = resp.Body.Close()
return nil, "", fmt.Errorf("unable to upgrade to %s, received %d", proto, resp.StatusCode)
}

c, br := clientconn.Hijack()
if br.Buffered() > 0 {
if hc.r.Buffered() > 0 {
// If there is buffered content, wrap the connection. We return an
// object that implements CloseWrite if the underlying connection
// implements it.
if _, ok := c.(types.CloseWriter); ok {
c = &hijackedConnCloseWriter{&hijackedConn{c, br}}
if _, ok := hc.Conn.(types.CloseWriter); ok {
conn = &hijackedConnCloseWriter{hc}
} else {
c = &hijackedConn{c, br}
conn = hc
}
} else {
br.Reset(nil)
hc.r.Reset(nil)
}

var mediaType string
Expand All @@ -152,7 +102,7 @@ func (cli *Client) setupHijackConn(req *http.Request, proto string) (_ net.Conn,
mediaType = resp.Header.Get("Content-Type")
}

return c, mediaType, nil
return conn, mediaType, nil
}

// hijackedConn wraps a net.Conn and is returned by setupHijackConn in the case
Expand All @@ -164,6 +114,13 @@ type hijackedConn struct {
r *bufio.Reader
}

func (c *hijackedConn) RoundTrip(req *http.Request) (*http.Response, error) {
if err := req.Write(c.Conn); err != nil {
return nil, err
}
return http.ReadResponse(c.r, req)
}

func (c *hijackedConn) Read(b []byte) (int, error) {
return c.r.Read(b)
}
Expand Down
37 changes: 26 additions & 11 deletions integration/plugin/authz/authz_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
Expand All @@ -25,6 +24,7 @@ import (
"github.com/docker/docker/pkg/archive"
"github.com/docker/docker/pkg/authorization"
"github.com/docker/docker/testutil/environment"
"github.com/docker/go-connections/sockets"
"gotest.tools/v3/assert"
"gotest.tools/v3/skip"
)
Expand Down Expand Up @@ -81,6 +81,17 @@ func isAllowed(reqURI string) bool {
return false
}

func socketHTTPClient(u *url.URL) (*http.Client, error) {
transport := &http.Transport{}
err := sockets.ConfigureTransport(transport, u.Scheme, u.Path)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
}, nil
}

func TestAuthZPluginAllowRequest(t *testing.T) {
ctx := setupTestV1(t)

Expand Down Expand Up @@ -176,15 +187,17 @@ func TestAuthZPluginAPIDenyResponse(t *testing.T) {
daemonURL, err := url.Parse(d.Sock())
assert.NilError(t, err)

conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
socketClient, err := socketHTTPClient(daemonURL)
assert.NilError(t, err)
c := httputil.NewClientConn(conn, nil)
req, err := http.NewRequest(http.MethodGet, "/version", nil)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
assert.NilError(t, err)
req = req.WithContext(ctx)
resp, err := c.Do(req)
req.URL.Scheme = "http"
req.URL.Host = client.DummyHost

resp, err := socketClient.Do(req)
assert.NilError(t, err)

assert.DeepEqual(t, http.StatusForbidden, resp.StatusCode)
}

Expand Down Expand Up @@ -471,13 +484,15 @@ func TestAuthZPluginHeader(t *testing.T) {
daemonURL, err := url.Parse(d.Sock())
assert.NilError(t, err)

conn, err := net.DialTimeout(daemonURL.Scheme, daemonURL.Path, time.Second*10)
socketClient, err := socketHTTPClient(daemonURL)
assert.NilError(t, err)
client := httputil.NewClientConn(conn, nil)
req, err := http.NewRequest(http.MethodGet, "/version", nil)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/version", nil)
assert.NilError(t, err)
req = req.WithContext(ctx)
resp, err := client.Do(req)
req.URL.Scheme = "http"
req.URL.Host = client.DummyHost

resp, err := socketClient.Do(req)
assert.NilError(t, err)
assert.Equal(t, "application/json", resp.Header["Content-Type"][0])
}
Expand Down

0 comments on commit 0751141

Please sign in to comment.