Skip to content

Commit

Permalink
fix: avoid writing messages after close and improve handshake (#476)
Browse files Browse the repository at this point in the history
Co-authored-by: Mathias Fredriksson <[email protected]>
  • Loading branch information
FrauElster and mafredri authored Dec 4, 2024
1 parent 1253b77 commit 11bda98
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 65 deletions.
12 changes: 3 additions & 9 deletions close.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")

if !c.casClosing() {
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
Expand Down Expand Up @@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")

if !c.casClosing() {
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
Expand Down Expand Up @@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
}

func (c *Conn) casClosing() bool {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if !c.closing {
c.closing = true
return true
}
return false
return c.closing.Swap(true)
}

func (c *Conn) isClosed() bool {
Expand Down
10 changes: 8 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,19 @@ type Conn struct {
writeHeaderBuf [8]byte
writeHeader header

// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error

// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}

closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{}
closeMu sync.Mutex
closing bool

pingCounter atomic.Int64
activePingsMu sync.Mutex
Expand Down
149 changes: 148 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
}

func BenchmarkConn(b *testing.B) {
var benchCases = []struct {
benchCases := []struct {
name string
mode websocket.CompressionMode
}{
Expand Down Expand Up @@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
}()
}
}

func TestConnClosePropagation(t *testing.T) {
t.Parallel()

want := []byte("hello")
keepWriting := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
err := c.Write(context.Background(), websocket.MessageText, want)
if err != nil {
return err
}
}
})
}
keepReading := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
_, got, err := c.Read(context.Background())
if err != nil {
return err
}
if !bytes.Equal(want, got) {
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
}
}
})
}
checkReadErr := func(t *testing.T, err error) {
// Check read error (output depends on when read is called in relation to connection closure).
var ce websocket.CloseError
if errors.As(err, &ce) {
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
} else {
assert.ErrorIs(t, net.ErrClosed, err)
}
}
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
for _, c := range conn {
// Check write error.
err := c.Write(context.Background(), websocket.MessageText, want)
assert.ErrorIs(t, net.ErrClosed, err)

_, _, err = c.Read(context.Background())
checkReadErr(t, err)
}
}

t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)

_, got, err := other.Read(tt.ctx)
assert.Success(t, err)
assert.Equal(t, "msg", want, got)

err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
otherReadErr := keepReading(other)

err := this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

_ = other.CloseRead(tt.ctx)
errs := keepReading(this)

err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)

err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-errs:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)

thisReadErr := keepReading(this)
otherReadErr := keepReading(other)

err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)

err = this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)

select {
case err := <-thisReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}

checkConnErrs(t, this, other)
})
}
94 changes: 59 additions & 35 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
}
}

func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
// prepareRead sets the readTimeout context and returns a done function
// to be called after the read is done. It also returns an error if the
// connection is closed. The reference to the error is used to assign
// an error depending on if the connection closed or the context timed
// out during use. Typically the referenced error is a named return
// variable of the function calling this method.
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
select {
case <-c.closed:
return header{}, net.ErrClosed
return nil, net.ErrClosed
case c.readTimeout <- ctx:
}

h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
done := func() {
select {
case <-c.closed:
return header{}, net.ErrClosed
case <-ctx.Done():
return header{}, ctx.Err()
default:
return header{}, err
if *err != nil {
*err = net.ErrClosed
}
case c.readTimeout <- context.Background():
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}

select {
case <-c.closed:
return header{}, net.ErrClosed
case c.readTimeout <- context.Background():
c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
defer done()
return nil, closeReceivedErr
}

return h, nil
return done, nil
}

func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
select {
case <-c.closed:
return 0, net.ErrClosed
case c.readTimeout <- ctx:
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return header{}, err
}
defer readDone()

n, err := io.ReadFull(c.br, p)
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
select {
case <-c.closed:
return n, net.ErrClosed
case <-ctx.Done():
return n, ctx.Err()
default:
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
return header{}, err
}

select {
case <-c.closed:
return n, net.ErrClosed
case c.readTimeout <- context.Background():
return h, nil
}

func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return 0, err
}
defer readDone()

n, err := io.ReadFull(c.br, p)
if err != nil {
return n, fmt.Errorf("failed to read frame payload: %w", err)
}

return n, err
Expand Down Expand Up @@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
}

err = fmt.Errorf("received close frame: %w", ce)
c.writeClose(ce.Code, ce.Reason)
c.readMu.unlock()
c.close()
c.closeStateMu.Lock()
c.closeReceivedErr = err
closeSent := c.closeSentErr != nil
c.closeStateMu.Unlock()

// Only unlock readMu if this connection is being closed becaue
// c.close will try to acquire the readMu lock. We unlock for
// writeClose as well because it may also call c.close.
if !closeSent {
c.readMu.unlock()
_ = c.writeClose(ce.Code, ce.Reason)
}
if !c.casClosing() {
c.readMu.unlock()
_ = c.close()
}
return err
}

Expand Down
Loading

0 comments on commit 11bda98

Please sign in to comment.