Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cptbox: add types for debugger, process, and tracer #895

Merged
merged 4 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 79 additions & 3 deletions dmoj/cptbox/_cptbox.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, List
from typing import Callable, Dict, List, Optional

PTBOX_ABI_X86: int
PTBOX_ABI_X64: int
Expand All @@ -13,8 +13,84 @@ IS_WSL1: int
ALL_ABIS: List[int]
SUPPORTED_ABIS: List[int]

Debugger: Any
Process: Any
class Debugger:
    """Typed interface of the ``Debugger`` object declared in ``_cptbox``.

    This is a stub: every member is a declaration with no implementation.
    ``TracedPopen._callback`` in ``tracer.py`` reads ``abi`` from its debugger
    and passes the debugger instance to the handler callback it invokes.
    """

    # Syscall number, result, errno and six argument slots.
    # NOTE(review): presumably these describe the syscall the traced process
    # is currently stopped at -- confirm against the implementation.
    syscall: int
    result: int
    errno: int
    arg0: int
    arg1: int
    arg2: int
    arg3: int
    arg4: int
    arg5: int

    # NOTE(review): the ``u``-prefixed names look like unsigned views of the
    # result/argument values above -- confirm against the implementation.
    uresult: int
    uarg0: int
    uarg1: int
    uarg2: int
    uarg3: int
    uarg4: int
    uarg5: int

    # ``abi`` is compared against PTBOX_ABI_INVALID by the tracer before a
    # syscall is dispatched. ``pid``/``tid`` presumably identify the traced
    # task -- TODO confirm.
    pid: int
    tid: int
    abi: int

    # Built from the owning Process; tracer.py constructs its debugger as
    # ``AdvancedDebugger(self)`` from within ``create_debugger``.
    def __init__(self, process: Process) -> None: ...

    # Returns a ``str`` for the given ``address``, bounded by ``max_size``.
    # NOTE(review): presumably reads from the traced process's memory -- confirm.
    def readstr(self, address: int, max_size: int = ...) -> str: ...

    # Accepts a zero-argument callback returning None.
    # NOTE(review): the return type is unannotated in the original; presumably
    # the callback fires when the current syscall returns -- confirm.
    def on_return(self, callback: Callable[[], None]): ...

class Process:
    """Typed interface of the ``Process`` base class declared in ``_cptbox``.

    This is a stub: every member is a declaration with no implementation.
    ``TracedPopen`` in ``tracer.py`` subclasses it, overrides the hook methods
    (``create_debugger``, ``_callback``, ``_ptrace_error``,
    ``_protection_fault``, ``_cpu_time_exceeded``) and calls ``_spawn``.
    """

    # NOTE(review): presumably the object returned by create_debugger() -- confirm.
    debugger: Debugger

    # ``_child_stdin`` is passed to ``os.close`` by the tracer, so it is a file
    # descriptor; stdout/stderr presumably are as well -- TODO confirm.
    _child_stdin: int
    _child_stdout: int
    _child_stderr: int

    # NOTE(review): these look like the limits/settings applied to the child
    # (compare the ``memory``, ``personality``, ``nproc`` and ``fsize``
    # parameters of ``TracedPopen.__init__``) -- confirm units and meaning.
    _child_memory: int
    _child_address: int
    _child_personality: int
    _cpu_time: int
    _nproc: int
    _fsize: int

    # ``use_seccomp`` is set by ``TracedPopen.__init__`` to
    # ``security is not None and not avoid_seccomp``.
    use_seccomp: bool
    _trace_syscalls: bool

    # Factory hook; ``TracedPopen`` overrides it to return an AdvancedDebugger.
    def create_debugger(self) -> Debugger: ...

    # Hooks overridden by ``TracedPopen``.
    # NOTE(review): presumably invoked by the tracing implementation -- confirm.
    def _callback(self, syscall: int) -> bool: ...
    def _ptrace_error(self, errno: int) -> None: ...
    def _protection_fault(self, syscall: int, is_update: bool) -> None: ...
    def _cpu_time_exceeded(self) -> None: ...

    # NOTE(review): presumably registers ``handler`` for ``syscall`` under
    # ``abi`` -- confirm against the implementation.
    def _handler(self, abi: int, syscall: int, handler: int) -> None: ...
    def _get_seccomp_whitelist(self) -> List[bool]: ...
    def _get_seccomp_errnolist(self) -> List[int]: ...

    # The tracer calls this as
    # ``_spawn(self._executable, self._args, self._env, self._chdir)``.
    def _spawn(self, file: bytes, args: List[bytes], env: List[bytes], chdir: bytes = ...) -> None: ...
    def _monitor(self) -> int: ...

    # NOTE(review): ``_exited`` and ``_exitcode`` carry no return annotation in
    # the original; presumably ``bool`` and ``int`` -- confirm before annotating.
    @property
    def _exited(self): ...
    @property
    def _exitcode(self): ...

    # Checked by ``TracedPopen.wait()`` to decide whether a spawn-failure exit
    # code should be turned into a RuntimeError.
    @property
    def was_initialized(self) -> bool: ...
    @property
    def pid(self) -> int: ...
    @property
    def execution_time(self) -> float: ...
    @property
    def wall_clock_time(self) -> float: ...
    @property
    def cpu_time(self) -> float: ...

    # Compared against the configured memory limit by ``TracedPopen.is_mle``.
    @property
    def max_memory(self) -> int: ...
    @property
    def signal(self) -> Optional[int]: ...

    # May be None; ``TracedPopen.is_rte`` treats None or a negative value as a
    # runtime error (negative meaning killed by a signal).
    @property
    def returncode(self) -> Optional[int]: ...

