From 9d616c2b4186a19ef380a925705d74604b441d09 Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 2 Nov 2023 18:58:16 -0700 Subject: [PATCH] Edits. --- event-handler/cli.go | 3 +- event-handler/teleport_event_test.go | 56 +++++-- event-handler/teleport_events_watcher.go | 7 +- event-handler/teleport_events_watcher_test.go | 148 ++++++------------ 4 files changed, 102 insertions(+), 112 deletions(-) diff --git a/event-handler/cli.go b/event-handler/cli.go index 2ee3871b8..fa5557643 100644 --- a/event-handler/cli.go +++ b/event-handler/cli.go @@ -22,9 +22,10 @@ import ( "time" "github.com/alecthomas/kong" + "github.com/gravitational/trace" + "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/stringset" - "github.com/gravitational/trace" "github.com/gravitational/teleport-plugins/event-handler/lib" ) diff --git a/event-handler/teleport_event_test.go b/event-handler/teleport_event_test.go index 6ed79549d..bfe768e8d 100644 --- a/event-handler/teleport_event_test.go +++ b/event-handler/teleport_event_test.go @@ -21,14 +21,15 @@ import ( "encoding/hex" "testing" - auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" - "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport-plugins/event-handler/lib" + auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" + "github.com/gravitational/teleport/api/types/events" ) func TestNew(t *testing.T) { @@ -39,7 +40,10 @@ func TestNew(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.Equal(t, "test", event.ID) assert.Equal(t, "mock", event.Type) @@ -49,7 +53,10 @@ func TestNew(t *testing.T) { func TestGenID(t *testing.T) { e := &events.SessionPrint{} - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) } @@ -64,7 +71,10 @@ func TestSessionEnd(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) assert.NotEmpty(t, event.SessionID) @@ -81,7 +91,10 @@ func TestFailedLogin(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) assert.True(t, event.IsFailedLogin) @@ -97,28 +110,49 @@ func TestSuccessLogin(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) assert.False(t, event.IsFailedLogin) } -func eventToJSON(t *testing.T, e events.AuditEvent) *auditlogpb.EventUnstructured { +func eventToProto(e events.AuditEvent) (*auditlogpb.EventUnstructured, error) { data, err := lib.FastMarshal(e) - require.NoError(t, err) + if err != nil { + return nil, trace.Wrap(err) + } + str := &structpb.Struct{} - err = str.UnmarshalJSON(data) - require.NoError(t, err) + if err = str.UnmarshalJSON(data); err != nil { + return nil, trace.Wrap(err) + } + id := e.GetID() if id == "" { hash := sha256.Sum256(data) id = hex.EncodeToString(hash[:]) } + return &auditlogpb.EventUnstructured{ Type: e.GetType(), Unstructured: str, Id: id, Index: e.GetIndex(), Time: timestamppb.New(e.GetTime()), + }, nil +} + +func eventsToProto(events []events.AuditEvent) ([]*auditlogpb.EventUnstructured, error) { + protoEvents := make([]*auditlogpb.EventUnstructured, len(events)) + for i, event := range events { + protoEvent, err := eventToProto(event) + if err != nil { + return nil, trace.Wrap(err) + } + protoEvents[i] = protoEvent } + return protoEvents, nil } diff --git a/event-handler/teleport_events_watcher.go b/event-handler/teleport_events_watcher.go index 239d8033d..7ee45b9b7 100644 --- a/event-handler/teleport_events_watcher.go +++ b/event-handler/teleport_events_watcher.go @@ -20,6 +20,10 @@ import ( "fmt" "time" + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + "golang.org/x/net/context" + "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" @@ -28,9 +32,6 @@ import ( "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/credentials" "github.com/gravitational/teleport/integrations/lib/logger" - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - "golang.org/x/net/context" ) const ( diff --git a/event-handler/teleport_events_watcher_test.go b/event-handler/teleport_events_watcher_test.go index abfb6d1f6..2edcc9142 100644 --- a/event-handler/teleport_events_watcher_test.go +++ b/event-handler/teleport_events_watcher_test.go @@ -18,23 +18,24 @@ package main import ( "errors" + "sync" "testing" "time" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" + "github.com/gravitational/teleport/api/client/proto" auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/trace" - "github.com/stretchr/testify/require" - "golang.org/x/net/context" ) // mockTeleportEventWatcher is Teleport client mock type mockTeleportEventWatcher struct { // events is the mock list of events events []events.AuditEvent - t *testing.T } func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) { @@ -51,11 +52,12 @@ func (c *mockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, e := c.events c.events = nil - events := make([]*auditlogpb.EventUnstructured, len(e)) - for i, event := range e { - events[i] = eventToJSON(c.t, event) + protoEvents, err := eventsToProto(e) + if err != nil { + return nil, "", trace.Wrap(err) } - return events, "test", nil + + return protoEvents, "test", nil } func (c *mockTeleportEventWatcher) StreamUnstructuredSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan *auditlogpb.EventUnstructured, chan error) { @@ -76,11 +78,9 @@ func (c *mockTeleportEventWatcher) Close() error { return nil } -func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent, exitOnLastEvent bool) *TeleportEventsWatcher { - teleportEventWatcher := &mockTeleportEventWatcher{events: e, t: t} - +func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient, exitOnLastEvent bool) *TeleportEventsWatcher { client := &TeleportEventsWatcher{ - client: teleportEventWatcher, + client: eventsClient, pos: -1, config: &StartCmdConfig{ IngestConfig: IngestConfig{ @@ -94,124 +94,73 @@ func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent, exitOnLastEven } func TestNext(t *testing.T) { + const mockEventID = "1" e := []events.AuditEvent{ &events.UserCreate{ Metadata: events.Metadata{ - ID: "1", + ID: mockEventID, }, }, &events.UserDelete{ Metadata: events.Metadata{ - ID: "", + ID: mockEventID, }, }, } - client := newTeleportEventWatcher(t, e, true) + mockEventWatcher := &mockTeleportEventWatcher{e} + client := newTeleportEventWatcher(t, mockEventWatcher, true) chEvt, chErr := client.Events(context.Background()) select { case err := <-chErr: - require.Fail(t, "received unexpected error from error channel", "error: %v", err) + t.Fatalf("received unexpected error from error channel: %v", err) case e := <-chEvt: require.NotNil(t, e.Event) - require.Equal(t, "1", e.ID) + require.Equal(t, mockEventID, e.ID) case <-time.After(time.Second): - require.Fail(t, "No events were sent") + t.Fatalf("No events received withing one second") } select { case err := <-chErr: - require.Fail(t, "received unexpected error from error channel", "error: %v", err) + t.Fatalf("received unexpected error from error channel: %v", err) case e := <-chEvt: require.NotNil(t, e.Event) - require.Equal(t, "081ca05eea09ac0cd06e2d2acd06bec424146b254aa500de37bdc2c2b0a4dd0f", e.ID) + require.Equal(t, mockEventID, e.ID) case <-time.After(time.Second): - require.Fail(t, "No events were sent") + t.Fatalf("No events received withing one second") } } // errMockTeleportEventWatcher is Teleport client mock that returns an error after the first SearchUnstructuredEvents type errMockTeleportEventWatcher struct { - // events is the mock list of events - events []events.AuditEvent - t *testing.T - called bool -} - -func (c *errMockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) { - e := c.events - c.events = nil - return e, "test", nil -} - -func (c *errMockTeleportEventWatcher) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) { - return nil, nil + mockTeleportEventWatcher + searchUnstructuredEventsCalled bool } func (c *errMockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]*auditlogpb.EventUnstructured, string, error) { - if c.called { + if c.searchUnstructuredEventsCalled { return nil, "", errors.New("error") } + defer func() { c.searchUnstructuredEventsCalled = true }() - e := c.events - c.events = nil - - events := make([]*auditlogpb.EventUnstructured, len(e)) - for i, event := range e { - events[i] = eventToJSON(c.t, event) - } - c.called = true - return events, "", nil -} - -func (c *errMockTeleportEventWatcher) StreamUnstructuredSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan *auditlogpb.EventUnstructured, chan error) { - return nil, nil -} - -func (c *errMockTeleportEventWatcher) UpsertLock(ctx context.Context, lock types.Lock) error { - return nil -} - -func (c *errMockTeleportEventWatcher) Ping(ctx context.Context) (proto.PingResponse, error) { - return proto.PingResponse{ - ServerVersion: Version, - }, nil -} - -// Close is mock close method -func (c *errMockTeleportEventWatcher) Close() error { - return nil -} - -func newErrTeleportEventWatcher(t *testing.T, e []events.AuditEvent, exitOnLastEvent bool) *TeleportEventsWatcher { - teleportEventWatcher := &errMockTeleportEventWatcher{events: e, t: t} - - client := &TeleportEventsWatcher{ - client: teleportEventWatcher, - pos: -1, - config: &StartCmdConfig{ - IngestConfig: IngestConfig{ - BatchSize: 5, - ExitOnLastEvent: exitOnLastEvent, - }, - }, - } - - return client + return c.mockTeleportEventWatcher.SearchUnstructuredEvents(ctx, fromUTC, toUTC, namespace, eventTypes, limit, order, startKey) } func TestLastEvent(t *testing.T) { t.Run("should not leave hanging go-routines", func(t *testing.T) { + const mockEventID = "1" e := []events.AuditEvent{ &events.UserCreate{ Metadata: events.Metadata{ - ID: "1", + ID: mockEventID, }, }, } - client := newErrTeleportEventWatcher(t, e, false) + mockEventWatcher := &errMockTeleportEventWatcher{mockTeleportEventWatcher: mockTeleportEventWatcher{e}} + client := newTeleportEventWatcher(t, mockEventWatcher, true) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) t.Cleanup(cancel) @@ -220,33 +169,38 @@ func TestLastEvent(t *testing.T) { select { case err := <-chErr: - require.Fail(t, "received unexpected error from error channel", "error: %v", err) + t.Fatalf("received unexpected error from error channel: %v", err) case e := <-chEvt: require.NotNil(t, e.Event) - require.Equal(t, "1", e.ID) + require.Equal(t, mockEventID, e.ID) case <-time.After(time.Second): - require.Fail(t, "No events were sent") + t.Fatalf("No events received withing one second") } - allDone := make(chan struct{}) + var wg sync.WaitGroup - const nIters = 5 - for i := 0; i < nIters; i++ { + const numGoroutines = 5 + for i := 0; i < numGoroutines; i++ { + wg.Add(1) go func() { + defer wg.Done() chEvt, _ := client.Events(ctx) - // we're assuming that a closed channel == closed goroutine + // consume events. for range chEvt { } - allDone <- struct{}{} }() } - for i := 0; i < nIters; i++ { - select { - case <-allDone: - case <-ctx.Done(): - require.Fail(t, "timeout reached, some goroutines were not closed") - } + goroutinesDone := make(chan struct{}) + go func() { + wg.Wait() + close(goroutinesDone) + }() + + select { + case <-goroutinesDone: + case <-ctx.Done(): + require.Fail(t, "timeout reached, some goroutines were not closed") } }) }