Fix output capturing in ConcurrentWrapper for concurrent processes #261

Open · wants to merge 6 commits into main
38 changes: 24 additions & 14 deletions egg/nest/wrappers.py
@@ -40,18 +40,28 @@ def __init__(self, runnable, log_dir, job_id):

    def __call__(self, args):
        stdout_path = pathlib.Path(self.log_dir) / f"{self.job_id}.out"
-       self.stdout = open(stdout_path, "w")
-
        stderr_path = pathlib.Path(self.log_dir) / f"{self.job_id}.err"
-       self.stderr = open(stderr_path, "w")
-
-       sys.stdout = self.stdout
-       sys.stderr = self.stderr
-       cuda_id = -1
-       n_devices = torch.cuda.device_count()
-       if n_devices > 0:
-           cuda_id = self.job_id % n_devices
-       print(f"# {json.dumps(args)}", flush=True)
-
-       with torch.cuda.device(cuda_id):
-           self.runnable(args)
+
+       with open(stdout_path, "w") as self.stdout, open(
+           stderr_path, "w"
+       ) as self.stderr:
+           original_stdout = sys.stdout
+           original_stderr = sys.stderr
+           sys.stdout = self.stdout
+           sys.stderr = self.stderr
+
+           cuda_id = -1
+           n_devices = torch.cuda.device_count()
+           if n_devices > 0:
+               cuda_id = self.job_id % n_devices
+
+           print(f"# {json.dumps(args)}", flush=True)
+
+           with torch.cuda.device(cuda_id):
+               self.runnable(args)
+
+           sys.stdout.flush()
+           sys.stderr.flush()
+
+       sys.stdout = original_stdout
+       sys.stderr = original_stderr
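Note on the technique: the PR opens the per-job log files in a with block (so they are always closed) and restores the original streams by hand afterwards. For comparison, the standard library's contextlib.redirect_stdout and contextlib.redirect_stderr restore the previous streams automatically on exit, including when the wrapped callable raises. The sketch below is illustrative only and not part of this PR; run_captured and its parameters are hypothetical names:

import contextlib
import pathlib


def run_captured(runnable, args, log_dir, job_id):
    # Hypothetical helper: capture one job's output in per-job files.
    stdout_path = pathlib.Path(log_dir) / f"{job_id}.out"
    stderr_path = pathlib.Path(log_dir) / f"{job_id}.err"
    with open(stdout_path, "w") as out, open(stderr_path, "w") as err:
        # redirect_stdout/redirect_stderr swap sys.stdout/sys.stderr in,
        # and restore the previous streams on exit, even if runnable raises.
        with contextlib.redirect_stdout(out), contextlib.redirect_stderr(err):
            runnable(args)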
76 changes: 76 additions & 0 deletions tests/test_concurrent_wrapper.py
@@ -0,0 +1,76 @@
import json
import multiprocessing
import pathlib
import sys
import time

import pytest

from egg.nest.wrappers import ConcurrentWrapper

multiprocessing.set_start_method(
    "spawn", force=True
)  # avoid an issue with CUDA re-initialization in a forked subprocess


def dummy_runnable(args):
    print("Running dummy_runnable")
    print(json.dumps(args), file=sys.stderr)


def test_file_descriptor_closure(tmp_path):
    """
    Test to check if file descriptors are closed.
    Attempting to write to a closed file should raise a ValueError.
    """
    runnable = dummy_runnable
    log_dir = tmp_path
    job_id = 1

    wrapper = ConcurrentWrapper(runnable, log_dir, job_id)
    wrapper({"key": "value"})

    with pytest.raises(ValueError):
        wrapper.stdout.write("This should fail if the file is closed.")

    with pytest.raises(ValueError):
        wrapper.stderr.write("This should fail if the file is closed.")


def test_stdout_stderr_restoration(tmp_path):
    """Test to ensure sys.stdout and sys.stderr are restored."""
    original_stdout = sys.stdout
    original_stderr = sys.stderr

    runnable = dummy_runnable
    log_dir = tmp_path
    job_id = 2

    wrapper = ConcurrentWrapper(runnable, log_dir, job_id)
    wrapper({"another_key": "another_value"})

    assert sys.stdout == original_stdout
    assert sys.stderr == original_stderr


def delayed_print_runnable(args):
    print("This is a test.")
    time.sleep(0.1)  # introduce a slight delay


def test_delayed_output_capture(tmp_path):
    log_dir = tmp_path
    job_id = 1

    runner = ConcurrentWrapper(
        runnable=delayed_print_runnable, log_dir=log_dir, job_id=job_id
    )

    runner([])

    stdout_path = pathlib.Path(log_dir) / f"{job_id}.out"

    with open(stdout_path, "r") as f:
        output = f.read()

    assert "This is a test." in output, "Expected output was not captured in the file."