Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shared memory permission issue in a shared pod environment #813

Merged
merged 43 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
c32f066
update
XiaohanZhangCMU Oct 25, 2024
7934dcc
update
XiaohanZhangCMU Oct 28, 2024
98ac253
update
XiaohanZhangCMU Oct 28, 2024
56c364e
update
XiaohanZhangCMU Oct 28, 2024
06a585d
lint
XiaohanZhangCMU Oct 28, 2024
beac17b
update
XiaohanZhangCMU Oct 28, 2024
e48c11b
add test
XiaohanZhangCMU Oct 28, 2024
c9f4ff3
update
XiaohanZhangCMU Oct 28, 2024
aa5a808
update
XiaohanZhangCMU Oct 29, 2024
60969b3
lint
XiaohanZhangCMU Oct 29, 2024
bd81ba9
update
XiaohanZhangCMU Oct 29, 2024
d4b4715
update
XiaohanZhangCMU Oct 29, 2024
ea388d4
update
XiaohanZhangCMU Oct 29, 2024
f87e1a9
add prints
XiaohanZhangCMU Oct 29, 2024
13884ff
update
XiaohanZhangCMU Oct 29, 2024
f588bd5
update
XiaohanZhangCMU Oct 29, 2024
aa3c798
update
XiaohanZhangCMU Oct 29, 2024
44144fd
refactoring
XiaohanZhangCMU Oct 30, 2024
cdd3472
update
XiaohanZhangCMU Oct 30, 2024
00fa94e
update
XiaohanZhangCMU Oct 30, 2024
5bb325a
update
XiaohanZhangCMU Oct 30, 2024
3bc6f5c
update
XiaohanZhangCMU Oct 30, 2024
1158e6d
update
XiaohanZhangCMU Oct 30, 2024
e00976c
update
XiaohanZhangCMU Oct 30, 2024
4d550fd
update
XiaohanZhangCMU Oct 30, 2024
14cb77c
update
XiaohanZhangCMU Oct 30, 2024
50edd61
update
XiaohanZhangCMU Oct 30, 2024
587532a
update
XiaohanZhangCMU Oct 30, 2024
161768f
update
XiaohanZhangCMU Oct 30, 2024
2c1915e
update
XiaohanZhangCMU Oct 30, 2024
fd56d62
update
XiaohanZhangCMU Oct 30, 2024
f401ba3
update
XiaohanZhangCMU Oct 30, 2024
8309cc4
update
XiaohanZhangCMU Oct 30, 2024
4bbe87b
update
XiaohanZhangCMU Oct 30, 2024
f6fc0d7
update
XiaohanZhangCMU Oct 30, 2024
4cf8ade
update
XiaohanZhangCMU Oct 30, 2024
0861452
update
XiaohanZhangCMU Oct 30, 2024
c8ced84
update
XiaohanZhangCMU Oct 30, 2024
1532d72
update
XiaohanZhangCMU Nov 1, 2024
4b98159
update
XiaohanZhangCMU Nov 1, 2024
0ec6b84
update
XiaohanZhangCMU Nov 1, 2024
5dc5d14
update
XiaohanZhangCMU Nov 1, 2024
9d90a23
update
XiaohanZhangCMU Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion streaming/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

"""The Streaming Version."""

__version__ = '0.10.0.dev0'
__version__ = '0.10.0.dev2'
2 changes: 1 addition & 1 deletion streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def __init__(self,
]
self._shm_prefix_int, self._locals_shm = get_shm_prefix(streams_local, streams_remote,
self._unique_rank_world)
self._filelock_root = os.path.join(gettempdir(), 'streaming')
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
self._filelock_root = gettempdir()
os.makedirs(self._filelock_root, exist_ok=True)

# Create the shared memory-backed barrier, without its lock, which is unpickleable.
Expand Down
94 changes: 73 additions & 21 deletions streaming/base/shared/prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
prevent shared resources like shared memory from colliding.
"""

import os
from collections import Counter
from tempfile import gettempdir
from time import sleep
from typing import Iterator, Union

import numpy as np
from torch import distributed as dist

from streaming.base.constant import LOCALS, TICK
from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK
from streaming.base.shared import SharedMemory
from streaming.base.world import World

Expand Down Expand Up @@ -91,7 +93,8 @@ def _check_self(streams_local: list[str]) -> None:
f'Reused local directory: {duplicate_local_dirs}. Provide a different one.')


def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]]) -> int:
def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, None]],
shm_name: str) -> int:
"""Find the next available prefix while checking existing local dirs for overlap.

