diff --git a/client_option_test.go b/client_option_test.go index 7157c6b..c29ce40 100644 --- a/client_option_test.go +++ b/client_option_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "sync/atomic" "testing" "time" @@ -30,6 +31,23 @@ import ( "golang.org/x/net/proxy" ) +// // 实现安全的net.Conn +// type safeConn struct { +// net.Conn +// sync.Mutex +// } + +// func (s *safeConn) Write(b []byte) (n int, err error) { +// s.Lock() +// defer s.Unlock() +// return s.Conn.Write(b) +// } + +// func (s *safeConn) Read(b []byte) (n int, err error) { +// s.Lock() +// defer s.Unlock() +// return s.Conn.Read(b) +// } func Test_ClientOption(t *testing.T) { t.Run("ClientOption.WithClientHTTPHeader", func(t *testing.T) { done := make(chan string, 1) @@ -401,21 +419,76 @@ func Test_ClientOption(t *testing.T) { return } defer c2.Close() - done := make(chan struct{}) + + // done := make(chan struct{}) + // newConn = &safeConn{Conn: newConn} + // c2 = &safeConn{Conn: c2} + // go func() { + // _, err = io.Copy(newConn, c2) + // if err != nil { + // t.Error(err) + // return + // } + // close(done) + // }() + // _, err = io.Copy(c2, newConn) + // if err != nil { + // t.Error(err) + // return + // } + // <-done + + var ( + newConnMu sync.Mutex + c2Mu sync.Mutex + wg sync.WaitGroup + ) + + wg.Add(2) + go func() { - _, err = io.Copy(newConn, c2) - if err != nil { - t.Error(err) - return + defer wg.Done() + buf := make([]byte, 4096) + for { + n, err := c2.Read(buf) + if err != nil { + if err != io.EOF { + t.Error(err) + } + break + } + newConnMu.Lock() + _, err = newConn.Write(buf[:n]) + newConnMu.Unlock() + if err != nil { + t.Error(err) + break + } } - close(done) }() - _, err = io.Copy(c2, newConn) - if err != nil { - t.Error(err) - return - } - <-done + + go func() { + defer wg.Done() + buf := make([]byte, 4096) + for { + n, err := newConn.Read(buf) + if err != nil { + if err != io.EOF { + t.Error(err) + } + break + } + c2Mu.Lock() + _, err = c2.Write(buf[:n]) + c2Mu.Unlock() + if err != nil { + t.Error(err) + break + } + } + }() + + wg.Wait() }() got := make([]byte, 0, 128) diff --git a/common_options_test.go b/common_options_test.go index d90771e..6f2ec96 100644 --- a/common_options_test.go +++ b/common_options_test.go @@ -1698,6 +1698,7 @@ func Test_CommonOption(t *testing.T) { t.Run("13-15.client: WriteMessageDelay", func(t *testing.T) { run := int32(0) data := make(chan string, 1) + recvServer := int32(0) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := Upgrade(w, r, WithServerDecompressAndCompress(), @@ -1708,8 +1709,10 @@ func Test_CommonOption(t *testing.T) { t.Error(err) return } - atomic.AddInt32(&run, int32(1)) - data <- string(payload) + atomic.AddInt32(&recvServer, int32(1)) + if atomic.LoadInt32(&recvServer) == 3 { + c.Close() + } })) if err != nil { t.Error(err) @@ -1720,6 +1723,7 @@ func Test_CommonOption(t *testing.T) { defer ts.Close() url := strings.ReplaceAll(ts.URL, "http", "ws") + recv := int32(0) con, err := Dial(url, WithClientDecompressAndCompress(), WithClientDecompression(), WithClientMaxDelayWriteDuration(30*time.Millisecond), @@ -1727,21 +1731,9 @@ func Test_CommonOption(t *testing.T) { WithClientWindowsParseMode(), WithClientDelayWriteInitBufferSize(4096), WithClientOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { - err := c.WriteMessageDelay(op, []byte("hello")) - if err != nil { - t.Error(err) - return - } - - err = c.WriteMessageDelay(op, []byte("hello")) - if err != nil { - t.Error(err) - return - } - err = c.WriteMessageDelay(op, []byte("hello")) - if err != nil { - t.Error(err) - return + atomic.AddInt32(&recv, int32(1)) + if atomic.LoadInt32(&recv) == 3 { + data <- "hello" } })) if err != nil { @@ -1749,7 +1741,18 @@ func Test_CommonOption(t *testing.T) { } defer con.Close() - err = con.WriteMessage(Binary, []byte("hello")) + err = con.WriteMessageDelay(Text, []byte("hello")) + if err != nil { + t.Error(err) + return + } + + err = con.WriteMessageDelay(Text, []byte("hello")) + if err != nil { + t.Error(err) + return + } + err = con.WriteMessageDelay(Text, []byte("hello")) if err != nil { t.Error(err) return @@ -1760,10 +1763,11 @@ func Test_CommonOption(t *testing.T) { if d != "hello" { t.Errorf("write message or read message fail:got:%s, need:hello\n", d) } + run++ case <-time.After(1000 * time.Millisecond): } if atomic.LoadInt32(&run) != 1 { - t.Error("not run server:method fail") + t.Errorf("not run server:method fail:%d", atomic.LoadInt32(&run)) } }) @@ -1773,14 +1777,13 @@ func Test_CommonOption(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := Upgrade(w, r, WithServerDecompressAndCompress(), - // WithServerBufioParseMode(), WithServerCallbackFunc(nil, func(c *Conn, op Opcode, payload []byte) { - if op != Binary { + if op != Text { t.Error("opcode error") } - err := c.WriteMessage(op, payload) - if err != nil { - t.Error(err) + err1 := c.WriteMessage(op, payload) + if err1 != nil { + t.Errorf("c.WriteMessage:%v", err1) return } }, func(c *Conn, err error) { @@ -1804,6 +1807,7 @@ func Test_CommonOption(t *testing.T) { defer ts.Close() url := strings.ReplaceAll(ts.URL, "http", "ws") + recv := int32(0) con, err := Dial(url, WithClientDecompressAndCompress(), WithClientMaxDelayWriteDuration(30*time.Millisecond), @@ -1811,23 +1815,14 @@ func Test_CommonOption(t *testing.T) { WithClientWindowsParseMode(), WithClientDelayWriteInitBufferSize(4096), WithClientOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { - if op != Binary { + if op != Text { t.Error("opcode error") } - err := c.WriteMessageDelay(op, []byte("hello")) - if err != nil { - t.Error(err) - } - err = c.WriteMessageDelay(op, []byte("hello")) - if err != nil { - t.Error(err) + atomic.AddInt32(&recv, int32(1)) + if atomic.LoadInt32(&recv) == 3 { + data <- "hello" } - err = c.WriteMessageDelay(op, []byte("hello")) - if err != nil { - t.Error(err) - } - data <- "hello" - atomic.AddInt32(&run, int32(1)) + // atomic.AddInt32(&run, int32(1)) })) if err != nil { t.Error(err) @@ -1837,7 +1832,15 @@ func Test_CommonOption(t *testing.T) { if !con.Compression { t.Error("not compression:method fail") } - err = con.WriteMessage(Binary, []byte("hello")) + err = con.WriteMessageDelay(Text, []byte("hello")) + if err != nil { + t.Error(err) + } + err = con.WriteMessageDelay(Text, []byte("hello")) + if err != nil { + t.Error(err) + } + err = con.WriteMessageDelay(Text, []byte("hello")) if err != nil { t.Error(err) } @@ -1848,6 +1851,7 @@ func Test_CommonOption(t *testing.T) { if d != "hello" { t.Errorf("write message or read message fail:got:%s, need:hello\n", d) } + run++ case <-time.After(1000 * time.Millisecond): t.Errorf("write message timeout\n") } diff --git a/utils.go b/utils.go index 474de81..849b7f6 100644 --- a/utils.go +++ b/utils.go @@ -21,12 +21,14 @@ import ( "fmt" "math/rand" "net/http" + "sync" "time" "unsafe" ) var rng = rand.New(rand.NewSource(time.Now().UnixNano())) +var mu sync.Mutex var uuid = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func StringToBytes(s string) []byte { @@ -50,7 +52,9 @@ func StringToBytes(s string) []byte { func secWebSocketAccept() string { // rfc规定是16字节 var key [16]byte + mu.Lock() rng.Read(key[:]) + mu.Unlock() return base64.StdEncoding.EncodeToString(key[:]) }