diff --git a/dmoj/cptbox/_cptbox.pyx b/dmoj/cptbox/_cptbox.pyx index cf36c9136..837ea34d0 100644 --- a/dmoj/cptbox/_cptbox.pyx +++ b/dmoj/cptbox/_cptbox.pyx @@ -120,7 +120,7 @@ cdef extern from 'helper.h' nogil: int stderr_ bool use_seccomp int abi_for_seccomp - bint *seccomp_whitelist + int *seccomp_handlers void cptbox_closefrom(int lowfd) int cptbox_child_run(child_config *) @@ -474,14 +474,14 @@ cdef class Process: cpdef _cpu_time_exceeded(self): pass - cpdef _get_seccomp_whitelist(self): - raise NotImplementedError() + cpdef _get_seccomp_handlers(self): + return [-1] * MAX_SYSCALL cpdef _spawn(self, file, args, env=(), chdir=''): cdef child_config config config.argv = NULL config.envp = NULL - config.seccomp_whitelist = NULL + config.seccomp_handlers = NULL try: config.address_space = self._child_address @@ -500,20 +500,22 @@ cdef class Process: config.use_seccomp = self._use_seccomp() if config.use_seccomp: - whitelist = self._get_seccomp_whitelist() - assert len(whitelist) == MAX_SYSCALL - config.seccomp_whitelist = malloc(sizeof(bint) * MAX_SYSCALL) - if not config.seccomp_whitelist: + handlers = self._get_seccomp_handlers() + assert len(handlers) == MAX_SYSCALL + + config.seccomp_handlers = malloc(sizeof(int) * MAX_SYSCALL) + if not config.seccomp_handlers: PyErr_NoMemory() + for i in range(MAX_SYSCALL): - config.seccomp_whitelist[i] = whitelist[i] + config.seccomp_handlers[i] = handlers[i] if self.process.spawn(pt_child, &config): raise RuntimeError('failed to spawn child') finally: free(config.argv) free(config.envp) - free(config.seccomp_whitelist) + free(config.seccomp_handlers) cpdef _monitor(self): cdef int exitcode diff --git a/dmoj/cptbox/handlers.py b/dmoj/cptbox/handlers.py index 25fad47e2..34971f7b2 100644 --- a/dmoj/cptbox/handlers.py +++ b/dmoj/cptbox/handlers.py @@ -1,25 +1,29 @@ import errno +from dmoj.cptbox._cptbox import Debugger + DISALLOW = 0 ALLOW = 1 _CALLBACK = 2 STDOUTERR = 3 -def errno_handler(name, code): - def handler(debugger): +class ErrnoHandlerCallback: + errno: int + error_name: str + + def __init__(self, error_name: str, errno: int) -> None: + self.errno = errno + self.error_name = error_name + + def __call__(self, debugger: Debugger) -> bool: def on_return(): - debugger.errno = code + debugger.errno = self.errno debugger.syscall = -1 debugger.on_return(on_return) return True - handler.error_name = name - handler.errno = code - return handler - -for err in dir(errno): - if err[0] == 'E': - globals()['ACCESS_%s' % err] = errno_handler(err, getattr(errno, err)) +for code, name in errno.errorcode.items(): + globals()[f'ACCESS_{name}'] = ErrnoHandlerCallback(name, code) diff --git a/dmoj/cptbox/handlers.pyi b/dmoj/cptbox/handlers.pyi index a9536dd20..61f168630 100644 --- a/dmoj/cptbox/handlers.pyi +++ b/dmoj/cptbox/handlers.pyi @@ -1,13 +1,19 @@ -from typing import Any +from dmoj.cptbox._cptbox import Debugger ALLOW: int DISALLOW: int _CALLBACK: int STDOUTERR: int -ACCESS_EACCES: Any -ACCESS_EAGAIN: Any -ACCESS_EFAULT: Any -ACCESS_EINVAL: Any -ACCESS_ENOENT: Any -ACCESS_EPERM: Any -ACCESS_ENAMETOOLONG: Any + +class ErrnoHandlerCallback: + errno: int + error_name: str + def __call__(self, debugger: Debugger) -> bool: ... + +ACCESS_EACCES: ErrnoHandlerCallback +ACCESS_EAGAIN: ErrnoHandlerCallback +ACCESS_EFAULT: ErrnoHandlerCallback +ACCESS_EINVAL: ErrnoHandlerCallback +ACCESS_ENOENT: ErrnoHandlerCallback +ACCESS_EPERM: ErrnoHandlerCallback +ACCESS_ENAMETOOLONG: ErrnoHandlerCallback diff --git a/dmoj/cptbox/helper.cpp b/dmoj/cptbox/helper.cpp index 9e660cafa..7e49b0f21 100644 --- a/dmoj/cptbox/helper.cpp +++ b/dmoj/cptbox/helper.cpp @@ -93,9 +93,16 @@ int cptbox_child_run(const struct child_config *config) { } for (int syscall = 0; syscall < MAX_SYSCALL; syscall++) { - if (config->seccomp_whitelist[syscall]) { + int handler = config->seccomp_handlers[syscall]; + if (handler == 0) { if ((rc = seccomp_rule_add(ctx, SCMP_ACT_ALLOW, syscall, 0))) { - fprintf(stderr, "seccomp_rule_add(..., %d): %s\n", syscall, strerror(-rc)); + fprintf(stderr, "seccomp_rule_add(..., SCMP_ACT_ALLOW, %d): %s\n", syscall, strerror(-rc)); + // This failure is not fatal, it'll just cause the syscall to trap anyway. + } + } else if (handler > 0) { + if ((rc = seccomp_rule_add(ctx, SCMP_ACT_ERRNO(handler), syscall, 0))) { + fprintf(stderr, "seccomp_rule_add(..., SCMP_ACT_ERRNO(%d), %d): %s\n", + handler, syscall, strerror(-rc)); // This failure is not fatal, it'll just cause the syscall to trap anyway. } } diff --git a/dmoj/cptbox/helper.h b/dmoj/cptbox/helper.h index ff6219ff0..b8558e029 100644 --- a/dmoj/cptbox/helper.h +++ b/dmoj/cptbox/helper.h @@ -22,7 +22,7 @@ struct child_config { int stdout_; int stderr_; bool use_seccomp; - int *seccomp_whitelist; + int *seccomp_handlers; }; void cptbox_closefrom(int lowfd); diff --git a/dmoj/cptbox/isolate.py b/dmoj/cptbox/isolate.py index f3cff3fd0..e1391542f 100644 --- a/dmoj/cptbox/isolate.py +++ b/dmoj/cptbox/isolate.py @@ -1,8 +1,9 @@ import logging import os import sys +from typing import Optional, Tuple -from dmoj.cptbox._cptbox import AT_FDCWD, bsd_get_proc_cwd, bsd_get_proc_fdno +from dmoj.cptbox._cptbox import AT_FDCWD, Debugger, bsd_get_proc_cwd, bsd_get_proc_fdno from dmoj.cptbox.filesystem_policies import FilesystemPolicy from dmoj.cptbox.handlers import ( ACCESS_EACCES, @@ -12,9 +13,10 @@ ACCESS_ENOENT, ACCESS_EPERM, ALLOW, + ErrnoHandlerCallback, ) from dmoj.cptbox.syscalls import * -from dmoj.cptbox.tracer import MaxLengthExceeded +from dmoj.cptbox.tracer import HandlerCallback, MaxLengthExceeded from dmoj.utils.unicode import utf8text log = logging.getLogger('dmoj.security') @@ -194,7 +196,7 @@ def __init__(self, read_fs, write_fs=None, writable=(1, 2)): def _compile_fs_jail(self, fs): return FilesystemPolicy(fs or []) - def is_write_flags(self, open_flags): + def is_write_flags(self, open_flags: int) -> bool: for flag in open_write_flags: # Strict equality is necessary here, since e.g. O_TMPFILE has multiple bits set, # and O_DIRECTORY & O_TMPFILE > 0. @@ -203,8 +205,8 @@ def is_write_flags(self, open_flags): return False - def check_file_access(self, syscall, argument, is_open=False): - def check(debugger): + def check_file_access(self, syscall, argument, is_open=False) -> HandlerCallback: + def check(debugger: Debugger) -> bool: file_ptr = getattr(debugger, 'uarg%d' % argument) try: file = debugger.readstr(file_ptr) @@ -224,8 +226,8 @@ def check(debugger): return check - def check_file_access_at(self, syscall, is_open=False): - def check(debugger): + def check_file_access_at(self, syscall, is_open=False) -> HandlerCallback: + def check(debugger: Debugger) -> bool: try: file = debugger.readstr(debugger.uarg1) except MaxLengthExceeded as e: @@ -244,7 +246,9 @@ def check(debugger): return check - def _file_access_check(self, rel_file, debugger, is_open, flag_reg=1, dirfd=AT_FDCWD): + def _file_access_check( + self, rel_file, debugger, is_open, flag_reg=1, dirfd=AT_FDCWD + ) -> Tuple[str, Optional[ErrnoHandlerCallback]]: # Either process called open(NULL, ...), or we failed to read the path # in cptbox. Either way this call should not be allowed; if the path # was indeed NULL we can end the request before it gets to the kernel @@ -304,7 +308,7 @@ def _file_access_check(self, rel_file, debugger, is_open, flag_reg=1, dirfd=AT_F return real, None - def get_full_path(self, debugger, file, dirfd=AT_FDCWD): + def get_full_path(self, debugger: Debugger, file: str, dirfd: int = AT_FDCWD) -> str: dirfd = (dirfd & 0x7FFFFFFF) - (dirfd & 0x80000000) if not file.startswith('/'): dir = self._getcwd_pid(debugger.pid) if dirfd == AT_FDCWD else self._getfd_pid(debugger.pid, dirfd) @@ -312,15 +316,15 @@ def get_full_path(self, debugger, file, dirfd=AT_FDCWD): file = '/' + os.path.normpath(file).lstrip('/') return file - def do_kill(self, debugger): + def do_kill(self, debugger: Debugger) -> bool: # Allow tgkill to execute as long as the target thread group is the debugged process # libstdc++ seems to use this to signal itself, see return True if debugger.uarg0 == debugger.pid else ACCESS_EPERM(debugger) - def do_prlimit(self, debugger): + def do_prlimit(self, debugger: Debugger) -> bool: return True if debugger.uarg0 in (0, debugger.pid) else ACCESS_EPERM(debugger) - def do_prctl(self, debugger): + def do_prctl(self, debugger: Debugger) -> bool: PR_GET_DUMPABLE = 3 PR_SET_NAME = 15 PR_GET_NAME = 16 diff --git a/dmoj/cptbox/tracer.py b/dmoj/cptbox/tracer.py index 7e6e681c6..608e491d6 100644 --- a/dmoj/cptbox/tracer.py +++ b/dmoj/cptbox/tracer.py @@ -6,10 +6,10 @@ import subprocess import sys import threading -from typing import List, Optional +from typing import Callable, List, Optional from dmoj.cptbox._cptbox import * -from dmoj.cptbox.handlers import ALLOW, DISALLOW, _CALLBACK +from dmoj.cptbox.handlers import ALLOW, DISALLOW, ErrnoHandlerCallback, _CALLBACK from dmoj.cptbox.syscalls import SYSCALL_COUNT, by_id, sys_exit, sys_exit_group, sys_getpid, translator from dmoj.utils.communicate import safe_communicate as _safe_communicate from dmoj.utils.os_ext import OOM_SCORE_ADJ_MAX, oom_score_adj @@ -39,6 +39,8 @@ PTBOX_ABI_FREEBSD_X64: 64, } +HandlerCallback = Callable[[Debugger], bool] + class MaxLengthExceeded(ValueError): pass @@ -166,8 +168,8 @@ def __init__( if self._spawn_error: raise self._spawn_error - def _get_seccomp_whitelist(self): - whitelist = [False] * MAX_SYSCALL_NUMBER + def _get_seccomp_handlers(self): + handlers = [-1] * MAX_SYSCALL_NUMBER index = _SYSCALL_INDICIES[NATIVE_ABI] for i in range(SYSCALL_COUNT): # Ensure at least one syscall traps. @@ -178,9 +180,11 @@ def _get_seccomp_whitelist(self): for call in translator[i][index]: if call is None: continue - if isinstance(handler, int): - whitelist[call] = handler == ALLOW - return whitelist + if isinstance(handler, int) and handler == ALLOW: + handlers[call] = 0 + elif isinstance(handler, ErrnoHandlerCallback): + handlers[call] = handler.errno + return handlers def wait(self): self._died.wait() diff --git a/testsuite/helloworld/tests/sandbox_py3_mkdir/helloworld.py b/testsuite/helloworld/tests/sandbox_py3_mkdir/helloworld.py new file mode 100644 index 000000000..4f6d56aa3 --- /dev/null +++ b/testsuite/helloworld/tests/sandbox_py3_mkdir/helloworld.py @@ -0,0 +1,2 @@ +import os +os.mkdir('test') diff --git a/testsuite/helloworld/tests/sandbox_py3_mkdir/test.yml b/testsuite/helloworld/tests/sandbox_py3_mkdir/test.yml new file mode 100644 index 000000000..36d143c1e --- /dev/null +++ b/testsuite/helloworld/tests/sandbox_py3_mkdir/test.yml @@ -0,0 +1,6 @@ +language: PY3 +time: 2 +memory: 65536 +source: helloworld.py +expect: IR +feedback: PermissionError