From fc72beee65482b36bb45598c2e09f1f83baff582 Mon Sep 17 00:00:00 2001 From: Gabriel Adrian Samfira Date: Mon, 6 Feb 2023 11:17:25 -0800 Subject: [PATCH] Move bind filter to different package Signed-off-by: Gabriel Adrian Samfira --- .github/workflows/ci.yml | 7 +- .../bindfilter/bind_filter.go | 10 +- .../bindfilter/bind_filter_test.go | 49 +++++----- pkg/bindfilter/zsyscall_windows.go | 93 +++++++++++++++++++ zsyscall_windows.go | 57 +----------- 5 files changed, 137 insertions(+), 79 deletions(-) rename bind_filter.go => pkg/bindfilter/bind_filter.go (96%) rename bind_filter_test.go => pkg/bindfilter/bind_filter_test.go (81%) create mode 100644 pkg/bindfilter/zsyscall_windows.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bbd32daf..04b51ae6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,12 @@ jobs: - uses: actions/setup-go@v3 with: go-version: ${{ env.GO_VERSION }} - - run: go test -gcflags=all=-d=checkptr -v ./... + - name: Run tests on ltsc 2019 + if: matrix.os == 'windows-2019' + run: go test -gcflags=all=-d=checkptr -v --test.run="[^TestEnsureOnlyOneTargetCanBeMounted|^TestGetBindMappings|^TestRemoveFileBinding]" ./... + - name: Run tests + if: matrix.os != 'windows-2019' + run: go test -gcflags=all=-d=checkptr -v ./... build: name: Build Repo diff --git a/bind_filter.go b/pkg/bindfilter/bind_filter.go similarity index 96% rename from bind_filter.go rename to pkg/bindfilter/bind_filter.go index b10f218a..b7bfa761 100644 --- a/bind_filter.go +++ b/pkg/bindfilter/bind_filter.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package winio +package bindfilter import ( "bytes" @@ -18,12 +18,15 @@ import ( "golang.org/x/sys/windows" ) +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./bind_filter.go //sys bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) = bindfltapi.BfSetupFilter? //sys bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) = bindfltapi.BfRemoveMapping? //sys bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) = bindfltapi.BfGetMappings? // BfSetupFilter flags. See: // https://github.com/microsoft/BuildXL/blob/a6dce509f0d4f774255e5fbfb75fa6d5290ed163/Public/Src/Utilities/Native/Processes/Windows/NativeContainerUtilities.cs#L193-L240 +// +//nolint:revive // var-naming: ALL_CAPS const ( BINDFLT_FLAG_READ_ONLY_MAPPING uint32 = 0x00000001 // Generates a merged binding, mapping target entries to the virtualization root. @@ -48,6 +51,7 @@ const ( BINDFLT_FLAG_BATCHED_REMOVE_MAPPINGS uint32 = 0x20000000 ) +//nolint:revive // var-naming: ALL_CAPS const ( BINDFLT_GET_MAPPINGS_FLAG_VOLUME uint32 = 0x00000001 BINDFLT_GET_MAPPINGS_FLAG_SILO uint32 = 0x00000002 @@ -128,7 +132,7 @@ func GetBindMappings(volumePath string) ([]BindMapping, error) { return nil, err } - var flags uint32 = BINDFLT_GET_MAPPINGS_FLAG_VOLUME + flags := BINDFLT_GET_MAPPINGS_FLAG_VOLUME // allocate a large buffer for results var outBuffSize uint32 = 256 * 1024 buf := make([]byte, outBuffSize) @@ -224,7 +228,7 @@ func getTargetsFromBuffer(buffer []byte, offset, count int) ([]string, error) { if len(buffer) < int(tgt.TargetRootOffset)+int(tgt.TargetRootLength) { return nil, fmt.Errorf("invalid buffer") } - decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : uint32(tgt.TargetRootOffset)+uint32(tgt.TargetRootLength)]) + decoded, err := decodeEntry(buffer[tgt.TargetRootOffset : tgt.TargetRootOffset+tgt.TargetRootLength]) if err != nil { return nil, fmt.Errorf("decoding name: %w", err) } diff --git a/bind_filter_test.go b/pkg/bindfilter/bind_filter_test.go similarity index 81% rename from bind_filter_test.go rename to pkg/bindfilter/bind_filter_test.go index d4c5a401..9e7efd0c 100644 --- a/bind_filter_test.go +++ b/pkg/bindfilter/bind_filter_test.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package winio +package bindfilter import ( "errors" @@ -23,11 +23,11 @@ func TestApplyFileBinding(t *testing.T) { if err != nil { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) data := []byte("bind filter test") - if err := os.WriteFile(srcFile, data, 0755); err != nil { + if err := os.WriteFile(srcFile, data, 0600); err != nil { t.Fatal(err) } @@ -51,6 +51,12 @@ func TestApplyFileBinding(t *testing.T) { } } +func removeFileBinding(t *testing.T, mountpoint string) { + if err := RemoveFileBinding(mountpoint); err != nil { + t.Logf("failed to remove file binding from %s: %q", mountpoint, err) + } +} + func TestApplyFileBindingReadOnly(t *testing.T) { source := t.TempDir() destination := t.TempDir() @@ -62,11 +68,11 @@ func TestApplyFileBindingReadOnly(t *testing.T) { if err != nil { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) data := []byte("bind filter test") - if err := os.WriteFile(srcFile, data, 0755); err != nil { + if err := os.WriteFile(srcFile, data, 0600); err != nil { t.Fatal(err) } @@ -99,10 +105,11 @@ func TestEnsureOnlyOneTargetCanBeMounted(t *testing.T) { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) + err = ApplyFileBinding(destination, secondarySource, false) if err == nil { - RemoveFileBinding(destination) + removeFileBinding(t, destination) t.Fatalf("we should not be able to mount multiple targets in the same destination") } } @@ -150,7 +157,7 @@ func TestGetBindMappings(t *testing.T) { if err != nil { t.Fatal(err) } - defer RemoveFileBinding(destination) + defer removeFileBinding(t, destination) hasMapping, err := checkSourceIsMountedOnDestination(source, destination) if err != nil { @@ -177,32 +184,30 @@ func TestRemoveFileBinding(t *testing.T) { t.Fatalf("failed to get long path") } - err = ApplyFileBinding(destination, source, false) - if err != nil { + fileName := "testFile.txt" + srcFile := filepath.Join(source, fileName) + dstFile := filepath.Join(destination, fileName) + data := []byte("bind filter test") + + if err := os.WriteFile(srcFile, data, 0600); err != nil { t.Fatal(err) } - hasMapping, err := checkSourceIsMountedOnDestination(source, destination) + err = ApplyFileBinding(destination, source, false) if err != nil { - RemoveFileBinding(destination) t.Fatal(err) } + defer removeFileBinding(t, destination) - if !hasMapping { - RemoveFileBinding(destination) - t.Fatalf("expected to find %s mounted on %s, but could not", source, destination) + if _, err := os.Stat(dstFile); err != nil { + t.Fatalf("expected to find %s, but did not", dstFile) } if err := RemoveFileBinding(destination); err != nil { t.Fatal(err) } - hasMapping, err = checkSourceIsMountedOnDestination(source, destination) - if err != nil { - t.Fatal(err) - } - - if hasMapping { - t.Fatalf("expected to find %s unmounted from %s, but it seems to still be mounted", source, destination) + if _, err := os.Stat(dstFile); err == nil { + t.Fatalf("expected %s to be gone, but it not", dstFile) } } diff --git a/pkg/bindfilter/zsyscall_windows.go b/pkg/bindfilter/zsyscall_windows.go new file mode 100644 index 00000000..091065af --- /dev/null +++ b/pkg/bindfilter/zsyscall_windows.go @@ -0,0 +1,93 @@ +//go:build windows + +// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT. + +package bindfilter + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") + + procBfGetMappings = modbindfltapi.NewProc("BfGetMappings") + procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping") + procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") +) + +func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) { + hr = procBfGetMappings.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { + hr = procBfRemoveMapping.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} + +func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { + hr = procBfSetupFilter.Find() + if hr != nil { + return + } + r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount)) + if int32(r0) < 0 { + if r0&0x1fff0000 == 0x00070000 { + r0 &= 0xffff + } + hr = syscall.Errno(r0) + } + return +} diff --git a/zsyscall_windows.go b/zsyscall_windows.go index ba67daa0..83f45a13 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -40,11 +40,10 @@ func errnoErr(e syscall.Errno) error { } var ( - modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW") @@ -60,9 +59,6 @@ var ( procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") procRevertToSelf = modadvapi32.NewProc("RevertToSelf") - procBfGetMappings = modbindfltapi.NewProc("BfGetMappings") - procBfRemoveMapping = modbindfltapi.NewProc("BfRemoveMapping") - procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") procBackupRead = modkernel32.NewProc("BackupRead") procBackupWrite = modkernel32.NewProc("BackupWrite") procCancelIoEx = modkernel32.NewProc("CancelIoEx") @@ -253,51 +249,6 @@ func revertToSelf() (err error) { return } -func bfGetMappings(flags uint32, jobHandle windows.Handle, virtRootPath *uint16, sid *windows.SID, bufferSize *uint32, outBuffer uintptr) (hr error) { - hr = procBfGetMappings.Find() - if hr != nil { - return - } - r0, _, _ := syscall.Syscall6(procBfGetMappings.Addr(), 6, uintptr(flags), uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(bufferSize)), uintptr(outBuffer)) - if int32(r0) < 0 { - if r0&0x1fff0000 == 0x00070000 { - r0 &= 0xffff - } - hr = syscall.Errno(r0) - } - return -} - -func bfRemoveMapping(jobHandle windows.Handle, virtRootPath *uint16) (hr error) { - hr = procBfRemoveMapping.Find() - if hr != nil { - return - } - r0, _, _ := syscall.Syscall(procBfRemoveMapping.Addr(), 2, uintptr(jobHandle), uintptr(unsafe.Pointer(virtRootPath)), 0) - if int32(r0) < 0 { - if r0&0x1fff0000 == 0x00070000 { - r0 &= 0xffff - } - hr = syscall.Errno(r0) - } - return -} - -func bfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { - hr = procBfSetupFilter.Find() - if hr != nil { - return - } - r0, _, _ := syscall.Syscall6(procBfSetupFilter.Addr(), 6, uintptr(jobHandle), uintptr(flags), uintptr(unsafe.Pointer(virtRootPath)), uintptr(unsafe.Pointer(virtTargetPath)), uintptr(unsafe.Pointer(virtExceptions)), uintptr(virtExceptionPathCount)) - if int32(r0) < 0 { - if r0&0x1fff0000 == 0x00070000 { - r0 &= 0xffff - } - hr = syscall.Errno(r0) - } - return -} - func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { var _p0 *byte if len(b) > 0 {