Skip to content
This repository has been archived by the owner on Jun 4, 2024. It is now read-only.

Commit

Permalink
Fix lint and race condition in test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Nov 28, 2023
1 parent 416bb57 commit 9b199a2
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 30 deletions.
3 changes: 1 addition & 2 deletions event-handler/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import (
"time"

"github.com/alecthomas/kong"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/teleport/integrations/lib/stringset"
"github.com/gravitational/trace"

"github.com/gravitational/teleport-plugins/event-handler/lib"
)
Expand Down
4 changes: 2 additions & 2 deletions event-handler/teleport_event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ import (
"encoding/hex"
"testing"

auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/gravitational/teleport-plugins/event-handler/lib"
auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
"github.com/gravitational/teleport/api/types/events"
)

func TestNew(t *testing.T) {
Expand Down
5 changes: 2 additions & 3 deletions event-handler/teleport_events_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ import (
"fmt"
"time"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
Expand All @@ -32,6 +29,8 @@ import (
"github.com/gravitational/teleport/integrations/lib"
"github.com/gravitational/teleport/integrations/lib/credentials"
"github.com/gravitational/teleport/integrations/lib/logger"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
)

const (
Expand Down
64 changes: 41 additions & 23 deletions event-handler/teleport_events_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,46 @@ package main

import (
"strconv"
"sync"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"

"github.com/gravitational/teleport/api/client/proto"
auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)

// mockTeleportEventWatcher is Teleport client mock
type mockTeleportEventWatcher struct {
mu sync.Mutex
// events is the mock list of events
events []events.AuditEvent
// mockSearchErr is an error to return
mockSearchErr error
}

func (c *mockTeleportEventWatcher) setEvents(events []events.AuditEvent) {
c.mu.Lock()
defer c.mu.Unlock()

c.events = events
}

func (c *mockTeleportEventWatcher) setSearchEventsError(err error) {
c.mu.Lock()
defer c.mu.Unlock()

c.mockSearchErr = err
}

func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) {
c.mu.Lock()
defer c.mu.Unlock()

if c.mockSearchErr != nil {
return nil, "", c.mockSearchErr
}
Expand Down Expand Up @@ -151,7 +169,7 @@ func TestEvents(t *testing.T) {
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}
}
Expand All @@ -160,40 +178,40 @@ func TestEvents(t *testing.T) {
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}

select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}

// Events goroutine should return next page errors
mockErr := trace.Errorf("error")
mockEventWatcher.mockSearchErr = mockErr
mockEventWatcher.setSearchEventsError(mockErr)

select {
case err := <-chErr:
require.Error(t, mockErr, err)
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}

// Both channels should be closed
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}

select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}
}
Expand Down Expand Up @@ -221,7 +239,7 @@ func TestUpdatePage(t *testing.T) {
chEvt, chErr := client.Events(ctx)

// Add an incomplete page of 3 events and collect them.
mockEventWatcher.events = testAuditEvents[:3]
mockEventWatcher.setEvents(testAuditEvents[:3])
var i int
for ; i < 3; i++ {
select {
Expand All @@ -234,7 +252,7 @@ func TestUpdatePage(t *testing.T) {
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}
}
Expand All @@ -245,11 +263,11 @@ func TestUpdatePage(t *testing.T) {
t.Fatalf("Events channel should be open")
case <-chErr:
t.Fatalf("Events channel should be open")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
}

// Update the event watcher with the full page of events an collect.
mockEventWatcher.events = testAuditEvents[:5]
mockEventWatcher.setEvents(testAuditEvents[:5])
for ; i < 5; i++ {
select {
case event, ok := <-chEvt:
Expand All @@ -261,7 +279,7 @@ func TestUpdatePage(t *testing.T) {
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}
}
Expand All @@ -272,11 +290,11 @@ func TestUpdatePage(t *testing.T) {
t.Fatalf("Events channel should be open")
case <-chErr:
t.Fatalf("Events channel should be open")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
}

// Add another partial page and collect the events
mockEventWatcher.events = testAuditEvents[:7]
mockEventWatcher.setEvents(testAuditEvents[:7])
for ; i < 7; i++ {
select {
case event, ok := <-chEvt:
Expand All @@ -288,34 +306,34 @@ func TestUpdatePage(t *testing.T) {
case err := <-chErr:
t.Fatalf("Received unexpected error from error channel: %v", err)
return
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}
}

// Events goroutine should return update page errors
mockErr := trace.Errorf("error")
mockEventWatcher.mockSearchErr = mockErr
mockEventWatcher.setSearchEventsError(mockErr)

select {
case err := <-chErr:
require.Error(t, mockErr, err)
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}

// Both channels should be closed
select {
case _, ok := <-chEvt:
require.False(t, ok, "Events channel should be closed")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}

select {
case _, ok := <-chErr:
require.False(t, ok, "Error channel should be closed")
case <-time.After(time.Millisecond):
case <-time.After(100 * time.Millisecond):
t.Fatalf("No events received within deadline")
}
}
Expand Down

0 comments on commit 9b199a2

Please sign in to comment.