Skip to content

Commit

Permalink
修改1.timeout,2. 修改测试代码 (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong authored May 24, 2024
1 parent e2ce486 commit 601b46e
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 46 deletions.
27 changes: 7 additions & 20 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,34 +196,32 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
}

var conn net.Conn
begin := time.Now()

hostName := hostname.GetHostName(d.u)
// conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout)
dialFunc := net.Dial
dialFunc := net.DialTimeout
if d.dialFunc != nil {
dialInterface, err := d.dialFunc()
if err != nil {
return nil, err
}
dialFunc = dialInterface.Dial
dialFunc = func(network, address string, timeout time.Duration) (net.Conn, error) {
return dialInterface.Dial(network, address)
}
}

if d.proxyFunc != nil {
proxyURL, err := d.proxyFunc(req)
if err != nil {
return nil, err
}
dialFunc = newhttpProxy(proxyURL, dialFunc).Dial
dialFunc = newhttpProxy(proxyURL, dialFunc).DialTimeout
}

conn, err = dialFunc("tcp", hostName)
conn, err = dialFunc("tcp", hostName, d.dialTimeout)
if err != nil {
return nil, err
}

dialDuration := time.Since(begin)

conn = d.tlsConn(conn)
defer func() {
if err != nil && conn != nil {
Expand All @@ -232,18 +230,7 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
}
}()

if to := d.dialTimeout - dialDuration; to > 0 {
if err = conn.SetDeadline(time.Now().Add(to)); err != nil {
return
}
}

defer func() {
if err == nil {
err = conn.SetDeadline(time.Time{})
}
}()

