diff --git a/src/caustics/io.py b/src/caustics/io.py index d5a146a8..ddb517c1 100644 --- a/src/caustics/io.py +++ b/src/caustics/io.py @@ -44,7 +44,9 @@ def to_file( # Normalize path to pathlib.Path object path = _normalize_path(path) - path.write_bytes(data) + with open(path, "wb") as f: + f.write(data) + return str(path.absolute()) diff --git a/src/caustics/sims/state_dict.py b/src/caustics/sims/state_dict.py index 66bcade0..49e26df3 100644 --- a/src/caustics/sims/state_dict.py +++ b/src/caustics/sims/state_dict.py @@ -2,7 +2,6 @@ from collections import OrderedDict from typing import Any, Dict, Optional from pathlib import Path -import os from torch import Tensor import torch @@ -185,7 +184,7 @@ def save(self, file_path: Optional[str] = None) -> str: The final path of the saved file """ if not file_path: - file_path = Path(os.path.curdir) / self.__st_file + file_path = Path.cwd() / self.__st_file elif isinstance(file_path, str): file_path = Path(file_path) diff --git a/tests/sims/test_simulator.py b/tests/sims/test_simulator.py index 1a5d528b..e6adc2be 100644 --- a/tests/sims/test_simulator.py +++ b/tests/sims/test_simulator.py @@ -43,6 +43,10 @@ def test_set_module_params(self, simple_common_sim): assert simple_common_sim.param1 == params["param1"] assert simple_common_sim.param2 == params["param2"] + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Built-in open has different behavior on Windows", + ) def test_load_state_dict(self, simple_common_sim): fpath = simple_common_sim.state_dict().save() loaded_state_dict = StateDict.load(fpath) @@ -65,6 +69,5 @@ def test_load_state_dict(self, simple_common_sim): == simple_common_sim.z_s.value ) - # Cleanup after only for non-windows - if not sys.platform.startswith("win"): - Path(fpath).unlink(missing_ok=True) + # Cleanup after + Path(fpath).unlink() diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index ae420770..4d15591d 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -1,6 +1,5 @@ from pathlib import Path from tempfile import TemporaryDirectory -import os import sys import pytest @@ -136,17 +135,20 @@ def test_st_file_string(self, simple_state_dict): assert simple_state_dict._StateDict__st_file == expected_file + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Built-in open has different behavior on Windows", + ) def test_save(self, simple_state_dict): # Check for default save path - expected_fpath = Path(os.path.curdir) / simple_state_dict._StateDict__st_file + expected_fpath = Path.cwd() / simple_state_dict._StateDict__st_file default_fpath = simple_state_dict.save() assert Path(default_fpath).exists() assert default_fpath == str(expected_fpath.absolute()) - # Cleanup after only for non-windows - if not sys.platform.startswith("win"): - Path(default_fpath).unlink(missing_ok=True) + # Cleanup after + Path(default_fpath).unlink() # Check for specified save path with TemporaryDirectory() as tempdir: @@ -163,11 +165,14 @@ def test_save(self, simple_state_dict): with pytest.raises(ValueError): saved_path = simple_state_dict.save(str(wrong_fpath.absolute())) + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Built-in open has different behavior on Windows", + ) def test_load(self, simple_state_dict): fpath = simple_state_dict.save() loaded_state_dict = StateDict.load(fpath) assert loaded_state_dict == simple_state_dict - # Cleanup after only for non-windows - if not sys.platform.startswith("win"): - Path(fpath).unlink(missing_ok=True) + # Cleanup after + Path(fpath).unlink() diff --git a/tests/test_io.py b/tests/test_io.py index 18d56bfc..d3116b7c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -14,16 +14,16 @@ def test_normalize_path(): + path_obj = Path().joinpath("path", "to", "file.txt") # Test with a string path - path_str = "/path/to/file.txt" + path_str = str(path_obj) normalized_path = _normalize_path(path_str) - assert normalized_path == Path(path_str) - assert str(normalized_path), path_str + assert normalized_path == path_obj.absolute() + assert str(normalized_path) == str(path_obj.absolute()) # Test with a Path object - path_obj = Path("/path/to/file.txt") normalized_path = _normalize_path(path_obj) - assert normalized_path == path_obj + assert normalized_path == path_obj.absolute() def test_to_and_from_file():