Skip to content

Commit

Permalink
dial proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Dec 15, 2023
1 parent 256e849 commit 8d72fb7
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 69 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*swp
/.idea
/coverage.out
/cover.cov
autobahn-testsuite
Expand Down
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ quickws是一个高性能的websocket库
* [配置握手时的超时时间](#配置握手时的超时时间)
* [配置自动回复ping消息](#配置自动回复ping消息)
* [配置socks5代理](#配置socks5代理)
* [配置proxy代理](#配置proxy代理)
* [服务配置参数](#服务端配置)
* [配置服务自动回复ping消息](#配置服务自动回复ping消息)
## 注意⚠️
Expand Down Expand Up @@ -210,6 +211,21 @@ func main() {
}))
}
```
#### 配置proxy代理
```go
import(
"github.com/antlabs/quickws"
)

func main() {

proxy := func(*http.Request) (*url.URL, error) {
return url.Parse("http://127.0.0.1:1007")
}

quickws.Dial("ws://127.0.0.1:12345", quickws.WithClientProxyFunc(proxy))
}
```
### 服务端配置参数
#### 配置服务自动回复ping消息
```go
Expand Down
2 changes: 0 additions & 2 deletions autobahn/Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
all:
# mac, arm64
GOOS=darwin GOARCH=arm64 go build -o autobahn-server-darwin-arm64 ./autobahn-server.go
# mac, arm64
GOOS=darwin GOARCH=arm64 go build -tags=goexperiment.arenas -o autobahn-server-darwin-arm64-arena ./autobahn-server.go
# linux amd64
GOOS=linux GOARCH=amd64 go build -o autobahn-server-linux-amd64 ./autobahn-server.go

Expand Down
7 changes: 6 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/antlabs/wsutil/bytespool"
"github.com/antlabs/wsutil/enum"
"github.com/antlabs/wsutil/fixedreader"
"github.com/antlabs/wsutil/hostname"
)

var (
Expand Down Expand Up @@ -127,6 +128,10 @@ func (d *DialOption) handshake() (*http.Request, string, error) {
d.Header.Add("Sec-WebSocket-Extensions", strExtensions)
}

if len(d.subProtocols) > 0 {
d.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.subProtocols, ", ")}
}

req.Header = d.Header
return req, secWebSocket, nil
}
Expand Down Expand Up @@ -193,7 +198,7 @@ func (d *DialOption) Dial() (c *Conn, err error) {
var conn net.Conn
begin := time.Now()

hostName := getHostName(d.u)
hostName := hostname.GetHostName(d.u)
// conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout)
dialFunc := net.Dial
if d.dialFunc != nil {
Expand Down
15 changes: 15 additions & 0 deletions common_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,18 @@ func WithClientProxyFunc(proxyFunc func(*http.Request) (*url.URL, error)) Client
o.proxyFunc = proxyFunc
}
}

// 20. 设置支持的子协议
// 20.1 设置客户端支持的子协议
func WithClientSubprotocols(subprotocols []string) ClientOption {
return func(o *DialOption) {
o.subProtocols = subprotocols
}
}

