Skip to content

Commit

Permalink
fix: signal SIGINT/SIGTERM in windows correctly
Browse files Browse the repository at this point in the history
emissary tries to send a signal but `os/Process.Kill` only supports
sending SIGKILL and returns an error for all other cases.

Using code found in hcsshim this changes signal handling in emissary for
windows by translating SIGINT and SIGTERM to their appropriate windows
signal and sending it to the process.

Signed-off-by: Michael Weibel <[email protected]>
  • Loading branch information
mweibel committed Oct 3, 2024
1 parent ba75efb commit e43d4e1
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 5 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ require (
go.opencensus.io v0.24.0 // indirect
go.starlark.net v0.0.0-20230525235612-a134d8f9ddca // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/sys v0.21.0
golang.org/x/term v0.21.0
golang.org/x/text v0.16.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
Expand Down
101 changes: 97 additions & 4 deletions workflow/executor/os-specific/signal_windows.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package os_specific

import (
"fmt"
"os"
"syscall"
"unsafe"

"github.com/argoproj/argo-workflows/v3/util/errors"
"golang.org/x/sys/windows"
)

var (
Term = os.Interrupt

modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procCreateRemoteThread = modkernel32.NewProc("CreateRemoteThread")
procCtrlRoutine = modkernel32.NewProc("CtrlRoutine")
)

func CanIgnoreSignal(s os.Signal) bool {
Expand All @@ -19,11 +26,23 @@ func Kill(pid int, s syscall.Signal) error {
if pid < 0 {
pid = -pid // // we cannot kill a negative process on windows
}
p, err := os.FindProcess(pid)
if err != nil {
return err

winSignal := -1
switch s {
case syscall.SIGTERM:
winSignal = windows.CTRL_SHUTDOWN_EVENT
case syscall.SIGINT:
winSignal = windows.CTRL_C_EVENT
}
return p.Signal(s)

if winSignal == -1 {
p, err := os.FindProcess(pid)
if err != nil {
return err
}
return p.Signal(s)
}
return signalProcess(uint32(pid), winSignal)
}

func Setpgid(a *syscall.SysProcAttr) {
Expand All @@ -37,3 +56,77 @@ func Wait(process *os.Process) error {
}
return err
}

// signalProcess sends the specified signal to a process.
//
// Code +/- copied from: https://github.com/microsoft/hcsshim/blob/1d69a9c658655b77dd4e5275bff99caad6b38416/internal/jobcontainers/process.go#L251
// License: MIT
// Author: Microsoft
func signalProcess(pid uint32, signal int) error {
hProc, err := windows.OpenProcess(windows.PROCESS_TERMINATE, true, pid)
if err != nil {
return fmt.Errorf("failed to open process: %w", err)
}
defer func() {
_ = windows.Close(hProc)
}()

if err := procCtrlRoutine.Find(); err != nil {
return fmt.Errorf("failed to load CtrlRoutine: %w", err)
}

threadHandle, err := createRemoteThread(hProc, nil, 0, procCtrlRoutine.Addr(), uintptr(signal), 0, nil)
if err != nil {
return fmt.Errorf("failed to open remote thread in target process %d: %w", pid, err)
}
defer func() {
_ = windows.Close(windows.Handle(threadHandle))
}()
return nil
}

// Following code has been generated using github.com/Microsoft/go-winio/tools/mkwinsyscall and inlined
// for easier usage

// HANDLE CreateRemoteThread(
//
// HANDLE hProcess,
// LPSECURITY_ATTRIBUTES lpThreadAttributes,
// SIZE_T dwStackSize,
// LPTHREAD_START_ROUTINE lpStartAddress,
// LPVOID lpParameter,
// DWORD dwCreationFlags,
// LPDWORD lpThreadId
//
// );
func createRemoteThread(process windows.Handle, sa *windows.SecurityAttributes, stackSize uint32, startAddr uintptr, parameter uintptr, creationFlags uint32, threadID *uint32) (handle windows.Handle, err error) {
r0, _, e1 := syscall.SyscallN(procCreateRemoteThread.Addr(), uintptr(process), uintptr(unsafe.Pointer(sa)), uintptr(stackSize), uintptr(startAddr), uintptr(parameter), uintptr(creationFlags), uintptr(unsafe.Pointer(threadID)))
handle = windows.Handle(r0)
if handle == 0 {
err = errnoErr(e1)
}
return
}

// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)

var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)

// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
return e
}
47 changes: 47 additions & 0 deletions workflow/executor/os-specific/signal_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//go:build windows

package os_specific

import (
"bytes"
"os/exec"
"sync"
"syscall"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestKill(t *testing.T) {
shell := "pwsh.exe"
cmd := exec.Command(shell, "-c", `echo "running"; while(1) { sleep 600000 }`)
var stdout bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stdout

_, err := StartCommand(cmd)
require.NoError(t, err)

var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()

err = cmd.Wait()
// we'll get an exit code
assert.Error(t, err)
}()

// Wait for echo to have run before calling Kill
time.Sleep(500 * time.Millisecond)

err = Kill(cmd.Process.Pid, syscall.SIGTERM)
require.NoError(t, err)

wg.Wait()

expected := "running\r\n"
assert.Equal(t, expected, stdout.String())
}

0 comments on commit e43d4e1

Please sign in to comment.