Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

fix: Fix I/O for accounting windows #61

Merged
merged 5 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/caustics/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def to_file(
# Normalize path to pathlib.Path object
path = _normalize_path(path)

path.write_bytes(data)
with open(path, "wb") as f:
f.write(data)

return str(path.absolute())


Expand Down
3 changes: 1 addition & 2 deletions src/caustics/sims/state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import OrderedDict
from typing import Any, Dict, Optional
from pathlib import Path
import os

from torch import Tensor
import torch
Expand Down Expand Up @@ -185,7 +184,7 @@ def save(self, file_path: Optional[str] = None) -> str:
The final path of the saved file
"""
if not file_path:
file_path = Path(os.path.curdir) / self.__st_file
file_path = Path.cwd() / self.__st_file
elif isinstance(file_path, str):
file_path = Path(file_path)

Expand Down
9 changes: 6 additions & 3 deletions tests/sims/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def test_set_module_params(self, simple_common_sim):
assert simple_common_sim.param1 == params["param1"]
assert simple_common_sim.param2 == params["param2"]

@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Built-in open has different behavior on Windows",
)
def test_load_state_dict(self, simple_common_sim):
fpath = simple_common_sim.state_dict().save()
loaded_state_dict = StateDict.load(fpath)
Expand All @@ -65,6 +69,5 @@ def test_load_state_dict(self, simple_common_sim):
== simple_common_sim.z_s.value
)

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(fpath).unlink(missing_ok=True)
# Cleanup after
Path(fpath).unlink()
21 changes: 13 additions & 8 deletions tests/sims/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path
from tempfile import TemporaryDirectory
import os
import sys

import pytest
Expand Down Expand Up @@ -136,17 +135,20 @@ def test_st_file_string(self, simple_state_dict):

assert simple_state_dict._StateDict__st_file == expected_file

@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Built-in open has different behavior on Windows",
)
def test_save(self, simple_state_dict):
# Check for default save path
expected_fpath = Path(os.path.curdir) / simple_state_dict._StateDict__st_file
expected_fpath = Path.cwd() / simple_state_dict._StateDict__st_file
default_fpath = simple_state_dict.save()

assert Path(default_fpath).exists()
assert default_fpath == str(expected_fpath.absolute())

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(default_fpath).unlink(missing_ok=True)
# Cleanup after
Path(default_fpath).unlink()

# Check for specified save path
with TemporaryDirectory() as tempdir:
Expand All @@ -163,11 +165,14 @@ def test_save(self, simple_state_dict):
with pytest.raises(ValueError):
saved_path = simple_state_dict.save(str(wrong_fpath.absolute()))

@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Built-in open has different behavior on Windows",
)
def test_load(self, simple_state_dict):
fpath = simple_state_dict.save()
loaded_state_dict = StateDict.load(fpath)
assert loaded_state_dict == simple_state_dict

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(fpath).unlink(missing_ok=True)
# Cleanup after
Path(fpath).unlink()
10 changes: 5 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@


def test_normalize_path():
path_obj = Path().joinpath("path", "to", "file.txt")
# Test with a string path
path_str = "/path/to/file.txt"
path_str = str(path_obj)
normalized_path = _normalize_path(path_str)
assert normalized_path == Path(path_str)
assert str(normalized_path), path_str
assert normalized_path == path_obj.absolute()
assert str(normalized_path) == str(path_obj.absolute())

# Test with a Path object
path_obj = Path("/path/to/file.txt")
normalized_path = _normalize_path(path_obj)
assert normalized_path == path_obj
assert normalized_path == path_obj.absolute()


def test_to_and_from_file():
Expand Down
Loading