Skip to content

Commit

Permalink
Improve error message in _read_snapshot_metadata
Browse files Browse the repository at this point in the history
Summary:
This mitigates confusion about what went wrong with the snapshot when
its metadata cannot be read.

Reviewed By: JKSenthil

Differential Revision: D54705863
  • Loading branch information
schwarzmx authored and facebook-github-bot committed Mar 11, 2024
1 parent 514e43d commit ac51bab
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
19 changes: 19 additions & 0 deletions tests/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import copy
from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -226,3 +227,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
snapshot = Snapshot.take(app_state={"state": src}, path=str(tmp_path))
snapshot.restore(app_state={"state": dst})
assert check_state_dict_eq(src.state_dict(), dst.state_dict())


@pytest.mark.usefixtures("toggle_batching")
def test_snapshot_metadata_error(tmp_path: Path) -> None:
    """Check that a failed metadata read surfaces as a descriptive RuntimeError.

    A mocked storage plugin is made to raise from ``sync_read``; the call to
    ``Snapshot._read_snapshot_metadata`` must then raise ``RuntimeError`` with
    the user-facing hint about the snapshot path / corruption, chained to the
    original storage error.

    Args:
        tmp_path: pytest temporary directory fixture. Not used by the body;
            kept so the test signature is unchanged.
    """
    # Local import so this test does not depend on a module-level ``import re``.
    import re

    read_error = Exception("Mock error reading from storage")
    mock_storage_plugin = MagicMock()
    mock_event_loop = MagicMock()
    mock_storage_plugin.sync_read.side_effect = read_error

    expected_message = (
        "Failed to read .snapshot_metadata. "
        "Ensure path to snapshot is correct, "
        "otherwise snapshot is likely incomplete or corrupted."
    )
    # ``match`` is applied as a regular expression (re.search), so the literal
    # message must be escaped; otherwise every "." would match any character.
    with pytest.raises(
        expected_exception=RuntimeError,
        match=re.escape(expected_message),
    ) as exc_info:
        Snapshot._read_snapshot_metadata(mock_storage_plugin, mock_event_loop)

    # The underlying storage failure must stay attached for debugging.
    assert exc_info.value.__cause__ is read_error
9 changes: 8 additions & 1 deletion torchsnapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,14 @@ def _read_snapshot_metadata(
storage: StoragePlugin, event_loop: asyncio.AbstractEventLoop
) -> SnapshotMetadata:
read_io = ReadIO(path=SNAPSHOT_METADATA_FNAME)
storage.sync_read(read_io=read_io, event_loop=event_loop)
try:
storage.sync_read(read_io=read_io, event_loop=event_loop)
except Exception as e:
raise RuntimeError(
f"Failed to read {SNAPSHOT_METADATA_FNAME}. "
"Ensure path to snapshot is correct, "
"otherwise snapshot is likely incomplete or corrupted."
) from e
yaml_str = read_io.buf.getvalue().decode("utf-8")
return SnapshotMetadata.from_yaml(yaml_str)

Expand Down

0 comments on commit ac51bab

Please sign in to comment.