From 726b1c32367720a4cbf77cb5319077a41a78757f Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Sat, 4 Feb 2023 08:33:20 -0800 Subject: [PATCH 01/10] Add some basic bind filter functions This change adds the ability to mount a a single folder or a volume inside another folder, using the bind filter API. While the API allows mounting multiple sources inside a single mount point, acting as an overlay, we disable this functionality in the ApplyFileBinding function. Signed-off-by: Gabriel Adrian Samfira --- bind_filter.go | 329 ++++++++++++++++++++++++++++++++++++++++++++ zsyscall_windows.go | 57 +++++++- 2 files changed, 382 insertions(+), 4 deletions(-) create mode 100644 bind_filter.go diff --git a/bind_filter.go b/bind_filter.go new file mode 100644 index 00000000..392fbb28 --- /dev/null +++ b/bind_filter.go @@ -0,0 +1,329 @@ +//go:build windows +// +build windows + +package winio + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +//sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? +//sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) = bindfltapi.BfRemoveMapping? +//sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings? + +// BfSetupFilter flags. See: +// https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240 +const ( + BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001 + // Generates a merged binding, mapping target entries to the virtualization root. + BINDFLT_FLAG_MERGED_BIND_MAPPING uint32 = 0x00000002 + // Use the binding mapping attached to the mapped-in job object (silo) instead of the default global mapping. + BINDFLT_FLAG_USE_CURRENT_SILO_MAPPING uint32 = 0x00000004 + BINDFLT_FLAG_REPARSE_ON_FILES uint32 = 0x00000008 + // Skips checks on file/dir creation inside a non-merged, read-only mapping. + // Only usable when READ_ONLY_MAPPING is set. + BINDFLT_FLAG_SKIP_SHARING_CHECK uint32 = 0x00000010 + BINDFLT_FLAG_CLOUD_FILES_ECPS uint32 = 0x00000020 + // Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces + // multiple targets. + BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040 + // Turns on caching by asserting that the backing store for name mappings is immutable. + BINDFLT_FLAG_IMMUTABLE_BACKING uint32 = 0x00000080 + BINDFLT_FLAG_PREVENT_CASE_SENSITIVE_BINDING uint32 = 0x00000100 + // Tells bindflt to fail with STATUS_OBJECT_PATH_NOT_FOUND when a mapping is being added + // but its parent paths (ancestors) have not already been added. + BINDFLT_FLAG_EMPTY_VIRT_ROOT uint32 = 0x00000200 + BINDFLT_FLAG_NO_REPARSE_ON_ROOT uint32 = 0x10000000 + BINDFLT_FLAG_BATCHED_REMOVE_MAPPINGS uint32 = 0x20000000 +) + +const ( + BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001 + BINDFLT_GET_MAPPINGS_FLAG_SILO uint32 = 0x00000002 + BINDFLT_GET_MAPPINGS_FLAG_USER uint32 = 0x00000004 +) + +// ApplyFileBinding creates a global mount of the source in root, with an optional +// read only flag. +// The bind filter allows us to create mounts of directories and volumes. By default it allows +// us to mount multiple sources inside a single root, acting as an overlay. Files from the +// second source will superscede the first source that was mounted. +// This function disables this behavior and sets the BINDFLT_FLAG_NO_MULTIPLE_TARGETS flag +// on the mount. +func ApplyFileBinding(root, source string, readOnly bool) error { + // The parent directory needs to exist for the bind to work. MkdirAll stats and + // returns nil if the directory exists internally so we should be fine to mkdirall + // every time. + if err := os.MkdirAll(filepath.Dir(root), 0); err != nil { + return err + } + + if strings.Contains(source, "Volume{") && !strings.HasSuffix(source, "\\") { + // Add trailing slash to volumes, otherwise we get an error when binding it to + // a folder. + source = source + "\\" + } + + rootPtr, err := windows.UTF16PtrFromString(root) + if err != nil { + return err + } + + targetPtr, err := windows.UTF16PtrFromString(source) + if err != nil { + return err + } + flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS + if readOnly { + flags |= BINDFLT_FLAG_READ_ONLY_MAPPING + } + + // Set the job handle to 0 to create a global mount. + if err := bfSetupFilter( + 0, + flags, + rootPtr, + targetPtr, + nil, + 0, + ); err != nil { + return fmt.Errorf("failed to bind target %q to root %q: %w", source, root, err) + } + return nil +} + +func RemoveFileBinding(root string) error { + rootPtr, err := windows.UTF16PtrFromString(root) + if err != nil { + return fmt.Errorf("converting path to utf-16: %w", err) + } + + if err := bfRemoveMapping(0, rootPtr); err != nil { + return fmt.Errorf("removing file binding: %w", err) + } + return nil +} + +// mappingEntry holds information about where in the response buffer we can +// find information about the virtual root (the mount point) and the targets (sources) +// that get mounted, as well as the flags used to bind the targets to the virtual root. +type mappingEntry struct { + VirtRootLength uint32 + VirtRootOffset uint32 + Flags uint32 + NumberOfTargets uint32 + TargetEntriesOffset uint32 +} + +type mappingTargetEntry struct { + TargetRootLength uint32 + TargetRootOffset uint32 +} + +// getMappingsResponseHeader represents the first 12 bytes of the BfGetMappings() response. +// It gives us the size of the buffer, the status of the call and the number of mappings. +// A response +type getMappingsResponseHeader struct { + Size uint32 + Status uint32 + MappingCount uint32 +} + +type BindMapping struct { + MountPoint string + Flags uint32 + Targets []string +} + +func decodeEntry(buffer []byte) (string, error) { + name := make([]uint16, len(buffer)/2) + err := binary.Read(bytes.NewReader(buffer), binary.LittleEndian, &name) + if err != nil { + return "", fmt.Errorf("decoding name: %w", err) + } + return string(utf16.Decode(name)), nil +} + +func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) { + if len(buffer) < offset+count*6 { + return nil, fmt.Errorf("invalid buffer") + } + + targets := make([]string, count) + for i := 0; i < count; i++ { + entryBuf := buffer[offset+i*8 : offset+i*8+8] + tgt := *(*mappingTargetEntry)(unsafe.Pointer(&entryBuf[0])) + if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) { + return nil, fmt.Errorf("invalid buffer") + } + decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : uint32(tgt.TargetRootOffset)+uint32(tgt.TargetRootLength)]) + if err != nil { + return nil, fmt.Errorf("decoding name: %w", err) + } + decoded, err = getFinalPath(decoded) + if err != nil { + return nil, fmt.Errorf("fetching final path: %w", err) + } + + targets[i] = decoded + } + return targets, nil +} + +func getFinalPath(pth string) (string, error) { + // BfGetMappings returns VOLUME_NAME_NT paths like \Device\HarddiskVolume2\ProgramData. + // These can be accessed by prepending \\.\GLOBALROOT to the path. We use this to get the + // DOS paths for these files. + if strings.HasPrefix(pth, `\Device`) { + pth = `\\.\GLOBALROOT` + pth + } + + han, err := getFileHandle(pth) + if err != nil { + return "", fmt.Errorf("fetching file handle: %w", err) + } + + buf := make([]uint16, 100) + var flags uint32 = 0x0 + for { + n, err := windows.GetFinalPathNameByHandle(windows.Handle(han), &buf[0], uint32(len(buf)), flags) + if err != nil { + // if we mounted a volume that does not also have a drive letter assigned, attempting to + // fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID. + if errors.Is(err, os.ErrNotExist) && flags != 0x1 { + flags = 0x1 + continue + } + return "", fmt.Errorf("getting final path name: %w", err) + } + if n < uint32(len(buf)) { + break + } + buf = make([]uint16, n) + } + finalPath := syscall.UTF16ToString(buf) + // We got VOLUME_NAME_DOS, we need to strip away some leading slashes. + // Leave unchanged if we ended up requesting VOLUME_NAME_GUID + if len(finalPath) > 4 && finalPath[:4] == `\\?\` && flags == 0x0 { + finalPath = finalPath[4:] + if len(finalPath) > 3 && finalPath[:3] == `UNC` { + // return path like \\server\share\... + finalPath = `\` + finalPath[3:] + } + } + + return finalPath, nil +} + +func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, error) { + if len(buffer) < int(entry.VirtRootOffset)+int(entry.VirtRootLength) { + return BindMapping{}, fmt.Errorf("invalid buffer") + } + + src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+uint32(entry.VirtRootLength)]) + if err != nil { + return BindMapping{}, fmt.Errorf("decoding entry: %w", err) + } + targets, err := getTargetsFromBuffer(buffer, int(entry.TargetEntriesOffset), int(entry.NumberOfTargets)) + if err != nil { + return BindMapping{}, fmt.Errorf("fetching targets: %w", err) + } + + src, err = getFinalPath(src) + if err != nil { + return BindMapping{}, fmt.Errorf("fetching final path: %w", err) + } + + return BindMapping{ + Flags: entry.Flags, + Targets: targets, + MountPoint: src, + }, nil +} + +func getFileHandle(pth string) (syscall.Handle, error) { + info, err := os.Lstat(pth) + if err != nil { + return 0, fmt.Errorf("accessing file: %w", err) + } + p, err := syscall.UTF16PtrFromString(pth) + if err != nil { + return 0, err + } + attrs := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS) + if info.Mode()&os.ModeSymlink != 0 { + attrs |= syscall.FILE_FLAG_OPEN_REPARSE_POINT + } + h, err := syscall.CreateFile(p, 0, 0, nil, syscall.OPEN_EXISTING, attrs, 0) + if err != nil { + return 0, err + } + return h, nil +} + +// GetBindMappings returns a list of bind mappings that have their root on a +// particular volume. The volumePath parameter can be any path that exists on +// a volume. For example, if a number of mappings are created in C:\ProgramData\test, +// to get a list of those mappings, the volumePath parameter would have to be set to +// C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child +// path that exists. +func GetBindMappings(volumePath string) ([]BindMapping, error) { + rootPtr, err := windows.UTF16PtrFromString(volumePath) + if err != nil { + return nil, err + } + + var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME + // allocate a large buffer for results + var outBuffSize uint32 = 256 * 1024 + buf := make([]byte, outBuffSize) + + if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, uintptr(unsafe.Pointer(&buf[0]))); err != nil { + return nil, err + } + + if outBuffSize < 12 { + return nil, fmt.Errorf("invalid buffer returned") + } + + result := buf[:outBuffSize] + + // The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{} + headerBuffer := result[:12] + // The alternative to using unsafe and casting it to the above defined structures, is to manually + // parse the fields. Not too terrible, but not sure it'd worth the trouble. + header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0])) + + if header.MappingCount == 0 { + // no mappings + return []BindMapping{}, nil + } + + mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)] + // Get a pointer to the first mapping in the slice + mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0])) + // Get slice of mappings + mappings := unsafe.Slice(mappingsPointer, header.MappingCount) + + mappingEntries := make([]BindMapping, header.MappingCount) + for i := 0; i < int(header.MappingCount); i++ { + bindMapping, err := getBindMappingFromBuffer(result, mappings[i]) + if err != nil { + return nil, fmt.Errorf("fetching bind mappings: %w", err) + } + mappingEntries[i] = bindMapping + } + + return mappingEntries, nil +} diff --git a/zsyscall_windows.go b/zsyscall_windows.go index 83f45a13..ba67daa0 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -40,10 +40,11 @@ func errnoErr(e syscall.Errno) error { } var ( - modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW") @@ -59,6 +60,9 @@ var ( procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") procRevertToSelf = modadvapi32.NewProc("RevertToSelf") + procBfGetMappings = modbindfltapi.NewProc("BfGetMappings") + procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping") + procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") procBackupRead = modkernel32.NewProc("BackupRead") procBackupWrite = modkernel32.NewProc("BackupWrite") procCancelIoEx = modkernel32.NewProc("CancelIoEx") @@ -249,6 +253,51 @@ func revertToSelf() (err error) { return } +func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) { + hr = procBfGetMappings.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { + hr = procBfRemoveMapping.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { + hr = procBfSetupFilter.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { var _p0 *byte if len(b) > 0 { From 9a22fe182686de4065010e9acd5b2a4efcf330d5 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 6 Feb 2023 08:13:01 -0800 Subject: [PATCH 02/10] Add some tests Signed-off-by: Gabriel Adrian Samfira --- bind_filter.go | 113 ++++++++++++------------ bind_filter_test.go | 208 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+), 56 deletions(-) create mode 100644 bind_filter_test.go diff --git a/bind_filter.go b/bind_filter.go index 392fbb28..b10f218a 100644 --- a/bind_filter.go +++ b/bind_filter.go @@ -103,6 +103,7 @@ func ApplyFileBinding(root, source string, readOnly bool) error { return nil } +// RemoveFileBinding removes a mount from the root path. func RemoveFileBinding(root string) error { rootPtr, err := windows.UTF16PtrFromString(root) if err != nil { @@ -115,6 +116,62 @@ func RemoveFileBinding(root string) error { return nil } +// GetBindMappings returns a list of bind mappings that have their root on a +// particular volume. The volumePath parameter can be any path that exists on +// a volume. For example, if a number of mappings are created in C:\ProgramData\test, +// to get a list of those mappings, the volumePath parameter would have to be set to +// C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child +// path that exists. +func GetBindMappings(volumePath string) ([]BindMapping, error) { + rootPtr, err := windows.UTF16PtrFromString(volumePath) + if err != nil { + return nil, err + } + + var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME + // allocate a large buffer for results + var outBuffSize uint32 = 256 * 1024 + buf := make([]byte, outBuffSize) + + if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, uintptr(unsafe.Pointer(&buf[0]))); err != nil { + return nil, err + } + + if outBuffSize < 12 { + return nil, fmt.Errorf("invalid buffer returned") + } + + result := buf[:outBuffSize] + + // The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{} + headerBuffer := result[:12] + // The alternative to using unsafe and casting it to the above defined structures, is to manually + // parse the fields. Not too terrible, but not sure it'd worth the trouble. + header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0])) + + if header.MappingCount == 0 { + // no mappings + return []BindMapping{}, nil + } + + mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)] + // Get a pointer to the first mapping in the slice + mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0])) + // Get slice of mappings + mappings := unsafe.Slice(mappingsPointer, header.MappingCount) + + mappingEntries := make([]BindMapping, header.MappingCount) + for i := 0; i < int(header.MappingCount); i++ { + bindMapping, err := getBindMappingFromBuffer(result, mappings[i]) + if err != nil { + return nil, fmt.Errorf("fetching bind mappings: %w", err) + } + mappingEntries[i] = bindMapping + } + + return mappingEntries, nil +} + // mappingEntry holds information about where in the response buffer we can // find information about the virtual root (the mount point) and the targets (sources) // that get mounted, as well as the flags used to bind the targets to the virtual root. @@ -271,59 +328,3 @@ func getFileHandle(pth string) (syscall.Handle, error) { } return h, nil } - -// GetBindMappings returns a list of bind mappings that have their root on a -// particular volume. The volumePath parameter can be any path that exists on -// a volume. For example, if a number of mappings are created in C:\ProgramData\test, -// to get a list of those mappings, the volumePath parameter would have to be set to -// C:\ or the VOLUME_NAME_GUID notation of C:\ (\\?\Volume{GUID}\), or any child -// path that exists. -func GetBindMappings(volumePath string) ([]BindMapping, error) { - rootPtr, err := windows.UTF16PtrFromString(volumePath) - if err != nil { - return nil, err - } - - var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME - // allocate a large buffer for results - var outBuffSize uint32 = 256 * 1024 - buf := make([]byte, outBuffSize) - - if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, uintptr(unsafe.Pointer(&buf[0]))); err != nil { - return nil, err - } - - if outBuffSize < 12 { - return nil, fmt.Errorf("invalid buffer returned") - } - - result := buf[:outBuffSize] - - // The first 12 bytes are the three uint32 fields in getMappingsResponseHeader{} - headerBuffer := result[:12] - // The alternative to using unsafe and casting it to the above defined structures, is to manually - // parse the fields. Not too terrible, but not sure it'd worth the trouble. - header := *(*getMappingsResponseHeader)(unsafe.Pointer(&headerBuffer[0])) - - if header.MappingCount == 0 { - // no mappings - return []BindMapping{}, nil - } - - mappingsBuffer := result[12 : int(unsafe.Sizeof(mappingEntry{}))*int(header.MappingCount)] - // Get a pointer to the first mapping in the slice - mappingsPointer := (*mappingEntry)(unsafe.Pointer(&mappingsBuffer[0])) - // Get slice of mappings - mappings := unsafe.Slice(mappingsPointer, header.MappingCount) - - mappingEntries := make([]BindMapping, header.MappingCount) - for i := 0; i < int(header.MappingCount); i++ { - bindMapping, err := getBindMappingFromBuffer(result, mappings[i]) - if err != nil { - return nil, fmt.Errorf("fetching bind mappings: %w", err) - } - mappingEntries[i] = bindMapping - } - - return mappingEntries, nil -} diff --git a/bind_filter_test.go b/bind_filter_test.go new file mode 100644 index 00000000..d4c5a401 --- /dev/null +++ b/bind_filter_test.go @@ -0,0 +1,208 @@ +//go:build windows +// +build windows + +package winio + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestApplyFileBinding(t *testing.T) { + source := t.TempDir() + destination := t.TempDir() + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + + err := ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + defer RemoveFileBinding(destination) + + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0755); err != nil { + t.Fatal(err) + } + + readData, err := os.ReadFile(dstFile) + if err != nil { + t.Fatal(err) + } + + if string(readData) != string(data) { + t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData)) + } + + // Remove the file on the mount point. The mount is not read-only, this should work. + if err := os.Remove(dstFile); err != nil { + t.Fatalf("failed to remove file from mount point: %s", err) + } + + // Check that it's gone from the source as well. + if _, err := os.Stat(srcFile); err == nil { + t.Fatalf("expected file %s to be gone but is not", srcFile) + } +} + +func TestApplyFileBindingReadOnly(t *testing.T) { + source := t.TempDir() + destination := t.TempDir() + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + + err := ApplyFileBinding(destination, source, true) + if err != nil { + t.Fatal(err) + } + defer RemoveFileBinding(destination) + + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0755); err != nil { + t.Fatal(err) + } + + readData, err := os.ReadFile(dstFile) + if err != nil { + t.Fatal(err) + } + + if string(readData) != string(data) { + t.Fatalf("source and destination file contents differ. Expected: %s, got: %s", string(data), string(readData)) + } + + // Attempt to remove the file on the mount point + err = os.Remove(dstFile) + if err == nil { + t.Fatalf("should not be able to remove a file from a read-only mount") + } + if !errors.Is(err, os.ErrPermission) { + t.Fatalf("expected an access denied error, got: %q", err) + } +} + +func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { + source := t.TempDir() + secondarySource := t.TempDir() + destination := t.TempDir() + + err := ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + + defer RemoveFileBinding(destination) + err = ApplyFileBinding(destination, secondarySource, false) + if err == nil { + RemoveFileBinding(destination) + t.Fatalf("we should not be able to mount multiple targets in the same destination") + } +} + +func checkSourceIsMountedOnDestination(src, dst string) (bool, error) { + mappings, err := GetBindMappings(dst) + if err != nil { + return false, err + } + + found := false + // There may be pre-existing mappings on the system. + for _, mapping := range mappings { + if mapping.MountPoint == dst { + found = true + if len(mapping.Targets) != 1 { + return false, fmt.Errorf("expected only one target, got: %s", strings.Join(mapping.Targets, ", ")) + } + if mapping.Targets[0] != src { + return false, fmt.Errorf("expected target to be %s, got %s", src, mapping.Targets[0]) + } + break + } + } + + return found, nil +} + +func TestGetBindMappings(t *testing.T) { + // GetBindMappings will exoand short paths like ADMINI~1 and PROGRA~1 to their + // full names. In order to properly match the names later, we expand them here. + srcShort := t.TempDir() + source, err := getFinalPath(srcShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + dstShort := t.TempDir() + destination, err := getFinalPath(dstShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + err = ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + defer RemoveFileBinding(destination) + + hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + if err != nil { + t.Fatal(err) + } + + if !hasMapping { + t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + } +} + +func TestRemoveFileBinding(t *testing.T) { + // GetBindMappings will exoand short paths like ADMINI~1 and PROGRA~1 to their + // full names. In order to properly match the names later, we expand them here. + srcShort := t.TempDir() + source, err := getFinalPath(srcShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + dstShort := t.TempDir() + destination, err := getFinalPath(dstShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + err = ApplyFileBinding(destination, source, false) + if err != nil { + t.Fatal(err) + } + + hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + if err != nil { + RemoveFileBinding(destination) + t.Fatal(err) + } + + if !hasMapping { + RemoveFileBinding(destination) + t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + } + + if err := RemoveFileBinding(destination); err != nil { + t.Fatal(err) + } + + hasMapping, err = checkSourceIsMountedOnDestination(source, destination) + if err != nil { + t.Fatal(err) + } + + if hasMapping { + t.Fatalf("expected to find %s unmounted from %s, but it seems to still be mounted", source, destination) + } +} From 76c22e3aa4278b7c11f940eddb58626fed5ad369 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 6 Feb 2023 11:17:25 -0800 Subject: [PATCH 03/10] Move bind filter to different package Signed-off-by: Gabriel Adrian Samfira --- .github/workflows/ci.yml | 7 +- .../bindfilter/bind_filter.go | 12 ++- .../bindfilter/bind_filter_test.go | 49 +++++----- pkg/bindfilter/zsyscall_windows.go | 93 +++++++++++++++++++ zsyscall_windows.go | 57 +----------- 5 files changed, 138 insertions(+), 80 deletions(-) rename bind_filter.go => pkg/bindfilter/bind_filter.go (96%) rename bind_filter_test.go => pkg/bindfilter/bind_filter_test.go (81%) create mode 100644 pkg/bindfilter/zsyscall_windows.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bbd32daf..04b51ae6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,12 @@ jobs: - uses: actions/setup-go@v3 with: go-version: ${{ env.GO_VERSION }} - - run: go test -gcflags=all=-d=checkptr -v ./... + - name: Run tests on ltsc 2019 + if: matrix.os == 'windows-2019' + run: go test -gcflags=all=-d=checkptr -v --test.run="[^TestEnsureOnlyOneTargetCanBeMounted|^TestGetBindMappings|^TestRemoveFileBinding]" ./... + - name: Run tests + if: matrix.os != 'windows-2019' + run: go test -gcflags=all=-d=checkptr -v ./... build: name: Build Repo diff --git a/bind_filter.go b/pkg/bindfilter/bind_filter.go similarity index 96% rename from bind_filter.go rename to pkg/bindfilter/bind_filter.go index b10f218a..b601c760 100644 --- a/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package winio +package bindfilter import ( "bytes" @@ -18,12 +18,15 @@ import ( "golang.org/x/sys/windows" ) +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go //sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? //sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) = bindfltapi.BfRemoveMapping? //sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings? // BfSetupFilter flags. See: // https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240 +// +//nolint:revive // var-naming: ALL_CAPS const ( BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001 // Generates a merged binding, mapping target entries to the virtualization root. @@ -48,6 +51,7 @@ const ( BINDFLT_FLAG_BATCHED_REMOVE_MAPPINGS uint32 = 0x20000000 ) +//nolint:revive // var-naming: ALL_CAPS const ( BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001 BINDFLT_GET_MAPPINGS_FLAG_SILO uint32 = 0x00000002 @@ -128,7 +132,7 @@ func GetBindMappings(volumePath string) ([]BindMapping, error) { return nil, err } - var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME + flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME // allocate a large buffer for results var outBuffSize uint32 = 256 * 1024 buf := make([]byte, outBuffSize) @@ -224,7 +228,7 @@ func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) { if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) { return nil, fmt.Errorf("invalid buffer") } - decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : uint32(tgt.TargetRootOffset)+uint32(tgt.TargetRootLength)]) + decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength]) if err != nil { return nil, fmt.Errorf("decoding name: %w", err) } @@ -288,7 +292,7 @@ func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, e return BindMapping{}, fmt.Errorf("invalid buffer") } - src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+uint32(entry.VirtRootLength)]) + src, err := decodeEntry(buffer[entry.VirtRootOffset : entry.VirtRootOffset+entry.VirtRootLength]) if err != nil { return BindMapping{}, fmt.Errorf("decoding entry: %w", err) } diff --git a/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go similarity index 81% rename from bind_filter_test.go rename to pkg/bindfilter/bind_filter_test.go index d4c5a401..9e7efd0c 100644 --- a/bind_filter_test.go +++ b/pkg/bindfilter/bind_filter_test.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package winio +package bindfilter import ( "errors" @@ -23,11 +23,11 @@ func TestApplyFileBinding(t *testing.T) { if err != nil { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) data := []byte("bind filter test") - if err := os.WriteFile(srcFile, data, 0755); err != nil { + if err := os.WriteFile(srcFile, data, 0600); err != nil { t.Fatal(err) } @@ -51,6 +51,12 @@ func TestApplyFileBinding(t *testing.T) { } } +func removeFileBinding(t *testing.T, mountpoint string) { + if err := RemoveFileBinding(mountpoint); err != nil { + t.Logf("failed to remove file binding from %s: %q", mountpoint, err) + } +} + func TestApplyFileBindingReadOnly(t *testing.T) { source := t.TempDir() destination := t.TempDir() @@ -62,11 +68,11 @@ func TestApplyFileBindingReadOnly(t *testing.T) { if err != nil { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) data := []byte("bind filter test") - if err := os.WriteFile(srcFile, data, 0755); err != nil { + if err := os.WriteFile(srcFile, data, 0600); err != nil { t.Fatal(err) } @@ -99,10 +105,11 @@ func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) + err = ApplyFileBinding(destination, secondarySource, false) if err == nil { - RemoveFileBinding(destination) + removeFileBinding(t, destination) t.Fatalf("we should not be able to mount multiple targets in the same destination") } } @@ -150,7 +157,7 @@ func TestGetBindMappings(t *testing.T) { if err != nil { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) hasMapping, err := checkSourceIsMountedOnDestination(source, destination) if err != nil { @@ -177,32 +184,30 @@ func TestRemoveFileBinding(t *testing.T) { t.Fatalf("failed to get long path") } - err = ApplyFileBinding(destination, source, false) - if err != nil { + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0600); err != nil { t.Fatal(err) } - hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + err = ApplyFileBinding(destination, source, false) if err != nil { - RemoveFileBinding(destination) t.Fatal(err) } + defer removeFileBinding(t, destination) - if !hasMapping { - RemoveFileBinding(destination) - t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + if _, err := os.Stat(dstFile); err != nil { + t.Fatalf("expected to find %s, but did not", dstFile) } if err := RemoveFileBinding(destination); err != nil { t.Fatal(err) } - hasMapping, err = checkSourceIsMountedOnDestination(source, destination) - if err != nil { - t.Fatal(err) - } - - if hasMapping { - t.Fatalf("expected to find %s unmounted from %s, but it seems to still be mounted", source, destination) + if _, err := os.Stat(dstFile); err == nil { + t.Fatalf("expected %s to be gone, but it not", dstFile) } } diff --git a/pkg/bindfilter/zsyscall_windows.go b/pkg/bindfilter/zsyscall_windows.go new file mode 100644 index 00000000..091065af --- /dev/null +++ b/pkg/bindfilter/zsyscall_windows.go @@ -0,0 +1,93 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package bindfilter + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") + + procBfGetMappings = modbindfltapi.NewProc("BfGetMappings") + procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping") + procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") +) + +func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) { + hr = procBfGetMappings.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { + hr = procBfRemoveMapping.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { + hr = procBfSetupFilter.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} diff --git a/zsyscall_windows.go b/zsyscall_windows.go index ba67daa0..83f45a13 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -40,11 +40,10 @@ func errnoErr(e syscall.Errno) error { } var ( - modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW") @@ -60,9 +59,6 @@ var ( procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") procRevertToSelf = modadvapi32.NewProc("RevertToSelf") - procBfGetMappings = modbindfltapi.NewProc("BfGetMappings") - procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping") - procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") procBackupRead = modkernel32.NewProc("BackupRead") procBackupWrite = modkernel32.NewProc("BackupWrite") procCancelIoEx = modkernel32.NewProc("CancelIoEx") @@ -253,51 +249,6 @@ func revertToSelf() (err error) { return } -func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) { - hr = procBfGetMappings.Find() - if hr != nil { - return - } - r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer)) - if int32(r0) < 0 { - if r0&0x1fff0000 == 0x00070000 { - r0 &= 0xffff - } - hr = syscall.Errno(r0) - } - return -} - -func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { - hr = procBfRemoveMapping.Find() - if hr != nil { - return - } - r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0) - if int32(r0) < 0 { - if r0&0x1fff0000 == 0x00070000 { - r0 &= 0xffff - } - hr = syscall.Errno(r0) - } - return -} - -func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { - hr = procBfSetupFilter.Find() - if hr != nil { - return - } - r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount)) - if int32(r0) < 0 { - if r0&0x1fff0000 == 0x00070000 { - r0 &= 0xffff - } - hr = syscall.Errno(r0) - } - return -} - func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { var _p0 *byte if len(b) > 0 { From 2c3a145e69188f95ffdd98c5ea02a114cabc2ccb Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 6 Feb 2023 15:50:01 -0800 Subject: [PATCH 04/10] Use string in signature and fix getFinalPath * Properly close handle in getFinalPath() * Use string in function signature. mksyscall generates proper code to convert to utf16 * Enable TestRemoveFileBinding on Windows Server 2019 Windows Server 2019 only exposes 2 function in bindfltapi.dll: * BfRemoveMapping * BfSetupFilter Signed-off-by: Gabriel Adrian Samfira --- .github/workflows/ci.yml | 2 +- pkg/bindfilter/bind_filter.go | 25 ++++++------------------- pkg/bindfilter/bind_filter_test.go | 4 +--- pkg/bindfilter/zsyscall_windows.go | 27 +++++++++++++++++++++++++-- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 04b51ae6..c6ff6045 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,7 +72,7 @@ jobs: go-version: ${{ env.GO_VERSION }} - name: Run tests on ltsc 2019 if: matrix.os == 'windows-2019' - run: go test -gcflags=all=-d=checkptr -v --test.run="[^TestEnsureOnlyOneTargetCanBeMounted|^TestGetBindMappings|^TestRemoveFileBinding]" ./... + run: go test -gcflags=all=-d=checkptr -v --test.run="[^TestEnsureOnlyOneTargetCanBeMounted|^TestGetBindMappings]" ./... - name: Run tests if: matrix.os != 'windows-2019' run: go test -gcflags=all=-d=checkptr -v ./... diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go index b601c760..7ac01f1f 100644 --- a/pkg/bindfilter/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -19,8 +19,8 @@ import ( ) //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go -//sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? -//sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) = bindfltapi.BfRemoveMapping? +//sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? +//sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath string) (hr error) = bindfltapi.BfRemoveMapping? //sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings? // BfSetupFilter flags. See: @@ -79,15 +79,6 @@ func ApplyFileBinding(root, source string, readOnly bool) error { source = source + "\\" } - rootPtr, err := windows.UTF16PtrFromString(root) - if err != nil { - return err - } - - targetPtr, err := windows.UTF16PtrFromString(source) - if err != nil { - return err - } flags := BINDFLT_FLAG_NO_MULTIPLE_TARGETS if readOnly { flags |= BINDFLT_FLAG_READ_ONLY_MAPPING @@ -97,8 +88,8 @@ func ApplyFileBinding(root, source string, readOnly bool) error { if err := bfSetupFilter( 0, flags, - rootPtr, - targetPtr, + root, + source, nil, 0, ); err != nil { @@ -109,12 +100,7 @@ func ApplyFileBinding(root, source string, readOnly bool) error { // RemoveFileBinding removes a mount from the root path. func RemoveFileBinding(root string) error { - rootPtr, err := windows.UTF16PtrFromString(root) - if err != nil { - return fmt.Errorf("converting path to utf-16: %w", err) - } - - if err := bfRemoveMapping(0, rootPtr); err != nil { + if err := bfRemoveMapping(0, root); err != nil { return fmt.Errorf("removing file binding: %w", err) } return nil @@ -254,6 +240,7 @@ func getFinalPath(pth string) (string, error) { if err != nil { return "", fmt.Errorf("fetching file handle: %w", err) } + defer syscall.CloseHandle(han) buf := make([]uint16, 100) var flags uint32 = 0x0 diff --git a/pkg/bindfilter/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go index 9e7efd0c..57cce6ad 100644 --- a/pkg/bindfilter/bind_filter_test.go +++ b/pkg/bindfilter/bind_filter_test.go @@ -170,8 +170,6 @@ func TestGetBindMappings(t *testing.T) { } func TestRemoveFileBinding(t *testing.T) { - // GetBindMappings will exoand short paths like ADMINI~1 and PROGRA~1 to their - // full names. In order to properly match the names later, we expand them here. srcShort := t.TempDir() source, err := getFinalPath(srcShort) if err != nil { @@ -197,9 +195,9 @@ func TestRemoveFileBinding(t *testing.T) { if err != nil { t.Fatal(err) } - defer removeFileBinding(t, destination) if _, err := os.Stat(dstFile); err != nil { + removeFileBinding(t, destination) t.Fatalf("expected to find %s, but did not", dstFile) } diff --git a/pkg/bindfilter/zsyscall_windows.go b/pkg/bindfilter/zsyscall_windows.go index 091065af..65e2e7d8 100644 --- a/pkg/bindfilter/zsyscall_windows.go +++ b/pkg/bindfilter/zsyscall_windows.go @@ -62,7 +62,16 @@ func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, return } -func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { +func bfRemoveMapping(jobHandle windows.Handle, virtRootPath string) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(virtRootPath) + if hr != nil { + return + } + return _bfRemoveMapping(jobHandle, _p0) +} + +func _bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { hr = procBfRemoveMapping.Find() if hr != nil { return @@ -77,7 +86,21 @@ func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) return } -func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { +func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { + var _p0 *uint16 + _p0, hr = syscall.UTF16PtrFromString(virtRootPath) + if hr != nil { + return + } + var _p1 *uint16 + _p1, hr = syscall.UTF16PtrFromString(virtTargetPath) + if hr != nil { + return + } + return _bfSetupFilter(jobHandle, flags, _p0, _p1, virtExceptions, virtExceptionPathCount) +} + +func _bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { hr = procBfSetupFilter.Find() if hr != nil { return From e1b82b1fe29405f0d1ec87fc7b09bad872ffbb0f Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 6 Feb 2023 15:56:13 -0800 Subject: [PATCH 05/10] Use windows.UTF16ToString to decode string Signed-off-by: Gabriel Adrian Samfira --- pkg/bindfilter/bind_filter.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go index 7ac01f1f..c35d5ca4 100644 --- a/pkg/bindfilter/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -12,7 +12,6 @@ import ( "path/filepath" "strings" "syscall" - "unicode/utf16" "unsafe" "golang.org/x/sys/windows" @@ -199,7 +198,7 @@ func decodeEntry(buffer []byte) (string, error) { if err != nil { return "", fmt.Errorf("decoding name: %w", err) } - return string(utf16.Decode(name)), nil + return windows.UTF16ToString(name), nil } func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) { From 258cf205600b2c9f208648d63f34e6180da68d3f Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 6 Feb 2023 15:58:52 -0800 Subject: [PATCH 06/10] Optimize bfGetMappings signature Signed-off-by: Gabriel Adrian Samfira --- pkg/bindfilter/bind_filter.go | 8 +++++--- pkg/bindfilter/zsyscall_windows.go | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go index c35d5ca4..5c84cfb3 100644 --- a/pkg/bindfilter/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -20,7 +20,7 @@ import ( //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go //sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath string, virtTargetPath string, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? //sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath string) (hr error) = bindfltapi.BfRemoveMapping? -//sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings? +//sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte) (hr error) = bindfltapi.BfGetMappings? // BfSetupFilter flags. See: // https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240 @@ -122,7 +122,7 @@ func GetBindMappings(volumePath string) ([]BindMapping, error) { var outBuffSize uint32 = 256 * 1024 buf := make([]byte, outBuffSize) - if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, uintptr(unsafe.Pointer(&buf[0]))); err != nil { + if err := bfGetMappings(flags, 0, rootPtr, nil, &outBuffSize, &buf[0]); err != nil { return nil, err } @@ -239,7 +239,9 @@ func getFinalPath(pth string) (string, error) { if err != nil { return "", fmt.Errorf("fetching file handle: %w", err) } - defer syscall.CloseHandle(han) + defer func() { + _ = syscall.CloseHandle(han) + }() buf := make([]uint16, 100) var flags uint32 = 0x0 diff --git a/pkg/bindfilter/zsyscall_windows.go b/pkg/bindfilter/zsyscall_windows.go index 65e2e7d8..45c45c96 100644 --- a/pkg/bindfilter/zsyscall_windows.go +++ b/pkg/bindfilter/zsyscall_windows.go @@ -47,12 +47,12 @@ var ( procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") ) -func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) { +func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer *byte) (hr error) { hr = procBfGetMappings.Find() if hr != nil { return } - r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer)) + r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(unsafe.Pointer(outBuffer))) if int32(r0) < 0 { if r0&0x1fff0000 == 0x00070000 { r0 &= 0xffff From 841fe8910c9aacde8f759427709a4c3735b18af3 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Tue, 7 Feb 2023 10:17:49 -0800 Subject: [PATCH 07/10] Skip unsupported tests on ltsc2019 Signed-off-by: Gabriel Adrian Samfira --- .github/workflows/ci.yml | 7 +---- pkg/bindfilter/bind_filter_test.go | 42 +++++++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c6ff6045..bbd32daf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,12 +70,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: ${{ env.GO_VERSION }} - - name: Run tests on ltsc 2019 - if: matrix.os == 'windows-2019' - run: go test -gcflags=all=-d=checkptr -v --test.run="[^TestEnsureOnlyOneTargetCanBeMounted|^TestGetBindMappings]" ./... - - name: Run tests - if: matrix.os != 'windows-2019' - run: go test -gcflags=all=-d=checkptr -v ./... + - run: go test -gcflags=all=-d=checkptr -v ./... build: name: Build Repo diff --git a/pkg/bindfilter/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go index 57cce6ad..a5f2da91 100644 --- a/pkg/bindfilter/bind_filter_test.go +++ b/pkg/bindfilter/bind_filter_test.go @@ -8,8 +8,11 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" "testing" + + "golang.org/x/sys/windows/registry" ) func TestApplyFileBinding(t *testing.T) { @@ -96,11 +99,19 @@ func TestApplyFileBindingReadOnly(t *testing.T) { } func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { + version, err := getWindowsBuildNumber() + if err != nil { + t.Fatalf("couldn't get version number: %s", err) + } + + if version <= 17763 { + t.Skip("not supported on RS5 or earlier") + } source := t.TempDir() secondarySource := t.TempDir() destination := t.TempDir() - err := ApplyFileBinding(destination, source, false) + err = ApplyFileBinding(destination, source, false) if err != nil { t.Fatal(err) } @@ -139,6 +150,14 @@ func checkSourceIsMountedOnDestination(src, dst string) (bool, error) { } func TestGetBindMappings(t *testing.T) { + version, err := getWindowsBuildNumber() + if err != nil { + t.Fatalf("couldn't get version number: %s", err) + } + + if version <= 17763 { + t.Skip("not supported on RS5 or earlier") + } // GetBindMappings will exoand short paths like ADMINI~1 and PROGRA~1 to their // full names. In order to properly match the names later, we expand them here. srcShort := t.TempDir() @@ -198,7 +217,7 @@ func TestRemoveFileBinding(t *testing.T) { if _, err := os.Stat(dstFile); err != nil { removeFileBinding(t, destination) - t.Fatalf("expected to find %s, but did not", dstFile) + t.Fatalf("expected to find %s, but could not", dstFile) } if err := RemoveFileBinding(destination); err != nil { @@ -206,6 +225,23 @@ func TestRemoveFileBinding(t *testing.T) { } if _, err := os.Stat(dstFile); err == nil { - t.Fatalf("expected %s to be gone, but it not", dstFile) + t.Fatalf("expected %s to be gone, but it is not", dstFile) + } +} + +func getWindowsBuildNumber() (int, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) + if err != nil { + return 0, fmt.Errorf("read CurrentVersion reg key: %w", err) + } + defer k.Close() + buildNumStr, _, err := k.GetStringValue("CurrentBuild") + if err != nil { + return 0, fmt.Errorf("read CurrentBuild reg value: %w", err) + } + buildNum, err := strconv.Atoi(buildNumStr) + if err != nil { + return 0, err } + return buildNum, nil } From 14245b872fd9af0c42ee18dc75e6fc2d98029c11 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Sun, 12 Feb 2023 14:01:02 -0800 Subject: [PATCH 08/10] Fix typo, add testcase * Additionally check if we can write to a read-only mount point, not just delete from it * No need to set FILE_FLAG_OPEN_REPARSE_POINT when opening a file Signed-off-by: Gabriel Adrian Samfira --- pkg/bindfilter/bind_filter.go | 31 ++++++++++++++++-------------- pkg/bindfilter/bind_filter_test.go | 11 ++++++++++- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go index 5c84cfb3..0151aace 100644 --- a/pkg/bindfilter/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -235,12 +235,12 @@ func getFinalPath(pth string) (string, error) { pth = `\\.\GLOBALROOT` + pth } - han, err := getFileHandle(pth) + han, err := openPath(pth) if err != nil { return "", fmt.Errorf("fetching file handle: %w", err) } defer func() { - _ = syscall.CloseHandle(han) + _ = windows.CloseHandle(han) }() buf := make([]uint16, 100) @@ -301,22 +301,25 @@ func getBindMappingFromBuffer(buffer []byte, entry mappingEntry) (BindMapping, e }, nil } -func getFileHandle(pth string) (syscall.Handle, error) { - info, err := os.Lstat(pth) - if err != nil { - return 0, fmt.Errorf("accessing file: %w", err) - } - p, err := syscall.UTF16PtrFromString(pth) +func openPath(path string) (windows.Handle, error) { + u16, err := windows.UTF16PtrFromString(path) if err != nil { return 0, err } - attrs := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS) - if info.Mode()&os.ModeSymlink != 0 { - attrs |= syscall.FILE_FLAG_OPEN_REPARSE_POINT - } - h, err := syscall.CreateFile(p, 0, 0, nil, syscall.OPEN_EXISTING, attrs, 0) + h, err := windows.CreateFile( + u16, + 0, + windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, + nil, + windows.OPEN_EXISTING, + windows.FILE_FLAG_BACKUP_SEMANTICS, // Needed to open a directory handle. + 0) if err != nil { - return 0, err + return 0, &os.PathError{ + Op: "CreateFile", + Path: path, + Err: err, + } } return h, nil } diff --git a/pkg/bindfilter/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go index a5f2da91..38d62c9c 100644 --- a/pkg/bindfilter/bind_filter_test.go +++ b/pkg/bindfilter/bind_filter_test.go @@ -96,6 +96,15 @@ func TestApplyFileBindingReadOnly(t *testing.T) { if !errors.Is(err, os.ErrPermission) { t.Fatalf("expected an access denied error, got: %q", err) } + + // Attempt to write on the read-only mount point. + err = os.WriteFile(dstFile, []byte("something else"), 0600) + if err == nil { + t.Fatalf("should not be able to overwrite a file from a read-only mount") + } + if !errors.Is(err, os.ErrPermission) { + t.Fatalf("expected an access denied error, got: %q", err) + } } func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { @@ -158,7 +167,7 @@ func TestGetBindMappings(t *testing.T) { if version <= 17763 { t.Skip("not supported on RS5 or earlier") } - // GetBindMappings will exoand short paths like ADMINI~1 and PROGRA~1 to their + // GetBindMappings will expand short paths like ADMINI~1 and PROGRA~1 to their // full names. In order to properly match the names later, we expand them here. srcShort := t.TempDir() source, err := getFinalPath(srcShort) From d05c10d463079cd61bbdcca38fef78b9215844f6 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Sun, 12 Feb 2023 14:24:12 -0800 Subject: [PATCH 09/10] Remove extra flags Signed-off-by: Gabriel Adrian Samfira --- pkg/bindfilter/bind_filter.go | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go index 0151aace..0d427798 100644 --- a/pkg/bindfilter/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -28,26 +28,9 @@ import ( //nolint:revive // var-naming: ALL_CAPS const ( BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001 - // Generates a merged binding, mapping target entries to the virtualization root. - BINDFLT_FLAG_MERGED_BIND_MAPPING uint32 = 0x00000002 - // Use the binding mapping attached to the mapped-in job object (silo) instead of the default global mapping. - BINDFLT_FLAG_USE_CURRENT_SILO_MAPPING uint32 = 0x00000004 - BINDFLT_FLAG_REPARSE_ON_FILES uint32 = 0x00000008 - // Skips checks on file/dir creation inside a non-merged, read-only mapping. - // Only usable when READ_ONLY_MAPPING is set. - BINDFLT_FLAG_SKIP_SHARING_CHECK uint32 = 0x00000010 - BINDFLT_FLAG_CLOUD_FILES_ECPS uint32 = 0x00000020 // Tells bindflt to fail mapping with STATUS_INVALID_PARAMETER if a mapping produces // multiple targets. BINDFLT_FLAG_NO_MULTIPLE_TARGETS uint32 = 0x00000040 - // Turns on caching by asserting that the backing store for name mappings is immutable. - BINDFLT_FLAG_IMMUTABLE_BACKING uint32 = 0x00000080 - BINDFLT_FLAG_PREVENT_CASE_SENSITIVE_BINDING uint32 = 0x00000100 - // Tells bindflt to fail with STATUS_OBJECT_PATH_NOT_FOUND when a mapping is being added - // but its parent paths (ancestors) have not already been added. - BINDFLT_FLAG_EMPTY_VIRT_ROOT uint32 = 0x00000200 - BINDFLT_FLAG_NO_REPARSE_ON_ROOT uint32 = 0x10000000 - BINDFLT_FLAG_BATCHED_REMOVE_MAPPINGS uint32 = 0x20000000 ) //nolint:revive // var-naming: ALL_CAPS From 33c45b1d317d9127a248ee91bbf7a2aa10afa086 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Sun, 12 Feb 2023 14:52:50 -0800 Subject: [PATCH 10/10] Add test to account for symlinks as sources Signed-off-by: Gabriel Adrian Samfira --- pkg/bindfilter/bind_filter.go | 2 +- pkg/bindfilter/bind_filter_test.go | 53 ++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/pkg/bindfilter/bind_filter.go b/pkg/bindfilter/bind_filter.go index 0d427798..7ac377ae 100644 --- a/pkg/bindfilter/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -229,7 +229,7 @@ func getFinalPath(pth string) (string, error) { buf := make([]uint16, 100) var flags uint32 = 0x0 for { - n, err := windows.GetFinalPathNameByHandle(windows.Handle(han), &buf[0], uint32(len(buf)), flags) + n, err := windows.GetFinalPathNameByHandle(han, &buf[0], uint32(len(buf)), flags) if err != nil { // if we mounted a volume that does not also have a drive letter assigned, attempting to // fetch the VOLUME_NAME_DOS will fail with os.ErrNotExist. Attempt to get the VOLUME_NAME_GUID. diff --git a/pkg/bindfilter/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go index 38d62c9c..d4450e59 100644 --- a/pkg/bindfilter/bind_filter_test.go +++ b/pkg/bindfilter/bind_filter_test.go @@ -254,3 +254,56 @@ func getWindowsBuildNumber() (int, error) { } return buildNum, nil } + +func TestGetBindMappingsSymlinks(t *testing.T) { + version, err := getWindowsBuildNumber() + if err != nil { + t.Fatalf("couldn't get version number: %s", err) + } + + if version <= 17763 { + t.Skip("not supported on RS5 or earlier") + } + + srcShort := t.TempDir() + sourceNested := filepath.Join(srcShort, "source") + if err := os.MkdirAll(sourceNested, 0600); err != nil { + t.Fatalf("failed to create folder: %s", err) + } + simlinkSource := filepath.Join(srcShort, "symlink") + if err := os.Symlink(sourceNested, simlinkSource); err != nil { + t.Fatalf("failed to create symlink: %s", err) + } + + // We'll need the long form of the source folder, as we expect bfSetupFilter() + // to resolve the symlink and create a mapping to the actual source the symlink + // points to. + source, err := getFinalPath(sourceNested) + if err != nil { + t.Fatalf("failed to get long path") + } + + dstShort := t.TempDir() + destination, err := getFinalPath(dstShort) + if err != nil { + t.Fatalf("failed to get long path") + } + + // Use the symlink as a source for the mapping. + err = ApplyFileBinding(destination, simlinkSource, false) + if err != nil { + t.Fatal(err) + } + defer removeFileBinding(t, destination) + + // We expect the mapping to point to the folder the symlink points to, not to the + // actual symlink. + hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + if err != nil { + t.Fatal(err) + } + + if !hasMapping { + t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + } +}