diff --git a/.gitignore b/.gitignore index 3fe4214..f242a8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *swp +/.idea /coverage.out /cover.cov autobahn-testsuite diff --git a/README.md b/README.md index 472f2e2..aff8b7d 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ quickws是一个高性能的websocket库 * [配置握手时的超时时间](#配置握手时的超时时间) * [配置自动回复ping消息](#配置自动回复ping消息) * [配置socks5代理](#配置socks5代理) + * [配置proxy代理](#配置proxy代理) * [服务配置参数](#服务端配置) * [配置服务自动回复ping消息](#配置服务自动回复ping消息) ## 注意⚠️ @@ -210,6 +211,21 @@ func main() { })) } ``` +#### 配置proxy代理 +```go +import( + "github.com/antlabs/quickws" +) + +func main() { + + proxy := func(*http.Request) (*url.URL, error) { + return url.Parse("http://127.0.0.1:1007") + } + + quickws.Dial("ws://127.0.0.1:12345", quickws.WithClientProxyFunc(proxy)) +} +``` ### 服务端配置参数 #### 配置服务自动回复ping消息 ```go diff --git a/autobahn/Makefile b/autobahn/Makefile index 355258e..f50d11c 100644 --- a/autobahn/Makefile +++ b/autobahn/Makefile @@ -1,8 +1,6 @@ all: # mac, arm64 GOOS=darwin GOARCH=arm64 go build -o autobahn-server-darwin-arm64 ./autobahn-server.go - # mac, arm64 - GOOS=darwin GOARCH=arm64 go build -tags=goexperiment.arenas -o autobahn-server-darwin-arm64-arena ./autobahn-server.go # linux amd64 GOOS=linux GOARCH=amd64 go build -o autobahn-server-linux-amd64 ./autobahn-server.go diff --git a/client.go b/client.go index 9412da0..8f34dde 100644 --- a/client.go +++ b/client.go @@ -28,6 +28,7 @@ import ( "github.com/antlabs/wsutil/bytespool" "github.com/antlabs/wsutil/enum" "github.com/antlabs/wsutil/fixedreader" + "github.com/antlabs/wsutil/hostname" ) var ( @@ -127,6 +128,10 @@ func (d *DialOption) handshake() (*http.Request, string, error) { d.Header.Add("Sec-WebSocket-Extensions", strExtensions) } + if len(d.subProtocols) > 0 { + d.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.subProtocols, ", ")} + } + req.Header = d.Header return req, secWebSocket, nil } @@ -193,7 +198,7 @@ func (d *DialOption) Dial() (c *Conn, err error) { var conn net.Conn begin := time.Now() - hostName := getHostName(d.u) + hostName := hostname.GetHostName(d.u) // conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout) dialFunc := net.Dial if d.dialFunc != nil { diff --git a/common_options.go b/common_options.go index f2fb947..ce0e648 100644 --- a/common_options.go +++ b/common_options.go @@ -316,3 +316,18 @@ func WithClientProxyFunc(proxyFunc func(*http.Request) (*url.URL, error)) Client o.proxyFunc = proxyFunc } } + +// 20. 设置支持的子协议 +// 20.1 设置客户端支持的子协议 +func WithClientSubprotocols(subprotocols []string) ClientOption { + return func(o *DialOption) { + o.subProtocols = subprotocols + } +} + +// 20.2 设置服务端支持的子协议 +func WithServerSubprotocols(subprotocols []string) ServerOption { + return func(o *ConnOption) { + o.subProtocols = subprotocols + } +} diff --git a/go.mod b/go.mod index ad9d6d2..7183af5 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/antlabs/quickws go 1.20 require ( - github.com/antlabs/wsutil v0.1.2 + github.com/antlabs/wsutil v0.1.6 golang.org/x/net v0.19.0 ) diff --git a/go.sum b/go.sum index b5dc98b..b6aaea6 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ -github.com/antlabs/wsutil v0.1.2 h1:8H6E0eMJ2Wp0qi9YGDeyG3DlIfIZncw2NSScC5bYSBQ= -github.com/antlabs/wsutil v0.1.2/go.mod h1:7ec5eUM7nmKW+Oi6F1I58iatOeL9k+yIsfOh1zh910g= +github.com/antlabs/wsutil v0.1.6 h1:K7wR+EvqQT1Nn7jAKs3dKsGtUykPD2OYlCicv4/tUf8= +github.com/antlabs/wsutil v0.1.6/go.mod h1:7ec5eUM7nmKW+Oi6F1I58iatOeL9k+yIsfOh1zh910g= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= diff --git a/proxy.go b/proxy.go index 0b9a06c..9aaa8a8 100644 --- a/proxy.go +++ b/proxy.go @@ -20,6 +20,8 @@ import ( "net" "net/http" "net/url" + + "github.com/antlabs/wsutil/hostname" ) type ( @@ -41,7 +43,7 @@ func (h *httpProxy) Dial(network, addr string) (c net.Conn, err error) { return h.dial(network, addr) } - hostName := getHostName(h.proxyAddr) + hostName := hostname.GetHostName(h.proxyAddr) c, err = h.dial(network, hostName) if err != nil { return nil, err diff --git a/proxy_test.go b/proxy_test.go index 863f614..4d3d8fa 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -14,8 +14,112 @@ package quickws -import "testing" +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type testServer struct { + path string + rawQuery string + requestURL string + subprotos []string + *testing.T +} + +func newTestServer(t *testing.T) *testServer { + return &testServer{path: "/test", rawQuery: "a=1&b=2", requestURL: "/test?a=1&b=2", T: t, subprotos: []string{"proto1", "proto2"}} +} + +func (t *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != t.path { + t.Errorf("path error: %s", r.URL.Path) + return + } + + if r.URL.RawQuery != t.rawQuery { + t.Errorf("raw query error: %s", r.URL.RawQuery) + return + } + + sub := subProtocol(r.Header.Get("Sec-Websocket-Protocol"), &Config{subProtocols: t.subprotos}) + if sub != "proto1" { + t.Errorf("sub protocol error: (%s)", sub) + return + } + + conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) { + c.WriteMessage(o, b) + })) + if err != nil { + t.Error(err) + return + } + conn.ReadLoop() +} + +func (t *testServer) clientSend(c *Conn) { + c.WriteMessage(Text, []byte("hello world")) +} + +func HTTPToWS(u string) string { + return strings.ReplaceAll(u, "http://", "ws://") +} + +func WsToHTTP(u string) string { + return strings.ReplaceAll(u, "ws://", "http://") +} func Test_Proxy(t *testing.T) { - t.Run("test proxy", func(t *testing.T) {}) + t.Run("test proxy dial.1", func(t *testing.T) { + connect := false + s := newTestServer(t) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("method: %s, url: %s", r.Method, r.URL.String()) + if r.Method == http.MethodConnect { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Error("test proxy dial fail: not connect") + http.Error(w, "not connect", http.StatusMethodNotAllowed) + return + } + s.ServeHTTP(w, r) + })) + + defer ts.Close() + + proxy := func(*http.Request) (*url.URL, error) { + return url.Parse(HTTPToWS(ts.URL)) + } + + got := make(chan string, 1) + dstURL := HTTPToWS(ts.URL + s.requestURL) + con, err := Dial(dstURL, + WithClientProxyFunc(proxy), + WithClientSubprotocols(s.subprotos), + WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) { + got <- string(b) + })) + if err != nil { + t.Error(err) + return + } + con.StartReadLoop() + s.clientSend(con) + + defer con.Close() + gotValue := <-got + if gotValue != "hello world" { + t.Errorf("got: %s, want: %s", gotValue, "hello world") + return + } + }) } diff --git a/server_options.go b/server_options.go index 53e4fe8..6755a78 100644 --- a/server_options.go +++ b/server_options.go @@ -27,10 +27,3 @@ func WithServerDecompressAndCompress() ServerOption { o.decompression = true } } - -// 2. 设置服务端支持的子协议 -func WithServerSubprotocols(subprotocols []string) ServerOption { - return func(o *ConnOption) { - o.subProtocols = subprotocols - } -} diff --git a/utils.go b/utils.go index 3ce1c14..a0389d5 100644 --- a/utils.go +++ b/utils.go @@ -21,9 +21,7 @@ import ( "fmt" "math/rand" "net/http" - "net/url" "reflect" - "strings" "time" "unsafe" ) @@ -89,23 +87,6 @@ func maybeCompressionDecompression(header http.Header) bool { return false } -func getHostName(u *url.URL) (hostName string) { - hostName = u.Hostname() - if u.Port() == "" { - switch strings.ToLower(u.Scheme) { - case "https": - hostName += ":443" - case "http": - hostName += ":80" - default: - panic(fmt.Sprintf("unknown scheme:%s", u.Scheme)) - } - return - } - - return u.Host -} - func getHttpErrMsg(statusCode int) error { errMsg := http.StatusText(statusCode) if errMsg != "" { diff --git a/utils_test.go b/utils_test.go index d1c22a7..b295f30 100644 --- a/utils_test.go +++ b/utils_test.go @@ -15,7 +15,6 @@ package quickws import ( - "net/url" "testing" ) @@ -42,36 +41,3 @@ func Test_getHttpErrMsg(t *testing.T) { } }) } - -type test_getHostName struct { - data string - need string -} - -func Test_getHostName(t *testing.T) { - t.Run("test 1", func(t *testing.T) { - for _, d := range []test_getHostName{ - { - data: "http://www.baidu.com", - need: "www.baidu.com:80", - }, - { - data: "http://www.baidu.com:333", - need: "www.baidu.com:333", - }, - { - data: "https://www.baidu.com", - need: "www.baidu.com:443", - }, - } { - - u, err := url.Parse(d.data) - if err != nil { - t.Errorf("err should be nil, got %s", err) - } - if getHostName(u) != d.need { - t.Errorf("need %s, got %s", d.need, getHostName(u)) - } - } - }) -}