From 3f16a33760de312e985f61f50129fb7e64a6eb05 Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Tue, 23 Jan 2024 10:16:01 -0800 Subject: [PATCH 1/7] test(StateDict): Added tests for methods in 'StateDict' Added save and load tests for 'StateDict' to ensure that every functionality within the methods are covered in test. --- tests/sims/test_state_dict.py | 98 ++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 6 deletions(-) diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index ba5ec204..33699651 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -1,15 +1,43 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + import pytest import torch +from collections import OrderedDict from safetensors.torch import save, load from datetime import datetime as dt from caustics.parameter import Parameter from caustics.namespace_dict import NamespaceDict, NestedNamespaceDict -from caustics.sims.state_dict import StateDict, IMMUTABLE_ERR, _sanitize +from caustics.sims.state_dict import ImmutableODict, StateDict, IMMUTABLE_ERR, _sanitize from caustics import __version__ from helpers.sims import extract_tensors +class TestImmutableODict: + def test_constructor(self): + odict = ImmutableODict(a=1, b=2, c=3) + assert isinstance(odict, OrderedDict) + assert odict == {"a": 1, "b": 2, "c": 3} + assert hasattr(odict, "_created") + assert odict._created is True + + def test_setitem(self): + odict = ImmutableODict() + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + odict["key"] = "value" + + def test_delitem(self): + odict = ImmutableODict(key="value") + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + del odict["key"] + + def test_setattr(self): + odict = ImmutableODict() + with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): + odict.meta = {"key": "value"} + + class TestStateDict: simple_tensors = {"var1": torch.as_tensor(1.0), "var2": torch.as_tensor(2.0)} @@ -33,6 +61,15 @@ def test_constructor(self): assert sd_ct_str == time_str_now assert dict(state_dict) == self.simple_tensors + def test_constructor_with_metadata(self): + time_format = "%Y-%m-%dT%H:%M:%S" + time_str_now = dt.utcnow().strftime(time_format) + metadata = {"created_time": time_str_now, "software_version": "0.0.1"} + state_dict = StateDict(metadata=metadata, **self.simple_tensors) + + assert isinstance(state_dict._metadata, ImmutableODict) + assert dict(state_dict._metadata) == dict(metadata) + def test_setitem(self, simple_state_dict): with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): simple_state_dict["var1"] = torch.as_tensor(3.0) @@ -56,14 +93,26 @@ def test_from_params(self, simple_common_sim): state_dict = StateDict.from_params(all_params) assert state_dict == expected_state_dict - def test_to_params(self, simple_state_dict): - params = simple_state_dict.to_params() + # Check for TypeError when passing a NamespaceDict or NestedNamespaceDict + with pytest.raises(TypeError): + StateDict.from_params({"a": 1, "b": 2}) + + # Check for TypeError when passing a NestedNamespaceDict + # without the "static" and "dynamic" keys + with pytest.raises(ValueError): + StateDict.from_params(NestedNamespaceDict({"a": 1, "b": 2})) + + def test_to_params(self): + params_with_none = {"var3": torch.ones(0), **self.simple_tensors} + state_dict = StateDict(**params_with_none) + params = StateDict(**params_with_none).to_params() assert isinstance(params, NamespaceDict) for k, v in params.items(): - tensor_value = simple_state_dict[k] - assert isinstance(v, Parameter) - assert v.value == tensor_value + tensor_value = state_dict[k] + if tensor_value.nelement() > 0: + assert isinstance(v, Parameter) + assert v.value == tensor_value def test__to_safetensors(self): state_dict = StateDict(**self.simple_tensors) @@ -78,3 +127,40 @@ def test__to_safetensors(self): loaded_tensors = load(tensors_bytes) loaded_expected_tensors = load(expected_bytes) assert loaded_tensors == loaded_expected_tensors + + def test_st_file_string(self, simple_state_dict): + file_format = "%Y%m%dT%H%M%S_caustics.st" + expected_file = simple_state_dict._created_time.strftime(file_format) + + assert simple_state_dict._StateDict__st_file == expected_file + + def test_save(self, simple_state_dict): + # Check for default save path + expected_fpath = Path(".") / simple_state_dict._StateDict__st_file + default_fpath = simple_state_dict.save() + + assert Path(default_fpath).exists() + assert default_fpath == str(expected_fpath.absolute()) + + # Cleanup after + Path(default_fpath).unlink() + + # Check for specified save path + with TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) + # Correct extension and path in a tempdir + fpath = tempdir / "test.st" + saved_path = simple_state_dict.save(str(fpath.absolute())) + + assert Path(saved_path).exists() + assert saved_path == str(fpath.absolute()) + + # Wrong extension + wrong_fpath = tempdir / "test.txt" + with pytest.raises(ValueError): + saved_path = simple_state_dict.save(str(wrong_fpath.absolute())) + + def test_load(self, simple_state_dict): + fpath = simple_state_dict.save() + loaded_state_dict = StateDict.load(fpath) + assert loaded_state_dict == simple_state_dict From a254566cd3c3baafd245a719a2ee1083262a2382 Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Tue, 23 Jan 2024 10:51:59 -0800 Subject: [PATCH 2/7] test: Added test for 'load_state_dict' for 'Simulator' class Added test for the loading functionalities of a 'Simulator' class to ensure that it's working properly --- tests/sims/test_simulator.py | 37 +++++++++++++++++++++++++++++++++++ tests/sims/test_state_dict.py | 3 +++ 2 files changed, 40 insertions(+) diff --git a/tests/sims/test_simulator.py b/tests/sims/test_simulator.py index 8cedc350..e0c47331 100644 --- a/tests/sims/test_simulator.py +++ b/tests/sims/test_simulator.py @@ -1,4 +1,7 @@ import pytest +from pathlib import Path + +import torch from caustics.sims.state_dict import StateDict from helpers.sims import extract_tensors @@ -29,3 +32,37 @@ def test_state_dict(self, state_dict, expected_tensors): # Check params assert dict(state_dict) == expected_tensors + + def test_set_module_params(self, simple_common_sim): + params = {"param1": torch.as_tensor(1), "param2": torch.as_tensor(2)} + # Call the __set_module_params method + simple_common_sim._Simulator__set_module_params(simple_common_sim, params) + + # Check if the module attributes have been set correctly + assert simple_common_sim.param1 == params["param1"] + assert simple_common_sim.param2 == params["param2"] + + def test_load_state_dict(self, simple_common_sim): + fpath = simple_common_sim.state_dict().save() + loaded_state_dict = StateDict.load(fpath) + + # Change a value in the simulator + simple_common_sim.z_s = 3.0 + + # Ensure that the simulator has been changed + assert ( + loaded_state_dict[f"{simple_common_sim.name}.z_s"] + != simple_common_sim.z_s.value + ) + + # Load the state dict form file + simple_common_sim.load_state_dict(fpath) + + # Once loaded now the values should be the same + assert ( + loaded_state_dict[f"{simple_common_sim.name}.z_s"] + == simple_common_sim.z_s.value + ) + + # Cleanup + Path(fpath).unlink() diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index 33699651..9fb0ec3c 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -164,3 +164,6 @@ def test_load(self, simple_state_dict): fpath = simple_state_dict.save() loaded_state_dict = StateDict.load(fpath) assert loaded_state_dict == simple_state_dict + + # Cleanup after + Path(fpath).unlink() From 415540f33fd0892cbb9acc68b534186fef950d0a Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Tue, 23 Jan 2024 13:39:42 -0800 Subject: [PATCH 3/7] test: Added tests for 'io' module Added test functions for the whole 'io' module --- tests/test_io.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/test_io.py diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 00000000..18d56bfc --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,63 @@ +from pathlib import Path +import tempfile +import struct +import json +import torch +from safetensors.torch import save +from caustics.io import ( + _get_safetensors_header, + _normalize_path, + to_file, + from_file, + get_safetensors_metadata, +) + + +def test_normalize_path(): + # Test with a string path + path_str = "/path/to/file.txt" + normalized_path = _normalize_path(path_str) + assert normalized_path == Path(path_str) + assert str(normalized_path), path_str + + # Test with a Path object + path_obj = Path("/path/to/file.txt") + normalized_path = _normalize_path(path_obj) + assert normalized_path == path_obj + + +def test_to_and_from_file(): + with tempfile.TemporaryDirectory() as tmpdir: + fpath = Path(tmpdir) / "test.txt" + data = "test data" + + # Test to file + ffile = to_file(fpath, data) + + assert Path(ffile).exists() + assert ffile == str(fpath.absolute()) + assert Path(ffile).read_text() == data + + # Test from file + assert from_file(fpath) == data.encode("utf-8") + + +def test_get_safetensors_metadata(): + with tempfile.TemporaryDirectory() as tmpdir: + fpath = Path(tmpdir) / "test.st" + meta_dict = {"meta": "data"} + tensors_bytes = save({"test1": torch.as_tensor(1.0)}, metadata=meta_dict) + fpath.write_bytes(tensors_bytes) + + # Manually get header + first_bytes_length = 8 + (length_of_header,) = struct.unpack(" Date: Tue, 23 Jan 2024 13:44:45 -0800 Subject: [PATCH 4/7] test: Allow for missing ok when unlink, in case it auto deleted --- tests/sims/test_simulator.py | 2 +- tests/sims/test_state_dict.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/sims/test_simulator.py b/tests/sims/test_simulator.py index e0c47331..2c3eb3fe 100644 --- a/tests/sims/test_simulator.py +++ b/tests/sims/test_simulator.py @@ -65,4 +65,4 @@ def test_load_state_dict(self, simple_common_sim): ) # Cleanup - Path(fpath).unlink() + Path(fpath).unlink(missing_ok=True) diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index 9fb0ec3c..8b4f972e 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -143,7 +143,7 @@ def test_save(self, simple_state_dict): assert default_fpath == str(expected_fpath.absolute()) # Cleanup after - Path(default_fpath).unlink() + Path(default_fpath).unlink(missing_ok=True) # Check for specified save path with TemporaryDirectory() as tempdir: @@ -166,4 +166,4 @@ def test_load(self, simple_state_dict): assert loaded_state_dict == simple_state_dict # Cleanup after - Path(fpath).unlink() + Path(fpath).unlink(missing_ok=True) From 3ebeba704a5804c50eade4eb97c17c9a1c095d96 Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Tue, 23 Jan 2024 13:58:55 -0800 Subject: [PATCH 5/7] fix: Removed hardcoding of current dir of '.' Removed hardcoding of current directory using '.', instead use 'os.path.curdir' to be more OS agnostic --- src/caustics/sims/state_dict.py | 3 ++- tests/sims/test_state_dict.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/caustics/sims/state_dict.py b/src/caustics/sims/state_dict.py index 5ca04130..66bcade0 100644 --- a/src/caustics/sims/state_dict.py +++ b/src/caustics/sims/state_dict.py @@ -2,6 +2,7 @@ from collections import OrderedDict from typing import Any, Dict, Optional from pathlib import Path +import os from torch import Tensor import torch @@ -184,7 +185,7 @@ def save(self, file_path: Optional[str] = None) -> str: The final path of the saved file """ if not file_path: - file_path = Path(".") / self.__st_file + file_path = Path(os.path.curdir) / self.__st_file elif isinstance(file_path, str): file_path = Path(file_path) diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index 8b4f972e..4db439f9 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -1,5 +1,6 @@ from pathlib import Path from tempfile import TemporaryDirectory +import os import pytest import torch @@ -136,7 +137,7 @@ def test_st_file_string(self, simple_state_dict): def test_save(self, simple_state_dict): # Check for default save path - expected_fpath = Path(".") / simple_state_dict._StateDict__st_file + expected_fpath = Path(os.path.curdir) / simple_state_dict._StateDict__st_file default_fpath = simple_state_dict.save() assert Path(default_fpath).exists() From 03aba1129eb712bdce6c0413de93b77443bb079d Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Tue, 23 Jan 2024 14:08:03 -0800 Subject: [PATCH 6/7] test: Add platform check before 'unlink', skip windows --- tests/sims/test_simulator.py | 6 ++++-- tests/sims/test_state_dict.py | 11 +++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/sims/test_simulator.py b/tests/sims/test_simulator.py index 2c3eb3fe..1a5d528b 100644 --- a/tests/sims/test_simulator.py +++ b/tests/sims/test_simulator.py @@ -1,5 +1,6 @@ import pytest from pathlib import Path +import sys import torch @@ -64,5 +65,6 @@ def test_load_state_dict(self, simple_common_sim): == simple_common_sim.z_s.value ) - # Cleanup - Path(fpath).unlink(missing_ok=True) + # Cleanup after only for non-windows + if not sys.platform.startswith("win"): + Path(fpath).unlink(missing_ok=True) diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index 4db439f9..ae420770 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -1,6 +1,7 @@ from pathlib import Path from tempfile import TemporaryDirectory import os +import sys import pytest import torch @@ -143,8 +144,9 @@ def test_save(self, simple_state_dict): assert Path(default_fpath).exists() assert default_fpath == str(expected_fpath.absolute()) - # Cleanup after - Path(default_fpath).unlink(missing_ok=True) + # Cleanup after only for non-windows + if not sys.platform.startswith("win"): + Path(default_fpath).unlink(missing_ok=True) # Check for specified save path with TemporaryDirectory() as tempdir: @@ -166,5 +168,6 @@ def test_load(self, simple_state_dict): loaded_state_dict = StateDict.load(fpath) assert loaded_state_dict == simple_state_dict - # Cleanup after - Path(fpath).unlink(missing_ok=True) + # Cleanup after only for non-windows + if not sys.platform.startswith("win"): + Path(fpath).unlink(missing_ok=True) From eb1c24f95141abb649e7176fe606d28c3f21594b Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Tue, 23 Jan 2024 14:53:09 -0800 Subject: [PATCH 7/7] fix: Changed normalized path to use absolute --- src/caustics/io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/caustics/io.py b/src/caustics/io.py index f7042fa2..d5a146a8 100644 --- a/src/caustics/io.py +++ b/src/caustics/io.py @@ -10,7 +10,9 @@ def _normalize_path(path: "str | Path") -> Path: # Convert string path to Path object if isinstance(path, str): path = Path(path) - return path + + # Get absolute path + return path.absolute() def to_file(