From 11bda985bf5f0a7169f973913a52ca7f39ab9d99 Mon Sep 17 00:00:00 2001 From: Moritz Date: Wed, 4 Dec 2024 20:11:43 +0100 Subject: [PATCH] fix: avoid writing messages after close and improve handshake (#476) Co-authored-by: Mathias Fredriksson --- close.go | 12 ++--- conn.go | 10 +++- conn_test.go | 149 ++++++++++++++++++++++++++++++++++++++++++++++++++- read.go | 94 ++++++++++++++++++++------------ write.go | 52 +++++++++++------- 5 files changed, 252 insertions(+), 65 deletions(-) diff --git a/close.go b/close.go index ff2e878a..f94951dc 100644 --- a/close.go +++ b/close.go @@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode { func (c *Conn) Close(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) { func (c *Conn) CloseNow() (err error) { defer errd.Wrap(&err, "failed to immediately close WebSocket") - if !c.casClosing() { + if c.casClosing() { err = c.waitGoroutines() if err != nil { return err @@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) { } func (c *Conn) casClosing() bool { - c.closeMu.Lock() - defer c.closeMu.Unlock() - if !c.closing { - c.closing = true - return true - } - return false + return c.closing.Swap(true) } func (c *Conn) isClosed() bool { diff --git a/conn.go b/conn.go index d7434a9d..76b057dd 100644 --- a/conn.go +++ b/conn.go @@ -69,13 +69,19 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header + // Close handshake state. + closeStateMu sync.RWMutex + closeReceivedErr error + closeSentErr error + + // CloseRead state. closeReadMu sync.Mutex closeReadCtx context.Context closeReadDone chan struct{} + closing atomic.Bool + closeMu sync.Mutex // Protects following. closed chan struct{} - closeMu sync.Mutex - closing bool pingCounter atomic.Int64 activePingsMu sync.Mutex diff --git a/conn_test.go b/conn_test.go index b4d57f21..9ed8c7ea 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "os" @@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } func BenchmarkConn(b *testing.B) { - var benchCases = []struct { + benchCases := []struct { name string mode websocket.CompressionMode }{ @@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) { }() } } + +func TestConnClosePropagation(t *testing.T) { + t.Parallel() + + want := []byte("hello") + keepWriting := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + err := c.Write(context.Background(), websocket.MessageText, want) + if err != nil { + return err + } + } + }) + } + keepReading := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + _, got, err := c.Read(context.Background()) + if err != nil { + return err + } + if !bytes.Equal(want, got) { + return fmt.Errorf("unexpected message: want %q, got %q", want, got) + } + } + }) + } + checkReadErr := func(t *testing.T, err error) { + // Check read error (output depends on when read is called in relation to connection closure). + var ce websocket.CloseError + if errors.As(err, &ce) { + assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) + } else { + assert.ErrorIs(t, net.ErrClosed, err) + } + } + checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) { + for _, c := range conn { + // Check write error. + err := c.Write(context.Background(), websocket.MessageText, want) + assert.ErrorIs(t, net.ErrClosed, err) + + _, _, err = c.Read(context.Background()) + checkReadErr(t, err) + } + } + + t.Run("CloseOtherSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + + _, got, err := other.Read(tt.ctx) + assert.Success(t, err) + assert.Equal(t, "msg", want, got) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + otherReadErr := keepReading(other) + + err := this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseOtherSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = other.CloseRead(tt.ctx) + errs := keepReading(this) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-errs: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + thisReadErr := keepReading(this) + otherReadErr := keepReading(other) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) +} diff --git a/read.go b/read.go index e2699da5..1267b5b9 100644 --- a/read.go +++ b/read.go @@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } -func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { +// prepareRead sets the readTimeout context and returns a done function +// to be called after the read is done. It also returns an error if the +// connection is closed. The reference to the error is used to assign +// an error depending on if the connection closed or the context timed +// out during use. Typically the referenced error is a named return +// variable of the function calling this method. +func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: - return header{}, net.ErrClosed + return nil, net.ErrClosed case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) - if err != nil { + done := func() { select { case <-c.closed: - return header{}, net.ErrClosed - case <-ctx.Done(): - return header{}, ctx.Err() - default: - return header{}, err + if *err != nil { + *err = net.ErrClosed + } + case c.readTimeout <- context.Background(): + } + if *err != nil && ctx.Err() != nil { + *err = ctx.Err() } } - select { - case <-c.closed: - return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + c.closeStateMu.Lock() + closeReceivedErr := c.closeReceivedErr + c.closeStateMu.Unlock() + if closeReceivedErr != nil { + defer done() + return nil, closeReceivedErr } - return h, nil + return done, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - select { - case <-c.closed: - return 0, net.ErrClosed - case c.readTimeout <- ctx: +func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return header{}, err } + defer readDone() - n, err := io.ReadFull(c.br, p) + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { - select { - case <-c.closed: - return n, net.ErrClosed - case <-ctx.Done(): - return n, ctx.Err() - default: - return n, fmt.Errorf("failed to read frame payload: %w", err) - } + return header{}, err } - select { - case <-c.closed: - return n, net.ErrClosed - case c.readTimeout <- context.Background(): + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return 0, err + } + defer readDone() + + n, err := io.ReadFull(c.br, p) + if err != nil { + return n, fmt.Errorf("failed to read frame payload: %w", err) } return n, err @@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.writeClose(ce.Code, ce.Reason) - c.readMu.unlock() - c.close() + c.closeStateMu.Lock() + c.closeReceivedErr = err + closeSent := c.closeSentErr != nil + c.closeStateMu.Unlock() + + // Only unlock readMu if this connection is being closed becaue + // c.close will try to acquire the readMu lock. We unlock for + // writeClose as well because it may also call c.close. + if !closeSent { + c.readMu.unlock() + _ = c.writeClose(ce.Code, ce.Reason) + } + if !c.casClosing() { + c.readMu.unlock() + _ = c.close() + } return err } diff --git a/write.go b/write.go index e294a680..7324de74 100644 --- a/write.go +++ b/write.go @@ -5,6 +5,7 @@ package websocket import ( "bufio" + "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -14,8 +15,6 @@ import ( "net" "time" - "compress/flate" - "github.com/coder/websocket/internal/errd" "github.com/coder/websocket/internal/util" ) @@ -249,22 +248,36 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } defer c.writeFrameMu.unlock() + defer func() { + if c.isClosed() && opcode == opClose { + err = nil + } + if err != nil { + if ctx.Err() != nil { + err = ctx.Err() + } else if c.isClosed() { + err = net.ErrClosed + } + err = fmt.Errorf("failed to write frame: %w", err) + } + }() + + c.closeStateMu.Lock() + closeSentErr := c.closeSentErr + c.closeStateMu.Unlock() + if closeSentErr != nil { + return 0, net.ErrClosed + } + select { case <-c.closed: return 0, net.ErrClosed case c.writeTimeout <- ctx: } - defer func() { - if err != nil { - select { - case <-c.closed: - err = net.ErrClosed - case <-ctx.Done(): - err = ctx.Err() - default: - } - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + case c.writeTimeout <- context.Background(): } }() @@ -303,13 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } - select { - case <-c.closed: - if opcode == opClose { - return n, nil + if opcode == opClose { + c.closeStateMu.Lock() + c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed) + closeReceived := c.closeReceivedErr != nil + c.closeStateMu.Unlock() + + if closeReceived && !c.casClosing() { + c.writeFrameMu.unlock() + _ = c.close() } - return n, net.ErrClosed - case c.writeTimeout <- context.Background(): } return n, nil