diff --git a/client.go b/client.go index 61772b6..03e6db5 100644 --- a/client.go +++ b/client.go @@ -98,7 +98,7 @@ func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handle defer resp.Body.Close() reader := NewEventStreamReader(resp.Body, c.maxBufferSize) - eventChan, errorChan := c.startReadLoop(reader) + eventChan, errorChan := c.startReadLoop(ctx, reader) for { select { @@ -133,6 +133,8 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch c.subscribed[ch] = make(chan struct{}) c.mu.Unlock() + ctx, cancel := context.WithCancel(ctx) + operation := func() error { resp, err := c.request(ctx, stream) if err != nil { @@ -156,13 +158,14 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch } reader := NewEventStreamReader(resp.Body, c.maxBufferSize) - eventChan, errorChan := c.startReadLoop(reader) + eventChan, errorChan := c.startReadLoop(ctx, reader) for { var msg *Event // Wait for message to arrive or exit select { case <-c.subscribed[ch]: + cancel() return nil case err = <-errorChan: return err @@ -173,6 +176,7 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch if msg != nil { select { case <-c.subscribed[ch]: + cancel() return nil case ch <- msg: // message sent @@ -201,51 +205,72 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch return err } -func (c *Client) startReadLoop(reader *EventStreamReader) (chan *Event, chan error) { +func (c *Client) startReadLoop(ctx context.Context, reader *EventStreamReader) (chan *Event, chan error) { outCh := make(chan *Event) erChan := make(chan error) - go c.readLoop(reader, outCh, erChan) + go c.readLoop(ctx, reader, outCh, erChan) return outCh, erChan } -func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan chan error) { +func (c *Client) readLoop(ctx context.Context, reader *EventStreamReader, outCh chan *Event, erChan chan error) { for { - // Read each new line and process the type of event - event, err := reader.ReadEvent() - if err != nil { - if err == io.EOF { - erChan <- nil - return + msg, err := c.readLoopInner(reader) + if errors.Is(err, io.EOF) { + select { + case <-ctx.Done(): + case erChan <- nil: + } + break + } else if err != nil { + select { + case <-ctx.Done(): + case erChan <- err: } - // run user specified disconnect function - if c.disconnectcb != nil { - c.Connected = false - c.disconnectcb(c) + break + } else if msg != nil { + select { + case <-ctx.Done(): + case outCh <- msg: } - erChan <- err - return } + } +} - if !c.Connected && c.connectedcb != nil { - c.Connected = true - c.connectedcb(c) +func (c *Client) readLoopInner(reader *EventStreamReader) (*Event, error) { + // Read each new line and process the type of event + event, err := reader.ReadEvent() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, err + } + // run user specified disconnect function + if c.disconnectcb != nil { + c.Connected = false + c.disconnectcb(c) } + return nil, err + } - // If we get an error, ignore it. - var msg *Event - if msg, err = c.processEvent(event); err == nil { - if len(msg.ID) > 0 { - c.LastEventID.Store(msg.ID) - } else { - msg.ID, _ = c.LastEventID.Load().([]byte) - } + if !c.Connected && c.connectedcb != nil { + c.Connected = true + c.connectedcb(c) + } - // Send downstream if the event has something useful - if msg.hasContent() { - outCh <- msg - } + // If we get an error, ignore it. + if msg, err := c.processEvent(event); err == nil { + if len(msg.ID) > 0 { + c.LastEventID.Store(msg.ID) + } else { + msg.ID, _ = c.LastEventID.Load().([]byte) + } + + // Send downstream only if the event has something useful + if msg.hasContent() { + return msg, nil } } + + return nil, nil } // SubscribeRaw to an sse endpoint