Skip to content

Commit

Permalink
Fix bug that prevented dispatcher exit with downed DB (ansible#14469)
Browse files Browse the repository at this point in the history
* Separate handling of original sitTERM and sigINT
  • Loading branch information
AlanCoding authored Oct 26, 2023
1 parent bef0a8b commit fc0b58f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
29 changes: 22 additions & 7 deletions awx/main/tasks/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,54 @@ class SignalExit(Exception):
class SignalState:
def reset(self):
self.sigterm_flag = False
self.is_active = False
self.sigint_flag = False

self.is_active = False # for nested context managers
self.original_sigterm = None
self.original_sigint = None
self.raise_exception = False

def __init__(self):
self.reset()

def set_flag(self, *args):
"""Method to pass into the python signal.signal method to receive signals"""
self.sigterm_flag = True
def raise_if_needed(self):
if self.raise_exception:
self.raise_exception = False # so it is not raised a second time in error handling
raise SignalExit()

def set_sigterm_flag(self, *args):
self.sigterm_flag = True
self.raise_if_needed()

def set_sigint_flag(self, *args):
self.sigint_flag = True
self.raise_if_needed()

def connect_signals(self):
self.original_sigterm = signal.getsignal(signal.SIGTERM)
self.original_sigint = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGTERM, self.set_flag)
signal.signal(signal.SIGINT, self.set_flag)
signal.signal(signal.SIGTERM, self.set_sigterm_flag)
signal.signal(signal.SIGINT, self.set_sigint_flag)
self.is_active = True

def restore_signals(self):
signal.signal(signal.SIGTERM, self.original_sigterm)
signal.signal(signal.SIGINT, self.original_sigint)
# if we got a signal while context manager was active, call parent methods.
if self.sigterm_flag:
if callable(self.original_sigterm):
self.original_sigterm()
if self.sigint_flag:
if callable(self.original_sigint):
self.original_sigint()
self.reset()


signal_state = SignalState()


def signal_callback():
return signal_state.sigterm_flag
return bool(signal_state.sigterm_flag or signal_state.sigint_flag)


def with_signal_handling(f):
Expand Down
48 changes: 46 additions & 2 deletions awx/main/tests/unit/tasks/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,43 @@
import signal
import functools

from awx.main.tasks.signals import signal_state, signal_callback, with_signal_handling


def pytest_sigint():
pytest_sigint.called_count += 1


def pytest_sigterm():
pytest_sigterm.called_count += 1


def tmp_signals_for_test(func):
"""
When we run our internal signal handlers, it will call the original signal
handlers when its own work is finished.
This would crash the test runners normally, because those methods will
shut down the process.
So this is a decorator to safely replace existing signal handlers
with new signal handlers that do nothing so that tests do not crash.
"""

@functools.wraps(func)
def wrapper():
original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGTERM, pytest_sigterm)
signal.signal(signal.SIGINT, pytest_sigint)
pytest_sigterm.called_count = 0
pytest_sigint.called_count = 0
func()
signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint)

return wrapper


@tmp_signals_for_test
def test_outer_inner_signal_handling():
"""
Even if the flag is set in the outer context, its value should persist in the inner context
Expand All @@ -15,17 +50,22 @@ def f2():
@with_signal_handling
def f1():
assert signal_callback() is False
signal_state.set_flag()
signal_state.set_sigterm_flag()
assert signal_callback()
f2()

original_sigterm = signal.getsignal(signal.SIGTERM)
assert signal_callback() is False
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 1
assert pytest_sigint.called_count == 0


@tmp_signals_for_test
def test_inner_outer_signal_handling():
"""
Even if the flag is set in the inner context, its value should persist in the outer context
Expand All @@ -34,7 +74,7 @@ def test_inner_outer_signal_handling():
@with_signal_handling
def f2():
assert signal_callback() is False
signal_state.set_flag()
signal_state.set_sigint_flag()
assert signal_callback()

@with_signal_handling
Expand All @@ -45,6 +85,10 @@ def f1():

original_sigterm = signal.getsignal(signal.SIGTERM)
assert signal_callback() is False
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 1

0 comments on commit fc0b58f

Please sign in to comment.