diff --git a/event-handler/teleport_event.go b/event-handler/teleport_event.go index f3bd2832f..36833e169 100644 --- a/event-handler/teleport_event.go +++ b/event-handler/teleport_event.go @@ -20,9 +20,10 @@ import ( "encoding/json" "time" + "github.com/gravitational/trace" + auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/trace" ) const ( diff --git a/event-handler/teleport_events_watcher_test.go b/event-handler/teleport_events_watcher_test.go index 8ab2f5f28..6a0b8c399 100644 --- a/event-handler/teleport_events_watcher_test.go +++ b/event-handler/teleport_events_watcher_test.go @@ -18,6 +18,7 @@ package main import ( "strconv" + "sync" "testing" "time" @@ -33,13 +34,31 @@ import ( // mockTeleportEventWatcher is Teleport client mock type mockTeleportEventWatcher struct { + mu sync.Mutex // events is the mock list of events events []events.AuditEvent // mockSearchErr is an error to return mockSearchErr error } +func (c *mockTeleportEventWatcher) setEvents(events []events.AuditEvent) { + c.mu.Lock() + defer c.mu.Unlock() + + c.events = events +} + +func (c *mockTeleportEventWatcher) setSearchEventsError(err error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.mockSearchErr = err +} + func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.mockSearchErr != nil { return nil, "", c.mockSearchErr } @@ -173,7 +192,7 @@ func TestEvents(t *testing.T) { // Events goroutine should return next page errors mockErr := trace.Errorf("error") - mockEventWatcher.mockSearchErr = mockErr + mockEventWatcher.setSearchEventsError(mockErr) select { case err := <-chErr: @@ -221,7 +240,7 @@ func TestUpdatePage(t *testing.T) { chEvt, chErr := client.Events(ctx) // Add an incomplete page of 3 events and collect them. - mockEventWatcher.events = testAuditEvents[:3] + mockEventWatcher.setEvents(testAuditEvents[:3]) var i int for ; i < 3; i++ { select { @@ -249,7 +268,7 @@ func TestUpdatePage(t *testing.T) { } // Update the event watcher with the full page of events an collect. - mockEventWatcher.events = testAuditEvents[:5] + mockEventWatcher.setEvents(testAuditEvents[:5]) for ; i < 5; i++ { select { case event, ok := <-chEvt: @@ -276,7 +295,7 @@ func TestUpdatePage(t *testing.T) { } // Add another partial page and collect the events - mockEventWatcher.events = testAuditEvents[:7] + mockEventWatcher.setEvents(testAuditEvents[:7]) for ; i < 7; i++ { select { case event, ok := <-chEvt: @@ -295,7 +314,7 @@ func TestUpdatePage(t *testing.T) { // Events goroutine should return update page errors mockErr := trace.Errorf("error") - mockEventWatcher.mockSearchErr = mockErr + mockEventWatcher.setSearchEventsError(mockErr) select { case err := <-chErr: