Skip to content

Commit

Permalink
cptbox: add types for tracer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
quantum5 committed Sep 9, 2021
1 parent ad76261 commit eefbcbb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 45 deletions.
8 changes: 8 additions & 0 deletions dmoj/cptbox/_cptbox.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ class Process:
@property
def returncode(self) -> Optional[int]: ...

MAX_SYSCALL_NUMBER: int
NATIVE_ABI: int

PTBOX_SPAWN_FAIL_NO_NEW_PRIVS: int
PTBOX_SPAWN_FAIL_SECCOMP: int
PTBOX_SPAWN_FAIL_TRACEME: int
PTBOX_SPAWN_FAIL_EXECVE: int

AT_FDCWD: int
bsd_get_proc_cwd: Callable[[int], str]
bsd_get_proc_fdno: Callable[[int, int], str]
Expand Down
4 changes: 2 additions & 2 deletions dmoj/cptbox/syscalls.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Dict
from typing import List, Dict, Tuple

translator: List[List[int]]
translator: List[Tuple[List[int], ...]]
by_name: Dict[str, int]
by_id: List[str]
SYSCALL_COUNT: int
Expand Down
98 changes: 55 additions & 43 deletions dmoj/cptbox/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
import sys
import threading
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional, Tuple, Type

from dmoj.cptbox._cptbox import *
from dmoj.cptbox.handlers import ALLOW, DISALLOW, ErrnoHandlerCallback, _CALLBACK
Expand Down Expand Up @@ -87,29 +87,34 @@ def readstr(self, address, max_size=4096):


class TracedPopen(Process):
def create_debugger(self):
return AdvancedDebugger(self)
_executable: bytes
_last_ptrace_errno: Optional[int]
_spawn_error: Optional[Type[BaseException]]

debugger: AdvancedDebugger
protection_fault: Optional[Tuple[int, str, List[int], Optional[int]]]

def __init__(
self,
args,
avoid_seccomp=False,
executable=None,
args: List[bytes],
*,
executable: bytes,
avoid_seccomp: bool = False,
security=None,
time=0,
memory=0,
stdin=PIPE,
stdout=PIPE,
stderr=None,
env=None,
nproc=0,
fsize=0,
address_grace=4096,
data_grace=0,
personality=0,
cwd='',
wall_time=None,
):
time: int = 0,
memory: int = 0,
stdin: Optional[int] = PIPE,
stdout: Optional[int] = PIPE,
stderr: Optional[int] = None,
env: Optional[Dict[str, str]] = None,
nproc: int = 0,
fsize: int = 0,
address_grace: int = 4096,
data_grace: int = 0,
personality: int = 0,
cwd: bytes = b'',
wall_time: Optional[float] = None,
) -> None:
self._executable = executable
self.use_seccomp = security is not None and not avoid_seccomp

