Skip to content
This repository has been archived by the owner on Jun 4, 2024. It is now read-only.

Commit

Permalink
Update test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Nov 5, 2023
1 parent 9d616c2 commit 6984a71
Showing 1 changed file with 206 additions and 91 deletions.
297 changes: 206 additions & 91 deletions event-handler/teleport_events_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ limitations under the License.
package main

import (
"errors"
"sync"
"strconv"
"testing"
"time"

Expand All @@ -36,28 +35,53 @@ import (
type mockTeleportEventWatcher struct {
// events is the mock list of events
events []events.AuditEvent
// mockSearchErr is an error to return
mockSearchErr error
}

func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) {
e := c.events
c.events = nil
return e, "test", nil
if c.mockSearchErr != nil {
return nil, "", c.mockSearchErr
}

var startIndex int
if startKey != "" {
startIndex, _ = strconv.Atoi(startKey)
}

endIndex := startIndex + limit
if endIndex >= len(c.events) {
endIndex = len(c.events)
}

// Get the next page
e := c.events[startIndex:endIndex]

// Check if we finished the page
var lastKey string
if len(e) == limit {
lastKey = strconv.Itoa(startIndex + (len(e) - 1))
}

return e, lastKey, nil
}

func (c *mockTeleportEventWatcher) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) {
return nil, nil
}

func (c *mockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]*auditlogpb.EventUnstructured, string, error) {
e := c.events
c.events = nil
events, lastKey, err := c.SearchEvents(ctx, fromUTC, toUTC, namespace, eventTypes, limit, order, startKey)
if err != nil {
return nil, "", trace.Wrap(err)
}

protoEvents, err := eventsToProto(e)
protoEvents, err := eventsToProto(events)
if err != nil {
return nil, "", trace.Wrap(err)
}

return protoEvents, "test", nil
return protoEvents, lastKey, nil
}

func (c *mockTeleportEventWatcher) StreamUnstructuredSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan *auditlogpb.EventUnstructured, chan error) {
Expand All @@ -78,131 +102,222 @@ func (c *mockTeleportEventWatcher) Close() error {
return nil
}

func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient, exitOnLastEvent bool) *TeleportEventsWatcher {
func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient) *TeleportEventsWatcher {
client := &TeleportEventsWatcher{
client: eventsClient,
pos: -1,
config: &StartCmdConfig{
IngestConfig: IngestConfig{
BatchSize: 5,
ExitOnLastEvent: exitOnLastEvent,
ExitOnLastEvent: true,
},
},
}

return client
}

