diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go new file mode 100644 index 00000000..7ac377ae --- /dev/null +++ b/pkg/bindfilter/bind_filter.go @@ -0,0 +1,308 @@ +//go:build windows +// +build windows + +package bindfilter + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go +//sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? +//sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath string) (hr error) = bindfltapi.BfRemoveMapping? +//sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte) (hr error) = bindfltapi.BfGetMappings? + +// BfSetupFilter flags. See: +// https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240 +// +//nolint:revive // var-naming: ALL_CAPS +const ( + BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001 + // Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces + // multiple targets. + BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040 +) + +//nolint:revive // var-naming: ALL_CAPS +const ( + BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001 + BINDFLT_GET_MAPPINGS_FLAG_SILO uint32 = 0x00000002 + BINDFLT_GET_MAPPINGS_FLAG_USER uint32 = 0x00000004 +) + +// ApplyFileBinding creates a global mount of the source in root, with an optional +// read only flag. +// The bind filter allows us to create mounts of directories and volumes. By default it allows +// us to mount multiple sources inside a single root, acting as an overlay. Files from the +// second source will superscede the first source that was mounted. +// This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag +// on the mount. +func ApplyFileBinding(root, source string, readOnly bool) error { + // The parent directory needs to exist for the bind to work. MkdirAll stats and + // returns nil if the directory exists internally so we should be fine to mkdirall + // every time. + if err := os.MkdirAll(filepath.Dir(root), 0); err != nil { + return err + } + + if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") { + // Add trailing slash to volumes, otherwise we get an error when binding it to + // a folder. + source = source + "\\" + } + + flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS + if readOnly { + flags |= BINDFLT_FLAG_READ_ONLY_MAPPING + } + + // Set the job handle to 0 to create a global mount. + if err := bfSetupFilter( + 0, + flags, + root, + source, + nil, + 0, + ); err != nil { + return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err) + } + return nil +} + +// RemoveFileBinding removes a mount from the root path. +func RemoveFileBinding(root string) error { + if err := bfRemoveMapping(0, root); err != nil { + return fmt.Errorf("removing file binding: %w", err) + } + return nil +} + +// GetBindMappings returns a list of bind mappings that have their root on a +// particular volume. The volumePath parameter can be any path that exists on +// a volume. For example, if a number of mappings are created in C:\ProgramData\test, +// to get a list of those mappings, the volumePath parameter would have to be set to +// C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child +// path that exists. +func GetBindMappings(volumePath string) ([]BindMapping, error) { + rootPtr, err := windows.UTF16PtrFromString(volumePath) + if err != nil { + return nil, err + } + + flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME + // allocate a large buffer for results + var outBuffSize uint32 = 256 * 1024 + buf := make([]byte, outBuffSize) + + if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, &buf[0]); err != nil { + return nil, err + } + + if outBuffSize < 12 { + return nil, fmt.Errorf("invalid buffer returned") + } + + result := buf[:outBuffSize] + + // The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{} + headerBuffer := result[:12] + // The alternative to using unsafe and casting it to the above defined structures, is to manually + // parse the fields. Not too terrible, but not sure it'd worth the trouble. + header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0])) + + if header.MappingCount == 0 { + // no mappings + return []BindMapping{}, nil + } + + mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)] + // Get a pointer to the first mapping in the slice + mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0])) + // Get slice of mappings + mappings := unsafe.Slice(mappingsPointer, header.MappingCount) + + mappingEntries := make([]BindMapping, header.MappingCount) + for i := 0; i < int(header.MappingCount); i++ { + bindMapping, err := getBindMappingFromBuffer(result, mappings[i]) + if err != nil { + return nil, fmt.Errorf("fetching bind mappings: %w", err) + } + mappingEntries[i] = bindMapping + } + + return mappingEntries, nil +} + +// mappingEntry holds information about where in the response buffer we can +// find information about the virtual root (the mount point) and the targets (sources) +// that get mounted, as well as the flags used to bind the targets to the virtual root. +type mappingEntry struct { + VirtRootLength uint32 + VirtRootOffset uint32 + Flags uint32 + NumberOfTargets uint32 + TargetEntriesOffset uint32 +} + +type mappingTargetEntry struct { + TargetRootLength uint32 + TargetRootOffset uint32 +} + +// getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response. +// It gives us the size of the buffer, the status of the call and the number of mappings. +// A response +type getMappingsResponseHeader struct { + Size uint32 + Status uint32 + MappingCount uint32 +} + +type BindMapping struct { + MountPoint string + Flags uint32 + Targets []string +} + +func decodeEntry(buffer []byte) (string, error) { + name := make([]uint16, len(buffer)/2) + err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name) + if err != nil { + return "", fmt.Errorf("decoding name: %w", err) + } + return windows.UTF16ToString(name), nil +} + +func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) { + if len(buffer) < offset+count*6 { + return nil, fmt.Errorf("invalid buffer") + } + + targets := make([]string, count) + for i := 0; i < count; i++ { + entryBuf := buffer[offset+i*8 : offset+i*8+8] + tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0])) + if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) { + return nil, fmt.Errorf("invalid buffer") + } + decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength]) + if err != nil { + return nil, fmt.Errorf("decoding name: %w", err) + } + decoded, err = getFinalPath(decoded) + if err != nil { + return nil, fmt.Errorf("fetching final path: %w", err) + } + + targets[i] = decoded + } + return targets, nil +} + +func getFinalPath(pth string) (string, error) { + // BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData. + // These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the + // DOS paths for these files. + if strings.HasPrefix(pth, `\Device`) { + pth = `\\.\GLOBALROOT` + pth + } + + han, err := openPath(pth) + if err != nil { + return "", fmt.Errorf("fetching file handle: %w", err) + } + defer func() { + _ = windows.CloseHandle(han) + }() + + buf := make([]uint16, 100) + var flags uint32 = 0x0 + for { + n, err := windows.GetFinalPathNameByHandle(han, &buf[0], uint32(len(buf)), flags) + if err != nil { + // if we mounted a volume that does not also have a drive letter assigned, attempting to + // fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID. + if errors.Is(err, os.ErrNotExist) && flags != 0x1 { + flags = 0x1 + continue + } + return "", fmt.Errorf("getting final path name: %w", err) + } + if n < uint32(len(buf)) { + break + } + buf = make([]uint16, n) + } + finalPath := syscall.UTF16ToString(buf) + // We got VOLUME_NAME_DOS, we need to strip away some leading slashes. + // Leave unchanged if we ended up requesting VOLUME_NAME_GUID + if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 { + finalPath = finalPath[4:] + if len(finalPath) > 3 && finalPath[:3] == `UNC` { + // return path like \\server\share\... + finalPath = `\` + finalPath[3:] + } + } + + return finalPath, nil +} + +func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) { + if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) { + return BindMapping{}, fmt.Errorf("invalid buffer") + } + + src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+entry.VirtRootLength]) + if err != nil { + return BindMapping{}, fmt.Errorf("decoding entry: %w", err) + } + targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets)) + if err != nil { + return BindMapping{}, fmt.Errorf("fetching targets: %w", err) + } + + src, err = getFinalPath(src) + if err != nil { + return BindMapping{}, fmt.Errorf("fetching final path: %w", err) + } + + return BindMapping{ + Flags: entry.Flags, + Targets: targets, + MountPoint: src, + }, nil +} + +func openPath(path string) (windows.Handle, error) { + u16, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, err + } + h, err := windows.CreateFile( + u16, + 0, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, + nil, + windows.OPEN_EXISTING, + windows.FILE_FLAG_BACKUP_SEMANTICS, // Needed to open a directory handle. + 0) + if err != nil { + return 0, &os.PathError{ + Op: "CreateFile", + Path: path, + Err: err, + } + } + return h, nil +} diff --git a/pkg/bindfilter/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go new file mode 100644 index 00000000..d4450e59 --- /dev/null +++ b/pkg/bindfilter/bind_filter_test.go @@ -0,0 +1,309 @@ +//go:build windows +// +build windows + +package bindfilter + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "golang.org/x/sys/windows/registry" +) + +func TestApplyFileBinding(t *testing.T) { + source := t.TempDir() + destination := t.TempDir() + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + + err := ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + defer removeFileBinding(t, destination) + + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0600); err != nil { + t.Fatal(err) + } + + readData, err := os.ReadFile(dstFile) + if err != nil { + t.Fatal(err) + } + + if string(readData) != string(data) { + t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData)) + } + + // Remove the file on the mount point. The mount is not read-only, this should work. + if err := os.Remove(dstFile); err != nil { + t.Fatalf("failed to remove file from mount point: %s", err) + } + + // Check that it's gone from the source as well. + if _, err := os.Stat(srcFile); err == nil { + t.Fatalf("expected file %s to be gone but is not", srcFile) + } +} + +func removeFileBinding(t *testing.T, mountpoint string) { + if err := RemoveFileBinding(mountpoint); err != nil { + t.Logf("failed to remove file binding from %s: %q", mountpoint, err) + } +} + +func TestApplyFileBindingReadOnly(t *testing.T) { + source := t.TempDir() + destination := t.TempDir() + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + + err := ApplyFileBinding(destination, source, true) + if err != nil { + t.Fatal(err) + } + defer removeFileBinding(t, destination) + + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0600); err != nil { + t.Fatal(err) + } + + readData, err := os.ReadFile(dstFile) + if err != nil { + t.Fatal(err) + } + + if string(readData) != string(data) { + t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData)) + } + + // Attempt to remove the file on the mount point + err = os.Remove(dstFile) + if err == nil { + t.Fatalf("should not be able to remove a file from a read-only mount") + } + if !errors.Is(err, os.ErrPermission) { + t.Fatalf("expected an access denied error, got: %q", err) + } + + // Attempt to write on the read-only mount point. + err = os.WriteFile(dstFile, []byte("something else"), 0600) + if err == nil { + t.Fatalf("should not be able to overwrite a file from a read-only mount") + } + if !errors.Is(err, os.ErrPermission) { + t.Fatalf("expected an access denied error, got: %q", err) + } +} + +func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { + version, err := getWindowsBuildNumber() + if err != nil { + t.Fatalf("couldn't get version number: %s", err) + } + + if version <= 17763 { + t.Skip("not supported on RS5 or earlier") + } + source := t.TempDir() + secondarySource := t.TempDir() + destination := t.TempDir() + + err = ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + + defer removeFileBinding(t, destination) + + err = ApplyFileBinding(destination, secondarySource, false) + if err == nil { + removeFileBinding(t, destination) + t.Fatalf("we should not be able to mount multiple targets in the same destination") + } +} + +func checkSourceIsMountedOnDestination(src, dst string) (bool, error) { + mappings, err := GetBindMappings(dst) + if err != nil { + return false, err + } + + found := false + // There may be pre-existing mappings on the system. + for _, mapping := range mappings { + if mapping.MountPoint == dst { + found = true + if len(mapping.Targets) != 1 { + return false, fmt.Errorf("expected only one target, got: %s", strings.Join(mapping.Targets, ", ")) + } + if mapping.Targets[0] != src { + return false, fmt.Errorf("expected target to be %s, got %s", src, mapping.Targets[0]) + } + break + } + } + + return found, nil +} + +func TestGetBindMappings(t *testing.T) { + version, err := getWindowsBuildNumber() + if err != nil { + t.Fatalf("couldn't get version number: %s", err) + } + + if version <= 17763 { + t.Skip("not supported on RS5 or earlier") + } + // GetBindMappings will expand short paths like ADMINI~1 and PROGRA~1 to their + // full names. In order to properly match the names later, we expand them here. + srcShort := t.TempDir() + source, err := getFinalPath(srcShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + dstShort := t.TempDir() + destination, err := getFinalPath(dstShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + err = ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + defer removeFileBinding(t, destination) + + hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + if err != nil { + t.Fatal(err) + } + + if !hasMapping { + t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + } +} + +func TestRemoveFileBinding(t *testing.T) { + srcShort := t.TempDir() + source, err := getFinalPath(srcShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + dstShort := t.TempDir() + destination, err := getFinalPath(dstShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0600); err != nil { + t.Fatal(err) + } + + err = ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(dstFile); err != nil { + removeFileBinding(t, destination) + t.Fatalf("expected to find %s, but could not", dstFile) + } + + if err := RemoveFileBinding(destination); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(dstFile); err == nil { + t.Fatalf("expected %s to be gone, but it is not", dstFile) + } +} + +func getWindowsBuildNumber() (int, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) + if err != nil { + return 0, fmt.Errorf("read CurrentVersion reg key: %w", err) + } + defer k.Close() + buildNumStr, _, err := k.GetStringValue("CurrentBuild") + if err != nil { + return 0, fmt.Errorf("read CurrentBuild reg value: %w", err) + } + buildNum, err := strconv.Atoi(buildNumStr) + if err != nil { + return 0, err + } + return buildNum, nil +} + +func TestGetBindMappingsSymlinks(t *testing.T) { + version, err := getWindowsBuildNumber() + if err != nil { + t.Fatalf("couldn't get version number: %s", err) + } + + if version <= 17763 { + t.Skip("not supported on RS5 or earlier") + } + + srcShort := t.TempDir() + sourceNested := filepath.Join(srcShort, "source") + if err := os.MkdirAll(sourceNested, 0600); err != nil { + t.Fatalf("failed to create folder: %s", err) + } + simlinkSource := filepath.Join(srcShort, "symlink") + if err := os.Symlink(sourceNested, simlinkSource); err != nil { + t.Fatalf("failed to create symlink: %s", err) + } + + // We'll need the long form of the source folder, as we expect bfSetupFilter() + // to resolve the symlink and create a mapping to the actual source the symlink + // points to. + source, err := getFinalPath(sourceNested) + if err != nil { + t.Fatalf("failed to get long path") + } + + dstShort := t.TempDir() + destination, err := getFinalPath(dstShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + // Use the symlink as a source for the mapping. + err = ApplyFileBinding(destination, simlinkSource, false) + if err != nil { + t.Fatal(err) + } + defer removeFileBinding(t, destination) + + // We expect the mapping to point to the folder the symlink points to, not to the + // actual symlink. + hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + if err != nil { + t.Fatal(err) + } + + if !hasMapping { + t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + } +} diff --git a/pkg/bindfilter/zsyscall_windows.go b/pkg/bindfilter/zsyscall_windows.go new file mode 100644 index 00000000..45c45c96 --- /dev/null +++ b/pkg/bindfilter/zsyscall_windows.go @@ -0,0 +1,116 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package bindfilter + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") + + procBfGetMappings = modbindfltapi.NewProc("BfGetMappings") + procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping") + procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") +) + +func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte) (hr error) { + hr = procBfGetMappings.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(unsafe.Pointer(outBuffer))) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfRemoveMapping(jobHandle windows.Handle, virtRootPath string) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(virtRootPath) + if hr != nil { + return + } + return _bfRemoveMapping(jobHandle, _p0) +} + +func _bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { + hr = procBfRemoveMapping.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(virtRootPath) + if hr != nil { + return + } + var _p1 *uint16 + _p1, hr = syscall.UTF16PtrFromString(virtTargetPath) + if hr != nil { + return + } + return _bfSetupFilter(jobHandle, flags, _p0, _p1, virtExceptions, virtExceptionPathCount) +} + +func _bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { + hr = procBfSetupFilter.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +}