Skip to content
This repository has been archived by the owner on Jun 4, 2024. It is now read-only.

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tigrato committed May 14, 2024
1 parent b413446 commit f587ade
Showing 1 changed file with 150 additions and 10 deletions.
160 changes: 150 additions & 10 deletions event-handler/teleport_events_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -72,8 +73,19 @@ func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, to
endIndex = len(c.events)
}

// Get the next page
e := c.events[startIndex:endIndex]
// validate time
var e []events.AuditEvent
for i, event := range c.events {
if i < startIndex {
continue
}
if i >= endIndex {
break
}
if event.GetTime().After(fromUTC) && event.GetTime().Before(toUTC) {
e = append(e, event)
}
}

// Check if we finished the page
var lastKey string
Expand Down Expand Up @@ -120,17 +132,23 @@ func (c *mockTeleportEventWatcher) Close() error {
return nil
}

func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient) *TeleportEventsWatcher {

func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient, startTime time.Time, skipEventTypesRaw []string) *TeleportEventsWatcher {
skipEventTypes := map[string]struct{}{}
for _, eventType := range skipEventTypesRaw {
skipEventTypes[eventType] = struct{}{}
}
client := &TeleportEventsWatcher{
client: eventsClient,
pos: -1,
config: &StartCmdConfig{
IngestConfig: IngestConfig{
BatchSize: 5,
ExitOnLastEvent: true,
BatchSize: 5,
ExitOnLastEvent: true,
SkipEventTypes: skipEventTypes,
SkipSessionTypesRaw: skipEventTypesRaw,
},
},
windowStartTime: startTime,
}

return client
Expand All @@ -144,7 +162,9 @@ func TestEvents(t *testing.T) {
for i := 0; i < 20; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: strconv.Itoa(i),
ID: strconv.Itoa(i),
Time: time.Now(),
Type: libevents.UserUpdatedEvent,
},
}
}
Expand All @@ -153,7 +173,7 @@ func TestEvents(t *testing.T) {

// Add the 20 events to a mock event watcher.
mockEventWatcher := &mockTeleportEventWatcher{events: testAuditEvents}
client := newTeleportEventWatcher(t, mockEventWatcher)
client := newTeleportEventWatcher(t, mockEventWatcher, time.Now().Add(-48*time.Hour), nil)

// Start the events goroutine
chEvt, chErr := client.Events(ctx)
Expand Down Expand Up @@ -225,15 +245,17 @@ func TestUpdatePage(t *testing.T) {
for i := 0; i < 10; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: strconv.Itoa(i),
ID: strconv.Itoa(i),
Time: time.Now(),
Type: libevents.UserUpdatedEvent,
},
}
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()

mockEventWatcher := &mockTeleportEventWatcher{}
client := newTeleportEventWatcher(t, mockEventWatcher)
client := newTeleportEventWatcher(t, mockEventWatcher, time.Now().Add(-1*time.Hour), nil)
client.config.ExitOnLastEvent = false

// Start the events goroutine
Expand Down Expand Up @@ -474,3 +496,121 @@ func Test_splitRangeByDay(t *testing.T) {
})
}
}

func TestEventsWithWindowSkip(t *testing.T) {
ctx := context.Background()

// create fake audit events with ids 0-29
testAuditEvents := make([]events.AuditEvent, 30)
for i := 0; i < 10; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: strconv.Itoa(i),
Time: time.Now(),
Type: libevents.UserUpdatedEvent,
},
}
}
for i := 10; i < 20; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: strconv.Itoa(i),
Time: time.Now(),
Type: libevents.UserCreateEvent,
},
}
}

for i := 20; i < 30; i++ {
testAuditEvents[i] = &events.UserCreate{
Metadata: events.Metadata{
ID: strconv.Itoa(i),
Time: time.Now(),
Type: libevents.UserUpdatedEvent,
},
}
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()

// Add the 20 events to a mock event watcher.
mockEventWatcher := &mockTeleportEventWatcher{events: testAuditEvents}
client := newTeleportEventWatcher(t, mockEventWatcher, time.Now().Add(-48*time.Hour), []string{libevents.UserCreateEvent})

// Start the events goroutine
chEvt, chErr := client.Events(ctx)

// Collect all 10 first events
for i := 0; i < 10; i++ {
select {
case event, ok := <-chEvt:
require.NotNil(t, event, "Expected an event but got nil. i: %v", i)
require.Equal(t, strconv.Itoa(i), event.ID)
if !ok {
return
}
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}
}

for i := 20; i < 30; i++ {
select {
case event, ok := <-chEvt:
require.NotNil(t, event, "Expected an event but got nil. i: %v", i)
require.Equal(t, strconv.Itoa(i), event.ID)
if !ok {
return
}
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}
}

// Both channels should be closed once the last event is reached.
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}

select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}

// Events goroutine should return next page errors
mockErr := trace.Errorf("error")
mockEventWatcher.setSearchEventsError(mockErr)

select {
case err := <-chErr:
require.Error(t, mockErr, err)
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}

// Both channels should be closed
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}

select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(2 * time.Second):
t.Fatalf("No events received within deadline")
}
}

0 comments on commit f587ade

Please sign in to comment.