Skip to content
This repository has been archived by the owner on Jun 4, 2024. It is now read-only.

fix leaking go-routines in event-handler watcher #963

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion event-handler/teleport_events_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func (t *TeleportEventsWatcher) Events(ctx context.Context) (chan *TeleportEvent
err := t.fetch(ctx)
if err != nil {
e <- trace.Wrap(err)
continue
break
}

// If there is still nothing new on current page, sleep
Expand Down
143 changes: 129 additions & 14 deletions event-handler/teleport_events_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package main

import (
"errors"
"testing"
"time"

Expand All @@ -36,22 +37,19 @@ type mockTeleportEventWatcher struct {
t *testing.T
}

// SearchEvents is mock SearchEvents method which returns events
func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) {
e := c.events
c.events = make([]events.AuditEvent, 0) // nullify events
c.events = nil
return e, "test", nil
}

// StreamSessionEvents returns session events stream
func (c *mockTeleportEventWatcher) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) {
return nil, nil
}

// SearchEvents is mock SearchEvents method which returns events
func (c *mockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]*auditlogpb.EventUnstructured, string, error) {
e := c.events
c.events = make([]events.AuditEvent, 0) // nullify events
c.events = nil

events := make([]*auditlogpb.EventUnstructured, len(e))
for i, event := range e {
Expand All @@ -60,12 +58,10 @@ func (c *mockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context,
return events, "test", nil
}

// StreamSessionEvents returns session events stream
func (c *mockTeleportEventWatcher) StreamUnstructuredSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan *auditlogpb.EventUnstructured, chan error) {
return nil, nil
}

// UsertLock is mock UpsertLock method
func (c *mockTeleportEventWatcher) UpsertLock(ctx context.Context, lock types.Lock) error {
return nil
}
Expand All @@ -76,12 +72,11 @@ func (c *mockTeleportEventWatcher) Ping(ctx context.Context) (proto.PingResponse
}, nil
}

// Close is mock close method
func (c *mockTeleportEventWatcher) Close() error {
return nil
}

func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent) *TeleportEventsWatcher {
func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent, exitOnLastEvent bool) *TeleportEventsWatcher {
teleportEventWatcher := &mockTeleportEventWatcher{events: e, t: t}

client := &TeleportEventsWatcher{
Expand All @@ -90,7 +85,7 @@ func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent) *TeleportEvent
config: &StartCmdConfig{
IngestConfig: IngestConfig{
BatchSize: 5,
ExitOnLastEvent: true,
ExitOnLastEvent: exitOnLastEvent,
},
},
}
Expand All @@ -112,22 +107,22 @@ func TestNext(t *testing.T) {
},
}

client := newTeleportEventWatcher(t, e)
client := newTeleportEventWatcher(t, e, true)
chEvt, chErr := client.Events(context.Background())

select {
case err := <-chErr:
require.NoError(t, err)
require.Fail(t, "received unexpected error from error channel", "error: %v", err)
case e := <-chEvt:
require.NotNil(t, e.Event)
require.Equal(t, e.ID, "1")
require.Equal(t, "1", e.ID)
case <-time.After(time.Second):
require.Fail(t, "No events were sent")
}

select {
case err := <-chErr:
require.NoError(t, err)
require.Fail(t, "received unexpected error from error channel", "error: %v", err)
case e := <-chEvt:
require.NotNil(t, e.Event)
require.Equal(t, "081ca05eea09ac0cd06e2d2acd06bec424146b254aa500de37bdc2c2b0a4dd0f", e.ID)
Expand All @@ -136,6 +131,126 @@ func TestNext(t *testing.T) {
}
}

// errMockTeleportEventWatcher is Teleport client mock that returns an error after the first SearchUnstructuredEvents
type errMockTeleportEventWatcher struct {
// events is the mock list of events
events []events.AuditEvent
t *testing.T
zmb3 marked this conversation as resolved.
Show resolved Hide resolved
called bool
}

func (c *errMockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) {
e := c.events
c.events = nil
return e, "test", nil
}

func (c *errMockTeleportEventWatcher) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) {
return nil, nil
}

func (c *errMockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]*auditlogpb.EventUnstructured, string, error) {
if c.called {
return nil, "", errors.New("error")
}

e := c.events
c.events = nil

events := make([]*auditlogpb.EventUnstructured, len(e))
for i, event := range e {
events[i] = eventToJSON(c.t, event)
}
c.called = true
return events, "", nil
}

func (c *errMockTeleportEventWatcher) StreamUnstructuredSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan *auditlogpb.EventUnstructured, chan error) {
return nil, nil
}

func (c *errMockTeleportEventWatcher) UpsertLock(ctx context.Context, lock types.Lock) error {
return nil
}

func (c *errMockTeleportEventWatcher) Ping(ctx context.Context) (proto.PingResponse, error) {
return proto.PingResponse{
ServerVersion: Version,
}, nil
}

// Close is mock close method
func (c *errMockTeleportEventWatcher) Close() error {
return nil
}

func newErrTeleportEventWatcher(t *testing.T, e []events.AuditEvent, exitOnLastEvent bool) *TeleportEventsWatcher {
teleportEventWatcher := &errMockTeleportEventWatcher{events: e, t: t}

client := &TeleportEventsWatcher{
client: teleportEventWatcher,
pos: -1,
config: &StartCmdConfig{
IngestConfig: IngestConfig{
BatchSize: 5,
ExitOnLastEvent: exitOnLastEvent,
},
},
}

return client
}

func TestLastEvent(t *testing.T) {
t.Run("should not leave hanging go-routines", func(t *testing.T) {
e := []events.AuditEvent{
&events.UserCreate{
Metadata: events.Metadata{
ID: "1",
},
},
}

client := newErrTeleportEventWatcher(t, e, false)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)

chEvt, chErr := client.Events(ctx)

select {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only one of these cases will be hit - which one do we expect? I would write the test in a way that it fails if an unexpected case is executed, rather than writing it so that it can succeed in multiple ways.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the code, I think we don't ever receive a nil error in the chErr, so I think probably changing the corresponding case to just be a require.Fail should be fine. Do you agree? If so I wouldn't mind doing this change for the existing test from which I copied this boilerplate as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me.

case err := <-chErr:
require.Fail(t, "received unexpected error from error channel", "error: %v", err)
case e := <-chEvt:
require.NotNil(t, e.Event)
require.Equal(t, "1", e.ID)
case <-time.After(time.Second):
require.Fail(t, "No events were sent")
}

allDone := make(chan struct{})

const nIters = 5
for i := 0; i < nIters; i++ {
go func() {
zmb3 marked this conversation as resolved.
Show resolved Hide resolved
chEvt, _ := client.Events(ctx)
// we're assuming that a closed channel == closed goroutine
for range chEvt {
}
allDone <- struct{}{}
}()
}

for i := 0; i < nIters; i++ {
select {
case <-allDone:
case <-ctx.Done():
require.Fail(t, "timeout reached, some goroutines were not closed")
}
}
})
}

func TestValidateConfig(t *testing.T) {
for _, tc := range []struct {
name string
Expand Down