From d994f7d11843681aafc1d4eb13a497211690de72 Mon Sep 17 00:00:00 2001 From: Joshua Colvin Date: Fri, 25 Mar 2022 16:52:26 -0700 Subject: [PATCH] Backport feed fix Handle data race condition on initial connect. --- .../broadcastclient/broadcastclient.go | 43 ++++++++++++------ .../wsbroadcastserver/clientconnection.go | 2 +- packages/arb-util/wsbroadcastserver/utils.go | 45 +++++++++++++++++-- 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/packages/arb-util/broadcastclient/broadcastclient.go b/packages/arb-util/broadcastclient/broadcastclient.go index 4247af220a..8df20fe062 100644 --- a/packages/arb-util/broadcastclient/broadcastclient.go +++ b/packages/arb-util/broadcastclient/broadcastclient.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "github.com/offchainlabs/arbitrum/packages/arb-util/arblog" + "io" "math/big" "net" "strings" @@ -76,12 +77,12 @@ func (bc *BroadcastClient) Connect(ctx context.Context) (chan broadcaster.Broadc } func (bc *BroadcastClient) ConnectWithChannel(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) error { - _, err := bc.connect(ctx, messageReceiver) + earlyFrameData, _, err := bc.connect(ctx, messageReceiver) if err != nil { return err } - bc.startBackgroundReader(ctx, messageReceiver) + bc.startBackgroundReader(ctx, messageReceiver, earlyFrameData) return nil } @@ -104,11 +105,11 @@ func (bc *BroadcastClient) ConnectInBackground(ctx context.Context, messageRecei })() } -func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) (chan broadcaster.BroadcastFeedMessage, error) { +func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) (io.Reader, chan broadcaster.BroadcastFeedMessage, error) { if len(bc.websocketUrl) == 0 { // Nothing to do - return nil, nil + return nil, nil, nil } logger.Info().Str("url", bc.websocketUrl).Msg("connecting to arbitrum inbox message broadcaster") @@ -116,10 +117,22 @@ func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan bro Timeout: 10 * time.Second, } - conn, _, _, err := timeoutDialer.Dial(ctx, bc.websocketUrl) + conn, br, _, err := timeoutDialer.Dial(ctx, bc.websocketUrl) if err != nil { logger.Warn().Err(err).Msg("broadcast client unable to connect") - return nil, errors.Wrap(err, "broadcast client unable to connect") + return nil, nil, errors.Wrap(err, "broadcast client unable to connect") + } + + var earlyFrameData io.Reader + if br != nil { + // Depending on how long the client takes to read the response, there may be + // data after the WebSocket upgrade response in a single read from the socket, + // ie WebSocket frames sent by the server. If this happens, Dial returns + // a non-nil bufio.Reader so that data isn't lost. But beware, this buffered + // reader is still hooked up to the socket; trying to read past what had already + // been buffered will do a blocking read on the socket, so we have to wrap it + // in a LimitedReader. + earlyFrameData = io.LimitReader(br, int64(br.Buffered())) } bc.connMutex.Lock() @@ -128,10 +141,10 @@ func (bc *BroadcastClient) connect(ctx context.Context, messageReceiver chan bro logger.Info().Msg("Connected") - return messageReceiver, nil + return earlyFrameData, messageReceiver, nil } -func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) { +func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage, earlyFrameData io.Reader) { go func() { for { select { @@ -140,7 +153,7 @@ func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageRec default: } - msg, op, err := wsbroadcastserver.ReadData(ctx, bc.conn, bc.idleTimeout, ws.StateClientSide) + msg, op, err := wsbroadcastserver.ReadData(ctx, bc.conn, earlyFrameData, bc.idleTimeout, ws.StateClientSide) if err != nil { if bc.shuttingDown { return @@ -151,7 +164,7 @@ func (bc *BroadcastClient) startBackgroundReader(ctx context.Context, messageRec logger.Error().Err(err).Str("feed", bc.websocketUrl).Int("opcode", int(op)).Msgf("error calling readData") } _ = bc.conn.Close() - bc.RetryConnect(ctx, messageReceiver) + earlyFrameData = bc.RetryConnect(ctx, messageReceiver) continue } @@ -192,7 +205,7 @@ func (bc *BroadcastClient) GetRetryCount() int { return bc.retryCount } -func (bc *BroadcastClient) RetryConnect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) { +func (bc *BroadcastClient) RetryConnect(ctx context.Context, messageReceiver chan broadcaster.BroadcastFeedMessage) io.Reader { bc.retryMutex.Lock() defer bc.retryMutex.Unlock() @@ -202,21 +215,23 @@ func (bc *BroadcastClient) RetryConnect(ctx context.Context, messageReceiver cha for !bc.shuttingDown { select { case <-ctx.Done(): - return + return nil case <-time.After(waitDuration): } bc.retryCount++ - _, err := bc.connect(ctx, messageReceiver) + earlyFrameData, _, err := bc.connect(ctx, messageReceiver) if err == nil { bc.retrying = false - return + return earlyFrameData } if waitDuration < maxWaitDuration { waitDuration += 500 * time.Millisecond } } + + return nil } func (bc *BroadcastClient) Close() { diff --git a/packages/arb-util/wsbroadcastserver/clientconnection.go b/packages/arb-util/wsbroadcastserver/clientconnection.go index 65784b646b..7182660a61 100644 --- a/packages/arb-util/wsbroadcastserver/clientconnection.go +++ b/packages/arb-util/wsbroadcastserver/clientconnection.go @@ -122,7 +122,7 @@ func (cc *ClientConnection) readRequest(ctx context.Context, timeout time.Durati atomic.StoreInt64(&cc.lastHeardUnix, time.Now().Unix()) - return ReadData(ctx, cc.conn, timeout, ws.StateServerSide) + return ReadData(ctx, cc.conn, nil, timeout, ws.StateServerSide) } func (cc *ClientConnection) Write(x interface{}) error { diff --git a/packages/arb-util/wsbroadcastserver/utils.go b/packages/arb-util/wsbroadcastserver/utils.go index 3221e234ad..c99b2f2956 100644 --- a/packages/arb-util/wsbroadcastserver/utils.go +++ b/packages/arb-util/wsbroadcastserver/utils.go @@ -2,24 +2,61 @@ package wsbroadcastserver import ( "context" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" + "errors" + "io" "io/ioutil" "net" "strings" "time" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" ) +type chainedReader struct { + readers []io.Reader +} + +func (cr *chainedReader) Read(b []byte) (n int, err error) { + for len(cr.readers) > 0 { + n, err = cr.readers[0].Read(b) + if errors.Is(err, io.EOF) { + cr.readers = cr.readers[1:] + if n == 0 { + continue // EOF and empty, skip to next + } else { + // The Read interface specifies some data can be returned along with an EOF. + if len(cr.readers) != 1 { + // If this isn't the last reader, return the data without the EOF since this + // may not be the end of all the readers. + return n, nil + } else { + return + } + } + } + break + } + return +} + +func (cr *chainedReader) add(r io.Reader) *chainedReader { + if r != nil { + cr.readers = append(cr.readers, r) + } + return cr +} + func logError(err error, msg string) { if !strings.Contains(err.Error(), "use of closed network connection") { logger.Error().Err(err).Msg(msg) } } -func ReadData(ctx context.Context, conn net.Conn, idleTimeout time.Duration, state ws.State) ([]byte, ws.OpCode, error) { +func ReadData(ctx context.Context, conn net.Conn, earlyFrameData io.Reader, idleTimeout time.Duration, state ws.State) ([]byte, ws.OpCode, error) { controlHandler := wsutil.ControlFrameHandler(conn, state) reader := wsutil.Reader{ - Source: conn, + Source: (&chainedReader{}).add(earlyFrameData).add(conn), State: state, CheckUTF8: true, SkipHeaderCheck: false,