Local leader walks the existing shm prefixes starting from zero, verifying that there is no
Expand All @@ -101,18 +104,43 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No
Args:
streams_local (List[str]): Our local working directories.
streams_remote (List[Union[str, None]]): Our remote working directories.
shm_name (str): The shared memory file name, e.g., LOCALS, BARRIER etc.

Returns:
int: Next available prefix int.
"""
prefix_int = 0

for prefix_int in _each_prefix_int():
name = _get_path(prefix_int, LOCALS)

print(f'{shm_name=} - {prefix_int=}')
temproot = gettempdir()
print(f'{temproot=}')
name = _get_path(prefix_int, shm_name)

# Check if any shared memory filelocks exist for the current prefix
try:
filelock_exists = any(
os.path.exists(os.path.join(gettempdir(), _get_path(prefix_int, filelock_name)))
for filelock_name in [BARRIER_FILELOCK, CACHE_FILELOCK])
if filelock_exists:
continue
except PermissionError:
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
continue

# Attempt to access shared memory by name. Use this prefix_int if the shared memory does not exist.
try:
shm = SharedMemory(name, False)
except PermissionError:
continue
except FileNotFoundError:
break

if shm_name != LOCALS:
continue

their_locals, _ = _unpack_locals(bytes(shm.buf))

# Do not check for conflicting local directories across existing shared memory if
# remote directories are None. Get the next prefix.
if any(streams_remote):
Expand All @@ -127,15 +155,16 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No
raise ValueError(
f'Reused local directory: {streams_local} vs ' +
f'{their_locals}. Provide a different one. If using ' +
f'a unique local directory, try deleting the local directory and ' +
f'call `streaming.base.util.clean_stale_shared_memory()` only once ' +
f'in your script to clean up the stale shared memory before ' +
f'a unique local directory, try deleting the local directory and '
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
+
f'call `streaming.base.util.clean_stale_shared_memory()` only once '
+ f'in your script to clean up the stale shared memory before ' +
f'instantiation of `StreamingDataset`.')
return prefix_int


def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Union[str, None]],
retry: int) -> int:
shm_name: str, retry: int) -> int:
"""Find the next available prefix while checking existing dirs for overlap.

If an overlap is found, sleeps for a tick and then tries again, up to "retry" times. We allow
Expand All @@ -145,6 +174,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio
Args:
streams_local (List[str]): Our local working directories.
streams_remote (List[Union[str, None]]): Our remote working directories.
shm_name (str): The shared memory file name, e.g., LOCALS, BARRIER etc.
retry (int): Number of retries upon failure before raising an exception.

Returns:
Expand All @@ -155,7 +185,7 @@ def _check_and_find_retrying(streams_local: list[str], streams_remote: list[Unio
errs = []
for _ in range(1 + retry):
try:
return _check_and_find(streams_local, streams_remote)
return _check_and_find(streams_local, streams_remote, shm_name)
except ValueError as err:
errs.append(err)
sleep(TICK)
Expand Down Expand Up @@ -184,29 +214,51 @@ def get_shm_prefix(streams_local: list[str],
# Check my locals for overlap.
_check_self(streams_local)

prefix_int = max([
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
_check_and_find_retrying(streams_local, streams_remote, shm_name=shm_name, retry=retry)
karan6181 marked this conversation as resolved.
Show resolved Hide resolved
for shm_name in SHM_TO_CLEAN
])
print(f'before first barrier, {world.rank=}')

if dist.is_available() and dist.is_initialized():
print(f'before: in dist.avail 1 : {world.rank=}')
dist.barrier()
print(f'after: in dist.avail 1 : {world.rank=}')
print(f'after first barrier, {world.rank=}')

# First, the local leader registers the first available shm prefix, recording its locals.
if world.is_local_leader:
prefix_int = _check_and_find_retrying(streams_local, streams_remote, retry)
print(f'in local leader')
name = _get_path(prefix_int, LOCALS)
data = _pack_locals(streams_local, prefix_int)
shm = SharedMemory(name, True, len(data))
shm.buf[:len(data)] = data
their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))
karan6181 marked this conversation as resolved.
Show resolved Hide resolved
print(f"In world leader: {name=}, {their_locals=}, {their_prefix_int}, {data=}")

sleep(3)
print(f'{world.rank=}')
if dist.is_available() and dist.is_initialized():
print(f'before: in dist.avail 2: {world.rank=}')
dist.barrier()
print(f'after: in dist.avail 2: {world.rank=}')

#sleep(3)

# Non-local leaders go next, searching for match.
if not world.is_local_leader:
for prefix_int in _each_prefix_int():
name = _get_path(prefix_int, LOCALS)
try:
shm = SharedMemory(name, False)
except FileNotFoundError:
raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' +
f'local leader. This may be because you specified ' +
f'different ``local`` parameters from different ranks.')
their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))
if streams_local == their_locals and prefix_int == their_prefix_int:
break

