Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switching to using memfd for input data #990

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions dmoj/checkers/bridged.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess

from dmoj.contrib import contrib_modules
from dmoj.cptbox.filesystem_policies import ExactFile
from dmoj.error import InternalError
from dmoj.judgeenv import env, get_problem_root
from dmoj.result import CheckerResult
Expand All @@ -25,10 +26,10 @@ def get_executor(problem_id, files, flags, lang, compiler_time_limit):
def check(
process_output,
judge_output,
judge_input,
problem_id,
files,
lang,
case,
time_limit=env['generator_time_limit'],
memory_limit=env['generator_memory_limit'],
compiler_time_limit=env['generator_compiler_limit'],
Expand All @@ -46,16 +47,23 @@ def check(

args_format_string = args_format_string or contrib_modules[type].ContribModule.get_checker_args_format_string()

with mktemp(judge_input) as input_file, mktemp(process_output) as output_file, mktemp(judge_output) as answer_file:
with mktemp(process_output) as output_file, mktemp(judge_output) as answer_file:
input_path = case.input_data_fd().to_path()

checker_args = shlex.split(
args_format_string.format(
input_file=shlex.quote(input_file.name),
input_file=shlex.quote(input_path),
output_file=shlex.quote(output_file.name),
answer_file=shlex.quote(answer_file.name),
)
)
process = executor.launch(
*checker_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, memory=memory_limit, time=time_limit
*checker_args,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
memory=memory_limit,
time=time_limit,
extra_fs=[ExactFile(input_path)],
)

proc_output, error = process.communicate()
Expand Down
3 changes: 3 additions & 0 deletions dmoj/cptbox/_cptbox.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ bsd_get_proc_fdno: Callable[[int, int], str]

memory_fd_create: Callable[[], int]
memory_fd_seal: Callable[[int], None]

class BufferProxy:
def _get_real_buffer(self): ...
9 changes: 9 additions & 0 deletions dmoj/cptbox/_cptbox.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# cython: language_level=3
from cpython.exc cimport PyErr_NoMemory, PyErr_SetFromErrno
from cpython.buffer cimport PyObject_GetBuffer
from cpython.bytes cimport PyBytes_AsString, PyBytes_FromStringAndSize
from libc.stdio cimport FILE, fopen, fclose, fgets, sprintf
from libc.stdlib cimport malloc, free, strtoul
Expand Down Expand Up @@ -576,3 +577,11 @@ cdef class Process:
if not self._exited:
return None
return self._exitcode


cdef class BufferProxy:
def _get_real_buffer(self):
raise NotImplementedError

def __getbuffer__(self, Py_buffer *buffer, int flags):
PyObject_GetBuffer(self._get_real_buffer(), buffer, flags)
2 changes: 1 addition & 1 deletion dmoj/cptbox/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ int memory_fd_create(void) {
#ifdef __FreeBSD__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, is this function called on FreeBSD anymore? Are you creating the tempfile in Python instead?

char filename[] = "/tmp/cptbox-memoryfd-XXXXXXXX";
int fd = mkstemp(filename);
if (fd > 0)
if (fd >= 0)
unlink(filename);
return fd;
#else
Expand Down
2 changes: 1 addition & 1 deletion dmoj/cptbox/isolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _file_access_check(
real = os.path.realpath(file)

try:
same = normalized == real or os.path.samefile(projected, real)
same = normalized == real or real.startswith('/memfd:') or os.path.samefile(projected, real)
except OSError:
log.debug('Denying access due to inability to stat: normalizes to: %s, actually: %s', normalized, real)
return file, ACCESS_ENOENT
Expand Down
88 changes: 88 additions & 0 deletions dmoj/cptbox/lazy_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Based off https://github.com/django/django/blob/main/django/utils/functional.py, licensed under 3-clause BSD.
from functools import total_ordering

from dmoj.cptbox._cptbox import BufferProxy

_SENTINEL = object()


@total_ordering
class LazyBytes(BufferProxy):
"""
Encapsulate a function call and act as a proxy for methods that are
called on the result of that function. The function is not evaluated
until one of the methods on the result is called.
"""

def __init__(self, func):
self.__func = func
self.__value = _SENTINEL

def __get_value(self):
if self.__value is _SENTINEL:
self.__value = self.__func()
return self.__value

@classmethod
def _create_promise(cls, method_name):
# Builds a wrapper around some magic method
def wrapper(self, *args, **kw):
# Automatically triggers the evaluation of a lazy value and
# applies the given magic method of the result type.
res = self.__get_value()
return getattr(res, method_name)(*args, **kw)

return wrapper

def __cast(self):
return bytes(self.__get_value())

def _get_real_buffer(self):
return self.__cast()

def __bytes__(self):
return self.__cast()

def __repr__(self):
return repr(self.__cast())

def __str__(self):
return str(self.__cast())

def __eq__(self, other):
if isinstance(other, LazyBytes):
other = other.__cast()
return self.__cast() == other

def __lt__(self, other):
if isinstance(other, LazyBytes):
other = other.__cast()
return self.__cast() < other

def __hash__(self):
return hash(self.__cast())

def __mod__(self, rhs):
return self.__cast() % rhs

def __add__(self, other):
return self.__cast() + other

def __radd__(self, other):
return other + self.__cast()

def __deepcopy__(self, memo):
# Instances of this class are effectively immutable. It's just a
# collection of functions. So we don't need to do anything
# complicated for copying.
memo[id(self)] = self
return self


for type_ in bytes.mro():
for method_name in type_.__dict__:
# All __promise__ return the same wrapper method, they
# look up the correct implementation when called.
if hasattr(LazyBytes, method_name):
continue
setattr(LazyBytes, method_name, LazyBytes._create_promise(method_name))
57 changes: 54 additions & 3 deletions dmoj/cptbox/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,62 @@
import errno
import io
import mmap
import os
from tempfile import NamedTemporaryFile
from typing import Optional

from dmoj.cptbox._cptbox import memory_fd_create, memory_fd_seal
from dmoj.cptbox.tracer import FREEBSD


class MemoryIO(io.FileIO):
def __init__(self) -> None:
super().__init__(memory_fd_create(), 'r+')
_name: Optional[str] = None

def __init__(self, prefill: Optional[bytes] = None, seal=False) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe one or both of these should be required kwargs. I'm thinking the second should. What is the difference between prefilling with nothing, and passing None?

if FREEBSD:
with NamedTemporaryFile(delete=False) as f:
self._name = f.name
super().__init__(os.dup(f.fileno()), 'r+')
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to dup and specify delete=False ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely. Otherwise the fd gets closed or the file gets unlinked, respectively.

super().__init__(memory_fd_create(), 'r+')

if prefill:
self.write(prefill)
if seal:
self.seal()

def seal(self) -> None:
memory_fd_seal(self.fileno())
fd = self.fileno()
try:
memory_fd_seal(fd)
except OSError as e:
if e.errno == errno.ENOSYS:
# FreeBSD
self.seek(0, os.SEEK_SET)
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure what this does. Does it deserve more of a comment?

raise

new_fd = os.open(f'/proc/self/fd/{fd}', os.O_RDONLY)
try:
os.dup2(new_fd, fd)
finally:
os.close(new_fd)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could use a comment about why this dup is needed. Also, why isn't it implemented in the C function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C code is just that much more painful to maintain.


def close(self) -> None:
super().close()
if self._name:
os.unlink(self._name)

def to_path(self) -> str:
if self._name:
return self._name
return f'/proc/{os.getpid()}/fd/{self.fileno()}'

def to_bytes(self) -> bytes:
try:
with mmap.mmap(self.fileno(), 0, access=mmap.ACCESS_READ) as f:
return bytes(f)
except ValueError as e:
if e.args[0] == 'cannot mmap an empty file':
return b''
raise
9 changes: 6 additions & 3 deletions dmoj/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,11 @@ def _add_syscalls(self, sec: IsolateTracer) -> IsolateTracer:
sec[getattr(syscalls, f'sys_{name}')] = handler
return sec

def get_security(self, launch_kwargs=None) -> IsolateTracer:
sec = IsolateTracer(self.get_fs(), write_fs=self.get_write_fs())
def get_security(self, launch_kwargs=None, extra_fs=None) -> IsolateTracer:
read_fs = self.get_fs()
if extra_fs:
read_fs += extra_fs
sec = IsolateTracer(read_fs, write_fs=self.get_write_fs())
return self._add_syscalls(sec)

def get_fs(self) -> List[FilesystemAccessRule]:
Expand Down Expand Up @@ -281,7 +284,7 @@ def launch(self, *args, **kwargs) -> TracedPopen:
return TracedPopen(
[utf8bytes(a) for a in self.get_cmdline(**kwargs) + list(args)],
executable=utf8bytes(executable),
security=self.get_security(launch_kwargs=kwargs),
security=self.get_security(launch_kwargs=kwargs, extra_fs=kwargs.get('extra_fs')),
address_grace=self.get_address_grace(),
data_grace=self.data_grace,
personality=self.personality,
Expand Down
4 changes: 2 additions & 2 deletions dmoj/executors/shell_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def get_fs(self):
def get_allowed_syscalls(self):
return super().get_allowed_syscalls() + ['fork', 'waitpid', 'wait4']

def get_security(self, launch_kwargs=None):
def get_security(self, launch_kwargs=None, extra_fs=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that this is a direct result of this PR, but maybe this should be converted to **kwargs instead?

from dmoj.cptbox.syscalls import sys_execve, sys_access, sys_eaccess

sec = super().get_security(launch_kwargs)
sec = super().get_security(launch_kwargs=launch_kwargs, extra_fs=extra_fs)
allowed = set(self.get_allowed_exec())

def handle_execve(debugger):
Expand Down
12 changes: 8 additions & 4 deletions dmoj/graders/bridged.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess

from dmoj.contrib import contrib_modules
from dmoj.cptbox.filesystem_policies import ExactFile
from dmoj.error import InternalError
from dmoj.graders.standard import StandardGrader
from dmoj.judgeenv import env, get_problem_root
Expand Down Expand Up @@ -38,7 +39,7 @@ def check_result(self, case, result):

return (not result.result_flag) and parsed_result

def _launch_process(self, case):
def _launch_process(self, case, input_file=None):
self._interactor_stdin_pipe, submission_stdout_pipe = os.pipe()
submission_stdin_pipe, self._interactor_stdout_pipe = os.pipe()
self._current_proc = self.binary.launch(
Expand All @@ -53,7 +54,7 @@ def _launch_process(self, case):
os.close(submission_stdin_pipe)
os.close(submission_stdout_pipe)

def _interact_with_process(self, case, result, input):
def _interact_with_process(self, case, result):
judge_output = case.output_data()
# Give TL + 2s by default, so we do not race (and incorrectly throw IE) if submission gets TLE
self._interactor_time_limit = (self.handler_data.preprocessing_time or 2) + self.problem.time_limit
Expand All @@ -63,14 +64,16 @@ def _interact_with_process(self, case, result, input):
or contrib_modules[self.contrib_type].ContribModule.get_interactor_args_format_string()
)

with mktemp(input) as input_file, mktemp(judge_output) as answer_file:
with mktemp(judge_output) as answer_file:
input_path = case.input_data_fd().to_path()

# TODO(@kirito): testlib.h expects a file they can write to,
# but we currently don't have a sane way to allow this.
# Thus we pass /dev/null for now so testlib interactors will still
# work, albeit with diminished capabilities
interactor_args = shlex.split(
args_format_string.format(
input_file=shlex.quote(input_file.name),
input_file=shlex.quote(input_path),
output_file=shlex.quote(os.devnull),
answer_file=shlex.quote(answer_file.name),
)
Expand All @@ -82,6 +85,7 @@ def _interact_with_process(self, case, result, input):
stdin=self._interactor_stdin_pipe,
stdout=self._interactor_stdout_pipe,
stderr=subprocess.PIPE,
extra_fs=[ExactFile(input_path)],
)

os.close(self._interactor_stdin_pipe)
Expand Down
5 changes: 4 additions & 1 deletion dmoj/graders/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def close(self):


class InteractiveGrader(StandardGrader):
def _interact_with_process(self, case, result, input):
def _launch_process(self, case, input_file=None):
super()._launch_process(case, input_file=None)

def _interact_with_process(self, case, result):
interactor = Interactor(self._current_proc)
self.check = False
self.feedback = None
Expand Down
Loading