From c7e9f9acbd6699fbed81541dee40d9f52485ef06 Mon Sep 17 00:00:00 2001 From: Nick Stenning Date: Thu, 14 Mar 2024 10:48:18 +0100 Subject: [PATCH] Shut down subscriber read loop correctly readLoop was leaking goroutines when subscribers were unsubscribed, because it was blocked trying to write a nil to erChan in response to an upstream io.EOF. This commit passes in a context to startReadLoop, ensuring that when that context is closed, the read loop terminates correctly. For SubscribeWithContext we just pass in the parent context, as there is no way to unsubscribe the handler. For SubscribeChanWithContext, we pass in a derived cancelable context, which we cancel explicitly when the channel is unsubscribed. --- client.go | 89 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/client.go b/client.go index 61772b6..03e6db5 100644 --- a/client.go +++ b/client.go @@ -98,7 +98,7 @@ func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handle defer resp.Body.Close() reader := NewEventStreamReader(resp.Body, c.maxBufferSize) - eventChan, errorChan := c.startReadLoop(reader) + eventChan, errorChan := c.startReadLoop(ctx, reader) for { select { @@ -133,6 +133,8 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch c.subscribed[ch] = make(chan struct{}) c.mu.Unlock() + ctx, cancel := context.WithCancel(ctx) + operation := func() error { resp, err := c.request(ctx, stream) if err != nil { @@ -156,13 +158,14 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch } reader := NewEventStreamReader(resp.Body, c.maxBufferSize) - eventChan, errorChan := c.startReadLoop(reader) + eventChan, errorChan := c.startReadLoop(ctx, reader) for { var msg *Event // Wait for message to arrive or exit select { case <-c.subscribed[ch]: + cancel() return nil case err = <-errorChan: return err @@ -173,6 +176,7 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch if msg != nil { select { case <-c.subscribed[ch]: + cancel() return nil case ch <- msg: // message sent @@ -201,51 +205,72 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch return err } -func (c *Client) startReadLoop(reader *EventStreamReader) (chan *Event, chan error) { +func (c *Client) startReadLoop(ctx context.Context, reader *EventStreamReader) (chan *Event, chan error) { outCh := make(chan *Event) erChan := make(chan error) - go c.readLoop(reader, outCh, erChan) + go c.readLoop(ctx, reader, outCh, erChan) return outCh, erChan } -func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan chan error) { +func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh chan *Event, erChan chan error) { for { - // Read each new line and process the type of event - event, err := reader.ReadEvent() - if err != nil { - if err == io.EOF { - erChan <- nil - return + msg, err := c.readLoopInner(reader) + if errors.Is(err, io.EOF) { + select { + case <-ctx.Done(): + case erChan <- nil: + } + break + } else if err != nil { + select { + case <-ctx.Done(): + case erChan <- err: } - // run user specified disconnect function - if c.disconnectcb != nil { - c.Connected = false - c.disconnectcb(c) + break + } else if msg != nil { + select { + case <-ctx.Done(): + case outCh <- msg: } - erChan <- err - return } + } +} - if !c.Connected && c.connectedcb != nil { - c.Connected = true - c.connectedcb(c) +func (c *Client) readLoopInner(reader *EventStreamReader) (*Event, error) { + // Read each new line and process the type of event + event, err := reader.ReadEvent() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, err + } + // run user specified disconnect function + if c.disconnectcb != nil { + c.Connected = false + c.disconnectcb(c) } + return nil, err + } - // If we get an error, ignore it. - var msg *Event - if msg, err = c.processEvent(event); err == nil { - if len(msg.ID) > 0 { - c.LastEventID.Store(msg.ID) - } else { - msg.ID, _ = c.LastEventID.Load().([]byte) - } + if !c.Connected && c.connectedcb != nil { + c.Connected = true + c.connectedcb(c) + } - // Send downstream if the event has something useful - if msg.hasContent() { - outCh <- msg - } + // If we get an error, ignore it. + if msg, err := c.processEvent(event); err == nil { + if len(msg.ID) > 0 { + c.LastEventID.Store(msg.ID) + } else { + msg.ID, _ = c.LastEventID.Load().([]byte) + } + + // Send downstream only if the event has something useful + if msg.hasContent() { + return msg, nil } } + + return nil, nil } // SubscribeRaw to an sse endpoint