func TestNext(t *testing.T) {
const mockEventID = "1"
e := []events.AuditEvent{
&events.UserCreate{
Metadata: events.Metadata{
ID: mockEventID,
},
},
&events.UserDelete{
func TestEvents(t *testing.T) {
ctx := context.Background()

// create fake audit events with ids 0-19
testAuditEvents := make([]events.AuditEvent, 20)
for i := 0; i < 20; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: mockEventID,
ID: strconv.Itoa(i),
},
},
}
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()

mockEventWatcher := &mockTeleportEventWatcher{e}
client := newTeleportEventWatcher(t, mockEventWatcher, true)
chEvt, chErr := client.Events(context.Background())
// Add the 20 events to a mock event watcher.
mockEventWatcher := &mockTeleportEventWatcher{events: testAuditEvents}
client := newTeleportEventWatcher(t, mockEventWatcher)

// Start the events goroutine
chEvt, chErr := client.Events(ctx)

// Collect all 20 events
for i := 0; i < 20; i++ {
select {
case event, ok := <-chEvt:
require.NotNil(t, event, "Expected an event but got nil. i: %v", i)
require.Equal(t, strconv.Itoa(i), event.ID)
if !ok {
return
}
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

// Both channels should be closed once the last event is reached.
select {
case err := <-chErr:
t.Fatalf("received unexpected error from error channel: %v", err)
case e := <-chEvt:
require.NotNil(t, e.Event)
require.Equal(t, mockEventID, e.ID)
case <-time.After(time.Second):
t.Fatalf("No events received withing one second")
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}

select {
case err := <-chErr:
t.Fatalf("received unexpected error from error channel: %v", err)
case e := <-chEvt:
require.NotNil(t, e.Event)
require.Equal(t, mockEventID, e.ID)
case <-time.After(time.Second):
t.Fatalf("No events received withing one second")
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

// errMockTeleportEventWatcher is Teleport client mock that returns an error after the first SearchUnstructuredEvents
type errMockTeleportEventWatcher struct {
mockTeleportEventWatcher
searchUnstructuredEventsCalled bool
}
// Events goroutine should return next page errors
mockErr := trace.Errorf("error")
mockEventWatcher.mockSearchErr = mockErr

select {
case err := <-chErr:
require.Error(t, mockErr, err)
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}

func (c *errMockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]*auditlogpb.EventUnstructured, string, error) {
if c.searchUnstructuredEventsCalled {
return nil, "", errors.New("error")
// Both channels should be closed
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
defer func() { c.searchUnstructuredEventsCalled = true }()

return c.mockTeleportEventWatcher.SearchUnstructuredEvents(ctx, fromUTC, toUTC, namespace, eventTypes, limit, order, startKey)
select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

func TestLastEvent(t *testing.T) {
t.Run("should not leave hanging go-routines", func(t *testing.T) {
const mockEventID = "1"
e := []events.AuditEvent{
&events.UserCreate{
Metadata: events.Metadata{
ID: mockEventID,
},
func TestUpdatePage(t *testing.T) {
ctx := context.Background()

// create fake audit events with ids 0-9
testAuditEvents := make([]events.AuditEvent, 10)
for i := 0; i < 10; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: strconv.Itoa(i),
},
}
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()

mockEventWatcher := &errMockTeleportEventWatcher{mockTeleportEventWatcher: mockTeleportEventWatcher{e}}
client := newTeleportEventWatcher(t, mockEventWatcher, true)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
mockEventWatcher := &mockTeleportEventWatcher{}
client := newTeleportEventWatcher(t, mockEventWatcher)
client.config.ExitOnLastEvent = false

chEvt, chErr := client.Events(ctx)
// Start the events goroutine
chEvt, chErr := client.Events(ctx)

// Add an incomplete page of 3 events and collect them.
mockEventWatcher.events = testAuditEvents[:3]
var i int
for ; i < 3; i++ {
select {
case event, ok := <-chEvt:
require.NotNil(t, event, "Expected an event but got nil")
require.Equal(t, strconv.Itoa(i), event.ID)
if !ok {
return
}
case err := <-chErr:
t.Fatalf("received unexpected error from error channel: %v", err)
case e := <-chEvt:
require.NotNil(t, e.Event)
require.Equal(t, mockEventID, e.ID)
case <-time.After(time.Second):
t.Fatalf("No events received withing one second")
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

var wg sync.WaitGroup

const numGoroutines = 5
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
chEvt, _ := client.Events(ctx)
// consume events.
for range chEvt {
}
}()
// Both channels should still be open and empty.
select {
case <-chEvt:
t.Fatalf("Events channel should be open")
case <-chErr:
t.Fatalf("Events channel should be open")
case <-time.After(time.Millisecond):
}

// Update the event watcher with the full page of events an collect.
mockEventWatcher.events = testAuditEvents[:5]
for ; i < 5; i++ {
select {
case event, ok := <-chEvt:
require.NotNil(t, event, "Expected an event but got nil")
require.Equal(t, strconv.Itoa(i), event.ID)
if !ok {
return
}
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

goroutinesDone := make(chan struct{})
go func() {
wg.Wait()
close(goroutinesDone)
}()
// Both channels should still be open and empty.
select {
case <-chEvt:
t.Fatalf("Events channel should be open")
case <-chErr:
t.Fatalf("Events channel should be open")
case <-time.After(time.Millisecond):
}

// Add another partial page and collect the events
mockEventWatcher.events = testAuditEvents[:7]
for ; i < 7; i++ {
select {
case <-goroutinesDone:
case <-ctx.Done():
require.Fail(t, "timeout reached, some goroutines were not closed")
case event, ok := <-chEvt:
require.NotNil(t, event, "Expected an event but got nil")
require.Equal(t, strconv.Itoa(i), event.ID)
if !ok {
return
}
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
})
}

// Events goroutine should return update page errors
mockErr := trace.Errorf("error")
mockEventWatcher.mockSearchErr = mockErr

select {
case err := <-chErr:
require.Error(t, mockErr, err)
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}

// Both channels should be closed
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}

select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

func TestValidateConfig(t *testing.T) {
Expand Down

0 comments on commit 6984a71

Please sign in to comment.