From ad3f0c2cbd6c11bc8adfc16a38e28f6f258546f0 Mon Sep 17 00:00:00 2001 From: Andrii Date: Wed, 13 Nov 2024 12:52:52 +0200 Subject: [PATCH] abolish an extra goroutine --- close.go | 6 ------ conn.go | 63 ++++++++++++++++++++++++++++++++------------------------ read.go | 12 +++++++---- write.go | 6 ++++-- 4 files changed, 48 insertions(+), 39 deletions(-) diff --git a/close.go b/close.go index ff2e878a..820354e4 100644 --- a/close.go +++ b/close.go @@ -232,12 +232,6 @@ func (c *Conn) waitGoroutines() error { t := time.NewTimer(time.Second * 15) defer t.Stop() - select { - case <-c.timeoutLoopDone: - case <-t.C: - return errors.New("failed to wait for timeoutLoop goroutine to exit") - } - c.closeReadMu.Lock() closeRead := c.closeReadCtx != nil c.closeReadMu.Unlock() diff --git a/conn.go b/conn.go index d7434a9d..17d2d6b5 100644 --- a/conn.go +++ b/conn.go @@ -52,9 +52,8 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context - timeoutLoopDone chan struct{} + readTimeoutCloser atomic.Value + writeTimeoutCloser atomic.Value // Read state. readMu *mu @@ -104,10 +103,6 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } @@ -133,8 +128,6 @@ func newConn(cfg connConfig) *Conn { c.close() }) - go c.timeoutLoop() - return c } @@ -164,26 +157,42 @@ func (c *Conn) close() error { return err } -func (c *Conn) timeoutLoop() { - defer close(c.timeoutLoopDone) +func (c *Conn) setupWriteTimeout(ctx context.Context) { + hammerTime := context.AfterFunc(ctx, func() { + c.close() + }) - readCtx := context.Background() - writeCtx := context.Background() + if closer := c.writeTimeoutCloser.Swap(hammerTime); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() + } + } +} - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.close() - return - case <-writeCtx.Done(): - c.close() - return +func (c *Conn) clearWriteTimeout() { + if closer := c.writeTimeoutCloser.Load(); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() + } + } +} + +func (c *Conn) setupReadTimeout(ctx context.Context) { + hammerTime := context.AfterFunc(ctx, func() { + defer c.close() + }) + + if closer := c.readTimeoutCloser.Swap(hammerTime); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() + } + } +} + +func (c *Conn) clearReadTimeout() { + if closer := c.readTimeoutCloser.Load(); closer != nil { + if fn, ok := closer.(func() bool); ok { + fn() } } } diff --git a/read.go b/read.go index e2699da5..89292356 100644 --- a/read.go +++ b/read.go @@ -221,7 +221,8 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, net.ErrClosed - case c.readTimeout <- ctx: + default: + c.setupReadTimeout(ctx) } h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) @@ -239,7 +240,8 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + default: + c.clearReadTimeout() } return h, nil @@ -249,7 +251,8 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return 0, net.ErrClosed - case c.readTimeout <- ctx: + default: + c.setupReadTimeout(ctx) } n, err := io.ReadFull(c.br, p) @@ -267,7 +270,8 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return n, net.ErrClosed - case c.readTimeout <- context.Background(): + default: + c.clearReadTimeout() } return n, err diff --git a/write.go b/write.go index e294a680..ac0a1ac7 100644 --- a/write.go +++ b/write.go @@ -252,7 +252,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco select { case <-c.closed: return 0, net.ErrClosed - case c.writeTimeout <- ctx: + default: + c.setupWriteTimeout(ctx) } defer func() { @@ -309,7 +310,8 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return n, nil } return n, net.ErrClosed - case c.writeTimeout <- context.Background(): + default: + c.clearWriteTimeout() } return n, nil