err = conn.SetDeadline(time.Time{})
if err = req.Write(conn); err != nil {
return
}
Expand Down
50 changes: 48 additions & 2 deletions common_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2227,13 +2227,14 @@ func Test_CommonOption(t *testing.T) {
t.Run("22.3.WithClientReadMaxMessage", func(t *testing.T) {
var tsort testServerOptionReadTimeout

upgrade := NewUpgrade(WithServerCallback(&tsort), WithServerReadTimeout(time.Millisecond*60))
upgrade := NewUpgrade()
tsort.err = make(chan error, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
time.Sleep(time.Second / 100)
err = c.WriteMessage(Binary, bytes.Repeat([]byte("1"), 1025))
if err != nil {
t.Error(err)
Expand All @@ -2245,12 +2246,57 @@ func Test_CommonOption(t *testing.T) {
defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientBufioParseMode(), WithClientReadMaxMessage(1<<10), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) {
con, err := Dial(url, WithClientCallback(&tsort), WithClientBufioParseMode(), WithClientReadMaxMessage(1<<10))
if err != nil {
t.Error(err)
return
}
defer con.Close()
go func() {
_ = con.ReadLoop()
}()

select {
case d := <-tsort.err:
if d == nil {
t.Errorf("got:nil, need:error\n")
}
case <-time.After(100 * time.Hour):
t.Errorf(" Test_ServerOption:WithServerReadMaxMessage timeout\n")
}
if atomic.LoadInt32(&tsort.run) != 1 {
t.Error("not run server:method fail")
}
})
t.Run("22.4.WithClientReadMaxMessage-ParseWindows", func(t *testing.T) {
var tsort testServerOptionReadTimeout

upgrade := NewUpgrade()
tsort.err = make(chan error, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
time.Sleep(time.Second / 100)
err = c.WriteMessage(Binary, bytes.Repeat([]byte("1"), 1025))
if err != nil {
t.Error(err)
return
}
c.StartReadLoop()
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientCallback(&tsort), WithClientReadMaxMessage(1<<10))
if err != nil {
t.Error(err)
return
}
defer con.Close()
con.StartReadLoop()

select {
case d := <-tsort.err:
Expand Down
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@ import (

var ErrDialFuncAndProxyFunc = errors.New("dialFunc and proxyFunc can't be set at the same time")

// 握手
type Dialer interface {
Dial(network, addr string) (c net.Conn, err error)
}

// 带超时时间的握手
type DialerTimeout interface {
DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error)
}

// Config的配置,有两个种用法
// 一种是声明一个全局的配置,后面不停使用。
// 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置
Expand Down
5 changes: 3 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,11 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPa
}
} else {
r := io.Reader(c.br)
var lr io.Reader
if c.readMaxMessage > 0 {
r = limitreader.NewLimitReader(c.br, c.readMaxMessage)
lr = limitreader.NewLimitReader(c.br, c.readMaxMessage)
}
f, err = frame.ReadFrameFromReaderV2(r, headArray, bufioPayload)
f, err = frame.ReadFrameFromReaderV3(r, lr, headArray, bufioPayload)
}
if err != nil {
c.writeAndMaybeOnClose(err)
Expand Down
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ func TestFragmentFrame(t *testing.T) {
select {
case <-data:
atomic.AddInt32(&run, 1)
case <-time.After(500 * time.Hour):
case <-time.After(500 * time.Millisecond):
}

if atomic.LoadInt32(&run) != 1 {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/antlabs/quickws
go 1.21

require (
github.com/antlabs/wsutil v0.1.10
github.com/antlabs/wsutil v0.1.11
golang.org/x/net v0.23.0
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/antlabs/wsutil v0.1.10 h1:86p67dG8/iiQ+yZrHVl73OPHGnXfXopFSU0w84fLOdE=
github.com/antlabs/wsutil v0.1.10/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A=
github.com/antlabs/wsutil v0.1.11 h1:bIVZ3Hxdq5ByZKu5OXL/cMtanEw6YlxdtUDiySI77Q0=
github.com/antlabs/wsutil v0.1.11/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A=
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
Expand Down
18 changes: 10 additions & 8 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,33 @@ import (
"net"
"net/http"
"net/url"
"time"

"github.com/antlabs/wsutil/hostname"
)

type (
dialFunc func(network, addr string) (c net.Conn, err error)
dialFunc func(network, addr string, timeout time.Duration) (c net.Conn, err error)
httpProxy struct {
proxyAddr *url.URL
dial func(network, addr string) (c net.Conn, err error)
proxyAddr *url.URL
dialTimeout func(network, addr string, timeout time.Duration) (c net.Conn, err error)
timeout time.Duration
}
)

var _ Dialer = (*httpProxy)(nil)
var _ DialerTimeout = (*httpProxy)(nil)

func newhttpProxy(u *url.URL, dial dialFunc) *httpProxy {
return &httpProxy{proxyAddr: u, dial: dial}
return &httpProxy{proxyAddr: u, dialTimeout: dial}
}

func (h *httpProxy) Dial(network, addr string) (c net.Conn, err error) {
func (h *httpProxy) DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error) {
if h.proxyAddr == nil {
return h.dial(network, addr)
return h.dialTimeout(network, addr, h.timeout)
}

hostName := hostname.GetHostName(h.proxyAddr)
c, err = h.dial(network, hostName)
c, err = h.dialTimeout(network, hostName, h.timeout)
if err != nil {
return nil, err
}
Expand Down
22 changes: 12 additions & 10 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/url"
"strings"
"testing"
"time"
)

type testServer struct {
Expand Down Expand Up @@ -133,7 +134,7 @@ func Test_Proxy(t *testing.T) {
func Test_httpProxy_Dial(t *testing.T) {
type fields struct {
proxyAddr *url.URL
dial func(network, addr string) (c net.Conn, err error)
dial func(network, addr string, timeout time.Duration) (c net.Conn, err error)
}
type args struct {
network string
Expand All @@ -146,12 +147,12 @@ func Test_httpProxy_Dial(t *testing.T) {
wantC net.Conn
wantErr bool
}{
// TODO: Add test cases.
// 0
{
name: "No proxy address",
fields: fields{
proxyAddr: nil,
dial: func(network, addr string) (c net.Conn, err error) {
dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) {
// Simulate successful dialing
return &net.TCPConn{}, errors.New("fail")
},
Expand All @@ -163,11 +164,12 @@ func Test_httpProxy_Dial(t *testing.T) {
wantC: &net.TCPConn{},
wantErr: true,
},
// 1
{
name: "Proxy address",
fields: fields{
proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")},
dial: func(network, addr string) (c net.Conn, err error) {
dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) {
// Simulate successful dialing
return &net.TCPConn{}, errors.New("fail")
},
Expand All @@ -179,11 +181,12 @@ func Test_httpProxy_Dial(t *testing.T) {
wantC: &net.TCPConn{},
wantErr: true,
},
// 2
{
name: "Proxy address",
fields: fields{
proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")},
dial: func(network, addr string) (c net.Conn, err error) {
dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) {
// Simulate successful dialing
return &net.TCPConn{}, nil
},
Expand All @@ -193,17 +196,16 @@ func Test_httpProxy_Dial(t *testing.T) {
addr: "a.b.c:80",
},
wantC: &net.TCPConn{},
wantErr: true,
wantErr: false,
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &httpProxy{
proxyAddr: tt.fields.proxyAddr,
dial: tt.fields.dial,
proxyAddr: tt.fields.proxyAddr,
dialTimeout: tt.fields.dial,
}
_, err := h.Dial(tt.args.network, tt.args.addr)
// gotC, err := h.Dial(tt.args.network, tt.args.addr)
_, err := h.dialTimeout(tt.args.network, tt.args.addr, 0)
if (err != nil) != tt.wantErr {
t.Errorf("index:%d, httpProxy.Dial() error = %v, wantErr %v", i, err, tt.wantErr)
return
Expand Down

0 comments on commit 601b46e

Please sign in to comment.