Skip to content

Commit

Permalink
OnClose加上once,保证只调用一次
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed May 20, 2024
1 parent 5fec807 commit 4b1f131
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@ quickws是一个高性能的websocket库
[![Go Report Card](https://goreportcard.com/badge/github.com/antlabs/quickws)](https://goreportcard.com/report/github.com/antlabs/quickws)

## 特性
* 3倍的简单
* 完整实现rfc6455
* 完整实现rfc7692
* 高tps
* 低内存占用
* 池化管理所有buffer

## 内容
* [安装](#Installation)
* [例子](#example)
* [标准库服务端](#标准库服务端)
* [gin服务端](#gin服务端)
* [net/http升级到websocket服务端](#net-http升级到websocket服务端)
* [gin升级到websocket服务端](#gin升级到websocket服务端)
* [客户端](#客户端)
* [配置函数](#配置函数)
* [客户端配置参数](#客户端配置)
Expand All @@ -34,7 +36,7 @@ go get github.com/antlabs/quickws
```

## example
### 标准库服务端
### net http升级到websocket服务端
```go

package main
Expand Down Expand Up @@ -87,7 +89,7 @@ func main() {
}

```
### gin服务端
### gin升级到websocket服务端
```go
package main

Expand Down
45 changes: 35 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/antlabs/wsutil/fixedwriter"
"github.com/antlabs/wsutil/frame"
"github.com/antlabs/wsutil/limitreader"
"github.com/antlabs/wsutil/myonce"
"github.com/antlabs/wsutil/opcode"
)

Expand Down Expand Up @@ -70,7 +71,9 @@ type Conn struct {
deCtx *deflate.DeCompressContextTakeover // 解压缩上下文
enCtx *deflate.CompressContextTakeover // 压缩上下文
closed int32 // 0: open, 1: closed
client bool // client(true) or server(flase)
mu2 sync.Mutex
onCloseOnce myonce.MyOnce // 保证只调用一次OnClose函数
client bool // client(true) or server(flase)
}

func setNoDelay(c net.Conn, noDelay bool) error {
Expand Down Expand Up @@ -105,7 +108,11 @@ func (c *Conn) NetConn() net.Conn {

func (c *Conn) writeAndMaybeOnClose(err error) error {
var sc *StatusCode
defer c.Callback.OnClose(c, err)
defer func() {
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, err)
})
}()

if errors.As(err, &sc) {
if err := c.WriteTimeout(opcode.Close, sc.toBytes(), 2*time.Second); err != nil {
Expand All @@ -116,7 +123,11 @@ func (c *Conn) writeAndMaybeOnClose(err error) error {
}

func (c *Conn) writeErrAndOnClose(code StatusCode, userErr error) error {
defer c.Callback.OnClose(c, userErr)
defer func() {
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, userErr)
})
}()
if err := c.WriteTimeout(opcode.Close, code.toBytes(), 2*time.Second); err != nil {
return err
}
Expand Down Expand Up @@ -180,7 +191,10 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPa
if c.readTimeout > 0 {
err = c.c.SetReadDeadline(time.Now().Add(c.readTimeout))
if err != nil {
c.Callback.OnClose(c, err)

c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, err)
})
return
}
}
Expand All @@ -204,7 +218,9 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPa

if c.readTimeout > 0 {
if err = c.c.SetReadDeadline(time.Time{}); err != nil {
c.Callback.OnClose(c, err)
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, err)
})
}
}
return
Expand Down Expand Up @@ -250,7 +266,9 @@ func (c *Conn) readMessage() (err error) {
// 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析
// TODO c.utf8Check 修改成流式解析
if c.fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(*c.fragmentFramePayload) {
c.Callback.OnClose(c, ErrTextNotUTF8)
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, ErrTextNotUTF8)
})
return ErrTextNotUTF8
}

Expand Down Expand Up @@ -300,7 +318,9 @@ func (c *Conn) readMessage() (err error) {
if f.Opcode == opcode.Text {
if !c.utf8Check(*f.Payload) {
c.c.Close()
c.Callback.OnClose(c, ErrTextNotUTF8)
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, ErrTextNotUTF8)
})
return ErrTextNotUTF8
}
}
Expand All @@ -326,7 +346,8 @@ func (c *Conn) readMessage() (err error) {

if f.Opcode == Close {
if len(*f.Payload) == 0 {
return c.writeErrAndOnClose(NormalClosure, ErrClosePayloadTooSmall)
c.writeErrAndOnClose(NormalClosure, &CloseErrMsg{Code: NormalClosure})
return nil
}

if len(*f.Payload) < 2 {
Expand All @@ -348,15 +369,19 @@ func (c *Conn) readMessage() (err error) {
}

err = bytesToCloseErrMsg(*f.Payload)
c.Callback.OnClose(c, err)
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, err)
})
return err
}

if f.Opcode == Ping {
// 回一个pong包
if c.replyPing {
if err := c.WriteTimeout(Pong, *f.Payload, 2*time.Second); err != nil {
c.Callback.OnClose(c, err)
c.onCloseOnce.Do(&c.mu2, func() {
c.Callback.OnClose(c, err)
})
return err
}
c.Callback.OnMessage(c, f.Opcode, *f.Payload)
Expand Down
2 changes: 1 addition & 1 deletion status_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (s StatusCode) String() string {
return "ServerTerminating"
}

return "unkown"
return "unknown"
}

func (s StatusCode) Error() string {
Expand Down

0 comments on commit 4b1f131

Please sign in to comment.