Expand All @@ -120,7 +125,7 @@ def __init__(
self._args = args
self._chdir = cwd
self._env = [
utf8bytes('%s=%s' % (arg, val))
utf8bytes(f'{arg}={val}')
for arg, val in (env if env is not None else os.environ).items()
if val is not None
]
Expand All @@ -146,6 +151,7 @@ def __init__(
else:
for abi in SUPPORTED_ABIS:
index = _SYSCALL_INDICIES[abi]
assert index is not None
for i in range(SYSCALL_COUNT):
for call in translator[i][index]:
if call is None:
Expand Down Expand Up @@ -173,9 +179,13 @@ def __init__(
if self._spawn_error:
raise self._spawn_error

def _get_seccomp_handlers(self):
def create_debugger(self) -> AdvancedDebugger:
return AdvancedDebugger(self)

def _get_seccomp_handlers(self) -> List[int]:
handlers = [-1] * MAX_SYSCALL_NUMBER
index = _SYSCALL_INDICIES[NATIVE_ABI]
assert index is not None
for i in range(SYSCALL_COUNT):
# Ensure at least one syscall traps.
# Otherwise, a simple assembly program could terminate without ever trapping.
Expand All @@ -191,8 +201,9 @@ def _get_seccomp_handlers(self):
handlers[call] = handler.errno
return handlers

def wait(self):
def wait(self) -> int:
self._died.wait()
assert self.returncode is not None
if not self.was_initialized:
if self.returncode == PTBOX_SPAWN_FAIL_NO_NEW_PRIVS:
raise RuntimeError('failed to call prctl(PR_SET_NO_NEW_PRIVS)')
Expand All @@ -210,33 +221,34 @@ def wait(self):
raise RuntimeError('process failed to initialize with unknown exit code: %d' % self.returncode)
return self.returncode

def poll(self):
def poll(self) -> Optional[int]:
return self.returncode

def mark_ole(self):
def mark_ole(self) -> None:
self._is_ole = True

@property
def is_ir(self):
def is_ir(self) -> bool:
assert self.returncode is not None
return self.returncode > 0

@property
def is_mle(self):
return self._memory and self.max_memory > self._memory
def is_mle(self) -> bool:
return self._memory != 0 and self.max_memory > self._memory

@property
def is_ole(self):
def is_ole(self) -> bool:
return self._is_ole

@property
def is_rte(self):
def is_rte(self) -> bool:
return self.returncode is None or self.returncode < 0 # Killed by signal

@property
def is_tle(self):
def is_tle(self) -> bool:
return self._is_tle

def kill(self):
def kill(self) -> None:
# FIXME(quantum): this is actually a race. The process may exit before we kill it.
# Under very unlikely circumstances, the pid could be reused and we will end up
# killing the wrong process.
Expand All @@ -251,7 +263,7 @@ def kill(self):
else:
log.warning('Skipping the killing of process because it already exited: %s', self.pid)

def _callback(self, syscall):
def _callback(self, syscall) -> bool:
if self.debugger.abi == PTBOX_ABI_INVALID:
log.warning('Received invalid ABI when handling syscall %d', syscall)
return False
Expand All @@ -268,7 +280,7 @@ def _callback(self, syscall):
return callback(self.debugger)
return False

def _protection_fault(self, syscall, is_update):
def _protection_fault(self, syscall: int, is_update: bool) -> None:
# When signed, 0xFFFFFFFF is equal to -1, meaning that ptrace failed to read the syscall for some reason.
# We can't continue debugging as this could potentially be unsafe, so we should exit loudly.
# See <https://github.com/DMOJ/judge/issues/181> for more details.
Expand All @@ -295,20 +307,20 @@ def _protection_fault(self, syscall, is_update):
self._last_ptrace_errno if is_update else None,
)

def _ptrace_error(self, error):
def _ptrace_error(self, error: int) -> None:
self._last_ptrace_errno = error

def _cpu_time_exceeded(self):
def _cpu_time_exceeded(self) -> None:
log.warning('SIGXCPU in process %d', self.pid)
self._is_tle = True

def _run_process(self):
def _run_process(self) -> Optional[int]:
try:
self._spawn(self._executable, self._args, self._env, self._chdir)
except: # noqa: E722, need to catch absolutely everything
self._spawn_error = sys.exc_info()[0]
self._died.set()
return
return None
finally:
if self.stdin_needs_close:
os.close(self._child_stdin)
Expand Down Expand Up @@ -338,14 +350,14 @@ def _run_process(self):

return code

def _shocker_thread(self):
def _shocker_thread(self) -> None:
# On Linux, ignored signals still cause a notification under ptrace.
# Hence, we use SIGWINCH, harmless and ignored signal to make wait4 return
# pt_process::monitor, causing time to be updated.
# On FreeBSD, a signal must not be ignored in order for wait4 to return.
# Hence, we swallow SIGSTOP, which should never be used anyway, and use it
# force an update.
wake_signal = signal.SIGSTOP if 'freebsd' in sys.platform else signal.SIGWINCH
wake_signal = signal.SIGSTOP if FREEBSD else signal.SIGWINCH
self._spawned_or_errored.wait()

while not self._died.wait(1):
Expand All @@ -359,7 +371,7 @@ def _shocker_thread(self):
except OSError:
pass

def __init_streams(self, stdin, stdout, stderr):
def __init_streams(self, stdin, stdout, stderr) -> None:
self.stdin = self.stdout = self.stderr = None
self.stdin_needs_close = self.stdout_needs_close = self.stderr_needs_close = False

Expand Down Expand Up @@ -398,9 +410,9 @@ def __init_streams(self, stdin, stdout, stderr):

communicate = _safe_communicate

def unsafe_communicate(self, input=None):
def unsafe_communicate(self, input: Optional[bytes] = None) -> Tuple[bytes, bytes]:
return _safe_communicate(self, input=input, outlimit=sys.maxsize, errlimit=sys.maxsize)


def can_debug(abi):
def can_debug(abi: int) -> bool:
return abi in SUPPORTED_ABIS

0 comments on commit eefbcbb

Please sign in to comment.