// 20.2 设置服务端支持的子协议
func WithServerSubprotocols(subprotocols []string) ServerOption {
return func(o *ConnOption) {
o.subProtocols = subprotocols
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ module github.com/antlabs/quickws
go 1.20

require (
github.com/antlabs/wsutil v0.1.2
github.com/antlabs/wsutil v0.1.6
golang.org/x/net v0.19.0
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
github.com/antlabs/wsutil v0.1.2 h1:8H6E0eMJ2Wp0qi9YGDeyG3DlIfIZncw2NSScC5bYSBQ=
github.com/antlabs/wsutil v0.1.2/go.mod h1:7ec5eUM7nmKW+Oi6F1I58iatOeL9k+yIsfOh1zh910g=
github.com/antlabs/wsutil v0.1.6 h1:K7wR+EvqQT1Nn7jAKs3dKsGtUykPD2OYlCicv4/tUf8=
github.com/antlabs/wsutil v0.1.6/go.mod h1:7ec5eUM7nmKW+Oi6F1I58iatOeL9k+yIsfOh1zh910g=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
4 changes: 3 additions & 1 deletion proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"net"
"net/http"
"net/url"

"github.com/antlabs/wsutil/hostname"
)

type (
Expand All @@ -41,7 +43,7 @@ func (h *httpProxy) Dial(network, addr string) (c net.Conn, err error) {
return h.dial(network, addr)
}

hostName := getHostName(h.proxyAddr)
hostName := hostname.GetHostName(h.proxyAddr)
c, err = h.dial(network, hostName)
if err != nil {
return nil, err
Expand Down
108 changes: 106 additions & 2 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,112 @@

package quickws

import "testing"
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)

type testServer struct {
path string
rawQuery string
requestURL string
subprotos []string
*testing.T
}

func newTestServer(t *testing.T) *testServer {
return &testServer{path: "/test", rawQuery: "a=1&b=2", requestURL: "/test?a=1&b=2", T: t, subprotos: []string{"proto1", "proto2"}}
}

func (t *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != t.path {
t.Errorf("path error: %s", r.URL.Path)
return
}

if r.URL.RawQuery != t.rawQuery {
t.Errorf("raw query error: %s", r.URL.RawQuery)
return
}

sub := subProtocol(r.Header.Get("Sec-Websocket-Protocol"), &Config{subProtocols: t.subprotos})
if sub != "proto1" {
t.Errorf("sub protocol error: (%s)", sub)
return
}

conn, err := Upgrade(w, r, WithServerOnMessageFunc(func(c *Conn, o Opcode, b []byte) {
c.WriteMessage(o, b)
}))
if err != nil {
t.Error(err)
return
}
conn.ReadLoop()
}

func (t *testServer) clientSend(c *Conn) {
c.WriteMessage(Text, []byte("hello world"))
}

func HTTPToWS(u string) string {
return strings.ReplaceAll(u, "http://", "ws://")
}

func WsToHTTP(u string) string {
return strings.ReplaceAll(u, "ws://", "http://")
}

func Test_Proxy(t *testing.T) {
t.Run("test proxy", func(t *testing.T) {})
t.Run("test proxy dial.1", func(t *testing.T) {
connect := false
s := newTestServer(t)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("method: %s, url: %s", r.Method, r.URL.String())
if r.Method == http.MethodConnect {
connect = true
w.WriteHeader(http.StatusOK)
return
}

if !connect {
t.Error("test proxy dial fail: not connect")
http.Error(w, "not connect", http.StatusMethodNotAllowed)
return
}
s.ServeHTTP(w, r)
}))

defer ts.Close()

proxy := func(*http.Request) (*url.URL, error) {
return url.Parse(HTTPToWS(ts.URL))
}

got := make(chan string, 1)
dstURL := HTTPToWS(ts.URL + s.requestURL)
con, err := Dial(dstURL,
WithClientProxyFunc(proxy),
WithClientSubprotocols(s.subprotos),
WithClientOnMessageFunc(func(c *Conn, o Opcode, b []byte) {
got <- string(b)
}))
if err != nil {
t.Error(err)
return
}
con.StartReadLoop()
s.clientSend(con)

defer con.Close()
gotValue := <-got
if gotValue != "hello world" {
t.Errorf("got: %s, want: %s", gotValue, "hello world")
return
}
})
}
7 changes: 0 additions & 7 deletions server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,3 @@ func WithServerDecompressAndCompress() ServerOption {
o.decompression = true
}
}

// 2. 设置服务端支持的子协议
func WithServerSubprotocols(subprotocols []string) ServerOption {
return func(o *ConnOption) {
o.subProtocols = subprotocols
}
}
19 changes: 0 additions & 19 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import (
"fmt"
"math/rand"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"unsafe"
)
Expand Down Expand Up @@ -89,23 +87,6 @@ func maybeCompressionDecompression(header http.Header) bool {
return false
}

func getHostName(u *url.URL) (hostName string) {
hostName = u.Hostname()
if u.Port() == "" {
switch strings.ToLower(u.Scheme) {
case "https":
hostName += ":443"
case "http":
hostName += ":80"
default:
panic(fmt.Sprintf("unknown scheme:%s", u.Scheme))
}
return
}

return u.Host
}

func getHttpErrMsg(statusCode int) error {
errMsg := http.StatusText(statusCode)
if errMsg != "" {
Expand Down
34 changes: 0 additions & 34 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package quickws

import (
"net/url"
"testing"
)

Expand All @@ -42,36 +41,3 @@ func Test_getHttpErrMsg(t *testing.T) {
}
})
}

type test_getHostName struct {
data string
need string
}

func Test_getHostName(t *testing.T) {
t.Run("test 1", func(t *testing.T) {
for _, d := range []test_getHostName{
{
data: "http://www.baidu.com",
need: "www.baidu.com:80",
},
{
data: "http://www.baidu.com:333",
need: "www.baidu.com:333",
},
{
data: "https://www.baidu.com",
need: "www.baidu.com:443",
},
} {

u, err := url.Parse(d.data)
if err != nil {
t.Errorf("err should be nil, got %s", err)
}
if getHostName(u) != d.need {
t.Errorf("need %s, got %s", d.need, getHostName(u))
}
}
})
}

0 comments on commit 8d72fb7

Please sign in to comment.