diff --git a/api/types/user.go b/api/types/user.go index f509ec508e991..8594d3805df12 100644 --- a/api/types/user.go +++ b/api/types/user.go @@ -521,11 +521,15 @@ func (u UserV2) GetGCPServiceAccounts() []string { // GetUserType indicates if the User was created by an SSO Provider or locally. func (u UserV2) GetUserType() UserType { - if u.GetCreatedBy().Connector == nil { - return UserTypeLocal + if u.GetCreatedBy().Connector != nil || + len(u.GetOIDCIdentities()) > 0 || + len(u.GetGithubIdentities()) > 0 || + len(u.GetSAMLIdentities()) > 0 { + + return UserTypeSSO } - return UserTypeSSO + return UserTypeLocal } // IsBot returns true if the user is a bot. diff --git a/constants.go b/constants.go index 37ac20acf60ad..892923c66d1d9 100644 --- a/constants.go +++ b/constants.go @@ -538,6 +538,10 @@ const ( // HomeDirNotFound is returned when a the "teleport checkhomedir" command cannot // find the user's home directory. HomeDirNotFound = 254 + // HomeDirNotAccessible is returned when a the "teleport checkhomedir" command has + // found the user's home directory, but the user does NOT have permissions to + // access it. + HomeDirNotAccessible = 253 ) // MaxEnvironmentFileLines is the maximum number of lines in a environment file. diff --git a/docs/pages/enroll-resources/desktop-access/rbac.mdx b/docs/pages/enroll-resources/desktop-access/rbac.mdx index 382de7010c32d..a9bb4bd822459 100644 --- a/docs/pages/enroll-resources/desktop-access/rbac.mdx +++ b/docs/pages/enroll-resources/desktop-access/rbac.mdx @@ -11,7 +11,7 @@ desktop access: ```yaml kind: role -version: v4 +version: v5 metadata: name: developer spec: @@ -31,6 +31,12 @@ spec: # the clipboard, then it will be disabled. desktop_clipboard: true + # Specify whether directory sharing should be allowed from the + # local machine to remote desktop (requires a supported browser). Defaults to true + # if unspecified. If one or more of the user's roles has disabled + # directory sharing, then it will be disabled. + desktop_directory_sharing: true + # Specify whether local users should be created automatically at connection # time. By default, this feature is disabled, and the user must already exist. # Note: this is applicable to local users only and is not supported in Active diff --git a/docs/pages/includes/role-spec.mdx b/docs/pages/includes/role-spec.mdx index a604bd49f5111..e899a5536ed7a 100644 --- a/docs/pages/includes/role-spec.mdx +++ b/docs/pages/includes/role-spec.mdx @@ -84,6 +84,11 @@ spec: # if unspecified. If one or more of the user's roles has disabled # the clipboard, then it will be disabled. desktop_clipboard: true + # Specify whether directory sharing should be allowed from the + # local machine to remote desktop (requires a supported browser). Defaults to true + # if unspecified. If one or more of the user's roles has disabled + # directory sharing, then it will be disabled. + desktop_directory_sharing: true # enterprise-only: when enabled, the source IP that was used to log in is embedded in the user # certificates, preventing a compromised certificate from being used on another # network. The default is false. diff --git a/docs/pages/reference/access-controls/roles.mdx b/docs/pages/reference/access-controls/roles.mdx index cea87e934d430..e3ca06f4a73aa 100644 --- a/docs/pages/reference/access-controls/roles.mdx +++ b/docs/pages/reference/access-controls/roles.mdx @@ -67,6 +67,7 @@ user: | `max_kubernetes_connections` | Defines the maximum number of concurrent Kubernetes sessions per user | | | `record_session` |Defines the [Session recording mode](../monitoring/audit.mdx).|The strictest value takes precedence.| | `desktop_clipboard` | Allow clipboard sharing for desktop sessions | Logical "AND" i.e. evaluates to "yes" if all roles enable clipboard sharing | +| `desktop_directory_sharing` | Allows sharing local workstation directory to remote desktop | Logical "AND" i.e. evaluates to "yes" if all roles enable directory sharing | | `pin_source_ip` | Enable source IP pinning for SSH certificates. | Logical "OR" i.e. evaluates to "yes" if at least one role requires session termination | | `cert_extensions` | Specifies extensions to be included in SSH certificates | | | `create_host_user_mode` | Allow users to be automatically created on a host | Logical "AND" i.e. if all roles matching a server specify host user creation (`off`, `keep`, `insecure-drop`), it will evaluate to the option specified by all of the roles. If some roles specify both `insecure-drop` or `keep` it will evaluate to `keep`| diff --git a/integration/integration_test.go b/integration/integration_test.go index 0d20d726cef75..745071ed3c020 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -4970,6 +4970,13 @@ func testProxyHostKeyCheck(t *testing.T, suite *integrationTestSuite) { _, err = clt.UpsertNode(context.Background(), server) require.NoError(t, err) + // Wait for the node to be visible before continuing. + require.EventuallyWithT(t, func(t *assert.CollectT) { + found, err := clt.GetNodes(context.Background(), defaults.Namespace) + assert.NoError(t, err) + assert.Len(t, found, 2) + }, 10*time.Second, 100*time.Millisecond) + _, err = runCommand(t, instance, []string{"echo hello"}, clientConfig, 1) // check if we were able to exec the command or not diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 63325136dd92f..eb57c6b60de8f 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -471,9 +471,9 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { log.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", s) } }), - inventory.WithOnDisconnect(func(s string) { + inventory.WithOnDisconnect(func(s string, c int) { if g, ok := connectedResourceGauges[s]; ok { - g.Dec() + g.Sub(float64(c)) } else { log.Warnf("missing connected resources gauge for keep alive %s (this is a bug)", s) } diff --git a/lib/inventory/controller.go b/lib/inventory/controller.go index ae5258cf97630..e401c4e239c01 100644 --- a/lib/inventory/controller.go +++ b/lib/inventory/controller.go @@ -105,7 +105,7 @@ type controllerOptions struct { maxKeepAliveErrs int authID string onConnectFunc func(string) - onDisconnectFunc func(string) + onDisconnectFunc func(string, int) } func (options *controllerOptions) SetDefaults() { @@ -127,11 +127,11 @@ func (options *controllerOptions) SetDefaults() { } if options.onConnectFunc == nil { - options.onConnectFunc = func(s string) {} + options.onConnectFunc = func(string) {} } if options.onDisconnectFunc == nil { - options.onDisconnectFunc = func(s string) {} + options.onDisconnectFunc = func(string, int) {} } } @@ -154,12 +154,12 @@ func WithOnConnect(f func(heartbeatKind string)) ControllerOption { } } -// WithOnDisconnect sets a function to be called every time an existing -// instance disconnects from the inventory control stream. The value -// provided to the callback is the keep alive type of the disconnected -// resource. The callback should return quickly so as not to prevent -// processing of heartbeats. -func WithOnDisconnect(f func(heartbeatKind string)) ControllerOption { +// WithOnDisconnect sets a function to be called every time an existing instance +// disconnects from the inventory control stream. The values provided to the +// callback are the keep alive type of the disconnected resource, as well as a +// count of how many resources disconnected at once. The callback should return +// quickly so as not to prevent processing of heartbeats. +func WithOnDisconnect(f func(heartbeatKind string, amount int)) ControllerOption { return func(opts *controllerOptions) { opts.onDisconnectFunc = f } @@ -200,7 +200,7 @@ type Controller struct { usageReporter usagereporter.UsageReporter testEvents chan testEvent onConnectFunc func(string) - onDisconnectFunc func(string) + onDisconnectFunc func(string, int) closeContext context.Context cancel context.CancelFunc } @@ -324,7 +324,10 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { defer func() { if handle.goodbye.GetDeleteResources() { - log.WithField("apps", len(handle.appServers)).Debug("Cleaning up resources in response to instance termination") + log.WithFields(log.Fields{ + "apps": len(handle.appServers), + "server_id": handle.Hello().ServerID, + }).Debug("Cleaning up resources in response to instance termination") for _, app := range handle.appServers { if err := c.auth.DeleteApplicationServer(c.closeContext, apidefaults.Namespace, app.resource.GetHostID(), app.resource.GetName()); err != nil && !trace.IsNotFound(err) { log.Warnf("Failed to remove app server %q on termination: %v.", handle.Hello().ServerID, err) @@ -341,11 +344,11 @@ func (c *Controller) handleControlStream(handle *upstreamHandle) { handle.ticker.Stop() if handle.sshServer != nil { - c.onDisconnectFunc(constants.KeepAliveNode) + c.onDisconnectFunc(constants.KeepAliveNode, 1) } - for range handle.appServers { - c.onDisconnectFunc(constants.KeepAliveApp) + if len(handle.appServers) > 0 { + c.onDisconnectFunc(constants.KeepAliveApp, len(handle.appServers)) } clear(handle.appServers) @@ -677,6 +680,7 @@ func (c *Controller) keepAliveAppServer(handle *upstreamHandle, now time.Time) e if shouldRemove { c.testEvent(appKeepAliveDel) + c.onDisconnectFunc(constants.KeepAliveApp, 1) delete(handle.appServers, name) } } else { diff --git a/lib/inventory/controller_test.go b/lib/inventory/controller_test.go index 323bc712e21b0..66c7bd3ca5940 100644 --- a/lib/inventory/controller_test.go +++ b/lib/inventory/controller_test.go @@ -144,11 +144,14 @@ func TestSSHServerBasics(t *testing.T) { expectAddr: wantAddr, } + rc := &resourceCounter{} controller := NewController( auth, usagereporter.DiscardUsageReporter{}, withServerKeepAlive(time.Millisecond*200), withTestEventsChannel(events), + WithOnConnect(rc.onConnect), + WithOnDisconnect(rc.onDisconnect), ) defer controller.Close() @@ -282,6 +285,9 @@ func TestSSHServerBasics(t *testing.T) { // here). require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count()) + // verify that metrics have been updated correctly + require.Zero(t, 0, rc.count()) + // verify that the peer address of the control stream was used to override // zero-value IPs for heartbeats. auth.mu.Lock() @@ -305,11 +311,14 @@ func TestAppServerBasics(t *testing.T) { auth := &fakeAuth{} + rc := &resourceCounter{} controller := NewController( auth, usagereporter.DiscardUsageReporter{}, withServerKeepAlive(time.Millisecond*200), withTestEventsChannel(events), + WithOnConnect(rc.onConnect), + WithOnDisconnect(rc.onDisconnect), ) defer controller.Close() @@ -500,6 +509,9 @@ func TestAppServerBasics(t *testing.T) { // always *before* closure is propagated to downstream handle, hence being safe to load // here). require.Equal(t, int64(0), controller.instanceHBVariableDuration.Count()) + + // verify that metrics have been updated correctly + require.Zero(t, rc.count()) } // TestInstanceHeartbeat verifies basic expected behaviors for instance heartbeat. @@ -897,7 +909,6 @@ func TestGoodbye(t *testing.T) { } func TestGetSender(t *testing.T) { - controller := NewController( &fakeAuth{}, usagereporter.DiscardUsageReporter{}, @@ -1008,3 +1019,37 @@ func awaitEvents(t *testing.T, ch <-chan testEvent, opts ...eventOption) { } } } + +type resourceCounter struct { + mu sync.Mutex + c map[string]int +} + +func (r *resourceCounter) onConnect(typ string) { + r.mu.Lock() + defer r.mu.Unlock() + if r.c == nil { + r.c = make(map[string]int) + } + r.c[typ]++ +} + +func (r *resourceCounter) onDisconnect(typ string, amount int) { + r.mu.Lock() + defer r.mu.Unlock() + if r.c == nil { + r.c = make(map[string]int) + } + r.c[typ] -= amount +} + +func (r *resourceCounter) count() int { + r.mu.Lock() + defer r.mu.Unlock() + + var count int + for _, v := range r.c { + count += v + } + return count +} diff --git a/lib/srv/app/connections_handler.go b/lib/srv/app/connections_handler.go index 559c812a75b07..1fade7931a4b5 100644 --- a/lib/srv/app/connections_handler.go +++ b/lib/srv/app/connections_handler.go @@ -755,6 +755,9 @@ func (c *ConnectionsHandler) deleteConnAuth(conn net.Conn) { // for Teleport application proxy servers. func CopyAndConfigureTLS(log logrus.FieldLogger, client authclient.AccessCache, config *tls.Config) *tls.Config { tlsConfig := config.Clone() + if log == nil { + log = logrus.StandardLogger() + } // Require clients to present a certificate tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert diff --git a/lib/srv/exec_linux_test.go b/lib/srv/exec_linux_test.go index 979992bf15e7d..da03e804c1641 100644 --- a/lib/srv/exec_linux_test.go +++ b/lib/srv/exec_linux_test.go @@ -26,6 +26,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "strconv" "syscall" "testing" @@ -41,17 +42,40 @@ import ( ) func TestOSCommandPrep(t *testing.T) { + utils.RequireRoot(t) + srv := newMockServer(t) scx := newExecServerContext(t, srv) - usr, err := user.Current() + // because CheckHomeDir now inspects access to the home directory as the actual user after a rexec, + // we need to setup a real, non-root user with a valid home directory in order for this test to + // exercise the correct paths + tempHome := t.TempDir() + require.NoError(t, os.Chmod(filepath.Dir(tempHome), 0777)) + + username := "test-os-command-prep" + scx.Identity.Login = username + _, err := host.UserAdd(username, nil, tempHome, "", "") + require.NoError(t, err) + t.Cleanup(func() { + // change homedir back so user deletion doesn't fail + changeHomeDir(t, username, tempHome) + _, err := host.UserDel(username) + require.NoError(t, err) + }) + + usr, err := user.Lookup(username) + require.NoError(t, err) + + uid, err := strconv.Atoi(usr.Uid) require.NoError(t, err) + require.NoError(t, os.Chown(tempHome, uid, -1)) expectedEnv := []string{ "LANG=en_US.UTF-8", - getDefaultEnvPath(strconv.Itoa(os.Geteuid()), defaultLoginDefsPath), + getDefaultEnvPath(usr.Uid, defaultLoginDefsPath), fmt.Sprintf("HOME=%s", usr.HomeDir), - fmt.Sprintf("USER=%s", usr.Username), + fmt.Sprintf("USER=%s", username), "SHELL=/bin/sh", "SSH_CLIENT=10.0.0.5 4817 3022", "SSH_CONNECTION=10.0.0.5 4817 127.0.0.1 3022", @@ -104,12 +128,9 @@ func TestOSCommandPrep(t *testing.T) { require.Equal(t, []string{"/bin/sh", "-c", "top"}, cmd.Args) require.Equal(t, syscall.SIGKILL, cmd.SysProcAttr.Pdeathsig) - if os.Geteuid() != 0 { - t.Skip("skipping portion of test which must run as root") - } - // Missing home directory - HOME should still be set to the given // home dir, but the command should set it's CWD to root instead. + changeHomeDir(t, username, "/wrong/place") usr.HomeDir = "/wrong/place" root := string(os.PathSeparator) expectedEnv[2] = "HOME=/wrong/place" diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index 130dd30594ac5..da1d84afa9079 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -615,6 +615,8 @@ func (o *osWrapper) startNewParker(ctx context.Context, credential *syscall.Cred type forwardHandler func(ctx context.Context, addr string, file *os.File) error +const rootDirectory = "/" + func handleLocalPortForward(ctx context.Context, addr string, file *os.File) error { conn, err := uds.FromFile(file) _ = file.Close() @@ -799,16 +801,21 @@ func RunRemoteForward() (errw io.Writer, code int, err error) { return errw, code, trace.Wrap(err) } -// runCheckHomeDir check's if the active user's $HOME dir exists. +// runCheckHomeDir checks if the active user's $HOME dir exists and is accessible. func runCheckHomeDir() (errw io.Writer, code int, err error) { - home, err := os.UserHomeDir() - if err != nil { - return io.Discard, teleport.HomeDirNotFound, nil - } - if !utils.IsDir(home) { - return io.Discard, teleport.HomeDirNotFound, nil + code = teleport.RemoteCommandSuccess + if err := hasAccessibleHomeDir(); err != nil { + switch { + case trace.IsNotFound(err), trace.IsBadParameter(err): + code = teleport.HomeDirNotFound + case trace.IsAccessDenied(err): + code = teleport.HomeDirNotAccessible + default: + code = teleport.RemoteCommandFailure + } } - return io.Discard, teleport.RemoteCommandSuccess, nil + + return io.Discard, code, nil } // runPark does nothing, forever. @@ -984,18 +991,20 @@ func buildCommand(c *ExecCommand, localUser *user.User, tty *os.File, pamEnviron // Set the command's cwd to the user's $HOME, or "/" if // they don't have an existing home dir. // TODO (atburke): Generalize this to support Windows. - exists, err := CheckHomeDir(localUser) + hasAccess, err := CheckHomeDir(localUser) if err != nil { return nil, trace.Wrap(err) - } else if exists { + } + + if hasAccess { cmd.Dir = localUser.HomeDir - } else if !exists { + } else { // Write failure to find home dir to stdout, same as OpenSSH. - msg := fmt.Sprintf("Could not set shell's cwd to home directory %q, defaulting to %q\n", localUser.HomeDir, string(os.PathSeparator)) + msg := fmt.Sprintf("Could not set shell's cwd to home directory %q, defaulting to %q\n", localUser.HomeDir, rootDirectory) if _, err := cmd.Stdout.Write([]byte(msg)); err != nil { return nil, trace.Wrap(err) } - cmd.Dir = string(os.PathSeparator) + cmd.Dir = rootDirectory } // Only set process credentials if the UID/GID of the requesting user are @@ -1157,16 +1166,73 @@ func copyCommand(ctx *ServerContext, cmdmsg *ExecCommand) { } } -// CheckHomeDir checks if the user's home dir exists +func coerceHomeDirError(usr *user.User, err error) error { + if os.IsNotExist(err) { + return trace.NotFound("home directory %q not found for user %q", usr.HomeDir, usr.Name) + } + + if os.IsPermission(err) { + return trace.AccessDenied("%q does not have permission to access %q", usr.Name, usr.HomeDir) + } + + return err +} + +// hasAccessibleHomeDir checks if the current user has access to an existing home directory. +func hasAccessibleHomeDir() error { + // this should usually be fetching a cached value + currentUser, err := user.Current() + if err != nil { + return trace.Wrap(err) + } + + fi, err := os.Stat(currentUser.HomeDir) + if err != nil { + return trace.Wrap(coerceHomeDirError(currentUser, err)) + } + + if !fi.IsDir() { + return trace.BadParameter("%q is not a directory", currentUser.HomeDir) + } + + cwd, err := os.Getwd() + if err != nil { + return trace.Wrap(err) + } + // make sure we return to the original working directory + defer os.Chdir(cwd) + + // attemping to cd into the target directory is the easiest, cross-platform way to test + // whether or not the current user has access + if err := os.Chdir(currentUser.HomeDir); err != nil { + return trace.Wrap(coerceHomeDirError(currentUser, err)) + } + + return nil +} + +// CheckHomeDir checks if the user's home directory exists and is accessible to the user. Only catastrophic +// errors will be returned, which means a missing, inaccessible, or otherwise invalid home directory will result +// in a return of (false, nil) func CheckHomeDir(localUser *user.User) (bool, error) { - if fi, err := os.Stat(localUser.HomeDir); err == nil { - return fi.IsDir(), nil + currentUser, err := user.Current() + if err != nil { + return false, trace.Wrap(err) + } + + // don't spawn a subcommand if already running as the user in question + if currentUser.Uid == localUser.Uid { + if err := hasAccessibleHomeDir(); err != nil { + if trace.IsNotFound(err) || trace.IsAccessDenied(err) || trace.IsBadParameter(err) { + return false, nil + } + + return false, trace.Wrap(err) + } + + return true, nil } - // In some environments, the user's home directory exists but isn't visible to - // root, e.g. /home is mounted to an nfs export with root_squash enabled. - // In case we are in that scenario, re-exec teleport as the user to check - // if the home dir actually does exist. executable, err := os.Executable() if err != nil { return false, trace.Wrap(err) @@ -1182,6 +1248,7 @@ func CheckHomeDir(localUser *user.User) (bool, error) { Path: executable, Args: []string{executable, teleport.CheckHomeDirSubCommand}, Env: []string{"HOME=" + localUser.HomeDir}, + Dir: rootDirectory, SysProcAttr: &syscall.SysProcAttr{ Setsid: true, Credential: credential, @@ -1192,11 +1259,13 @@ func CheckHomeDir(localUser *user.User) (bool, error) { reexecCommandOSTweaks(cmd) if err := cmd.Run(); err != nil { - if cmd.ProcessState.ExitCode() == teleport.HomeDirNotFound { - return false, nil + if cmd.ProcessState.ExitCode() == teleport.RemoteCommandFailure { + return false, trace.Wrap(err) } - return false, trace.Wrap(err) + + return false, nil } + return true, nil } diff --git a/lib/srv/reexec_test.go b/lib/srv/reexec_test.go index 4847198c52b12..69733790184b9 100644 --- a/lib/srv/reexec_test.go +++ b/lib/srv/reexec_test.go @@ -28,11 +28,13 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "strconv" "syscall" "testing" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" @@ -312,3 +314,72 @@ func TestRootRemotePortForwardCommand(t *testing.T) { testRemotePortForwardCommand(t, login) } + +func TestRootCheckHomeDir(t *testing.T) { + utils.RequireRoot(t) + + tmp := t.TempDir() + require.NoError(t, os.Chmod(filepath.Dir(tmp), 0777)) + require.NoError(t, os.Chmod(tmp, 0777)) + + home := filepath.Join(tmp, "home") + noAccess := filepath.Join(tmp, "no_access") + file := filepath.Join(tmp, "file") + notFound := filepath.Join(tmp, "not_found") + + require.NoError(t, os.Mkdir(home, 0700)) + require.NoError(t, os.Mkdir(noAccess, 0700)) + _, err := os.Create(file) + require.NoError(t, err) + + login := utils.GenerateLocalUsername(t) + _, err = host.UserAdd(login, nil, home, "", "") + require.NoError(t, err) + t.Cleanup(func() { + // change back to accessible home so deletion works + changeHomeDir(t, login, home) + _, err := host.UserDel(login) + require.NoError(t, err) + }) + + testUser, err := user.Lookup(login) + require.NoError(t, err) + + uid, err := strconv.Atoi(testUser.Uid) + require.NoError(t, err) + + gid, err := strconv.Atoi(testUser.Gid) + require.NoError(t, err) + + require.NoError(t, os.Chown(home, uid, gid)) + require.NoError(t, os.Chown(file, uid, gid)) + + hasAccess, err := CheckHomeDir(testUser) + require.NoError(t, err) + require.True(t, hasAccess) + + changeHomeDir(t, login, file) + hasAccess, err = CheckHomeDir(testUser) + require.NoError(t, err) + require.False(t, hasAccess) + + changeHomeDir(t, login, notFound) + hasAccess, err = CheckHomeDir(testUser) + require.NoError(t, err) + require.False(t, hasAccess) + + changeHomeDir(t, login, noAccess) + hasAccess, err = CheckHomeDir(testUser) + require.NoError(t, err) + require.False(t, hasAccess) +} + +func changeHomeDir(t *testing.T, username, home string) { + usermodBin, err := exec.LookPath("usermod") + assert.NoError(t, err, "usermod binary must be present") + + cmd := exec.Command(usermodBin, "--home", home, username) + _, err = cmd.CombinedOutput() + assert.NoError(t, err, "changing home should not error") + assert.Equal(t, 0, cmd.ProcessState.ExitCode(), "changing home should exit 0") +} diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 6305304fc8e46..a0984e9fd7c50 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1485,7 +1485,7 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont // Create host user. created, userCloser, err := s.termHandlers.SessionRegistry.UpsertHostUser(identityContext) if err != nil { - log.Infof("error while creating host users: %s", err) + log.Warnf("error while creating host users: %s", err) } // Indicate that the user was created by Teleport. diff --git a/lib/srv/sess.go b/lib/srv/sess.go index e3d30af0e3249..4b92762d39107 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -420,7 +420,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann return trace.Wrap(err) } - canStart, _, err := sess.checkIfStart() + sess.mu.Lock() + canStart, _, err := sess.checkIfStartUnderLock() + sess.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -507,7 +509,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro sess.fileTransferReq = nil sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser) - _ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx) + _ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return false, trace.AccessDenied("Teleport user does not match original requester") } @@ -540,9 +542,9 @@ const ( FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny" ) -// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied. +// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied. // The notification is a global ssh request and requires the client to update its UI state accordingly. -func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { +func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { session := scx.getSession() if session == nil { s.log.Debugf("Unable to notify %s, no session found in context.", res) @@ -1081,7 +1083,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { // Notify all members of the party that a new member has joined over the // "x-teleport-event" channel. - for _, p := range s.parties { + for _, p := range s.getParties() { if len(notifyPartyPayload) == 0 { s.log.Warnf("No join event to send to %v", p.sconn.RemoteAddr()) continue @@ -1099,10 +1101,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { } } -// emitSessionLeaveEvent emits a session leave event to both the Audit Log as +// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as // well as sending a "x-teleport-event" global request on the SSH connection. // Must be called under session Lock. -func (s *session) emitSessionLeaveEvent(ctx *ServerContext) { +func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) { sessionLeaveEvent := &apievents.SessionLeave{ Metadata: apievents.Metadata{ Type: events.SessionLeaveEvent, @@ -1296,7 +1298,9 @@ func (s *session) launch() { // startInteractive starts a new interactive process (or a shell) in the // current session. func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error { - canStart, _, err := s.checkIfStart() + s.mu.Lock() + canStart, _, err := s.checkIfStartUnderLock() + s.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -1556,11 +1560,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve } func (s *session) broadcastResult(r ExecResult) { - s.mu.Lock() - defer s.mu.Unlock() - payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)}) - for _, p := range s.parties { + for _, p := range s.getParties() { if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil { s.log.Infof("Failed to send exit status for %v: %v", r.Command, err) } @@ -1568,7 +1569,7 @@ func (s *session) broadcastResult(r ExecResult) { } func (s *session) String() string { - return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties)) + return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties())) } // removePartyUnderLock removes the party from the in-memory map that holds all party members @@ -1594,9 +1595,9 @@ func (s *session) removePartyUnderLock(p *party) error { // Emit session leave event to both the Audit Log and over the // "x-teleport-event" channel in the SSH connection. - s.emitSessionLeaveEvent(p.ctx) + s.emitSessionLeaveEventUnderLock(p.ctx) - canRun, policyOptions, err := s.checkIfStart() + canRun, policyOptions, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1821,7 +1822,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar } else { s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location) } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx) return trace.Wrap(err) } @@ -1864,7 +1865,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi } else { eventType = FileTransferUpdate } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx) return trace.Wrap(err) } @@ -1897,12 +1898,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP s.fileTransferReq = nil s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID) - err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx) + err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return trace.Wrap(err) } -func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { +// checkIfStartUnderLock determines if any moderation policies associated with +// the session are satisfied. +// Must be called under session Lock. +func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) { var participants []auth.SessionAccessContext for _, party := range s.parties { @@ -1941,7 +1945,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if len(s.parties) == 0 { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1994,7 +1998,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if s.tracker.GetState() == types.SessionState_SessionStatePending { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } diff --git a/lib/web/ui/usercontext.go b/lib/web/ui/usercontext.go index 66677845338bc..6a06fff1890e6 100644 --- a/lib/web/ui/usercontext.go +++ b/lib/web/ui/usercontext.go @@ -104,9 +104,7 @@ func NewUserContext(user types.User, userRoles services.RoleSet, features proto. authType := authLocal // check for any SSO identities - isSSO := len(user.GetOIDCIdentities()) > 0 || - len(user.GetGithubIdentities()) > 0 || - len(user.GetSAMLIdentities()) > 0 + isSSO := user.GetUserType() == types.UserTypeSSO if isSSO { // SSO user diff --git a/lib/web/ui/usercontext_test.go b/lib/web/ui/usercontext_test.go index cd1895fa2961e..18fab7d2c277c 100644 --- a/lib/web/ui/usercontext_test.go +++ b/lib/web/ui/usercontext_test.go @@ -68,6 +68,25 @@ func TestNewUserContext(t *testing.T) { userContext, err = NewUserContext(user, roleSet, proto.Features{}, true, false) require.NoError(t, err) require.Equal(t, authSSO, userContext.AuthType) + + // test sso auth type for users with the CreatedBy.Connector field set. + // Eg users import from okta do not have any Identities, so the CreatedBy.Connector must be checked. + userCreatedExternally := &types.UserV2{ + Metadata: types.Metadata{ + Name: "root", + }, + Status: types.UserStatusV2{ + PasswordState: types.PasswordState_PASSWORD_STATE_SET, + }, + Spec: types.UserSpecV2{ + CreatedBy: types.CreatedBy{ + Connector: &types.ConnectorRef{}, + }, + }, + } + userContext, err = NewUserContext(userCreatedExternally, roleSet, proto.Features{}, true, false) + require.NoError(t, err) + require.Equal(t, authSSO, userContext.AuthType) } func TestNewUserContextCloud(t *testing.T) { diff --git a/web/packages/shared/components/FieldMultiInput/FieldMultiInput.story.tsx b/web/packages/shared/components/FieldMultiInput/FieldMultiInput.story.tsx new file mode 100644 index 0000000000000..5362236a8b24d --- /dev/null +++ b/web/packages/shared/components/FieldMultiInput/FieldMultiInput.story.tsx @@ -0,0 +1,37 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import React, { useState } from 'react'; + +import Box from 'design/Box'; + +import { FieldMultiInput } from './FieldMultiInput'; + +export default { + title: 'Shared', +}; + +export function Story() { + const [items, setItems] = useState([]); + return ( + + + + ); +} +Story.storyName = 'FieldMultiInput'; diff --git a/web/packages/shared/components/FieldMultiInput/FieldMultiInput.test.tsx b/web/packages/shared/components/FieldMultiInput/FieldMultiInput.test.tsx new file mode 100644 index 0000000000000..ce023a071053a --- /dev/null +++ b/web/packages/shared/components/FieldMultiInput/FieldMultiInput.test.tsx @@ -0,0 +1,71 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import userEvent from '@testing-library/user-event'; +import React, { useState } from 'react'; + +import { render, screen } from 'design/utils/testing'; + +import { FieldMultiInput, FieldMultiInputProps } from './FieldMultiInput'; + +const TestFieldMultiInput = ({ + onChange, + ...rest +}: Partial) => { + const [items, setItems] = useState([]); + const handleChange = (it: string[]) => { + setItems(it); + onChange?.(it); + }; + return ; +}; + +test('adding, editing, and removing items', async () => { + const user = userEvent.setup(); + const onChange = jest.fn(); + render(); + + await user.type(screen.getByRole('textbox'), 'apples'); + expect(onChange).toHaveBeenLastCalledWith(['apples']); + + await user.click(screen.getByRole('button', { name: 'Add More' })); + expect(onChange).toHaveBeenLastCalledWith(['apples', '']); + + await user.type(screen.getAllByRole('textbox')[1], 'oranges'); + expect(onChange).toHaveBeenLastCalledWith(['apples', 'oranges']); + + await user.click(screen.getAllByRole('button', { name: 'Remove Item' })[0]); + expect(onChange).toHaveBeenLastCalledWith(['oranges']); + + await user.click(screen.getAllByRole('button', { name: 'Remove Item' })[0]); + expect(onChange).toHaveBeenLastCalledWith([]); +}); + +test('keyboard handling', async () => { + const user = userEvent.setup(); + const onChange = jest.fn(); + render(); + + await user.click(screen.getByRole('textbox')); + await user.keyboard('apples{Enter}oranges'); + expect(onChange).toHaveBeenLastCalledWith(['apples', 'oranges']); + + await user.click(screen.getAllByRole('textbox')[0]); + await user.keyboard('{Enter}bananas'); + expect(onChange).toHaveBeenLastCalledWith(['apples', 'bananas', 'oranges']); +}); diff --git a/web/packages/shared/components/FieldMultiInput/FieldMultiInput.tsx b/web/packages/shared/components/FieldMultiInput/FieldMultiInput.tsx new file mode 100644 index 0000000000000..eaa98ef0a6511 --- /dev/null +++ b/web/packages/shared/components/FieldMultiInput/FieldMultiInput.tsx @@ -0,0 +1,139 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import Box from 'design/Box'; +import { ButtonSecondary } from 'design/Button'; +import ButtonIcon from 'design/ButtonIcon'; +import Flex from 'design/Flex'; +import * as Icon from 'design/Icon'; +import Input from 'design/Input'; +import { useRef } from 'react'; +import styled, { useTheme } from 'styled-components'; + +export type FieldMultiInputProps = { + label?: string; + value: string[]; + disabled?: boolean; + onChange?(val: string[]): void; +}; + +/** + * Allows editing a list of strings, one value per row. Use instead of + * `FieldSelectCreatable` when: + * + * - There are no predefined values to be picked from. + * - Values are expected to be relatively long and would be unreadable after + * being truncated. + */ +export function FieldMultiInput({ + label, + value, + disabled, + onChange, +}: FieldMultiInputProps) { + if (value.length === 0) { + value = ['']; + } + + const theme = useTheme(); + // Index of the input to be focused after the next rendering. + const toFocus = useRef(); + + const setFocus = element => { + element?.focus(); + toFocus.current = undefined; + }; + + function insertItem(index: number) { + onChange?.(value.toSpliced(index, 0, '')); + } + + function removeItem(index: number) { + onChange?.(value.toSpliced(index, 1)); + } + + function handleKeyDown(index: number, e: React.KeyboardEvent) { + if (e.key === 'Enter') { + insertItem(index + 1); + toFocus.current = index + 1; + } + } + + return ( + +
+ {label && {label}} + {value.map((val, i) => ( + // Note on keys: using index as a key is an anti-pattern in general, + // but here, we can safely assume that even though the list is + // editable, we don't rely on any unmanaged HTML element state other + // than focus, which we deal with separately anyway. The alternatives + // would be either to require an array with keys generated + // synthetically and injected from outside (which would make the API + // difficult to use) or to keep the array with generated IDs as local + // state (which would require us to write a prop/state reconciliation + // procedure whose complexity would probably outweigh the benefits). + + + + onChange?.( + value.map((v, j) => (j === i ? e.target.value : v)) + ) + } + onKeyDown={e => handleKeyDown(i, e)} + /> + + removeItem(i)} + disabled={disabled} + > + + + + ))} + insertItem(value.length)} + > + + Add More + +
+
+ ); +} + +const Fieldset = styled.fieldset` + border: none; + margin: 0; + padding: 0; + display: flex; + flex-direction: column; + gap: ${props => props.theme.space[2]}px; +`; + +const Legend = styled.legend` + margin: 0 0 ${props => props.theme.space[1]}px 0; + padding: 0; + ${props => props.theme.typography.body3} +`; diff --git a/web/packages/teleport/src/Discover/SelectResource/__snapshots__/SelectResource.story.test.tsx.snap b/web/packages/teleport/src/Discover/SelectResource/__snapshots__/SelectResource.story.test.tsx.snap index b82f160b6f66c..f9e6fdc5ee263 100644 --- a/web/packages/teleport/src/Discover/SelectResource/__snapshots__/SelectResource.story.test.tsx.snap +++ b/web/packages/teleport/src/Discover/SelectResource/__snapshots__/SelectResource.story.test.tsx.snap @@ -561,7 +561,7 @@ exports[`render with URL loc state set to "server" 1`] = ` - Redshift Serverless + RDS SQL Server @@ -643,7 +643,7 @@ exports[`render with URL loc state set to "server" 1`] = ` - Azure + Amazon Web Services (AWS)
- SQL Server + Redshift Serverless
@@ -684,7 +684,7 @@ exports[`render with URL loc state set to "server" 1`] = `
- Microsoft + Azure
- Self-Hosted + Amazon Web Services (AWS)
- Redis + RDS SQL Server
@@ -2480,7 +2480,7 @@ exports[`render with all access 1`] = `
- Redis Cluster + Redis @@ -2521,7 +2521,7 @@ exports[`render with all access 1`] = ` - Amazon Web Services (AWS) + Self-Hosted
- Redshift PostgreSQL + Redis Cluster
@@ -2562,7 +2562,7 @@ exports[`render with all access 1`] = `
- Redshift Serverless + Redshift PostgreSQL @@ -2603,7 +2603,7 @@ exports[`render with all access 1`] = ` +
+ Amazon Web Services (AWS) +
- Snowflake + Redshift Serverless
@@ -2637,7 +2644,7 @@ exports[`render with all access 1`] = `
-
- Azure -
- SQL Server + Snowflake
@@ -2678,7 +2678,7 @@ exports[`render with all access 1`] = `
- Microsoft + Azure
- Self-Hosted + Amazon Web Services (AWS)
- Redis + RDS SQL Server
@@ -4549,7 +4549,7 @@ exports[`render with no access 1`] = ` class="c19" color="text.main" > - Redis Cluster + Redis @@ -4589,13 +4589,13 @@ exports[`render with no access 1`] = ` color="text.slightlyMuted" font-size="12px" > - Amazon Web Services (AWS) + Self-Hosted
- Redshift PostgreSQL + Redis Cluster
@@ -4641,14 +4641,17 @@ exports[`render with no access 1`] = ` class="c19" color="text.main" > - Redshift Serverless + Redshift PostgreSQL
-
- Server + Amazon Web Services (AWS)
- RHEL/CentOS 7+ + Redshift Serverless
- - +
- Snowflake + Server +
+
+ RHEL/CentOS 7+
-
+ -
- Azure -
- SQL Server + Snowflake
@@ -4808,7 +4808,7 @@ exports[`render with no access 1`] = ` color="text.slightlyMuted" font-size="12px" > - Microsoft + Azure
- Self-Hosted + Amazon Web Services (AWS)
- Redis + RDS SQL Server
@@ -6791,7 +6791,7 @@ exports[`render with partial access 1`] = ` class="c18" color="text.main" > - Redis Cluster + Redis @@ -6831,13 +6831,13 @@ exports[`render with partial access 1`] = ` color="text.slightlyMuted" font-size="12px" > - Amazon Web Services (AWS) + Self-Hosted
- Redshift PostgreSQL + Redis Cluster
@@ -6883,7 +6883,7 @@ exports[`render with partial access 1`] = ` class="c18" color="text.main" > - Redshift Serverless + Redshift PostgreSQL @@ -6918,11 +6918,18 @@ exports[`render with partial access 1`] = `
+
+ Amazon Web Services (AWS) +
- Snowflake + Redshift Serverless
@@ -6957,18 +6964,11 @@ exports[`render with partial access 1`] = `
-
- Azure -
- SQL Server + Snowflake
@@ -7008,7 +7008,7 @@ exports[`render with partial access 1`] = ` color="text.slightlyMuted" font-size="12px" > - Microsoft + Azure