Skip to content

Commit

Permalink
Determine session type on stream based on stream watcher events (#50395
Browse files Browse the repository at this point in the history
…) (#50592)

* refactor: determine session type on stream based on watcher

* refactor(auth): early return when is teleport server

* refactor: move to function and add tests

* chore(lib): fix lint

* refactor(events): make private function
  • Loading branch information
gabrielcorado authored Dec 28, 2024
1 parent 24102bd commit a06f8e0
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 52 deletions.
1 change: 1 addition & 0 deletions api/types/session_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
DatabaseSessionKind SessionKind = "db"
AppSessionKind SessionKind = "app"
WindowsDesktopSessionKind SessionKind = "desktop"
UnknownSessionKind SessionKind = ""
)

// SessionParticipantMode is the mode that determines what you can do when you join a session.
Expand Down
84 changes: 42 additions & 42 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,30 +211,14 @@ func (a *ServerWithRoles) actionWithExtendedContext(namespace, kind, verb string
// actionForKindSession is a special checker that grants access to session
// recordings. It can allow access to a specific recording based on the
// `where` section of the user's access rule for kind `session`.
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, namespace string, sid session.ID) (types.SessionKind, error) {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)

extendContext := func(ctx *services.Context) error {
ctx.Session = sessionEnd
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, namespace string, sid session.ID) error {
extendContext := func(servicesCtx *services.Context) error {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)
servicesCtx.Session = sessionEnd
return trace.Wrap(err)
}

var sessionKind types.SessionKind
switch e := sessionEnd.(type) {
case *apievents.SessionEnd:
sessionKind = types.SSHSessionKind
if e.KubernetesCluster != "" {
sessionKind = types.KubernetesSessionKind
}
case *apievents.DatabaseSessionEnd:
sessionKind = types.DatabaseSessionKind
case *apievents.AppSessionEnd:
sessionKind = types.AppSessionKind
case *apievents.WindowsDesktopSessionEnd:
sessionKind = types.WindowsDesktopSessionKind
}

return sessionKind, trace.Wrap(a.actionWithExtendedContext(namespace, types.KindSession, types.VerbRead, extendContext))
return trace.Wrap(a.actionWithExtendedContext(namespace, types.KindSession, types.VerbRead, extendContext))
}

// localServerAction returns an access denied error if the role is not one of the builtin server roles.
Expand Down Expand Up @@ -6063,44 +6047,60 @@ func (a *ServerWithRoles) ReplaceRemoteLocks(ctx context.Context, clusterName st
// channel if one is encountered. Otherwise the event channel is closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
createErrorChannel := func(err error) (chan apievents.AuditEvent, chan error) {
e := make(chan error, 1)
e <- trace.Wrap(err)
return nil, e
}

err := a.localServerAction()
isTeleportServer := err == nil

var sessionType types.SessionKind
if !isTeleportServer {
var err error
sessionType, err = a.actionForKindSession(ctx, apidefaults.Namespace, sessionID)
if err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}
// StreamSessionEvents can be called internally, and when that
// happens we don't want to emit an event or check for permissions.
if isTeleportServer {
return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
}

// StreamSessionEvents can be called internally, and when that happens we don't want to emit an event.
shouldEmitAuditEvent := !isTeleportServer
if shouldEmitAuditEvent {
if err := a.actionForKindSession(ctx, apidefaults.Namespace, sessionID); err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}

// We can only determine the session type after the streaming started. For
// this reason, we delay the emit audit event until the first event or if
// the streaming returns an error.
cb := func(evt apievents.AuditEvent, _ error) {
if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{
Metadata: apievents.Metadata{
Type: events.SessionRecordingAccessEvent,
Code: events.SessionRecordingAccessCode,
},
SessionID: sessionID.String(),
UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(),
SessionType: string(sessionType),
SessionType: string(sessionTypeFromStartEvent(evt)),
Format: metadata.SessionRecordingFormatFromContext(ctx),
}); err != nil {
return createErrorChannel(err)
log.WithError(err).Errorf("Failed to emit stream session event audit event")
}
}

return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
return a.alog.StreamSessionEvents(events.ContextWithSessionStartCallback(ctx, cb), sessionID, startIndex)
}

// sessionTypeFromStartEvent determines the session type given the session start
// event.
func sessionTypeFromStartEvent(sessionStart apievents.AuditEvent) types.SessionKind {
switch e := sessionStart.(type) {
case *apievents.SessionStart:
if e.KubernetesCluster != "" {
return types.KubernetesSessionKind
}
return types.SSHSessionKind
case *apievents.DatabaseSessionStart:
return types.DatabaseSessionKind
case *apievents.AppSessionStart:
return types.AppSessionKind
case *apievents.WindowsDesktopSessionStart:
return types.WindowsDesktopSessionKind
default:
return types.UnknownSessionKind
}
}

