Skip to content

Commit

Permalink
Add upper bound for prefix_int
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaohanZhangCMU committed Nov 5, 2024
1 parent 06b1d7f commit d19dce6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
3 changes: 3 additions & 0 deletions streaming/base/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@

# Default download timeout
DEFAULT_TIMEOUT = 60.0

# Maximum prefix integers
MAX_PREFIX_INT = 1000
8 changes: 7 additions & 1 deletion streaming/base/shared/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
from torch import distributed as dist

from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK
from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK, MAX_PREFIX_INT
from streaming.base.shared import SharedMemory
from streaming.base.world import World

Expand Down Expand Up @@ -113,6 +113,12 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No

for prefix_int in _each_prefix_int():

if prefix_int >= MAX_PREFIX_INT:
raise ValueError(f"prefix_int exceeds {MAX_PREFIX_INT}. This may happen " +
f"when you mock os.path.exists or os.stat so the filelock " +
f"checks always returns ``True`` " +
f"you need to clean up TMPDIR.")

name = _get_path(prefix_int, shm_name)

# Check if any shared memory filelocks exist for the current prefix
Expand Down
13 changes: 13 additions & 0 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,16 @@ def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock):
with patch('os.path.exists', return_value=False):
next_prefix = _check_and_find(['local'], [None], LOCALS)
assert next_prefix == 1


@pytest.mark.usefixtures('local_remote_dir')
def test_shared_memory_infinity_exception(local_remote_dir: tuple[str, str]):
local, remote = local_remote_dir
with patch('os.path.exists', return_value=True):
with pytest.raises(ValueError, match='prefix_int exceeds .*clean up TMPDIR.'):
_, _ = get_shm_prefix(streams_local=[local],
streams_remote=[remote],
world=World.detect())



0 comments on commit d19dce6

Please sign in to comment.