From 5bb325aa66f01dc984445d22995d47cf3fc9c72b Mon Sep 17 00:00:00 2001 From: xiaohanzhangcmu Date: Tue, 29 Oct 2024 23:46:35 -0700 Subject: [PATCH] update --- streaming/base/shared/prefix.py | 94 +++++++++++++++------------------ tests/test_shared.py | 4 +- 2 files changed, 46 insertions(+), 52 deletions(-) diff --git a/streaming/base/shared/prefix.py b/streaming/base/shared/prefix.py index e514334c6..bff7dc067 100644 --- a/streaming/base/shared/prefix.py +++ b/streaming/base/shared/prefix.py @@ -94,7 +94,7 @@ def _check_self(streams_local: list[str]) -> None: def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]], - shm_name: str, local_leader: bool) -> int: + shm_name: str) -> int: """Find the next available prefix while checking existing local dirs for overlap. Local leader walks the existing shm prefixes starting from zero, verifying that there is no @@ -110,12 +110,13 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No """ prefix_int = 0 for prefix_int in _each_prefix_int(): + print(f'{shm_name=} - {prefix_int=}') + temproot = gettempdir() + print(f'{temproot=}') name = _get_path(prefix_int, shm_name) # Check if any shared memory filelocks exist for the current prefix try: - print(f'{prefix_int=}') - print(f'{os.path.exists!r}') filelock_exists = any( os.path.exists(os.path.join(gettempdir(), _get_path(prefix_int, filelock_name))) for filelock_name in [BARRIER_FILELOCK, CACHE_FILELOCK]) @@ -130,42 +131,34 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No except PermissionError: continue except FileNotFoundError: - if not local_leader and shm_name == LOCALS: - raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' + - f'local leader. This may be because you specified ' + - f'different ``local`` parameters from different ranks.') break - their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf)) - - if not local_leader: - if streams_local == their_locals and prefix_int == their_prefix_int: - break - continue - - # Do not check for a conflicting local directories across existing shared memory if - # remote directories are None. Get the next prefix. - if any(streams_remote): - # Get the indices of the local directories which matches with the current - # shared memory. - matching_index = np.where(np.isin(streams_local, their_locals))[0] - if matching_index.size > 0: - for idx in matching_index: - # If there is a conflicting local directory for a non-None remote directory, - # raise an exception. - if streams_remote[idx] is not None: - raise ValueError( - f'Reused local directory: {streams_local} vs ' + - f'{their_locals}. Provide a different one. If using ' + - f'a unique local directory, try deleting the local directory and ' + - f'call `streaming.base.util.clean_stale_shared_memory()` only once ' + - f'in your script to clean up the stale shared memory before ' + - f'instantiation of `StreamingDataset`.') + if shm_name == LOCALS: + their_locals, _ = _unpack_locals(bytes(shm.buf)) + + # Do not check for a conflicting local directories across existing shared memory if + # remote directories are None. Get the next prefix. + if any(streams_remote): + # Get the indices of the local directories which matches with the current + # shared memory. + matching_index = np.where(np.isin(streams_local, their_locals))[0] + if matching_index.size > 0: + for idx in matching_index: + # If there is a conflicting local directory for a non-None remote directory, + # raise an exception. + if streams_remote[idx] is not None: + raise ValueError( + f'Reused local directory: {streams_local} vs ' + + f'{their_locals}. Provide a different one. If using ' + + f'a unique local directory, try deleting the local directory and ' + + f'call `streaming.base.util.clean_stale_shared_memory()` only once ' + + f'in your script to clean up the stale shared memory before ' + + f'instantiation of `StreamingDataset`.') return prefix_int def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Union[str, None]], - shm_name: str, local_leader: bool, retry: int) -> int: + shm_name: str, retry: int) -> int: """Find the next available prefix while checking existing dirs for overlap. If an overlap is found, sleeps for a tick and then tries again, up to "retry" times. We allow @@ -185,7 +178,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio errs = [] for _ in range(1 + retry): try: - return _check_and_find(streams_local, streams_remote, shm_name, local_leader) + return _check_and_find(streams_local, streams_remote, shm_name) except ValueError as err: errs.append(err) sleep(TICK) @@ -214,34 +207,35 @@ def get_shm_prefix(streams_local: list[str], # Check my locals for overlap. _check_self(streams_local) + prefix_int = max([ + _check_and_find_retrying(streams_local, + streams_remote, + shm_name=shm_name, + retry=retry) for shm_name in SHM_TO_CLEAN + ]) + # First, the local leader registers the first available shm prefix, recording its locals. if world.is_local_leader: - prefix_int = max([ - _check_and_find_retrying(streams_local, - streams_remote, - shm_name=shm_name, - local_leader=True, - retry=retry) for shm_name in SHM_TO_CLEAN - ]) name = _get_path(prefix_int, LOCALS) data = _pack_locals(streams_local, prefix_int) shm = SharedMemory(name, True, len(data)) shm.buf[:len(data)] = data + their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf)) if dist.is_available() and dist.is_initialized(): dist.barrier() # Non-local leaders go next, searching for match. if not world.is_local_leader: - prefix_int = 0 - prefix_int = max([ - _check_and_find_retrying(streams_local, - streams_remote, - shm_name=shm_name, - local_leader=False, - retry=1) for shm_name in SHM_TO_CLEAN - ]) name = _get_path(prefix_int, LOCALS) - shm = SharedMemory(name, False) + try: + shm = SharedMemory(name, False) + except FileNotFoundError: + raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' + + f'local leader. This may be because you specified ' + + f'different ``local`` parameters from different ranks.') + if streams_local != their_locals or prefix_int != their_prefix_int: + raise RuntimeError(f'Internal error: shared memory registered does not match ' + + f'local leader as streams_local or prefix_int not match.') return prefix_int, shm # pyright: ignore diff --git a/tests/test_shared.py b/tests/test_shared.py index 877432801..d1914617d 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -176,7 +176,7 @@ def test_check_and_find_skips_filelock_conflict(): mock_exists.side_effect = lambda path: path == bf_path # Expect _check_and_find to return 1 as the next available prefix - next_prefix = _check_and_find(['local_dir'], [None], LOCALS, True) + next_prefix = _check_and_find(['local_dir'], [None], LOCALS) assert next_prefix == 1 @@ -188,5 +188,5 @@ def test_check_and_find_skips_filelock_conflict(): ]) def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock): with patch('os.path.exists', return_value=False): - next_prefix = _check_and_find(['local'], [None], LOCALS, True) + next_prefix = _check_and_find(['local'], [None], LOCALS) assert next_prefix == 1