// CreateApp creates a new application resource.
Expand Down
49 changes: 39 additions & 10 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2267,7 +2267,29 @@ func TestStreamSessionEvents(t *testing.T) {
func TestStreamSessionEvents_SessionType(t *testing.T) {
t.Parallel()

srv := newTestTLSServer(t)
authServerConfig := TestAuthServerConfig{
Dir: t.TempDir(),
Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()),
}
require.NoError(t, authServerConfig.CheckAndSetDefaults())

uploader := eventstest.NewMemoryUploader()
localLog, err := events.NewAuditLog(events.AuditLogConfig{
DataDir: authServerConfig.Dir,
ServerID: authServerConfig.ClusterName,
Clock: authServerConfig.Clock,
UploadHandler: uploader,
})
require.NoError(t, err)
authServerConfig.AuditLog = localLog

as, err := NewTestAuthServer(authServerConfig)
require.NoError(t, err)

srv, err := as.NewTestTLSServer()
require.NoError(t, err)
t.Cleanup(func() { srv.Close() })

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

Expand All @@ -2278,22 +2300,29 @@ func TestStreamSessionEvents_SessionType(t *testing.T) {
identity := TestUser(user.GetName())
clt, err := srv.NewClient(identity)
require.NoError(t, err)
sessionID := "44c6cea8-362f-11ea-83aa-125400432324"
sessionID := session.NewID()

// Emitting a session end event will cause the listing to correctly locate
// the recording (even if there might not be a recording file to stream).
require.NoError(t, srv.Auth().EmitAuditEvent(ctx, &apievents.DatabaseSessionEnd{
streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{
Uploader: uploader,
})
require.NoError(t, err)
stream, err := streamer.CreateAuditStream(ctx, sessionID)
require.NoError(t, err)
// The event is not required to pass through the auth server, we only need
// the upload to be present.
require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(&apievents.DatabaseSessionStart{
Metadata: apievents.Metadata{
Type: events.DatabaseSessionEndEvent,
Code: events.DatabaseSessionEndCode,
Type: events.DatabaseSessionStartEvent,
Code: events.DatabaseSessionStartCode,
},
SessionMetadata: apievents.SessionMetadata{
SessionID: sessionID,
SessionID: sessionID.String(),
},
}))
})))
require.NoError(t, stream.Complete(ctx))

accessedFormat := teleport.PTY
clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), session.ID(sessionID), 0)
clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), sessionID, 0)

// Perform the listing an eventually loop to ensure the event is emitted.
var searchEvents []apievents.AuditEvent
Expand Down
60 changes: 60 additions & 0 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,23 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
e := make(chan error, 1)
c := make(chan apievents.AuditEvent)

sessionStartCh := make(chan apievents.AuditEvent, 1)
if startCb, err := sessionStartCallbackFromContext(ctx); err == nil {
go func() {
evt, ok := <-sessionStartCh
if !ok {
startCb(nil, trace.NotFound("session start event not found"))
return
}

startCb(evt, nil)
}()
}

rawSession, err := os.CreateTemp(l.playbackDir, string(sessionID)+".stream.tar.*")
if err != nil {
e <- trace.Wrap(trace.ConvertSystemError(err), "creating temporary stream file")
close(sessionStartCh)
return c, e
}
// The file is still perfectly usable after unlinking it, and the space it's
Expand All @@ -528,6 +542,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
if err := os.Remove(rawSession.Name()); err != nil {
_ = rawSession.Close()
e <- trace.Wrap(trace.ConvertSystemError(err), "removing temporary stream file")
close(sessionStartCh)
return c, e
}

Expand All @@ -538,6 +553,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
err = trace.NotFound("a recording for session %v was not found", sessionID)
}
e <- trace.Wrap(err)
close(sessionStartCh)
return c, e
}
l.log.DebugContext(ctx, "Downloaded session to a temporary file for streaming.",
Expand All @@ -547,6 +563,8 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID

go func() {
defer rawSession.Close()
defer close(sessionStartCh)

// this shouldn't be necessary as the position should be already 0 (Download
// takes an io.WriterAt), but it's better to be safe than sorry
if _, err := rawSession.Seek(0, io.SeekStart); err != nil {
Expand All @@ -557,6 +575,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
protoReader := NewProtoReader(rawSession)
defer protoReader.Close()

firstEvent := true
for {
if ctx.Err() != nil {
e <- trace.Wrap(ctx.Err())
Expand All @@ -573,6 +592,11 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
return
}

if firstEvent {
sessionStartCh <- event
firstEvent = false
}

if event.GetIndex() >= startIndex {
select {
case c <- event:
Expand Down Expand Up @@ -667,3 +691,39 @@ func (l *AuditLog) periodicSpaceMonitor() {
}
}
}

// streamSessionEventsContextKey represent context keys used by
// StreamSessionEvents function.
type streamSessionEventsContextKey string

const (
// sessionStartCallbackContextKey is the context key used to store the
// session start callback function.
sessionStartCallbackContextKey streamSessionEventsContextKey = "session-start"
)

// SessionStartCallback is the function used when streaming reaches the start
// event. If any error, such as session not found, the event will be nil, and
// the error will be set.
type SessionStartCallback func(startEvent apievents.AuditEvent, err error)

// ContextWithSessionStartCallback returns a context.Context containing a
// session start event callback.
func ContextWithSessionStartCallback(ctx context.Context, cb SessionStartCallback) context.Context {
return context.WithValue(ctx, sessionStartCallbackContextKey, cb)
}

// sessionStartCallbackFromContext returns the session start callback from
// context.Context.
func sessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) {
if ctx == nil {
return nil, trace.BadParameter("context is nil")
}

cb, ok := ctx.Value(sessionStartCallbackContextKey).(SessionStartCallback)
if !ok {
return nil, trace.BadParameter("session start callback function was not found in the context")
}

return cb, nil
}
Loading

0 comments on commit a06f8e0

Please sign in to comment.