diff --git a/client/client.go b/client/client.go index 0108471c..77568dd6 100644 --- a/client/client.go +++ b/client/client.go @@ -52,7 +52,7 @@ type Requester interface { // Transport provides direct access to the transport instance for // specialist operations. - Transport() http.RoundTripper + Transport() *http.Transport } // RequestOptions allows setting up a specific request. @@ -198,16 +198,11 @@ func (client *Client) getTaskWebsocket(taskID, websocketID string) (clientWebsoc return client.getWebsocket(url) } -func getWebsocket(transport http.RoundTripper, url string) (clientWebsocket, error) { - httpTransport, ok := transport.(*http.Transport) - if !ok { - return nil, fmt.Errorf("cannot create websocket: transport not compatible") - } - +func getWebsocket(transport *http.Transport, url string) (clientWebsocket, error) { dialer := websocket.Dialer{ - NetDial: httpTransport.Dial, - Proxy: httpTransport.Proxy, - TLSClientConfig: httpTransport.TLSClientConfig, + NetDial: transport.Dial, + Proxy: transport.Proxy, + TLSClientConfig: transport.TLSClientConfig, HandshakeTimeout: 5 * time.Second, } conn, _, err := dialer.Dial(url, nil) @@ -216,16 +211,7 @@ func getWebsocket(transport http.RoundTripper, url string) (clientWebsocket, err // CloseIdleConnections closes any API connections that are currently unused. func (client *Client) CloseIdleConnections() { - transport := client.Requester().Transport() - // The following is taken from net/http/client.go because - // we are directly going to try and close idle connections and - // we must make sure the transport supports this. - type closeIdler interface { - CloseIdleConnections() - } - if tr, ok := transport.(closeIdler); ok { - tr.CloseIdleConnections() - } + client.Requester().Transport().CloseIdleConnections() } // Maintenance returns an error reflecting the daemon maintenance status or nil. @@ -589,7 +575,7 @@ type DefaultRequester struct { baseURL url.URL doer doer userAgent string - transport http.RoundTripper + transport *http.Transport decoder DecoderFunc } @@ -625,6 +611,6 @@ func (br *DefaultRequester) SetDecoder(decoder DecoderFunc) { br.decoder = decoder } -func (br *DefaultRequester) Transport() http.RoundTripper { +func (br *DefaultRequester) Transport() *http.Transport { return br.transport }