Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s/a/notify: add proper version handling to the notification protocol #14914

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions sandbox/apparmor/notify/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,49 @@ package notify

import (
"golang.org/x/sys/unix"

"github.com/snapcore/snapd/testutil"
)

var (
Versions = versions
VersionLikelySupportedChecks = versionLikelySupportedChecks

LikelySupported = ProtocolVersion.likelySupported
LikelySupportedProtocolVersion = likelySupportedProtocolVersion
)

func MockSyscall(syscall func(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err unix.Errno)) (restore func()) {
old := doSyscall
doSyscall = syscall
restore = func() { doSyscall = old }
return testutil.Mock(&doSyscall, syscall)
}

// VersionAndCheck couples protocol version with a support check function which
// returns true if the version is supported. This type is used so that
// `versions` and `versionLikelySupportedChecks` can be mocked to avoid
// calling the actual check functions (which generally probe the host
// system), and so that the logic around handling of unsupported and supported
// versions can be tested.
type VersionAndCheck struct {
Version ProtocolVersion
Check func() bool
}

func MockVersionLikelySupportedChecks(pairs []VersionAndCheck) (restore func()) {
restoreVersions := testutil.Backup(&versions)
restoreChecks := testutil.Backup(&versionLikelySupportedChecks)
restore = func() {
restoreChecks()
restoreVersions()
}
versions = make([]ProtocolVersion, 0, len(pairs))
versionLikelySupportedChecks = make(map[ProtocolVersion]func() bool, len(pairs))
for _, pair := range pairs {
versions = append(versions, pair.Version)
versionLikelySupportedChecks[pair.Version] = pair.Check
}
return restore
}

func MockIoctl(f func(fd uintptr, req IoctlRequest, buf IoctlRequestBuffer) ([]byte, error)) (restore func()) {
return testutil.Mock(&doIoctl, f)
}
2 changes: 1 addition & 1 deletion sandbox/apparmor/notify/ioctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func Ioctl(fd uintptr, req IoctlRequest, buf IoctlRequestBuffer) ([]byte, error)
}
}
if errno != 0 {
return nil, fmt.Errorf("cannot perform IOCTL request %v: %v", req, unix.Errno(errno))
return nil, fmt.Errorf("cannot perform IOCTL request %v: %w (%s)", req, errno, unix.ErrnoName(errno))
}
if size >= 0 && size <= len(buf) {
buf = buf[:size]
Expand Down
29 changes: 21 additions & 8 deletions sandbox/apparmor/notify/listener/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,31 @@ func MockEpollWait(f func(l *Listener) ([]epoll.Event, error)) (restore func())
return restore
}

func MockNotifyRegisterFileDescriptor(f func(fd uintptr) (notify.ProtocolVersion, error)) (restore func()) {
restore = testutil.Backup(&notifyRegisterFileDescriptor)
notifyRegisterFileDescriptor = f
return restore
}

func MockNotifyIoctl(f func(fd uintptr, req notify.IoctlRequest, buf notify.IoctlRequestBuffer) ([]byte, error)) (restore func()) {
restore = testutil.Backup(&notifyIoctl)
notifyIoctl = f
return restore
}

