Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaohanZhangCMU committed Nov 6, 2024
1 parent 066edce commit e68903e
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,7 @@ def test_ft_dataloader_with_extra_keys():
# TODO: Change this back to xfail after figuring out why it caused CI to hang
@pytest.mark.skip
def test_text_dataloader_with_extra_keys():
from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK
max_seq_len = 1024
cfg = {
'dataset': {
Expand Down Expand Up @@ -1549,15 +1550,23 @@ def test_text_dataloader_with_extra_keys():

device_batch_size = 2

mock_stat = MagicMock()
mock_stat.st_size = 1024 # Mock st_size with a desired value
mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems
def custom_stat_mock(path: Any):
if any([BARRIER_FILELOCK in path, CACHE_FILELOCK in path]):
return original_os_stat(path)
else:
mock_stat = MagicMock()
mock_stat.st_size = 1024 # Mock st_size with a desired value
mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems
return mock_stat

original_os_stat = os.stat


#with patch('streaming.base.stream.get_shards', return_value=None):
with patch('os.makedirs'), \
patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \
patch('json.load') as mock_json_load, \
patch('os.stat', return_value=mock_stat), \
patch('os.stat', side_effect=custom_stat_mock), \
patch('torch.distributed.is_available', return_value=True), \
patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.broadcast_object_list'), \
Expand Down

0 comments on commit e68903e

Please sign in to comment.