print(f'in non local world leader')
name = _get_path(prefix_int, LOCALS)
try:
shm = SharedMemory(name, False)
except FileNotFoundError:
raise RuntimeError(f'Internal error: shared memory prefix was not registered by ' +
f'local leader. This may be because you specified ' +
f'different ``local`` parameters from different ranks.')

their_locals, their_prefix_int = _unpack_locals(bytes(shm.buf))
print(f"{name=}, {streams_local=}, {their_locals=}, {prefix_int=}, {their_prefix_int}")
if streams_local != their_locals or prefix_int != their_prefix_int:
karan6181 marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError(f'Internal error: shared memory registered does not match ' +
f'local leader as streams_local or prefix_int not match.')
return prefix_int, shm # pyright: ignore
11 changes: 8 additions & 3 deletions streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,23 @@ def clean_stale_shared_memory() -> None:
# Initialize torch.distributed ourselves, if necessary.
destroy_dist = maybe_init_dist()

print('I am here')
# Perform clean up on local rank 0
if get_local_rank() == 0:
for prefix_int in range(1000000):
print(f'shm -- {prefix_int=}')
leaked_shm = False
for shm_name in SHM_TO_CLEAN:
name = _get_path(prefix_int, shm_name)
try:
shm = BuiltinSharedMemory(name, True, 4)
except FileExistsError:
shm = BuiltinSharedMemory(name, False, 4)
leaked_shm = True
finally:
try:
shm = BuiltinSharedMemory(name, False, 4)
leaked_shm = True
except PermissionError:
continue
if shm:
shm.close() # pyright: ignore
shm.unlink()
# Come out of loop if no leaked shared memory
Expand Down
33 changes: 33 additions & 0 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from streaming.base import StreamingDataset
from streaming.base.constant import LOCALS
from streaming.base.shared import SharedArray, get_shm_prefix
from streaming.base.shared.memory import SharedMemory
from streaming.base.shared.prefix import _check_and_find
from streaming.base.util import clean_stale_shared_memory
from streaming.base.world import World
from tests.common.utils import convert_to_mds

Expand Down Expand Up @@ -157,3 +163,30 @@ def test_shared_array_size_is_integer(mock_shared_memory: MagicMock, dtype: type
mock_shared_memory.assert_called_once() # pyright: ignore
size_arg = mock_shared_memory.call_args[1]['size']
assert isinstance(size_arg, int), 'Size passed to SharedMemory is not an integer'


def test_check_and_find_skips_filelock_conflict():
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
"""Test _check_and_find skips prefix due to file lock conflict."""
clean_stale_shared_memory()

with patch('os.path.exists') as mock_exists, \
patch('multiprocessing.shared_memory.SharedMemory', side_effect=FileNotFoundError):
# Simulate that `<tmpdir>/000000_barrier_filelock` exists, indicating a lock conflict
bf_path = os.path.join(tempfile.gettempdir(), '000000_barrier_filelock')
mock_exists.side_effect = lambda path: path == bf_path

# Expect _check_and_find to return 1 as the next available prefix
next_prefix = _check_and_find(['local_dir'], [None], LOCALS)
assert next_prefix == 1


@patch.object(
    SharedMemory,
    '__init__',
    side_effect=[
        PermissionError('Mocked permission error'),
        FileNotFoundError('Mocked file not found error'),
    ],
)
def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock):
    """Check that a prefix whose shared memory raises PermissionError is skipped.

    The patched ``SharedMemory.__init__`` raises ``PermissionError`` on its first call and
    ``FileNotFoundError`` on its second, while ``os.path.exists`` is patched to always
    return ``False``. Under those conditions ``_check_and_find`` is expected to return 1.
    """
    # Report every probed path as missing so only the shared memory lookups decide the result.
    nothing_exists = patch('os.path.exists', return_value=False)
    with nothing_exists:
        found_prefix = _check_and_find(['local'], [None], LOCALS)
        assert found_prefix == 1
Loading