// Mocks epoll.Wait and notify.Ioctl calls by sending data over channels.
// Mocks epoll.Wait, notify.Ioctl, and notify.RegisterFileDescriptor calls by
// sending data over channels, using the given version as the protocol version
// for the listener.
//
// When data is sent over the recv channel (to be consumed by a mocked ioctl
// call), it triggers an epoll event with the listener's notify socket fd, and
// then passes the data on to the next ioctl RECV call. When the listener makes
// a SEND call via ioctl, the data is instead written to the send channel.
func MockEpollWaitNotifyIoctl() (recvChan chan<- []byte, sendChan <-chan []byte, restore func()) {
func MockEpollWaitNotifyIoctl(protoVersion notify.ProtocolVersion) (recvChan chan<- []byte, sendChan <-chan []byte, restore func()) {
recvChanRW := make(chan []byte)
sendChanRW := make(chan []byte)
internalRecvChan := make(chan []byte, 1)
ef := func(l *Listener) ([]epoll.Event, error) {
epollF := func(l *Listener) ([]epoll.Event, error) {
for {
select {
case request := <-recvChanRW:
Expand All @@ -103,7 +112,7 @@ func MockEpollWaitNotifyIoctl() (recvChan chan<- []byte, sendChan <-chan []byte,
}
}
}
nf := func(fd uintptr, req notify.IoctlRequest, buf notify.IoctlRequestBuffer) ([]byte, error) {
ioctlF := func(fd uintptr, req notify.IoctlRequest, buf notify.IoctlRequestBuffer) ([]byte, error) {
switch req {
case notify.APPARMOR_NOTIF_RECV:
request := <-internalRecvChan
Expand All @@ -115,13 +124,17 @@ func MockEpollWaitNotifyIoctl() (recvChan chan<- []byte, sendChan <-chan []byte,
}
return buf, nil
}
restoreEpoll := testutil.Backup(&listenerEpollWait)
listenerEpollWait = ef
restoreIoctl := testutil.Backup(&notifyIoctl)
notifyIoctl = nf
rfdF := func(fd uintptr) (notify.ProtocolVersion, error) {
return protoVersion, nil
}
restoreEpoll := testutil.Mock(&listenerEpollWait, epollF)
restoreIoctl := testutil.Mock(&notifyIoctl, ioctlF)
restoreRegisterFileDescriptor := testutil.Mock(&notifyRegisterFileDescriptor, rfdF)

restore = func() {
restoreEpoll()
restoreIoctl()
restoreRegisterFileDescriptor()
close(recvChanRW)
close(sendChanRW)
}
Expand Down
27 changes: 17 additions & 10 deletions sandbox/apparmor/notify/listener/listener.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-

/*
* Copyright (C) 2023-2024 Canonical Ltd
* Copyright (C) 2023-2025 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
Expand Down Expand Up @@ -53,8 +53,9 @@ var (
// ErrNotSupported indicates that the kernel does not support apparmor prompting.
ErrNotSupported = errors.New("kernel does not support apparmor notifications")

osOpen = os.Open
notifyIoctl = notify.Ioctl
osOpen = os.Open
notifyRegisterFileDescriptor = notify.RegisterFileDescriptor
notifyIoctl = notify.Ioctl
)

// Request is a high-level representation of an apparmor prompting message.
Expand Down Expand Up @@ -153,6 +154,12 @@ type Listener struct {
// and needs to be replied to.
reqs chan *Request

// protocolVersion is the notification protocol version associated with the
// listener's notify socket. Once registered with a particular version,
// that version will be used for all messages sent or received over that
// socket.
protocolVersion notify.ProtocolVersion

notifyFile *os.File
poll *epoll.Epoll

Expand Down Expand Up @@ -197,16 +204,10 @@ func Register() (listener *Listener, err error) {
}
}()

msg := notify.MsgNotificationFilter{ModeSet: notify.APPARMOR_MODESET_USER}
data, err := msg.MarshalBinary()
protoVersion, err := notifyRegisterFileDescriptor(notifyFile.Fd())
if err != nil {
return nil, err
}
ioctlBuf := notify.IoctlRequestBuffer(data)
_, err = notifyIoctl(notifyFile.Fd(), notify.APPARMOR_NOTIF_SET_FILTER, ioctlBuf)
if err != nil {
return nil, fmt.Errorf("cannot notify ioctl to modeset user on %q: %v", path, err)
}

poll, err := epoll.Open()
if err != nil {
Expand All @@ -224,6 +225,8 @@ func Register() (listener *Listener, err error) {
listener = &Listener{
reqs: make(chan *Request, 1),

protocolVersion: protoVersion,

notifyFile: notifyFile,
poll: poll,
}
Expand Down Expand Up @@ -371,6 +374,9 @@ func (l *Listener) decodeAndDispatchRequest(buf []byte) error {
if err := nmsg.UnmarshalBinary(first); err != nil {
return err
}
if nmsg.Version != l.protocolVersion {
return fmt.Errorf("unexpected protocol version: listener registered with %d, but received %d", l.protocolVersion, nmsg.Version)
}
// What kind of notification message did we get?
if nmsg.NotificationType != notify.APPARMOR_NOTIF_OP {
return fmt.Errorf("unsupported notification type: %v", nmsg.NotificationType)
Expand Down Expand Up @@ -423,6 +429,7 @@ func (l *Listener) handleRequestAaClassFile(buf []byte) error {

func (l *Listener) waitAndRespondAaClassFile(req *Request, msg *notify.MsgNotificationFile) error {
resp := notify.ResponseForRequest(&msg.MsgNotification)
resp.Version = l.protocolVersion
resp.MsgNotification.Error = 0 // ignored in responses
resp.MsgNotification.NoCache = 1

Expand Down
Loading
Loading