Skip to content

Commit

Permalink
Tests for socket inheritance
Browse files Browse the repository at this point in the history
Closes: #3
  • Loading branch information
waveform80 committed Mar 1, 2024
1 parent d3cd911 commit 7f2fa48
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 36 deletions.
9 changes: 4 additions & 5 deletions nobodd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def __init__(self, server_address, boards):
self.boards = boards
self.images = {}
if isinstance(server_address, int):
st = os.fstat(server_address)
if not stat.S_ISSOCK(st.st_mode):
if not stat.S_ISSOCK(os.fstat(server_address).st_mode):
raise RuntimeError(
f'inherited fd {server_address} is not a socket')
# If we've been passed an fd directly, we don't actually want the
Expand Down Expand Up @@ -235,18 +234,18 @@ def main(args=None):
board.serial: board
for board in conf.boards
}
if not boards:
raise RuntimeError('No boards defined')

if conf.listen == 'stdin':
# Yes, this should always be zero but ... just in case
server_address = sys.stdin.fileno()
elif conf.listen == 'stdout':
server_address = sys.stdout.fileno()
elif conf.listen == 'systemd':
fds = sd.listen_fds()
if len(fds) != 1:
raise RuntimeError(
f'Expected 1 fd from systemd but got {len(fds)}')
server_address = fds.pop()
server_address, name = fds.popitem()
else:
server_address = (conf.listen, conf.port)

