Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaohanZhangCMU committed Oct 30, 2024
1 parent 00fa94e commit 5bb325a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 52 deletions.
94 changes: 44 additions & 50 deletions streaming/base/shared/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _check_self(streams_local: list[str]) -> None:


def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]],
shm_name: str, local_leader: bool) -> int:
shm_name: str) -> int:
"""Find the next available prefix while checking existing local dirs for overlap.
Local leader walks the existing shm prefixes starting from zero, verifying that there is no
Expand All @@ -110,12 +110,13 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No
"""
prefix_int = 0
for prefix_int in _each_prefix_int():
print(f'{shm_name=} - {prefix_int=}')
temproot = gettempdir()
print(f'{temproot=}')
name = _get_path(prefix_int, shm_name)

# Check if any shared memory filelocks exist for the current prefix
try:
print(f'{prefix_int=}')
print(f'{os.path.exists!r}')
filelock_exists = any(
os.path.exists(os.path.join(gettempdir(), _get_path(prefix_int, filelock_name)))
for filelock_name in [BARRIER_FILELOCK, CACHE_FILELOCK])
Expand All @@ -130,42 +131,34 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No
except PermissionError:
continue
except FileNotFoundError:
if not local_leader and shm_name == LOCALS:
raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' +
f'local leader. This may be because you specified ' +
f'different ``local`` parameters from different ranks.')
break

their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))

if not local_leader:
if streams_local == their_locals and prefix_int == their_prefix_int:
break
continue

# Do not check for conflicting local directories across existing shared memory if
# all remote directories are None. Move on to the next prefix.
if any(streams_remote):
# Get the indices of the local directories that match those recorded in the
# current shared memory.
matching_index = np.where(np.isin(streams_local, their_locals))[0]
if matching_index.size > 0:
for idx in matching_index:
# If there is a conflicting local directory for a non-None remote directory,
# raise an exception.
if streams_remote[idx] is not None:
raise ValueError(
f'Reused local directory: {streams_local} vs ' +
f'{their_locals}. Provide a different one. If using ' +
f'a unique local directory, try deleting the local directory and ' +
f'call `streaming.base.util.clean_stale_shared_memory()` only once ' +
f'in your script to clean up the stale shared memory before ' +
f'instantiation of `StreamingDataset`.')
if shm_name == LOCALS:
their_locals, _ = _unpack_locals(bytes(shm.buf))

# Do not check for conflicting local directories across existing shared memory if
# all remote directories are None. Move on to the next prefix.
if any(streams_remote):
# Get the indices of the local directories that match those recorded in the
# current shared memory.
matching_index = np.where(np.isin(streams_local, their_locals))[0]
if matching_index.size > 0:
for idx in matching_index:
# If there is a conflicting local directory for a non-None remote directory,
# raise an exception.
if streams_remote[idx] is not None:
raise ValueError(
f'Reused local directory: {streams_local} vs ' +
f'{their_locals}. Provide a different one. If using ' +
f'a unique local directory, try deleting the local directory and ' +
f'call `streaming.base.util.clean_stale_shared_memory()` only once ' +
f'in your script to clean up the stale shared memory before ' +
f'instantiation of `StreamingDataset`.')
return prefix_int


def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Union[str, None]],
shm_name: str, local_leader: bool, retry: int) -> int:
shm_name: str, retry: int) -> int:
"""Find the next available prefix while checking existing dirs for overlap.
If an overlap is found, sleeps for a tick and then tries again, up to "retry" times. We allow
Expand All @@ -185,7 +178,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio
errs = []
for _ in range(1 + retry):
try:
return _check_and_find(streams_local, streams_remote, shm_name, local_leader)
return _check_and_find(streams_local, streams_remote, shm_name)
except ValueError as err:
errs.append(err)
sleep(TICK)
Expand Down Expand Up @@ -214,34 +207,35 @@ def get_shm_prefix(streams_local: list[str],
# Check my locals for overlap.
_check_self(streams_local)

prefix_int = max([
_check_and_find_retrying(streams_local,
streams_remote,
shm_name=shm_name,
retry=retry) for shm_name in SHM_TO_CLEAN
])

# First, the local leader registers the first available shm prefix, recording its locals.
if world.is_local_leader:
prefix_int = max([
_check_and_find_retrying(streams_local,
streams_remote,
shm_name=shm_name,
local_leader=True,
retry=retry) for shm_name in SHM_TO_CLEAN
])
name = _get_path(prefix_int, LOCALS)
data = _pack_locals(streams_local, prefix_int)
shm = SharedMemory(name, True, len(data))
shm.buf[:len(data)] = data
their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))

if dist.is_available() and dist.is_initialized():
dist.barrier()

# Non-local leaders go next, searching for match.
if not world.is_local_leader:
prefix_int = 0
prefix_int = max([
_check_and_find_retrying(streams_local,
streams_remote,
shm_name=shm_name,
local_leader=False,
retry=1) for shm_name in SHM_TO_CLEAN
])
name = _get_path(prefix_int, LOCALS)
shm = SharedMemory(name, False)
try:
shm = SharedMemory(name, False)
except FileNotFoundError:
raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' +
f'local leader. This may be because you specified ' +
f'different ``local`` parameters from different ranks.')

if streams_local != their_locals or prefix_int != their_prefix_int:
raise RuntimeError(f'Internal error: shared memory registered does not match ' +
f'local leader as streams_local or prefix_int not match.')
return prefix_int, shm # pyright: ignore
4 changes: 2 additions & 2 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_check_and_find_skips_filelock_conflict():
mock_exists.side_effect = lambda path: path == bf_path

# Expect _check_and_find to return 1 as the next available prefix
next_prefix = _check_and_find(['local_dir'], [None], LOCALS, True)
next_prefix = _check_and_find(['local_dir'], [None], LOCALS)
assert next_prefix == 1


Expand All @@ -188,5 +188,5 @@ def test_check_and_find_skips_filelock_conflict():
])
def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock):
with patch('os.path.exists', return_value=False):
next_prefix = _check_and_find(['local'], [None], LOCALS, True)
next_prefix = _check_and_find(['local'], [None], LOCALS)
assert next_prefix == 1

0 comments on commit 5bb325a

Please sign in to comment.