Skip to content

Commit

Permalink
Prevent racy access to session parties (#47935)
Browse files Browse the repository at this point in the history
Prefer using session.getParties instead of using session.parties
directly to prevent races when new parties are added. Any functions
that are using session.parties AND are called from another function
that already obtains the lock have been renamed to reflect that
they must only be called if the session lock is held.
  • Loading branch information
rosstimothy authored Oct 25, 2024
1 parent 39e88a6 commit 5ceeae9
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann
return trace.Wrap(err)
}

canStart, _, err := sess.checkIfStart()
sess.mu.Lock()
canStart, _, err := sess.checkIfStartUnderLock()
sess.mu.Unlock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -507,7 +509,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro
sess.fileTransferReq = nil

sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser)
_ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx)
_ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx)

return false, trace.AccessDenied("Teleport user does not match original requester")
}
Expand Down Expand Up @@ -540,9 +542,9 @@ const (
FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny"
)

// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied.
// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied.
// The notification is a global ssh request and requires the client to update its UI state accordingly.
func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error {
func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error {
session := scx.getSession()
if session == nil {
s.log.Debugf("Unable to notify %s, no session found in context.", res)
Expand Down Expand Up @@ -1081,7 +1083,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {

// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.parties {
for _, p := range s.getParties() {
if len(notifyPartyPayload) == 0 {
s.log.Warnf("No join event to send to %v", p.sconn.RemoteAddr())
continue
Expand All @@ -1099,10 +1101,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {
}
}

// emitSessionLeaveEvent emits a session leave event to both the Audit Log as
// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as
// well as sending a "x-teleport-event" global request on the SSH connection.
// Must be called under session Lock.
func (s *session) emitSessionLeaveEvent(ctx *ServerContext) {
func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) {
sessionLeaveEvent := &apievents.SessionLeave{
Metadata: apievents.Metadata{
Type: events.SessionLeaveEvent,
Expand Down Expand Up @@ -1296,7 +1298,9 @@ func (s *session) launch() {
// startInteractive starts a new interactive process (or a shell) in the
// current session.
func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error {
canStart, _, err := s.checkIfStart()
s.mu.Lock()
canStart, _, err := s.checkIfStartUnderLock()
s.mu.Unlock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1556,19 +1560,16 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
}

func (s *session) broadcastResult(r ExecResult) {
s.mu.Lock()
defer s.mu.Unlock()

payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)})
for _, p := range s.parties {
for _, p := range s.getParties() {
if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil {
s.log.Infof("Failed to send exit status for %v: %v", r.Command, err)
}
}
}

func (s *session) String() string {
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties))
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties()))
}

// removePartyUnderLock removes the party from the in-memory map that holds all party members
Expand All @@ -1594,9 +1595,9 @@ func (s *session) removePartyUnderLock(p *party) error {

// Emit session leave event to both the Audit Log and over the
// "x-teleport-event" channel in the SSH connection.
s.emitSessionLeaveEvent(p.ctx)
s.emitSessionLeaveEventUnderLock(p.ctx)

canRun, policyOptions, err := s.checkIfStart()
canRun, policyOptions, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1821,7 +1822,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar
} else {
s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location)
}
err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx)
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx)

return trace.Wrap(err)
}
Expand Down Expand Up @@ -1864,7 +1865,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi
} else {
eventType = FileTransferUpdate
}
err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx)
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx)

return trace.Wrap(err)
}
Expand Down Expand Up @@ -1897,12 +1898,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP
s.fileTransferReq = nil

s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID)
err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx)
err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx)

return trace.Wrap(err)
}

func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) {
// checkIfStartUnderLock determines if any moderation policies associated with
// the session are satisfied.
// Must be called under session Lock.
func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) {
var participants []auth.SessionAccessContext

for _, party := range s.parties {
Expand Down Expand Up @@ -1941,7 +1945,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}

if len(s.parties) == 0 {
canStart, _, err := s.checkIfStart()
canStart, _, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1994,7 +1998,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}

if s.tracker.GetState() == types.SessionState_SessionStatePending {
canStart, _, err := s.checkIfStart()
canStart, _, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down

0 comments on commit 5ceeae9

Please sign in to comment.