# Length of the handler table built by ``TracedPopen._get_seccomp_handlers``
# (``[-1] * MAX_SYSCALL_NUMBER``).
MAX_SYSCALL_NUMBER: int
# ABI identifier used by the tracer to look up the syscall index of the
# native ABI (``_SYSCALL_INDICIES[NATIVE_ABI]``).
NATIVE_ABI: int

# Return codes checked by ``TracedPopen.wait()`` when the process was not
# initialized; PTBOX_SPAWN_FAIL_NO_NEW_PRIVS maps to a RuntimeError about
# prctl(PR_SET_NO_NEW_PRIVS). NOTE(review): the other three are presumably
# handled the same way for seccomp, PTRACE_TRACEME and execve -- confirm.
PTBOX_SPAWN_FAIL_NO_NEW_PRIVS: int
PTBOX_SPAWN_FAIL_SECCOMP: int
PTBOX_SPAWN_FAIL_TRACEME: int
PTBOX_SPAWN_FAIL_EXECVE: int

AT_FDCWD: int
bsd_get_proc_cwd: Callable[[int], str]
Expand Down
4 changes: 2 additions & 2 deletions dmoj/cptbox/syscalls.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Dict
from typing import List, Dict, Tuple

translator: List[List[int]]
translator: List[Tuple[List[int], ...]]
by_name: Dict[str, int]
by_id: List[str]
SYSCALL_COUNT: int
Expand Down
98 changes: 55 additions & 43 deletions dmoj/cptbox/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
import sys
import threading
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional, Tuple, Type

from dmoj.cptbox._cptbox import *
from dmoj.cptbox.handlers import ALLOW, DISALLOW, ErrnoHandlerCallback, _CALLBACK
Expand Down Expand Up @@ -86,36 +86,41 @@ def readstr(self, address, max_size=4096):


class TracedPopen(Process):
def create_debugger(self):
return AdvancedDebugger(self)
_executable: bytes
_last_ptrace_errno: Optional[int]
_spawn_error: Optional[Type[BaseException]]

debugger: AdvancedDebugger
protection_fault: Optional[Tuple[int, str, List[int], Optional[int]]]

def __init__(
self,
args,
avoid_seccomp=False,
executable=None,
args: List[bytes],
*,
executable: bytes,
avoid_seccomp: bool = False,
security=None,
time=0,
memory=0,
stdin=PIPE,
stdout=PIPE,
stderr=None,
env=None,
nproc=0,
fsize=0,
address_grace=4096,
data_grace=0,
personality=0,
cwd='',
wall_time=None,
):
time: int = 0,
memory: int = 0,
stdin: Optional[int] = PIPE,
stdout: Optional[int] = PIPE,
stderr: Optional[int] = None,
env: Optional[Dict[str, str]] = None,
nproc: int = 0,
fsize: int = 0,
address_grace: int = 4096,
data_grace: int = 0,
personality: int = 0,
cwd: bytes = b'',
wall_time: Optional[float] = None,
) -> None:
self._executable = executable
self.use_seccomp = security is not None and not avoid_seccomp