Expand Down
1 change: 1 addition & 0 deletions nobodd/systemd.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def listen_fds(self):
:class:`dict` mapping each file-descriptor to its name, or the string
"unknown" if no name was given.
"""
print(repr(os.environ), flush=True)
try:
if int(os.environ['LISTEN_PID']) != os.getpid():
raise ValueError('wrong LISTEN_PID')
Expand Down
222 changes: 191 additions & 31 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def run(self):
class MyBootServer(BootServer):
def server_bind(slf):
super().server_bind()
self.address = slf.server_address
if self.address is None:
self.address = slf.server_address

try:
with mock.patch('nobodd.server.BootServer', MyBootServer):
Expand All @@ -46,7 +47,9 @@ def wait_for_ready(self, capsys):
capture = capsys.readouterr()
if 'Ready' in capture.err:
return
sleep(0.1)
self.join(0.1)
if not self.is_alive():
assert False, 'service died before becoming ready'
assert False, 'service did not become ready'

def __enter__(self):
Expand Down Expand Up @@ -79,7 +82,10 @@ def test_help(capsys):


def test_ctrl_c(main_thread, capsys):
main_thread.argv = ['--listen', '127.0.0.1', '--port', '0']
main_thread.argv = [
'--listen', '127.0.0.1', '--port', '0',
'--board', f'1234abcd,foo.img',
]
with main_thread:
os.kill(os.getpid(), signal.SIGINT)
capture = capsys.readouterr()
Expand All @@ -89,15 +95,47 @@ def test_ctrl_c(main_thread, capsys):


def test_sigterm(main_thread, capsys):
main_thread.argv = ['--listen', '127.0.0.1', '--port', '0']
main_thread.argv = [
'--listen', '127.0.0.1', '--port', '0',
'--board', f'1234abcd,foo.img',
]
with main_thread:
os.kill(os.getpid(), signal.SIGTERM)
capture = capsys.readouterr()
assert capture.err.strip().endswith('Terminated')
assert main_thread.exception is None
assert main_thread.exit_code == 0


def test_sighup(main_thread, capsys):
main_thread.argv = [
'--listen', '127.0.0.1', '--port', '0',
'--board', f'1234abcd,foo.img',
]
with main_thread:
os.kill(os.getpid(), signal.SIGHUP)
os.kill(os.getpid(), signal.SIGTERM)
capture = capsys.readouterr()
assert 'Reloading configuration' in capture.err.strip()
assert capture.err.strip().endswith('Terminated')
assert main_thread.exception is None
assert main_thread.exit_code == 0


def test_error_exit_no_boards(main_thread, capsys, monkeypatch):
with \
monkeypatch.context() as m:

m.delenv('DEBUG', raising=False)
main_thread.argv = ['--listen', '127.0.0.1', '--port', '0']
with main_thread:
pass
capture = capsys.readouterr()
assert 'No boards defined' in capture.err
assert main_thread.exception is None
assert main_thread.exit_code == 1


def test_error_exit_no_debug(main_thread, capsys, monkeypatch):
with \
mock.patch('nobodd.server.get_parser') as get_parser, \
Expand Down Expand Up @@ -144,48 +182,170 @@ def test_error_exit_with_pdb(main_thread, capsys, monkeypatch):
assert post_mortem.called


def test_regular_operation(fat16_disk, main_thread, capsys):
def test_regular_operation(fat16_disk, main_thread, capsys, monkeypatch):
with \
disk.DiskImage(fat16_disk) as img, \
fs.FatFileSystem(img.partitions[1].data) as boot:

expected = (boot.root / 'random').read_bytes()

main_thread.argv = [
'--listen', '127.0.0.1', '--port', '0',
'--board', f'1234abcd,{fat16_disk}',
]
with main_thread:
main_thread.wait_for_ready(capsys)
with monkeypatch.context() as m:
m.delenv('DEBUG', raising=False)
main_thread.argv = [
'--listen', '127.0.0.1', '--port', '0',
'--board', f'1234abcd,{fat16_disk}',
]
with main_thread:
main_thread.wait_for_ready(capsys)

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
# Start a valid transfer from client...
client.settimeout(10)
client.sendto(
bytes(tftp.RRQPacket('1234abcd/random', 'octet')),
main_thread.address)
received = []
for block, offset in enumerate(range(0, len(expected), 512), start=1):
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block
received.append(pkt.data)
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
# Because random is a precise multiple of the block size, there
# should be one final (empty) DATA packet
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block + 1
assert pkt.data == b''
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
assert b''.join(received) == expected

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
# Start a valid transfer from client...
client.settimeout(10)
client.sendto(
bytes(tftp.RRQPacket('1234abcd/random', 'octet')),
main_thread.address)
received = []
for block, offset in enumerate(range(0, len(expected), 512), start=1):

def test_bad_listen_stdin(main_thread, capsys, tmp_path, monkeypatch):
with (tmp_path / 'foo').open('wb') as f:
with mock.patch('nobodd.server.sys.stdin', f), monkeypatch.context() as m:
m.delenv('DEBUG', raising=False)
main_thread.argv = [
'--listen', 'stdin',
'--board', '1234abcd,foo.img',
]
with main_thread:
pass
capture = capsys.readouterr()
assert f'inherited fd {f.fileno()} is not a socket' in capture.err
assert main_thread.exit_code == 1


def test_listen_stdin(fat16_disk, main_thread, capsys, monkeypatch):
with \
disk.DiskImage(fat16_disk) as img, \
fs.FatFileSystem(img.partitions[1].data) as boot:

expected = (boot.root / 'random').read_bytes()

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('127.0.0.1', 0))
with monkeypatch.context() as m, mock.patch('nobodd.server.sys.stdin', sock):
m.delenv('DEBUG', raising=False)
main_thread.argv = [
'--listen', 'stdin', '--board', f'1234abcd,{fat16_disk}',
]
main_thread.address = sock.getsockname()
with main_thread:
main_thread.wait_for_ready(capsys)

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
# Start a valid transfer from client...
client.settimeout(10)
client.sendto(
bytes(tftp.RRQPacket('1234abcd/random', 'octet')),
main_thread.address)
received = []
for block, offset in enumerate(range(0, len(expected), 512), start=1):
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block
received.append(pkt.data)
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
# Because random is a precise multiple of the block size, there
# should be one final (empty) DATA packet
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block
received.append(pkt.data)
assert pkt.block == block + 1
assert pkt.data == b''
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
# Because random is a precise multiple of the block size, there should
# be one final (empty) DATA packet
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block + 1
assert pkt.data == b''
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
assert b''.join(received) == expected
assert b''.join(received) == expected


def test_bad_listen_systemd(main_thread, capsys, monkeypatch):
with monkeypatch.context() as m:
m.delenv('DEBUG', raising=False)
m.setenv('LISTEN_PID', str(os.getpid()))
m.setenv('LISTEN_FDS', '2')
main_thread.argv = [
'--listen', 'systemd',
'--board', '1234abcd,foo.img',
]
with main_thread:
pass
capture = capsys.readouterr()
assert f'Expected 1 fd from systemd but got 2' in capture.err
assert main_thread.exit_code == 1


def test_listen_systemd(fat16_disk, main_thread, capsys, monkeypatch):
with \
disk.DiskImage(fat16_disk) as img, \
fs.FatFileSystem(img.partitions[1].data) as boot:

expected = (boot.root / 'random').read_bytes()

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('127.0.0.1', 0))
with monkeypatch.context() as m, \
mock.patch('nobodd.systemd.Systemd.LISTEN_FDS_START', sock.fileno()):
m.delenv('DEBUG', raising=False)
m.setenv('LISTEN_PID', str(os.getpid()))
m.setenv('LISTEN_FDS', '1')
main_thread.argv = [
'--listen', 'systemd', '--board', f'1234abcd,{fat16_disk}',
]
main_thread.address = sock.getsockname()
with main_thread:
main_thread.wait_for_ready(capsys)

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
# Start a valid transfer from client...
client.settimeout(10)
client.sendto(
bytes(tftp.RRQPacket('1234abcd/random', 'octet')),
main_thread.address)
received = []
for block, offset in enumerate(range(0, len(expected), 512), start=1):
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block
received.append(pkt.data)
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
# Because random is a precise multiple of the block size, there
# should be one final (empty) DATA packet
buf, addr = client.recvfrom(1500)
pkt = tftp.Packet.from_bytes(buf)
assert isinstance(pkt, tftp.DATAPacket)
assert pkt.block == block + 1
assert pkt.data == b''
client.sendto(bytes(tftp.ACKPacket(pkt.block)), addr)
assert b''.join(received) == expected


def test_bad_requests(fat16_disk, main_thread, capsys):
main_thread.argv = [
'--listen', '127.0.0.1', '--port', '54321',
'--listen', '127.0.0.1', '--port', '0',
'--board', f'1234abcd,{fat16_disk}',
'--board', f'5678abcd,{fat16_disk},1,127.0.0.2',
]
Expand Down

0 comments on commit 7f2fa48

Please sign in to comment.