Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaohanZhangCMU committed Oct 30, 2024
1 parent 44144fd commit cdd3472
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
29 changes: 20 additions & 9 deletions streaming/base/shared/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
from torch import distributed as dist

from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, SHM_TO_CLEAN, LOCALS, TICK
from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK
from streaming.base.shared import SharedMemory
from streaming.base.world import World

Expand Down Expand Up @@ -93,10 +93,8 @@ def _check_self(streams_local: list[str]) -> None:
f'Reused local directory: {duplicate_local_dirs}. Provide a different one.')


def _check_and_find(streams_local: list[str],
streams_remote: list[Union[str, None]],
shm_name: str,
local_leader: bool) -> int:
def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]],
shm_name: str, local_leader: bool) -> int:
"""Find the next available prefix while checking existing local dirs for overlap.
Local leader walks the existing shm prefixes starting from zero, verifying that there is no
Expand All @@ -116,8 +114,8 @@ def _check_and_find(streams_local: list[str],

# Check if any shared memory filelocks exist for the current prefix
try:
print(f"{prefix_int=}")
print(f"{os.path.exists!r}")
print(f'{prefix_int=}')
print(f'{os.path.exists!r}')
filelock_exists = any(
os.path.exists(os.path.join(gettempdir(), _get_path(prefix_int, filelock_name)))
for filelock_name in [BARRIER_FILELOCK, CACHE_FILELOCK])
Expand Down Expand Up @@ -165,6 +163,7 @@ def _check_and_find(streams_local: list[str],
f'instantiation of `StreamingDataset`.')
return prefix_int


def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Union[str, None]],
shm_name: str, local_leader: bool, retry: int) -> int:
"""Find the next available prefix while checking existing dirs for overlap.
Expand Down Expand Up @@ -217,7 +216,13 @@ def get_shm_prefix(streams_local: list[str],

# First, the local leader registers the first available shm prefix, recording its locals.
if world.is_local_leader:
prefix_int = max([_check_and_find_retrying(streams_local, streams_remote, shm_name=shm_name, local_leader=True, retry=retry) for shm_name in SHM_TO_CLEAN])
prefix_int = max([
_check_and_find_retrying(streams_local,
streams_remote,
shm_name=shm_name,
local_leader=True,
retry=retry) for shm_name in SHM_TO_CLEAN
])
name = _get_path(prefix_int, LOCALS)
data = _pack_locals(streams_local, prefix_int)
shm = SharedMemory(name, True, len(data))
Expand All @@ -229,7 +234,13 @@ def get_shm_prefix(streams_local: list[str],
# Non-local leaders go next, searching for match.
if not world.is_local_leader:
prefix_int = 0
prefix_int = max([_check_and_find_retrying(streams_local, streams_remote, shm_name=shm_name, local_leader=False, retry=2) for shm_name in SHM_TO_CLEAN])
prefix_int = max([
_check_and_find_retrying(streams_local,
streams_remote,
shm_name=shm_name,
local_leader=False,
retry=2) for shm_name in SHM_TO_CLEAN
])
name = _get_path(prefix_int, LOCALS)
shm = SharedMemory(name, False)

Expand Down
2 changes: 1 addition & 1 deletion streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import torch.distributed as dist

from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, SHM_TO_CLEAN
from streaming.base.constant import SHM_TO_CLEAN
from streaming.base.distributed import get_local_rank, maybe_init_dist
from streaming.base.format.index import get_index_basename
from streaming.base.shared.prefix import _get_path
Expand Down
2 changes: 1 addition & 1 deletion tests/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import pytest

from streaming.base import StreamingDataset
from streaming.base.constant import LOCALS
from streaming.base.shared import SharedArray, get_shm_prefix
from streaming.base.shared.memory import SharedMemory
from streaming.base.shared.prefix import _check_and_find
from streaming.base.util import clean_stale_shared_memory
from streaming.base.world import World
from streaming.base.constant import LOCALS
from tests.common.utils import convert_to_mds


Expand Down

0 comments on commit cdd3472

Please sign in to comment.