diff --git a/pkg/etw/activityid.go b/pkg/etw/activityid.go new file mode 100644 index 00000000..ff17ebbd --- /dev/null +++ b/pkg/etw/activityid.go @@ -0,0 +1,74 @@ +//go:build windows + +package etw + +import "github.com/Microsoft/go-winio/pkg/guid" + +type eventActivityIDControlCode uint32 + +//nolint:unused // all values listed here for completeness. +const ( + // Sets the ActivityId parameter to the value of the current thread's activity ID. + getEventActivityID eventActivityIDControlCode = iota + 1 + // Sets the current thread's activity ID to the value of the ActivityId parameter. + setEventActivityID + // Sets the ActivityId parameter to the value of a newly-generated locally-unique activity ID. + createEventActivityID + // Swaps the values of the ActivityId parameter and the current thread's activity ID. + // (Saves the value of the current thread's activity ID, then sets the current thread's activity ID to + // the value of the ActivityId parameter, then sets the ActivityId parameter to the saved value.) + getSetEventActivityID + // Sets the ActivityId parameter to the value of the current thread's activity ID, + // then sets the current thread's activity ID to the value of a newly-generated locally-unique activity ID + createSetEventActivityID +) + +// Activity ID is thread local, but since go doesn't expose a way to initialize threads, +// we have no way of calling this for all threads, or even knowing if the current thread +// was initialized without a syscall to [eventActivityIdControl] + +// InitializeThreadActivityID checks if the current thread's activity ID is empty, and, if so, +// creates a new activity ID for the thread. +// +// Subsequent ETW calls from this thread will use that Activity ID, if no ID is specified. +// +// See [EventActivityIdControl] for more information. +// +// [EventActivityIdControl]: https://learn.microsoft.com/en-us/windows/win32/api/evntprov/nf-evntprov-eventactivityidcontrol +func InitializeThreadActivityID() (guid.GUID, error) { + // check if the current thread is intialized + var g guid.GUID + if err := eventActivityIdControl(getEventActivityID, &g); err != nil { + return guid.GUID{}, err + } + if !g.IsEmpty() { + return g, nil + } + + // create a new activity ID + if err := eventActivityIdControl(createEventActivityID, &g); err != nil { + return guid.GUID{}, err + } + + // set the ID + if err := eventActivityIdControl(setEventActivityID, &g); err != nil { + return guid.GUID{}, err + } + return g, nil +} + +// GetThreadActivityID returns the current thread's activity ID. +// +// See [InitializeThreadActivityID] for more details. +func GetThreadActivityID() (guid.GUID, error) { + var g guid.GUID + err := eventActivityIdControl(getEventActivityID, &g) + return g, err +} + +// SetThreadActivityID returns the current thread's activity ID. +// +// See [InitializeThreadActivityID] for more details. +func SetThreadActivityID(g guid.GUID) error { + return eventActivityIdControl(setEventActivityID, &g) +} diff --git a/pkg/etw/newprovider.go b/pkg/etw/newprovider.go index 3669b4f7..e8027f38 100644 --- a/pkg/etw/newprovider.go +++ b/pkg/etw/newprovider.go @@ -40,7 +40,7 @@ func NewProviderWithOptions(name string, options ...ProviderOpt) (provider *Prov provider.ID = opts.id provider.callback = opts.callback - if err := eventRegister((*windows.GUID)(&provider.ID), globalProviderCallback, uintptr(provider.index), &provider.handle); err != nil { + if err := eventRegister(&provider.ID, globalProviderCallback, uintptr(provider.index), &provider.handle); err != nil { return nil, err } diff --git a/pkg/etw/provider.go b/pkg/etw/provider.go index 8174bff1..cbec7c88 100644 --- a/pkg/etw/provider.go +++ b/pkg/etw/provider.go @@ -10,7 +10,6 @@ import ( "unicode/utf16" "github.com/Microsoft/go-winio/pkg/guid" - "golang.org/x/sys/windows" ) // Provider represents an ETW event provider. It is identified by a provider @@ -277,7 +276,17 @@ func (provider *Provider) writeEventRaw( activityID guid.GUID, relatedActivityID guid.GUID, metadataBlobs [][]byte, - dataBlobs [][]byte) error { + dataBlobs [][]byte, +) error { + // Passing in an empty activity ID will override the thread's activity ID, so set it nil + // if no activity ID is specified. + // + // https://learn.microsoft.com/en-us/windows/win32/api/evntprov/nf-evntprov-eventactivityidcontrol#remarks + pActID := (*guid.GUID)(nil) + if !activityID.IsEmpty() { + pActID = &activityID + } + dataDescriptorCount := uint32(1 + len(metadataBlobs) + len(dataBlobs)) dataDescriptors := make([]eventDataDescriptor, 0, dataDescriptorCount) @@ -294,8 +303,8 @@ func (provider *Provider) writeEventRaw( return eventWriteTransfer(provider.handle, descriptor, - (*windows.GUID)(&activityID), - (*windows.GUID)(&relatedActivityID), + pActID, + &relatedActivityID, dataDescriptorCount, &dataDescriptors[0]) } diff --git a/pkg/etw/syscall.go b/pkg/etw/syscall.go index 16f3bb13..04a21edc 100644 --- a/pkg/etw/syscall.go +++ b/pkg/etw/syscall.go @@ -2,14 +2,23 @@ package etw -//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go syscall.go +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -imports "github.com/Microsoft/go-winio/pkg/guid" -output zsyscall_windows.go syscall.go -//sys eventRegister(providerId *windows.GUID, callback uintptr, callbackContext uintptr, providerHandle *providerHandle) (win32err error) = advapi32.EventRegister +//sys eventRegister(providerId *guid.GUID, callback uintptr, callbackContext uintptr, providerHandle *providerHandle) (win32err error) = advapi32.EventRegister //sys eventUnregister_64(providerHandle providerHandle) (win32err error) = advapi32.EventUnregister -//sys eventWriteTransfer_64(providerHandle providerHandle, descriptor *eventDescriptor, activityID *windows.GUID, relatedActivityID *windows.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) = advapi32.EventWriteTransfer +//sys eventWriteTransfer_64(providerHandle providerHandle, descriptor *eventDescriptor, activityID *guid.GUID, relatedActivityID *guid.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) = advapi32.EventWriteTransfer //sys eventSetInformation_64(providerHandle providerHandle, class eventInfoClass, information uintptr, length uint32) (win32err error) = advapi32.EventSetInformation //sys eventUnregister_32(providerHandle_low uint32, providerHandle_high uint32) (win32err error) = advapi32.EventUnregister -//sys eventWriteTransfer_32(providerHandle_low uint32, providerHandle_high uint32, descriptor *eventDescriptor, activityID *windows.GUID, relatedActivityID *windows.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) = advapi32.EventWriteTransfer +//sys eventWriteTransfer_32(providerHandle_low uint32, providerHandle_high uint32, descriptor *eventDescriptor, activityID *guid.GUID, relatedActivityID *guid.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) = advapi32.EventWriteTransfer //sys eventSetInformation_32(providerHandle_low uint32, providerHandle_high uint32, class eventInfoClass, information uintptr, length uint32) (win32err error) = advapi32.EventSetInformation + +// ULONG EVNTAPI EventActivityIdControl( +// [in] ULONG ControlCode, +// [in, out] LPGUID ActivityId +// ); +// +// https://learn.microsoft.com/en-us/windows/win32/api/evntprov/nf-evntprov-eventactivityidcontrol +// +//sys eventActivityIdControl(code eventActivityIDControlCode, activityID *guid.GUID) (win32err error)= advapi32.EventActivityIdControl? diff --git a/pkg/etw/wrapper_32.go b/pkg/etw/wrapper_32.go index 14c49984..2a3dede5 100644 --- a/pkg/etw/wrapper_32.go +++ b/pkg/etw/wrapper_32.go @@ -24,8 +24,8 @@ func eventUnregister(providerHandle providerHandle) (win32err error) { func eventWriteTransfer( providerHandle providerHandle, descriptor *eventDescriptor, - activityID *windows.GUID, - relatedActivityID *windows.GUID, + activityID *guid.GUID, + relatedActivityID *guid.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { diff --git a/pkg/etw/wrapper_64.go b/pkg/etw/wrapper_64.go index 8cfe2e8c..509cc9ff 100644 --- a/pkg/etw/wrapper_64.go +++ b/pkg/etw/wrapper_64.go @@ -6,7 +6,6 @@ package etw import ( "github.com/Microsoft/go-winio/pkg/guid" - "golang.org/x/sys/windows" ) func eventUnregister(providerHandle providerHandle) (win32err error) { @@ -16,8 +15,8 @@ func eventUnregister(providerHandle providerHandle) (win32err error) { func eventWriteTransfer( providerHandle providerHandle, descriptor *eventDescriptor, - activityID *windows.GUID, - relatedActivityID *windows.GUID, + activityID *guid.GUID, + relatedActivityID *guid.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { return eventWriteTransfer_64( diff --git a/pkg/etw/zsyscall_windows.go b/pkg/etw/zsyscall_windows.go index c78a6ed5..bf7616c5 100644 --- a/pkg/etw/zsyscall_windows.go +++ b/pkg/etw/zsyscall_windows.go @@ -8,6 +8,7 @@ import ( "syscall" "unsafe" + "github.com/Microsoft/go-winio/pkg/guid" "golang.org/x/sys/windows" ) @@ -42,13 +43,26 @@ func errnoErr(e syscall.Errno) error { var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - procEventRegister = modadvapi32.NewProc("EventRegister") - procEventSetInformation = modadvapi32.NewProc("EventSetInformation") - procEventUnregister = modadvapi32.NewProc("EventUnregister") - procEventWriteTransfer = modadvapi32.NewProc("EventWriteTransfer") + procEventActivityIdControl = modadvapi32.NewProc("EventActivityIdControl") + procEventRegister = modadvapi32.NewProc("EventRegister") + procEventSetInformation = modadvapi32.NewProc("EventSetInformation") + procEventUnregister = modadvapi32.NewProc("EventUnregister") + procEventWriteTransfer = modadvapi32.NewProc("EventWriteTransfer") ) -func eventRegister(providerId *windows.GUID, callback uintptr, callbackContext uintptr, providerHandle *providerHandle) (win32err error) { +func eventActivityIdControl(code eventActivityIDControlCode, activityID *guid.GUID) (win32err error) { + win32err = procEventActivityIdControl.Find() + if win32err != nil { + return + } + r0, _, _ := syscall.Syscall(procEventActivityIdControl.Addr(), 2, uintptr(code), uintptr(unsafe.Pointer(activityID)), 0) + if r0 != 0 { + win32err = syscall.Errno(r0) + } + return +} + +func eventRegister(providerId *guid.GUID, callback uintptr, callbackContext uintptr, providerHandle *providerHandle) (win32err error) { r0, _, _ := syscall.Syscall6(procEventRegister.Addr(), 4, uintptr(unsafe.Pointer(providerId)), uintptr(callback), uintptr(callbackContext), uintptr(unsafe.Pointer(providerHandle)), 0, 0) if r0 != 0 { win32err = syscall.Errno(r0) @@ -88,7 +102,7 @@ func eventUnregister_32(providerHandle_low uint32, providerHandle_high uint32) ( return } -func eventWriteTransfer_64(providerHandle providerHandle, descriptor *eventDescriptor, activityID *windows.GUID, relatedActivityID *windows.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { +func eventWriteTransfer_64(providerHandle providerHandle, descriptor *eventDescriptor, activityID *guid.GUID, relatedActivityID *guid.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { r0, _, _ := syscall.Syscall6(procEventWriteTransfer.Addr(), 6, uintptr(providerHandle), uintptr(unsafe.Pointer(descriptor)), uintptr(unsafe.Pointer(activityID)), uintptr(unsafe.Pointer(relatedActivityID)), uintptr(dataDescriptorCount), uintptr(unsafe.Pointer(dataDescriptors))) if r0 != 0 { win32err = syscall.Errno(r0) @@ -96,7 +110,7 @@ func eventWriteTransfer_64(providerHandle providerHandle, descriptor *eventDescr return } -func eventWriteTransfer_32(providerHandle_low uint32, providerHandle_high uint32, descriptor *eventDescriptor, activityID *windows.GUID, relatedActivityID *windows.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { +func eventWriteTransfer_32(providerHandle_low uint32, providerHandle_high uint32, descriptor *eventDescriptor, activityID *guid.GUID, relatedActivityID *guid.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { r0, _, _ := syscall.Syscall9(procEventWriteTransfer.Addr(), 7, uintptr(providerHandle_low), uintptr(providerHandle_high), uintptr(unsafe.Pointer(descriptor)), uintptr(unsafe.Pointer(activityID)), uintptr(unsafe.Pointer(relatedActivityID)), uintptr(dataDescriptorCount), uintptr(unsafe.Pointer(dataDescriptors)), 0, 0) if r0 != 0 { win32err = syscall.Errno(r0) diff --git a/pkg/guid/guid.go b/pkg/guid/guid.go index 48ce4e92..01ed3666 100644 --- a/pkg/guid/guid.go +++ b/pkg/guid/guid.go @@ -220,7 +220,7 @@ func (g GUID) MarshalText() ([]byte, error) { return []byte(g.String()), nil } -// UnmarshalText takes the textual representation of a GUID, and unmarhals it +// UnmarshalText takes the textual representation of a GUID, and unmarshals it // into this GUID. func (g *GUID) UnmarshalText(text []byte) error { g2, err := FromString(string(text)) @@ -230,3 +230,9 @@ func (g *GUID) UnmarshalText(text []byte) error { *g = g2 return nil } + +// hopefully this saves on allocating a new GUID per g.Empty() call +var empty GUID + +// IsEmpty returns if the GUID is equal to 00000000-0000-0000-0000-000000000000 +func (g *GUID) IsEmpty() bool { return *g == empty } diff --git a/tools/mkwinsyscall/mkwinsyscall.go b/tools/mkwinsyscall/mkwinsyscall.go index 20d9e3d2..3a8cf8da 100644 --- a/tools/mkwinsyscall/mkwinsyscall.go +++ b/tools/mkwinsyscall/mkwinsyscall.go @@ -53,6 +53,7 @@ var ( winio = flag.Bool("winio", false, `import this package ("github.com/Microsoft/go-winio")`) utf16 = flag.Bool("utf16", true, "encode string arguments as UTF-16 for syscalls not ending in 'A' or 'W'") sortdecls = flag.Bool("sort", true, "sort DLL and function declarations") + extraImports = flag.String("imports", "", "comma separated list of additional `packages` to import") ) func trim(s string) string { @@ -873,6 +874,12 @@ func (src *Source) Generate(w io.Writer) error { if packageName != "syscall" { src.Import("syscall") } + if *extraImports != "" { + for _, pkg := range strings.Split(*extraImports, ",") { + src.ExternalImport(pkg) + } + } + funcMap := template.FuncMap{ "packagename": packagename, "syscalldot": syscalldot,