diff --git a/dmoj/checkers/bridged.py b/dmoj/checkers/bridged.py
index 81c0a53cd..ba5cb00eb 100644
--- a/dmoj/checkers/bridged.py
+++ b/dmoj/checkers/bridged.py
@@ -3,6 +3,7 @@
 import subprocess
 
 from dmoj.contrib import contrib_modules
+from dmoj.cptbox.filesystem_policies import ExactFile
 from dmoj.error import InternalError
 from dmoj.judgeenv import env, get_problem_root
 from dmoj.result import CheckerResult
@@ -25,10 +26,10 @@ def get_executor(problem_id, files, flags, lang, compiler_time_limit):
 def check(
     process_output,
     judge_output,
-    judge_input,
     problem_id,
     files,
     lang,
+    case,
     time_limit=env['generator_time_limit'],
     memory_limit=env['generator_memory_limit'],
     compiler_time_limit=env['generator_compiler_limit'],
@@ -46,16 +47,23 @@ def check(
 
     args_format_string = args_format_string or contrib_modules[type].ContribModule.get_checker_args_format_string()
 
-    with mktemp(judge_input) as input_file, mktemp(process_output) as output_file, mktemp(judge_output) as answer_file:
+    with mktemp(process_output) as output_file, mktemp(judge_output) as answer_file:
+        input_path = case.input_data_fd().to_path()
+
         checker_args = shlex.split(
             args_format_string.format(
-                input_file=shlex.quote(input_file.name),
+                input_file=shlex.quote(input_path),
                 output_file=shlex.quote(output_file.name),
                 answer_file=shlex.quote(answer_file.name),
             )
         )
         process = executor.launch(
-            *checker_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, memory=memory_limit, time=time_limit
+            *checker_args,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            memory=memory_limit,
+            time=time_limit,
+            extra_fs=[ExactFile(input_path)],
         )
         proc_output, error = process.communicate()
 
diff --git a/dmoj/cptbox/_cptbox.pyi b/dmoj/cptbox/_cptbox.pyi
index 89824f501..479942fdd 100644
--- a/dmoj/cptbox/_cptbox.pyi
+++ b/dmoj/cptbox/_cptbox.pyi
@@ -98,3 +98,6 @@ bsd_get_proc_fdno: Callable[[int, int], str]
 
 memory_fd_create: Callable[[], int]
 memory_fd_seal: Callable[[int], None]
+
+class BufferProxy:
+    def _get_real_buffer(self): ...
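
Note: the checker no longer copies the judge input into a temporary file. It asks the test case for its memfd-backed descriptor, formats the resulting /proc/<pid>/fd/<n> path into the checker command line, and whitelists exactly that path with ExactFile. A minimal sketch of the underlying trick outside the judge, assuming Linux and Python 3.8+ for os.memfd_create (the patch wraps this in MemoryIO.to_path()):

    # Sketch only: hand an in-memory "file" to another process by path,
    # without ever writing the test data to disk.
    import os
    import subprocess

    fd = os.memfd_create('case-input')             # anonymous, memory-backed file
    os.write(fd, b'3\n1 2 3\n')                    # prefill with the test input

    path = f'/proc/{os.getpid()}/fd/{fd}'          # path any same-user process may open
    print(subprocess.run(['cat', path], capture_output=True).stdout)  # b'3\n1 2 3\n'
    os.close(fd)
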
diff --git a/dmoj/cptbox/_cptbox.pyx b/dmoj/cptbox/_cptbox.pyx
index e011082ea..0a7684be0 100644
--- a/dmoj/cptbox/_cptbox.pyx
+++ b/dmoj/cptbox/_cptbox.pyx
@@ -1,5 +1,6 @@
 # cython: language_level=3
 from cpython.exc cimport PyErr_NoMemory, PyErr_SetFromErrno
+from cpython.buffer cimport PyObject_GetBuffer
 from cpython.bytes cimport PyBytes_AsString, PyBytes_FromStringAndSize
 from libc.stdio cimport FILE, fopen, fclose, fgets, sprintf
 from libc.stdlib cimport malloc, free, strtoul
@@ -576,3 +577,11 @@ cdef class Process:
         if not self._exited:
             return None
         return self._exitcode
+
+
+cdef class BufferProxy:
+    def _get_real_buffer(self):
+        raise NotImplementedError
+
+    def __getbuffer__(self, Py_buffer *buffer, int flags):
+        PyObject_GetBuffer(self._get_real_buffer(), buffer, flags)
diff --git a/dmoj/cptbox/helper.cpp b/dmoj/cptbox/helper.cpp
index 6776423fb..aabc31b6c 100644
--- a/dmoj/cptbox/helper.cpp
+++ b/dmoj/cptbox/helper.cpp
@@ -306,7 +306,7 @@ int memory_fd_create(void) {
 #ifdef __FreeBSD__
     char filename[] = "/tmp/cptbox-memoryfd-XXXXXXXX";
     int fd = mkstemp(filename);
-    if (fd > 0)
+    if (fd >= 0)
         unlink(filename);
     return fd;
 #else
diff --git a/dmoj/cptbox/isolate.py b/dmoj/cptbox/isolate.py
index 364414010..8c83ffd24 100644
--- a/dmoj/cptbox/isolate.py
+++ b/dmoj/cptbox/isolate.py
@@ -295,7 +295,7 @@ def _file_access_check(
         real = os.path.realpath(file)
 
         try:
-            same = normalized == real or os.path.samefile(projected, real)
+            same = normalized == real or real.startswith('/memfd:') or os.path.samefile(projected, real)
         except OSError:
             log.debug('Denying access due to inability to stat: normalizes to: %s, actually: %s', normalized, real)
             return file, ACCESS_ENOENT
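
Note: the new '/memfd:' clause in isolate.py exists because realpath() of a /proc/<pid>/fd/<n> link that points at a memfd resolves to a pseudo-name rather than a real path, so the existing samefile() comparison can never succeed for it. A quick illustration, assuming Linux and Python 3.8+ (the exact printed strings are kernel-formatted and shown here as typical output, not guaranteed):

    # Sketch only: what the sandbox sees when a memfd is reached via its /proc path.
    import os

    fd = os.memfd_create('case-input')
    link = f'/proc/{os.getpid()}/fd/{fd}'

    print(os.readlink(link))        # e.g. '/memfd:case-input (deleted)'
    print(os.path.realpath(link))   # same pseudo-path; it cannot be stat()ed,
                                    # so os.path.samefile() raises OSError here
    os.close(fd)
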
diff --git a/dmoj/cptbox/lazy_bytes.py b/dmoj/cptbox/lazy_bytes.py
new file mode 100644
index 000000000..b6b3cd8f7
--- /dev/null
+++ b/dmoj/cptbox/lazy_bytes.py
@@ -0,0 +1,88 @@
+# Based off https://github.com/django/django/blob/main/django/utils/functional.py, licensed under 3-clause BSD.
+from functools import total_ordering
+
+from dmoj.cptbox._cptbox import BufferProxy
+
+_SENTINEL = object()
+
+
+@total_ordering
+class LazyBytes(BufferProxy):
+    """
+    Encapsulate a function call and act as a proxy for methods that are
+    called on the result of that function. The function is not evaluated
+    until one of the methods on the result is called.
+    """
+
+    def __init__(self, func):
+        self.__func = func
+        self.__value = _SENTINEL
+
+    def __get_value(self):
+        if self.__value is _SENTINEL:
+            self.__value = self.__func()
+        return self.__value
+
+    @classmethod
+    def _create_promise(cls, method_name):
+        # Builds a wrapper around some magic method
+        def wrapper(self, *args, **kw):
+            # Automatically triggers the evaluation of a lazy value and
+            # applies the given magic method of the result type.
+            res = self.__get_value()
+            return getattr(res, method_name)(*args, **kw)
+
+        return wrapper
+
+    def __cast(self):
+        return bytes(self.__get_value())
+
+    def _get_real_buffer(self):
+        return self.__cast()
+
+    def __bytes__(self):
+        return self.__cast()
+
+    def __repr__(self):
+        return repr(self.__cast())
+
+    def __str__(self):
+        return str(self.__cast())
+
+    def __eq__(self, other):
+        if isinstance(other, LazyBytes):
+            other = other.__cast()
+        return self.__cast() == other
+
+    def __lt__(self, other):
+        if isinstance(other, LazyBytes):
+            other = other.__cast()
+        return self.__cast() < other
+
+    def __hash__(self):
+        return hash(self.__cast())
+
+    def __mod__(self, rhs):
+        return self.__cast() % rhs
+
+    def __add__(self, other):
+        return self.__cast() + other
+
+    def __radd__(self, other):
+        return other + self.__cast()
+
+    def __deepcopy__(self, memo):
+        # Instances of this class are effectively immutable. It's just a
+        # collection of functions. So we don't need to do anything
+        # complicated for copying.
+        memo[id(self)] = self
+        return self
+
+
+for type_ in bytes.mro():
+    for method_name in type_.__dict__:
+        # All __promise__ return the same wrapper method, they
+        # look up the correct implementation when called.
+        if hasattr(LazyBytes, method_name):
+            continue
+        setattr(LazyBytes, method_name, LazyBytes._create_promise(method_name))
diff --git a/dmoj/cptbox/utils.py b/dmoj/cptbox/utils.py
index 476387b10..51f96b8db 100644
--- a/dmoj/cptbox/utils.py
+++ b/dmoj/cptbox/utils.py
@@ -1,11 +1,62 @@
+import errno
 import io
+import mmap
+import os
+from tempfile import NamedTemporaryFile
+from typing import Optional
 
 from dmoj.cptbox._cptbox import memory_fd_create, memory_fd_seal
+from dmoj.cptbox.tracer import FREEBSD
 
 
 class MemoryIO(io.FileIO):
-    def __init__(self) -> None:
-        super().__init__(memory_fd_create(), 'r+')
+    _name: Optional[str] = None
+
+    def __init__(self, prefill: Optional[bytes] = None, seal=False) -> None:
+        if FREEBSD:
+            with NamedTemporaryFile(delete=False) as f:
+                self._name = f.name
+                super().__init__(os.dup(f.fileno()), 'r+')
+        else:
+            super().__init__(memory_fd_create(), 'r+')
+
+        if prefill:
+            self.write(prefill)
+        if seal:
+            self.seal()
 
     def seal(self) -> None:
-        memory_fd_seal(self.fileno())
+        fd = self.fileno()
+        try:
+            memory_fd_seal(fd)
+        except OSError as e:
+            if e.errno == errno.ENOSYS:
+                # FreeBSD
+                self.seek(0, os.SEEK_SET)
+                return
+            raise
+
+        new_fd = os.open(f'/proc/self/fd/{fd}', os.O_RDONLY)
+        try:
+            os.dup2(new_fd, fd)
+        finally:
+            os.close(new_fd)
+
+    def close(self) -> None:
+        super().close()
+        if self._name:
+            os.unlink(self._name)
+
+    def to_path(self) -> str:
+        if self._name:
+            return self._name
+        return f'/proc/{os.getpid()}/fd/{self.fileno()}'
+
+    def to_bytes(self) -> bytes:
+        try:
+            with mmap.mmap(self.fileno(), 0, access=mmap.ACCESS_READ) as f:
+                return bytes(f)
+        except ValueError as e:
+            if e.args[0] == 'cannot mmap an empty file':
+                return b''
+            raise
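
Note: MemoryIO wraps a memfd (or, on FreeBSD, an unlinked temporary file) and seals it once the test data is written, so the sandboxed process can read it via /proc but can no longer change it. The body of the C helper memory_fd_seal() is not part of this diff; on Linux it roughly amounts to fcntl file seals, which the sketch below reproduces directly in Python (assumes Python 3.9+ for the F_SEAL_* constants):

    # Sketch only: approximately what MemoryIO(prefill=..., seal=True) does on Linux.
    import fcntl
    import os

    fd = os.memfd_create('case-input', os.MFD_CLOEXEC | os.MFD_ALLOW_SEALING)
    os.write(fd, b'42\n')

    # Forbid any further resizing or writing; readers are unaffected.
    fcntl.fcntl(fd, fcntl.F_ADD_SEALS, fcntl.F_SEAL_GROW | fcntl.F_SEAL_SHRINK | fcntl.F_SEAL_WRITE)

    # MemoryIO.seal() then swaps its own descriptor for a read-only reopen.
    ro = os.open(f'/proc/self/fd/{fd}', os.O_RDONLY)
    os.dup2(ro, fd)
    os.close(ro)

    print(os.pread(fd, 16, 0))   # b'42\n'
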
diff --git a/dmoj/executors/base_executor.py b/dmoj/executors/base_executor.py
index ff42ecd49..45dd0e342 100644
--- a/dmoj/executors/base_executor.py
+++ b/dmoj/executors/base_executor.py
@@ -212,8 +212,11 @@ def _add_syscalls(self, sec: IsolateTracer) -> IsolateTracer:
             sec[getattr(syscalls, f'sys_{name}')] = handler
         return sec
 
-    def get_security(self, launch_kwargs=None) -> IsolateTracer:
-        sec = IsolateTracer(self.get_fs(), write_fs=self.get_write_fs())
+    def get_security(self, launch_kwargs=None, extra_fs=None) -> IsolateTracer:
+        read_fs = self.get_fs()
+        if extra_fs:
+            read_fs += extra_fs
+        sec = IsolateTracer(read_fs, write_fs=self.get_write_fs())
         return self._add_syscalls(sec)
 
     def get_fs(self) -> List[FilesystemAccessRule]:
@@ -281,7 +284,7 @@ def launch(self, *args, **kwargs) -> TracedPopen:
         return TracedPopen(
             [utf8bytes(a) for a in self.get_cmdline(**kwargs) + list(args)],
             executable=utf8bytes(executable),
-            security=self.get_security(launch_kwargs=kwargs),
+            security=self.get_security(launch_kwargs=kwargs, extra_fs=kwargs.get('extra_fs')),
             address_grace=self.get_address_grace(),
             data_grace=self.data_grace,
             personality=self.personality,
diff --git a/dmoj/executors/shell_executor.py b/dmoj/executors/shell_executor.py
index 11b6f4b08..dc18fadca 100644
--- a/dmoj/executors/shell_executor.py
+++ b/dmoj/executors/shell_executor.py
@@ -21,10 +21,10 @@ def get_fs(self):
     def get_allowed_syscalls(self):
         return super().get_allowed_syscalls() + ['fork', 'waitpid', 'wait4']
 
-    def get_security(self, launch_kwargs=None):
+    def get_security(self, launch_kwargs=None, extra_fs=None):
         from dmoj.cptbox.syscalls import sys_execve, sys_access, sys_eaccess
 
-        sec = super().get_security(launch_kwargs)
+        sec = super().get_security(launch_kwargs=launch_kwargs, extra_fs=extra_fs)
         allowed = set(self.get_allowed_exec())
 
         def handle_execve(debugger):
diff --git a/dmoj/graders/bridged.py b/dmoj/graders/bridged.py
index d1ef69774..e0e34c405 100644
--- a/dmoj/graders/bridged.py
+++ b/dmoj/graders/bridged.py
@@ -3,6 +3,7 @@
 import subprocess
 
 from dmoj.contrib import contrib_modules
+from dmoj.cptbox.filesystem_policies import ExactFile
 from dmoj.error import InternalError
 from dmoj.graders.standard import StandardGrader
 from dmoj.judgeenv import env, get_problem_root
@@ -38,7 +39,7 @@ def check_result(self, case, result):
 
         return (not result.result_flag) and parsed_result
 
-    def _launch_process(self, case):
+    def _launch_process(self, case, input_file=None):
         self._interactor_stdin_pipe, submission_stdout_pipe = os.pipe()
         submission_stdin_pipe, self._interactor_stdout_pipe = os.pipe()
         self._current_proc = self.binary.launch(
@@ -53,7 +54,7 @@ def _launch_process(self, case):
         os.close(submission_stdin_pipe)
         os.close(submission_stdout_pipe)
 
-    def _interact_with_process(self, case, result, input):
+    def _interact_with_process(self, case, result):
         judge_output = case.output_data()
         # Give TL + 2s by default, so we do not race (and incorrectly throw IE) if submission gets TLE
         self._interactor_time_limit = (self.handler_data.preprocessing_time or 2) + self.problem.time_limit
@@ -63,14 +64,16 @@ def _interact_with_process(self, case, result, input):
             or contrib_modules[self.contrib_type].ContribModule.get_interactor_args_format_string()
         )
 
-        with mktemp(input) as input_file, mktemp(judge_output) as answer_file:
+        with mktemp(judge_output) as answer_file:
+            input_path = case.input_data_fd().to_path()
+
             # TODO(@kirito): testlib.h expects a file they can write to,
             # but we currently don't have a sane way to allow this.
            # Thus we pass /dev/null for now so testlib interactors will still
             # work, albeit with diminished capabilities
             interactor_args = shlex.split(
                 args_format_string.format(
-                    input_file=shlex.quote(input_file.name),
+                    input_file=shlex.quote(input_path),
                     output_file=shlex.quote(os.devnull),
                     answer_file=shlex.quote(answer_file.name),
                 )
             )
@@ -82,6 +85,7 @@
             stdin=self._interactor_stdin_pipe,
             stdout=self._interactor_stdout_pipe,
             stderr=subprocess.PIPE,
+            extra_fs=[ExactFile(input_path)],
         )
         os.close(self._interactor_stdin_pipe)
 
diff --git a/dmoj/graders/interactive.py b/dmoj/graders/interactive.py
index ab240f2bd..17071c930 100644
--- a/dmoj/graders/interactive.py
+++ b/dmoj/graders/interactive.py
@@ -91,7 +91,10 @@ def close(self):
 
 
 class InteractiveGrader(StandardGrader):
-    def _interact_with_process(self, case, result, input):
+    def _launch_process(self, case, input_file=None):
+        super()._launch_process(case, input_file=None)
+
+    def _interact_with_process(self, case, result):
         interactor = Interactor(self._current_proc)
         self.check = False
         self.feedback = None
diff --git a/dmoj/graders/standard.py b/dmoj/graders/standard.py
index c7ff0dabd..90d07920f 100644
--- a/dmoj/graders/standard.py
+++ b/dmoj/graders/standard.py
@@ -1,6 +1,7 @@
 import logging
 import subprocess
 
+from dmoj.cptbox.lazy_bytes import LazyBytes
 from dmoj.error import OutputLimitExceeded
 from dmoj.executors import executors
 from dmoj.graders.base import BaseGrader
@@ -13,11 +14,11 @@ class StandardGrader(BaseGrader):
     def grade(self, case):
         result = Result(case)
 
-        input = case.input_data()  # cache generator data
+        input_file = case.input_data_fd()
 
-        self._launch_process(case)
+        self._launch_process(case, input_file)
 
-        error = self._interact_with_process(case, result, input)
+        error = self._interact_with_process(case, result)
 
         process = self._current_proc
 
@@ -55,7 +56,7 @@ def check_result(self, case, result):
                 result.proc_output,
                 case.output_data(),
                 submission_source=self.source,
-                judge_input=case.input_data(),
+                judge_input=LazyBytes(case.input_data),
                 point_value=case.points,
                 case_position=case.position,
                 batch=case.batch,
@@ -63,6 +64,7 @@
                 binary_data=case.has_binary_data,
                 execution_time=result.execution_time,
                 problem_id=self.problem.id,
+                case=case,
             )
         except UnicodeDecodeError:
             # Don't rely on problemsetters to do sane things when it comes to Unicode handling, so
@@ -74,22 +76,22 @@ def check_result(self, case, result):
 
         return check
 
-    def _launch_process(self, case):
+    def _launch_process(self, case, input_file=None):
         self._current_proc = self.binary.launch(
             time=self.problem.time_limit,
             memory=self.problem.memory_limit,
             symlinks=case.config.symlinks,
-            stdin=subprocess.PIPE,
+            stdin=input_file or subprocess.PIPE,
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             wall_time=case.config.wall_time_factor * self.problem.time_limit,
         )
 
-    def _interact_with_process(self, case, result, input):
+    def _interact_with_process(self, case, result):
         process = self._current_proc
         try:
             result.proc_output, error = process.communicate(
-                input, outlimit=case.config.output_limit_length, errlimit=1048576
+                None, outlimit=case.config.output_limit_length, errlimit=1048576
             )
         except OutputLimitExceeded:
             error = b''
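
Note: checkers still receive judge_input, but through LazyBytes the full input is only read back out of the memfd if a checker actually touches it. A usage sketch, assuming this patch is applied; the callable here is a stand-in for case.input_data:

    # Sketch only: LazyBytes defers the wrapped call until the value is used.
    from dmoj.cptbox.lazy_bytes import LazyBytes

    def load_input() -> bytes:          # stand-in for case.input_data
        print('reading input back from the memfd...')
        return b'1 2 3\n'

    judge_input = LazyBytes(load_input)
    # Nothing has been read yet; a checker that ignores judge_input never pays the cost.

    print(judge_input.split())          # first real use triggers load_input()
    print(bytes(judge_input))           # later uses hit the cached value
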
diff --git a/dmoj/problem.py b/dmoj/problem.py
index 794099656..68669cca3 100644
--- a/dmoj/problem.py
+++ b/dmoj/problem.py
@@ -1,6 +1,7 @@
 import itertools
 import os
 import re
+import shutil
 import subprocess
 import zipfile
 from collections import defaultdict
@@ -12,9 +13,11 @@
 
 from dmoj import checkers
 from dmoj.config import ConfigNode, InvalidInitException
+from dmoj.cptbox.utils import MemoryIO
 from dmoj.judgeenv import env, get_problem_root
 from dmoj.utils.helper_files import compile_with_auxiliary_files, parse_helper_file_error
 from dmoj.utils.module import load_module_from_file
+from dmoj.utils.normalize import normalized_file_copy
 
 DEFAULT_TEST_CASE_INPUT_PATTERN = r'^(?=.*?\.in|in).*?(?:(?:^|\W)(?P<batch>\d+)[^\d\s]+)?(?P<case>\d+)[^\d\s]*$'
 DEFAULT_TEST_CASE_OUTPUT_PATTERN = r'^(?=.*?\.out|out).*?(?:(?:^|\W)(?P<batch>\d+)[^\d\s]+)?(?P<case>\d+)[^\d\s]*$'
@@ -177,17 +180,29 @@ def __init__(self, problem_root_dir, **kwargs):
         self.problem_root_dir = problem_root_dir
         self.archive = None
 
-    def __missing__(self, key):
+    def open(self, key):
         try:
-            with open(os.path.join(self.problem_root_dir, key), 'rb') as f:
-                return f.read()
+            return open(os.path.join(self.problem_root_dir, key), 'rb')
         except IOError:
             if self.archive:
                 zipinfo = self.archive.getinfo(key)
-                with self.archive.open(zipinfo) as f:
-                    return f.read()
+                return self.archive.open(zipinfo)
             raise KeyError('file "%s" could not be found in "%s"' % (key, self.problem_root_dir))
 
+    def as_fd(self, key, normalize=False):
+        memory = MemoryIO()
+        with self.open(key) as f:
+            if normalize:
+                normalized_file_copy(f, memory)
+            else:
+                shutil.copyfileobj(f, memory)
+        memory.seal()
+        return memory
+
+    def __missing__(self, key):
+        with self.open(key) as f:
+            return f.read()
+
     def __del__(self):
         if self.archive:
             self.archive.close()
@@ -241,6 +256,7 @@ def __init__(self, count, batch_no, config, problem):
         self.output_prefix_length = config.output_prefix_length
         self.has_binary_data = config.binary_data
         self._generated = None
+        self._input_data_fd = None
 
     def _normalize(self, data):
         # Perhaps the correct answer may be 'no output', in which case it'll be
@@ -330,6 +346,14 @@ def _run_generator(self, gen, args=None):
             parse_helper_file_error(proc, executor, 'generator', stderr, time_limit, memory_limit)
 
     def input_data(self):
+        return self.input_data_fd().to_bytes()
+
+    def input_data_fd(self):
+        if not self._input_data_fd:
+            self._input_data_fd = self._make_input_data_fd()
+        return self._input_data_fd
+
+    def _make_input_data_fd(self):
         gen = self.config.generator
 
         # don't try running the generator if we specify an output file explicitly,
@@ -337,10 +361,18 @@ def input_data(self):
         if gen and (not self.config['out'] or not self.config['in']):
             if self._generated is None:
                 self._run_generator(gen, args=self.config.generator_args)
+            # FIXME: generate into the MemoryIO.
             if self._generated[0]:
-                return self._generated[0]
+                memory = MemoryIO()
+                memory.write(self._generated[0])
+                memory.seal()
+                return memory
+
         # in file is optional
-        return self._normalize(self.problem.problem_data[self.config['in']]) if self.config['in'] else b''
+        if self.config['in']:
+            return self.problem.problem_data.as_fd(self.config['in'], normalize=not self.has_binary_data)
+        else:
+            return MemoryIO(seal=True)
 
     def output_data(self):
         if self.config.out:
@@ -376,13 +408,15 @@ def checker(self):
 
     def free_data(self):
         self._generated = None
+        if self._input_data_fd:
+            self._input_data_fd.close()
 
     def __str__(self):
         return 'TestCase{in=%s,out=%s,points=%s}' % (self.config['in'], self.config['out'], self.config['points'])
 
     # FIXME(tbrindus): this is a hack working around the fact we can't pickle these fields, but we do need parts of
     # TestCase itself on the other end of the IPC.
-    _pickle_blacklist = ('_generated', 'config', 'problem')
+    _pickle_blacklist = ('_generated', 'config', 'problem', '_input_data_fd')
 
     def __getstate__(self):
         k = {k: v for k, v in self.__dict__.items() if k not in self._pickle_blacklist}
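
Note: each test case now owns a single sealed MemoryIO. input_data_fd() hands out the descriptor (used for stdin and for to_path()-based checker arguments), while input_data() is just to_bytes() on the same object. A small sketch of the two views, assuming the MemoryIO API introduced above:

    # Sketch only: one sealed buffer, two views of the same test input.
    from dmoj.cptbox.utils import MemoryIO

    case_input = MemoryIO(prefill=b'5\n1 2 3 4 5\n', seal=True)

    print(case_input.to_path())     # '/proc/<pid>/fd/<n>', passed to checkers/interactors
    print(case_input.to_bytes())    # b'5\n1 2 3 4 5\n', what input_data() now returns
    case_input.close()
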
diff --git a/dmoj/tests/test_normalize.py b/dmoj/tests/test_normalize.py
new file mode 100644
index 000000000..b7fa2d868
--- /dev/null
+++ b/dmoj/tests/test_normalize.py
@@ -0,0 +1,56 @@
+import unittest
+from io import BytesIO
+
+from dmoj.utils.normalize import normalized_file_copy
+
+TEST_CASE = b'a\r\n\r\r\nb\r\r\nc\nd\n'
+TEST_CASE_NO_NEWLINE = b'a\r\n\r\r\nb\r\r\nc\nd'
+TEST_CASE_TRAILING_R = b'a\r\n\r\r\nb\r\r\nc\nd\r'
+RESULT = b'a\n\n\nb\n\nc\nd\n'
+
+
+class TestNormalizedCopy(unittest.TestCase):
+    def test_simple(self):
+        with BytesIO(TEST_CASE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst)
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_newline_add(self):
+        with BytesIO(TEST_CASE_NO_NEWLINE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst)
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_break_after_r(self):
+        with BytesIO(TEST_CASE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst, block_size=TEST_CASE.rindex(b'\r\n'))
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_break_after_r_newline_add(self):
+        with BytesIO(TEST_CASE_NO_NEWLINE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst, block_size=TEST_CASE_NO_NEWLINE.rindex(b'\r\n'))
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_break_between_r_n(self):
+        with BytesIO(TEST_CASE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst, block_size=TEST_CASE.rindex(b'\r\n') + 1)
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_break_between_r_n_newline_add(self):
+        with BytesIO(TEST_CASE_NO_NEWLINE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst, block_size=TEST_CASE_NO_NEWLINE.rindex(b'\r\n') + 1)
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_break_before_trailing_newline(self):
+        with BytesIO(TEST_CASE) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst, block_size=len(TEST_CASE) - 1)
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_trailing_r(self):
+        with BytesIO(TEST_CASE_TRAILING_R) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst)
+            self.assertEqual(dst.getvalue(), RESULT)
+
+    def test_break_before_trailing_r(self):
+        with BytesIO(TEST_CASE_TRAILING_R) as src, BytesIO() as dst:
+            normalized_file_copy(src, dst, block_size=len(TEST_CASE_TRAILING_R))
+            self.assertEqual(dst.getvalue(), RESULT)
diff --git a/dmoj/utils/normalize.py b/dmoj/utils/normalize.py
new file mode 100644
index 000000000..e03d4914b
--- /dev/null
+++ b/dmoj/utils/normalize.py
@@ -0,0 +1,20 @@
+from io import TextIOWrapper
+
+
+def normalized_file_copy(src, dst, block_size=16384):
+    src_wrap = TextIOWrapper(src, encoding='iso-8859-1', newline=None)
+    dst_wrap = TextIOWrapper(dst, encoding='iso-8859-1', newline='')
+
+    add_newline = False
+    while True:
+        buf = src_wrap.read(block_size)
+        if not buf:
+            break
+        dst_wrap.write(buf)
+        add_newline = not buf.endswith('\n')
+
+    if add_newline:
+        dst_wrap.write('\n')
+
+    src_wrap.detach()
+    dst_wrap.detach()
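
Note: normalized_file_copy streams a file through two TextIOWrappers: reading with newline=None translates \r\n and bare \r to \n (even when a block boundary splits a \r\n pair, which is what the block_size tests above exercise), writing with newline='' prevents re-translation, and iso-8859-1 keeps every other byte intact. A quick usage sketch:

    # Sketch only: line endings are unified and a missing final newline is added.
    from io import BytesIO

    from dmoj.utils.normalize import normalized_file_copy

    with BytesIO(b'first\r\nsecond\rthird') as src, BytesIO() as dst:
        normalized_file_copy(src, dst)
        print(dst.getvalue())   # b'first\nsecond\nthird\n'
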
diff --git a/testsuite/shortest1/shortest1.py b/testsuite/shortest1/shortest1.py
index a1ef8260d..e94a1914f 100644
--- a/testsuite/shortest1/shortest1.py
+++ b/testsuite/shortest1/shortest1.py
@@ -8,7 +8,7 @@ def check_result(self, case, result):
         result.result_flag &= ~Result.TLE & ~Result.RTE & ~Result.IR
         return CheckerResult(passed, min((9. / len(self.source)) ** 5 * case.points, case.points) if passed else 0)
 
-    def _interact_with_process(self, case, result, input):
+    def _interact_with_process(self, case, result):
         process = self._current_proc
         for handle in [process.stdin, process.stdout, process.stderr]:
             if handle: