From c4f686f4e192fa92977d68f4e22efb49047c7441 Mon Sep 17 00:00:00 2001 From: xianzhe-databricks Date: Fri, 13 Dec 2024 20:01:15 +0100 Subject: [PATCH] Add a new RPC ConnectWithCreds to allow gofer to connect to a unix domain socket with application's credentials --- images/basic/integrationtest/Dockerfile | 5 +- pkg/lisafs/client_file.go | 29 +++++-- pkg/lisafs/fd.go | 7 ++ pkg/lisafs/handlers.go | 95 ++++++++++++++-------- pkg/lisafs/message.go | 19 +++++ pkg/sentry/fsimpl/gofer/dentry_impl.go | 11 ++- pkg/sentry/fsimpl/gofer/directfs_dentry.go | 4 +- pkg/sentry/kernel/auth/context.go | 8 ++ runsc/cmd/gofer.go | 31 ++++++- runsc/fsgofer/filter/config.go | 4 +- runsc/fsgofer/lisafs.go | 60 ++++++++++++++ test/e2e/integration_runtime_test.go | 43 ++++++++-- 12 files changed, 263 insertions(+), 53 deletions(-) diff --git a/images/basic/integrationtest/Dockerfile b/images/basic/integrationtest/Dockerfile index 8da8fc082b..fb98d6ae99 100644 --- a/images/basic/integrationtest/Dockerfile +++ b/images/basic/integrationtest/Dockerfile @@ -18,8 +18,11 @@ RUN gcc -O2 -o tcp_server tcp_server.c # Add nonprivileged regular user named "nonroot". RUN groupadd --gid 1337 nonroot && \ - useradd --uid 1337 --gid 1337 \ + useradd --uid 1338 --gid 1337 \ --create-home \ --shell $(which bash) \ --password '' \ nonroot + +# Copy host_connect to /home/nonroot so that "nonroot" can execute it. +RUN cp host_connect /home/nonroot/host_connect diff --git a/pkg/lisafs/client_file.go b/pkg/lisafs/client_file.go index fe4dd34167..20e43a19ae 100644 --- a/pkg/lisafs/client_file.go +++ b/pkg/lisafs/client_file.go @@ -467,13 +467,30 @@ func (f *ClientFD) BindAt(ctx context.Context, sockType linux.SockType, name str } // Connect makes the Connect RPC. -func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType) (int, error) { - req := ConnectReq{FD: f.fd, SockType: uint32(sockType)} - var resp ConnectResp +func (f *ClientFD) Connect(ctx context.Context, sockType linux.SockType, euid UID, egid GID) (int, error) { + credsAvailable := euid != NoUID && egid != NoGID + var err error var sockFD [1]int - ctx.UninterruptibleSleepStart(false) - err := f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.CheckedUnmarshal, sockFD[:], req.String, resp.String) - ctx.UninterruptibleSleepFinish(false) + var resp ConnectResp + if credsAvailable && f.client.IsSupported(ConnectWithCreds) { + req := ConnectWithCredsReq{ + ConnectReq: ConnectReq{ + FD: f.fd, + SockType: uint32(sockType), + }, + UID: euid, + GID: egid, + } + ctx.UninterruptibleSleepStart(false) + err = f.client.SndRcvMessage(ConnectWithCreds, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.CheckedUnmarshal, sockFD[:], req.String, resp.String) + ctx.UninterruptibleSleepFinish(false) + } else { + req := ConnectReq{FD: f.fd, SockType: uint32(sockType)} + ctx.UninterruptibleSleepStart(false) + err = f.client.SndRcvMessage(Connect, uint32(req.SizeBytes()), req.MarshalUnsafe, resp.CheckedUnmarshal, sockFD[:], req.String, resp.String) + ctx.UninterruptibleSleepFinish(false) + } + if err == nil && sockFD[0] < 0 { err = unix.EBADF } diff --git a/pkg/lisafs/fd.go b/pkg/lisafs/fd.go index 3b8ad6533f..93deceaf06 100644 --- a/pkg/lisafs/fd.go +++ b/pkg/lisafs/fd.go @@ -484,6 +484,13 @@ type ControlFDImpl interface { // On the server, Connect has a read concurrency guarantee. Connect(sockType uint32) (int, error) + // ConnectWithCreds is a wrapper around Connect but first changes the gofer's + // euid and egid to the given uid and gid before calling Connect. It restores + // the euid and egid after Connect. + // + // On the server, ConnectWithCreds has a read concurrency guarantee. + ConnectWithCreds(sockType uint32, uid UID, gid GID) (int, error) + // BindAt creates a host unix domain socket of type sockType, bound to // the given namt of type sockType, bound to the given name. It returns // a ControlFD that can be used for path operations on the socket, a diff --git a/pkg/lisafs/handlers.go b/pkg/lisafs/handlers.go index 9f3f9c1c66..78f80cc98d 100644 --- a/pkg/lisafs/handlers.go +++ b/pkg/lisafs/handlers.go @@ -46,38 +46,39 @@ const ( type RPCHandler func(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) var handlers = [...]RPCHandler{ - Error: ErrorHandler, - Mount: MountHandler, - Channel: ChannelHandler, - FStat: FStatHandler, - SetStat: SetStatHandler, - Walk: WalkHandler, - WalkStat: WalkStatHandler, - OpenAt: OpenAtHandler, - OpenCreateAt: OpenCreateAtHandler, - Close: CloseHandler, - FSync: FSyncHandler, - PWrite: PWriteHandler, - PRead: PReadHandler, - MkdirAt: MkdirAtHandler, - MknodAt: MknodAtHandler, - SymlinkAt: SymlinkAtHandler, - LinkAt: LinkAtHandler, - FStatFS: FStatFSHandler, - FAllocate: FAllocateHandler, - ReadLinkAt: ReadLinkAtHandler, - Flush: FlushHandler, - UnlinkAt: UnlinkAtHandler, - RenameAt: RenameAtHandler, - Getdents64: Getdents64Handler, - FGetXattr: FGetXattrHandler, - FSetXattr: FSetXattrHandler, - FListXattr: FListXattrHandler, - FRemoveXattr: FRemoveXattrHandler, - Connect: ConnectHandler, - BindAt: BindAtHandler, - Listen: ListenHandler, - Accept: AcceptHandler, + Error: ErrorHandler, + Mount: MountHandler, + Channel: ChannelHandler, + FStat: FStatHandler, + SetStat: SetStatHandler, + Walk: WalkHandler, + WalkStat: WalkStatHandler, + OpenAt: OpenAtHandler, + OpenCreateAt: OpenCreateAtHandler, + Close: CloseHandler, + FSync: FSyncHandler, + PWrite: PWriteHandler, + PRead: PReadHandler, + MkdirAt: MkdirAtHandler, + MknodAt: MknodAtHandler, + SymlinkAt: SymlinkAtHandler, + LinkAt: LinkAtHandler, + FStatFS: FStatFSHandler, + FAllocate: FAllocateHandler, + ReadLinkAt: ReadLinkAtHandler, + Flush: FlushHandler, + UnlinkAt: UnlinkAtHandler, + RenameAt: RenameAtHandler, + Getdents64: Getdents64Handler, + FGetXattr: FGetXattrHandler, + FSetXattr: FSetXattrHandler, + FListXattr: FListXattrHandler, + FRemoveXattr: FRemoveXattrHandler, + Connect: ConnectHandler, + BindAt: BindAtHandler, + Listen: ListenHandler, + Accept: AcceptHandler, + ConnectWithCreds: ConnectWithCredsHandler, } // ErrorHandler handles Error message. @@ -1069,6 +1070,36 @@ func ConnectHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32 return 0, nil } +// ConnectWithCredsHandler handles the ConnectWithCreds RPC. +func ConnectWithCredsHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { + var req ConnectWithCredsReq + if _, ok := req.CheckedUnmarshal(comm.PayloadBuf(payloadLen)); !ok { + return 0, unix.EIO + } + + fd, err := c.lookupControlFD(req.FD) + if err != nil { + return 0, err + } + defer fd.DecRef(nil) + if !fd.IsSocket() { + return 0, unix.ENOTSOCK + } + var sock int + if err := fd.safelyRead(func() error { + if fd.node.isDeleted() { + return unix.EINVAL + } + sock, err = fd.impl.ConnectWithCreds(req.SockType, req.UID, req.GID) + return err + }); err != nil { + return 0, err + } + + comm.DonateFD(sock) + return 0, nil +} + // BindAtHandler handles the BindAt RPC. func BindAtHandler(c *Connection, comm Communicator, payloadLen uint32) (uint32, error) { var req BindAtReq diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go index e623b68b37..2ee86329aa 100644 --- a/pkg/lisafs/message.go +++ b/pkg/lisafs/message.go @@ -172,6 +172,10 @@ const ( // Accept is analogous to accept4(2). Accept MID = 31 + + // ConnectWithCreds is analogous to connect(2) but it asks the server + // to connect with the provided effective uid/gid. + ConnectWithCreds MID = 32 ) const ( @@ -1318,6 +1322,21 @@ func (*ConnectResp) String() string { return "ConnectResp{}" } +// ConnectWithCredsReq is used to make a ConnectWithCreds request. The response is also ConnectResp. +// +// +marshal boundCheck +type ConnectWithCredsReq struct { + ConnectReq + // UID and GID are used to specify the credentials to connect with. + UID UID + GID GID +} + +// String implements fmt.Stringer.String. +func (c *ConnectWithCredsReq) String() string { + return fmt.Sprintf("ConnectWithCredsReq{FD: %d, SockType: %d, UID: %d, GID: %d}", c.FD, c.SockType, c.UID, c.GID) +} + // BindAtReq is used to make BindAt requests. type BindAtReq struct { createCommon diff --git a/pkg/sentry/fsimpl/gofer/dentry_impl.go b/pkg/sentry/fsimpl/gofer/dentry_impl.go index 8a494353e6..56a1b6e6e5 100644 --- a/pkg/sentry/fsimpl/gofer/dentry_impl.go +++ b/pkg/sentry/fsimpl/gofer/dentry_impl.go @@ -452,11 +452,18 @@ func (d *dentry) allocate(ctx context.Context, mode, offset, length uint64) erro // - !d.isSynthetic(). // - fs.renameMu is locked. func (d *dentry) connect(ctx context.Context, sockType linux.SockType) (int, error) { + credentials := auth.CredentialsFromContextOrNil(ctx) + euid := lisafs.NoUID + egid := lisafs.NoGID + if credentials != nil { + euid = lisafs.UID(credentials.EffectiveKUID) + egid = lisafs.GID(credentials.EffectiveKGID) + } switch dt := d.impl.(type) { case *lisafsDentry: - return dt.controlFD.Connect(ctx, sockType) + return dt.controlFD.Connect(ctx, sockType, euid, egid) case *directfsDentry: - return dt.connect(ctx, sockType) + return dt.connect(ctx, sockType, euid, egid) default: panic("unknown dentry implementation") } diff --git a/pkg/sentry/fsimpl/gofer/directfs_dentry.go b/pkg/sentry/fsimpl/gofer/directfs_dentry.go index 3759250077..7ebbc9d8e0 100644 --- a/pkg/sentry/fsimpl/gofer/directfs_dentry.go +++ b/pkg/sentry/fsimpl/gofer/directfs_dentry.go @@ -603,13 +603,13 @@ func (d *directfsDentry) getDirentsLocked(recordDirent func(name string, key ino } // Precondition: fs.renameMu is locked. -func (d *directfsDentry) connect(ctx context.Context, sockType linux.SockType) (int, error) { +func (d *directfsDentry) connect(ctx context.Context, sockType linux.SockType, euid lisafs.UID, egid lisafs.GID) (int, error) { // There are no filesystems mounted in the sandbox process's mount namespace. // So we can't perform absolute path traversals. So fallback to using lisafs. if err := d.ensureLisafsControlFD(ctx); err != nil { return -1, err } - return d.controlFDLisa.Connect(ctx, sockType) + return d.controlFDLisa.Connect(ctx, sockType, euid, egid) } func (d *directfsDentry) readlink() (string, error) { diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index e1a6cdac5f..09a40d1e9a 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -39,6 +39,14 @@ func CredentialsFromContext(ctx context.Context) *Credentials { return NewAnonymousCredentials() } +// CredentialsFromContextOrNil returns a copy of the Credentials used by ctx, or nil if ctx does not have Credentials. +func CredentialsFromContextOrNil(ctx context.Context) *Credentials { + if v := ctx.Value(CtxCredentials); v != nil { + return v.(*Credentials) + } + return nil +} + // ThreadGroupIDFromContext returns the current thread group ID when ctx // represents a task context. func ThreadGroupIDFromContext(ctx context.Context) (tgid int32, ok bool) { diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 645bbbe15c..00af04e718 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -53,6 +53,11 @@ var caps = []string{ "CAP_SYS_CHROOT", } +var udsOpenCaps = []string{ + "CAP_SETUID", + "CAP_SETGID", +} + // goferCaps is the minimal set of capabilities needed by the Gofer to operate // on files. var goferCaps = &specs.LinuxCapabilities{ @@ -61,6 +66,12 @@ var goferCaps = &specs.LinuxCapabilities{ Permitted: caps, } +var goferUdsOpenCaps = &specs.LinuxCapabilities{ + Bounding: udsOpenCaps, + Effective: udsOpenCaps, + Permitted: udsOpenCaps, +} + // goferSyncFDs contains file descriptors that are used for synchronization // of the Gofer startup process against other processes. type goferSyncFDs struct { @@ -180,7 +191,11 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...any) subcomm overrides["apply-caps"] = "false" overrides["setup-root"] = "false" args := prepareArgs(g.Name(), f, overrides) - util.Fatalf("setCapsAndCallSelf(%v, %v): %v", args, goferCaps, setCapsAndCallSelf(args, goferCaps)) + capsToApply := goferCaps + if conf.GetHostUDS().AllowOpen() { + capsToApply = specutils.MergeCapabilities(capsToApply, goferUdsOpenCaps) + } + util.Fatalf("setCapsAndCallSelf(%v, %v): %v", args, capsToApply, setCapsAndCallSelf(args, capsToApply)) panic("unreachable") } @@ -252,6 +267,12 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...any) subcomm } log.Infof("Process chroot'd to %q", root) + ruid := unix.Getuid() + euid := unix.Geteuid() + rgid := unix.Getgid() + egid := unix.Getegid() + log.Debugf("Process running as uid=%d euid=%d gid=%d egid=%d", ruid, euid, rgid, egid) + // Initialize filters. opts := filter.Options{ UDSOpenEnabled: conf.GetHostUDS().AllowOpen(), @@ -264,7 +285,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...any) subcomm util.Fatalf("installing seccomp filters: %v", err) } - return g.serve(spec, conf, root) + return g.serve(spec, conf, root, ruid, euid, rgid, egid) } func newSocket(ioFD int) *unet.Socket { @@ -275,7 +296,7 @@ func newSocket(ioFD int) *unet.Socket { return socket } -func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string) subcommands.ExitStatus { +func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string, ruid int, euid int, rgid int, egid int) subcommands.ExitStatus { type connectionConfig struct { sock *unet.Socket mountPath string @@ -288,6 +309,10 @@ func (g *Gofer) serve(spec *specs.Spec, conf *config.Config, root string) subcom HostUDS: conf.GetHostUDS(), HostFifo: conf.HostFifo, DonateMountPointFD: conf.DirectFS, + RUID: ruid, + EUID: euid, + RGID: rgid, + EGID: egid, }) ioFDs := g.ioFDs diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 3f6c150fdb..70e6130885 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -208,7 +208,9 @@ var udsCommonSyscalls = seccomp.MakeSyscallRules(map[uintptr]seccomp.SyscallRule }) var udsOpenSyscalls = seccomp.MakeSyscallRules(map[uintptr]seccomp.SyscallRule{ - unix.SYS_CONNECT: seccomp.MatchAll{}, + unix.SYS_CONNECT: seccomp.MatchAll{}, + unix.SYS_SETREUID: seccomp.MatchAll{}, + unix.SYS_SETREGID: seccomp.MatchAll{}, }) var udsCreateSyscalls = seccomp.MakeSyscallRules(map[uintptr]seccomp.SyscallRule{ diff --git a/runsc/fsgofer/lisafs.go b/runsc/fsgofer/lisafs.go index 3ee66deee5..e4a1506b82 100644 --- a/runsc/fsgofer/lisafs.go +++ b/runsc/fsgofer/lisafs.go @@ -17,14 +17,17 @@ package fsgofer import ( + "errors" "fmt" "io" "math" "os" "path" "path/filepath" + "runtime" "strconv" "sync" + "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -64,6 +67,18 @@ type Config struct { // DonateMountPointFD indicates whether a host FD to the mount point should // be donated to the client on Mount RPC. DonateMountPointFD bool + + // Gofer process's RUID. + RUID int + + // Gofer process's EUID. + EUID int + + // Gofer process's RGID. + RGID int + + // Gofer process's EGID. + EGID int } var procSelfFD *rwfd.FD @@ -178,6 +193,7 @@ func (s *LisafsServer) SupportedMessages() []lisafs.MID { lisafs.BindAt, lisafs.Listen, lisafs.Accept, + lisafs.ConnectWithCreds, } } @@ -823,6 +839,50 @@ func (fd *controlFDLisa) Connect(sockType uint32) (int, error) { return sock, nil } +// ConnectWithCreds implements lisafs.ControlFDImpl.ConnectWithCreds. +func (fd *controlFDLisa) ConnectWithCreds(sockType uint32, uid lisafs.UID, gid lisafs.GID) (int, error) { + serverConfig := fd.Conn().ServerImpl().(*LisafsServer).config + if !serverConfig.HostUDS.AllowOpen() { + logRejectedUdsConnectOnce.Do(func() { + log.Warningf("Rejecting attempt to connect to unix domain socket from host filesystem: %q. If you want to allow this, set flag --host-uds=open", fd.ControlFD.Node().FilePath()) + }) + return -1, unix.EPERM + } + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + _, _, err := unix.Syscall(unix.SYS_SETREGID, uintptr(serverConfig.RGID), uintptr(gid), 0) + if !errors.Is(err, syscall.Errno(0)) { + log.Warningf("Failed to set egid; err: %v", err) + } else { + log.Debugf("Successfully set egid to %d", gid) + } + + _, _, err = unix.Syscall(unix.SYS_SETREUID, uintptr(serverConfig.RUID), uintptr(uid), 0) + if !errors.Is(err, syscall.Errno(0)) { + log.Warningf("Failed to set euid; err: %v", err) + } else { + log.Debugf("Successfully set euid to %d", uid) + } + + defer func() { + _, _, err := unix.Syscall(unix.SYS_SETREUID, uintptr(serverConfig.RUID), uintptr(serverConfig.EUID), 0) + if !errors.Is(err, unix.Errno(0)) { + log.Warningf("Failed to restore euid; err: %v", err) + } else { + log.Debugf("Successfully restored euid to %d", serverConfig.EUID) + } + _, _, err = unix.Syscall(unix.SYS_SETREGID, uintptr(serverConfig.RGID), uintptr(serverConfig.EGID), 0) + if !errors.Is(err, unix.Errno(0)) { + log.Warningf("Failed to restore egid; err: %v", err) + } else { + log.Debugf("Successfully restored egid to %d", serverConfig.EGID) + } + }() + + return fd.Connect(sockType) +} + // BindAt implements lisafs.ControlFDImpl.BindAt. func (fd *controlFDLisa) BindAt(name string, sockType uint32, mode linux.FileMode, uid lisafs.UID, gid lisafs.GID) (*lisafs.ControlFD, linux.Statx, *lisafs.BoundSocketFD, int, error) { if !fd.Conn().ServerImpl().(*LisafsServer).config.HostUDS.AllowCreate() { diff --git a/test/e2e/integration_runtime_test.go b/test/e2e/integration_runtime_test.go index d455ad7b78..f31432497d 100644 --- a/test/e2e/integration_runtime_test.go +++ b/test/e2e/integration_runtime_test.go @@ -44,7 +44,9 @@ import ( const ( // defaultWait is the default wait time used for tests. defaultWait = time.Minute - + // nonRootUID and nonRootGID correspond to the uid/gid defined in images/basic/integrationtest/Dockerfile. + nonRootUID = 1338 + nonRootGID = 1337 memInfoCmd = "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'" ) @@ -54,6 +56,26 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func checkPeerCreds(conn net.Conn) error { + unixConn, ok := conn.(*net.UnixConn) + if !ok { + return fmt.Errorf("expected *net.UnixConn, got %T", conn) + } + file, err := unixConn.File() + if err != nil { + return fmt.Errorf("file error: %v", err) + } + defer file.Close() + cred, err := unix.GetsockoptUcred(int(file.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED) + if err != nil { + return fmt.Errorf("getsockopt error: %v", err) + } + if cred.Uid != nonRootUID || cred.Gid != nonRootGID { + return fmt.Errorf("expected uid/gid %d/%d, got %d/%d", nonRootUID, nonRootGID, cred.Uid, cred.Gid) + } + return nil +} + func TestRlimitNoFile(t *testing.T) { ctx := context.Background() d := dockerutil.MakeContainerWithRuntime(ctx, t, "-fdlimit") @@ -135,11 +157,16 @@ func TestHostSocketConnect(t *testing.T) { } defer unix.Close(tmpDirFD) // Use /proc/self/fd to generate path to avoid EINVAL on large path. - l, err := net.Listen("unix", filepath.Join("/proc/self/fd", strconv.Itoa(tmpDirFD), "test.sock")) + socketPath := filepath.Join("/proc/self/fd", strconv.Itoa(tmpDirFD), "test.sock") + l, err := net.Listen("unix", socketPath) if err != nil { t.Fatalf("listen error: %v", err) } defer l.Close() + // Change the socket's permission so that "nonroot" can connect to it. + if err := os.Chmod(socketPath, 0777); err != nil { + t.Errorf("chmod error: %v", err) + } var wg sync.WaitGroup wg.Add(1) @@ -150,7 +177,10 @@ func TestHostSocketConnect(t *testing.T) { t.Errorf("accept error: %v", err) return } - + if err := checkPeerCreds(conn); err != nil { + t.Errorf("peer creds check failed: %v", err) + return + } conn.SetReadDeadline(time.Now().Add(30 * time.Second)) var buf [5]byte if _, err := conn.Read(buf[:]); err != nil { @@ -165,16 +195,17 @@ func TestHostSocketConnect(t *testing.T) { opts := dockerutil.RunOpts{ Image: "basic/integrationtest", - WorkDir: "/root", + WorkDir: "/home/nonroot", + User: "nonroot", Mounts: []mount.Mount{ { Type: mount.TypeBind, Source: filepath.Join(tmpDir, "test.sock"), - Target: "/test.sock", + Target: "/home/nonroot/test.sock", }, }, } - if _, err := d.Run(ctx, opts, "./host_connect", "/test.sock"); err != nil { + if _, err := d.Run(ctx, opts, "./host_connect", "./test.sock"); err != nil { t.Fatalf("docker run failed: %v", err) } wg.Wait()