diff --git a/metric/system/process/process_windows.go b/metric/system/process/process_windows.go index 61882d6ec..dbd116121 100644 --- a/metric/system/process/process_windows.go +++ b/metric/system/process/process_windows.go @@ -80,26 +80,39 @@ func GetInfoForPid(_ resolve.Resolver, pid int) (ProcState, error) { } func FetchNumThreads(pid int) (int, error) { - pHandle, err := syscall.OpenProcess( + targetProcessHandle, err := syscall.OpenProcess( xsyswindows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) if err != nil { return 0, fmt.Errorf("OpenProcess failed for PID %d: %w", pid, err) } - defer syscall.CloseHandle(pHandle) + defer syscall.CloseHandle(targetProcessHandle) + + currentProcessHandle, err := syscall.GetCurrentProcess() + if err != nil { + return 0, fmt.Errorf("GetCurrentProcess failed: %w", err) + } + // The pseudo handle need not be closed when it is no longer + // needed, calling CloseHandle has no effect. Adding here to + // remind us to close any handles we open. + defer syscall.CloseHandle(currentProcessHandle) var snapshotHandle syscall.Handle - err = PssCaptureSnapshot(pHandle, PSSCaptureThreads, 0, &snapshotHandle) + err = PssCaptureSnapshot(targetProcessHandle, PSSCaptureThreads, 0, &snapshotHandle) if err != nil { return 0, fmt.Errorf("PssCaptureSnapshot failed: %w", err) } info := PssThreadInformation{} buffSize := unsafe.Sizeof(info) - err = PssQuerySnapshot(snapshotHandle, PssQueryThreadInformation, &info, uint32(buffSize)) - if err != nil { - return 0, fmt.Errorf("PssQuerySnapshot failed: %w", err) + queryErr := PssQuerySnapshot(snapshotHandle, PssQueryThreadInformation, &info, uint32(buffSize)) + freeErr := PssFreeSnapshot(currentProcessHandle, snapshotHandle) + if queryErr != nil || freeErr != nil { + //Join discards any nil errors + return 0, errors.Join( + fmt.Errorf("PssQuerySnapshot failed: %w", queryErr), + fmt.Errorf("PssFreeSnapshot failed: %w", freeErr)) } return int(info.ThreadsCaptured), nil diff --git a/metric/system/process/syscall_windows.go b/metric/system/process/syscall_windows.go index e7d965a41..e0bb9936d 100644 --- a/metric/system/process/syscall_windows.go +++ b/metric/system/process/syscall_windows.go @@ -21,6 +21,7 @@ package process // - https://learn.microsoft.com/en-us/previous-versions/windows/desktop/proc_snap/overview-of-process-snapshotting // PssCaptureSnapshot docs in https://learn.microsoft.com/en-us/windows/win32/api/processsnapshot/nf-processsnapshot-psscapturesnapshot // PssQuerySnapshot docs in https://learn.microsoft.com/en-us/windows/win32/api/processsnapshot/nf-processsnapshot-pssquerysnapshot +// PssFreeSnapshot docs in https://learn.microsoft.com/en-us/windows/win32/api/processsnapshot/nf-processsnapshot-pssfreesnapshot // Use golang.org/x/sys/windows/mkwinsyscall instead of adriansr/mksyscall // below once https://github.com/golang/go/issues/42373 is fixed. @@ -29,6 +30,7 @@ package process //sys PssCaptureSnapshot(processHandle syscall.Handle, captureFlags PSSCaptureFlags, threadContextFlags uint32, snapshotHandle *syscall.Handle) (err error) [failretval!=0] = kernel32.PssCaptureSnapshot //sys PssQuerySnapshot(snapshotHandle syscall.Handle, informationClass uint32, buffer *PssThreadInformation, bufferLength uint32) (err error) [failretval!=0] = kernel32.PssQuerySnapshot +//sys PssFreeSnapshot(processHandle syscall.Handle, snapshotHandle syscall.Handle) (err error) [failretval!=0] = kernel32.PssFreeSnapshot // The following constants are PssQueryInformationClass as defined on // https://learn.microsoft.com/en-us/windows/win32/api/processsnapshot/ne-processsnapshot-pss_query_information_class diff --git a/metric/system/process/zsyscall_windows.go b/metric/system/process/zsyscall_windows.go index e961c7c49..e2af05e9c 100644 --- a/metric/system/process/zsyscall_windows.go +++ b/metric/system/process/zsyscall_windows.go @@ -59,6 +59,7 @@ var ( procPssCaptureSnapshot = modkernel32.NewProc("PssCaptureSnapshot") procPssQuerySnapshot = modkernel32.NewProc("PssQuerySnapshot") + procPssFreeSnapshot = modkernel32.NewProc("PssFreeSnapshot") ) func PssCaptureSnapshot(processHandle syscall.Handle, captureFlags PSSCaptureFlags, threadContextFlags uint32, snapshotHandle *syscall.Handle) (err error) { @@ -76,3 +77,11 @@ func PssQuerySnapshot(snapshotHandle syscall.Handle, informationClass uint32, bu } return } + +func PssFreeSnapshot(processHandle syscall.Handle, snapshotHandle syscall.Handle) (err error) { + r1, _, e1 := syscall.Syscall6(procPssFreeSnapshot.Addr(), 2, uintptr(processHandle), uintptr(snapshotHandle), 0, 0, 0, 0) + if r1 != 0 { + err = errnoErr(e1) + } + return +}