From 06b072f8df807392f410039fc2cc96ea57d3120f Mon Sep 17 00:00:00 2001 From: Quantum Date: Tue, 7 Sep 2021 06:02:11 -0400 Subject: [PATCH] cptbox: add types for tracer.py --- dmoj/cptbox/_cptbox.pyi | 8 ++++ dmoj/cptbox/syscalls.pyi | 4 +- dmoj/cptbox/tracer.py | 98 ++++++++++++++++++++++------------------ 3 files changed, 65 insertions(+), 45 deletions(-) diff --git a/dmoj/cptbox/_cptbox.pyi b/dmoj/cptbox/_cptbox.pyi index 2d2bada73..0e410a252 100644 --- a/dmoj/cptbox/_cptbox.pyi +++ b/dmoj/cptbox/_cptbox.pyi @@ -84,6 +84,14 @@ class Process: @property def returncode(self) -> Optional[int]: ... +MAX_SYSCALL_NUMBER: int +NATIVE_ABI: int + +PTBOX_SPAWN_FAIL_NO_NEW_PRIVS: int +PTBOX_SPAWN_FAIL_SECCOMP: int +PTBOX_SPAWN_FAIL_TRACEME: int +PTBOX_SPAWN_FAIL_EXECVE: int + AT_FDCWD: int bsd_get_proc_cwd: Callable[[int], str] bsd_get_proc_fdno: Callable[[int, int], str] diff --git a/dmoj/cptbox/syscalls.pyi b/dmoj/cptbox/syscalls.pyi index 41c428ef0..a824fecc7 100644 --- a/dmoj/cptbox/syscalls.pyi +++ b/dmoj/cptbox/syscalls.pyi @@ -1,6 +1,6 @@ -from typing import List, Dict +from typing import List, Dict, Tuple -translator: List[List[int]] +translator: List[Tuple[List[int], ...]] by_name: Dict[str, int] by_id: List[str] SYSCALL_COUNT: int diff --git a/dmoj/cptbox/tracer.py b/dmoj/cptbox/tracer.py index 608e491d6..f423b3f1e 100644 --- a/dmoj/cptbox/tracer.py +++ b/dmoj/cptbox/tracer.py @@ -6,7 +6,7 @@ import subprocess import sys import threading -from typing import Callable, List, Optional +from typing import Callable, Dict, List, Optional, Tuple, Type from dmoj.cptbox._cptbox import * from dmoj.cptbox.handlers import ALLOW, DISALLOW, ErrnoHandlerCallback, _CALLBACK @@ -86,36 +86,41 @@ def readstr(self, address, max_size=4096): class TracedPopen(Process): - def create_debugger(self): - return AdvancedDebugger(self) + _executable: bytes + _last_ptrace_errno: Optional[int] + _spawn_error: Optional[Type[BaseException]] + + debugger: AdvancedDebugger + protection_fault: Optional[Tuple[int, str, List[int], Optional[int]]] def __init__( self, - args, - avoid_seccomp=False, - executable=None, + args: List[bytes], + *, + executable: bytes, + avoid_seccomp: bool = False, security=None, - time=0, - memory=0, - stdin=PIPE, - stdout=PIPE, - stderr=None, - env=None, - nproc=0, - fsize=0, - address_grace=4096, - data_grace=0, - personality=0, - cwd='', - wall_time=None, - ): + time: int = 0, + memory: int = 0, + stdin: Optional[int] = PIPE, + stdout: Optional[int] = PIPE, + stderr: Optional[int] = None, + env: Optional[Dict[str, str]] = None, + nproc: int = 0, + fsize: int = 0, + address_grace: int = 4096, + data_grace: int = 0, + personality: int = 0, + cwd: bytes = b'', + wall_time: Optional[float] = None, + ) -> None: self._executable = executable self.use_seccomp = security is not None and not avoid_seccomp self._args = args self._chdir = cwd self._env = [ - utf8bytes('%s=%s' % (arg, val)) + utf8bytes(f'{arg}={val}') for arg, val in (env if env is not None else os.environ).items() if val is not None ] @@ -141,6 +146,7 @@ def __init__( else: for abi in SUPPORTED_ABIS: index = _SYSCALL_INDICIES[abi] + assert index is not None for i in range(SYSCALL_COUNT): for call in translator[i][index]: if call is None: @@ -168,9 +174,13 @@ def __init__( if self._spawn_error: raise self._spawn_error - def _get_seccomp_handlers(self): + def create_debugger(self) -> AdvancedDebugger: + return AdvancedDebugger(self) + + def _get_seccomp_handlers(self) -> List[int]: handlers = [-1] * MAX_SYSCALL_NUMBER index = _SYSCALL_INDICIES[NATIVE_ABI] + assert index is not None for i in range(SYSCALL_COUNT): # Ensure at least one syscall traps. # Otherwise, a simple assembly program could terminate without ever trapping. @@ -186,8 +196,9 @@ def _get_seccomp_handlers(self): handlers[call] = handler.errno return handlers - def wait(self): + def wait(self) -> int: self._died.wait() + assert self.returncode is not None if not self.was_initialized: if self.returncode == PTBOX_SPAWN_FAIL_NO_NEW_PRIVS: raise RuntimeError('failed to call prctl(PR_SET_NO_NEW_PRIVS)') @@ -205,33 +216,34 @@ def wait(self): raise RuntimeError('process failed to initialize with unknown exit code: %d' % self.returncode) return self.returncode - def poll(self): + def poll(self) -> Optional[int]: return self.returncode - def mark_ole(self): + def mark_ole(self) -> None: self._is_ole = True @property - def is_ir(self): + def is_ir(self) -> bool: + assert self.returncode is not None return self.returncode > 0 @property - def is_mle(self): - return self._memory and self.max_memory > self._memory + def is_mle(self) -> bool: + return self._memory != 0 and self.max_memory > self._memory @property - def is_ole(self): + def is_ole(self) -> bool: return self._is_ole @property - def is_rte(self): + def is_rte(self) -> bool: return self.returncode is None or self.returncode < 0 # Killed by signal @property - def is_tle(self): + def is_tle(self) -> bool: return self._is_tle - def kill(self): + def kill(self) -> None: # FIXME(quantum): this is actually a race. The process may exit before we kill it. # Under very unlikely circumstances, the pid could be reused and we will end up # killing the wrong process. @@ -246,7 +258,7 @@ def kill(self): else: log.warning('Skipping the killing of process because it already exited: %s', self.pid) - def _callback(self, syscall): + def _callback(self, syscall) -> bool: if self.debugger.abi == PTBOX_ABI_INVALID: log.warning('Received invalid ABI when handling syscall %d', syscall) return False @@ -263,7 +275,7 @@ def _callback(self, syscall): return callback(self.debugger) return False - def _protection_fault(self, syscall, is_update): + def _protection_fault(self, syscall: int, is_update: bool) -> None: # When signed, 0xFFFFFFFF is equal to -1, meaning that ptrace failed to read the syscall for some reason. # We can't continue debugging as this could potentially be unsafe, so we should exit loudly. # See for more details. @@ -290,20 +302,20 @@ def _protection_fault(self, syscall, is_update): self._last_ptrace_errno if is_update else None, ) - def _ptrace_error(self, error): + def _ptrace_error(self, error: int) -> None: self._last_ptrace_errno = error - def _cpu_time_exceeded(self): + def _cpu_time_exceeded(self) -> None: log.warning('SIGXCPU in process %d', self.pid) self._is_tle = True - def _run_process(self): + def _run_process(self) -> Optional[int]: try: self._spawn(self._executable, self._args, self._env, self._chdir) except: # noqa: E722, need to catch absolutely everything self._spawn_error = sys.exc_info()[0] self._died.set() - return + return None finally: if self.stdin_needs_close: os.close(self._child_stdin) @@ -333,14 +345,14 @@ def _run_process(self): return code - def _shocker_thread(self): + def _shocker_thread(self) -> None: # On Linux, ignored signals still cause a notification under ptrace. # Hence, we use SIGWINCH, harmless and ignored signal to make wait4 return # pt_process::monitor, causing time to be updated. # On FreeBSD, a signal must not be ignored in order for wait4 to return. # Hence, we swallow SIGSTOP, which should never be used anyway, and use it # force an update. - wake_signal = signal.SIGSTOP if 'freebsd' in sys.platform else signal.SIGWINCH + wake_signal = signal.SIGSTOP if FREEBSD else signal.SIGWINCH self._spawned_or_errored.wait() while not self._died.wait(1): @@ -354,7 +366,7 @@ def _shocker_thread(self): except OSError: pass - def __init_streams(self, stdin, stdout, stderr): + def __init_streams(self, stdin, stdout, stderr) -> None: self.stdin = self.stdout = self.stderr = None self.stdin_needs_close = self.stdout_needs_close = self.stderr_needs_close = False @@ -393,9 +405,9 @@ def __init_streams(self, stdin, stdout, stderr): communicate = _safe_communicate - def unsafe_communicate(self, input=None): + def unsafe_communicate(self, input: Optional[bytes] = None) -> Tuple[bytes, bytes]: return _safe_communicate(self, input=input, outlimit=sys.maxsize, errlimit=sys.maxsize) -def can_debug(abi): +def can_debug(abi: int) -> bool: return abi in SUPPORTED_ABIS