self._args = args
self._chdir = cwd
self._env = [
utf8bytes('%s=%s' % (arg, val))
utf8bytes(f'{arg}={val}')
for arg, val in (env if env is not None else os.environ).items()
if val is not None
]
Expand All @@ -141,6 +146,7 @@ def __init__(
else:
for abi in SUPPORTED_ABIS:
index = _SYSCALL_INDICIES[abi]
assert index is not None
for i in range(SYSCALL_COUNT):
for call in translator[i][index]:
if call is None:
Expand Down Expand Up @@ -168,9 +174,13 @@ def __init__(
if self._spawn_error:
raise self._spawn_error

def _get_seccomp_handlers(self):
def create_debugger(self) -> AdvancedDebugger:
return AdvancedDebugger(self)

def _get_seccomp_handlers(self) -> List[int]:
handlers = [-1] * MAX_SYSCALL_NUMBER
index = _SYSCALL_INDICIES[NATIVE_ABI]
assert index is not None
for i in range(SYSCALL_COUNT):
# Ensure at least one syscall traps.
# Otherwise, a simple assembly program could terminate without ever trapping.
Expand All @@ -186,8 +196,9 @@ def _get_seccomp_handlers(self):
handlers[call] = handler.errno
return handlers

def wait(self):
def wait(self) -> int:
self._died.wait()
assert self.returncode is not None
if not self.was_initialized:
if self.returncode == PTBOX_SPAWN_FAIL_NO_NEW_PRIVS:
raise RuntimeError('failed to call prctl(PR_SET_NO_NEW_PRIVS)')
Expand All @@ -205,33 +216,34 @@ def wait(self):
raise RuntimeError('process failed to initialize with unknown exit code: %d' % self.returncode)
return self.returncode

def poll(self):
def poll(self) -> Optional[int]:
return self.returncode

def mark_ole(self):
def mark_ole(self) -> None:
self._is_ole = True

@property
def is_ir(self):
def is_ir(self) -> bool:
assert self.returncode is not None
return self.returncode > 0

@property
def is_mle(self):
return self._memory and self.max_memory > self._memory
def is_mle(self) -> bool:
return self._memory != 0 and self.max_memory > self._memory

@property
def is_ole(self):
def is_ole(self) -> bool:
return self._is_ole

@property
def is_rte(self):
def is_rte(self) -> bool:
return self.returncode is None or self.returncode < 0 # Killed by signal

@property
def is_tle(self):
def is_tle(self) -> bool:
return self._is_tle

def kill(self):
def kill(self) -> None:
# FIXME(quantum): this is actually a race. The process may exit before we kill it.
# Under very unlikely circumstances, the pid could be reused and we will end up
# killing the wrong process.
Expand All @@ -246,7 +258,7 @@ def kill(self):
else:
log.warning('Skipping the killing of process because it already exited: %s', self.pid)

def _callback(self, syscall):
def _callback(self, syscall) -> bool:
if self.debugger.abi == PTBOX_ABI_INVALID:
log.warning('Received invalid ABI when handling syscall %d', syscall)
return False
Expand All @@ -263,7 +275,7 @@ def _callback(self, syscall):
return callback(self.debugger)
return False

def _protection_fault(self, syscall, is_update):
def _protection_fault(self, syscall: int, is_update: bool) -> None:
# When signed, 0xFFFFFFFF is equal to -1, meaning that ptrace failed to read the syscall for some reason.
# We can't continue debugging as this could potentially be unsafe, so we should exit loudly.
# See <https://github.com/DMOJ/judge/issues/181> for more details.
Expand All @@ -290,20 +302,20 @@ def _protection_fault(self, syscall, is_update):
self._last_ptrace_errno if is_update else None,
)

def _ptrace_error(self, error):
def _ptrace_error(self, error: int) -> None:
self._last_ptrace_errno = error

def _cpu_time_exceeded(self):
def _cpu_time_exceeded(self) -> None:
log.warning('SIGXCPU in process %d', self.pid)
self._is_tle = True

def _run_process(self):
def _run_process(self) -> Optional[int]:
try:
self._spawn(self._executable, self._args, self._env, self._chdir)
except: # noqa: E722, need to catch absolutely everything
self._spawn_error = sys.exc_info()[0]
self._died.set()
return
return None
finally:
if self.stdin_needs_close:
os.close(self._child_stdin)
Expand Down Expand Up @@ -333,14 +345,14 @@ def _run_process(self):

return code

def _shocker_thread(self):
def _shocker_thread(self) -> None:
# On Linux, ignored signals still cause a notification under ptrace.
# Hence, we use SIGWINCH, a harmless and ignored signal, to make wait4 return in
# pt_process::monitor, causing time to be updated.
# On FreeBSD, a signal must not be ignored in order for wait4 to return.
# Hence, we swallow SIGSTOP, which should never be used anyway, and use it
# to force an update.
wake_signal = signal.SIGSTOP if 'freebsd' in sys.platform else signal.SIGWINCH
wake_signal = signal.SIGSTOP if FREEBSD else signal.SIGWINCH
self._spawned_or_errored.wait()

while not self._died.wait(1):
Expand All @@ -354,7 +366,7 @@ def _shocker_thread(self):
except OSError:
pass

def __init_streams(self, stdin, stdout, stderr):
def __init_streams(self, stdin, stdout, stderr) -> None:
self.stdin = self.stdout = self.stderr = None
self.stdin_needs_close = self.stdout_needs_close = self.stderr_needs_close = False

Expand Down Expand Up @@ -393,9 +405,9 @@ def __init_streams(self, stdin, stdout, stderr):

communicate = _safe_communicate

def unsafe_communicate(self, input=None):
def unsafe_communicate(self, input: Optional[bytes] = None) -> Tuple[bytes, bytes]:
return _safe_communicate(self, input=input, outlimit=sys.maxsize, errlimit=sys.maxsize)


def can_debug(abi):
def can_debug(abi: int) -> bool:
return abi in SUPPORTED_ABIS
Loading