Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
feat(StateDict): Added 'save' method to save to file
Browse files Browse the repository at this point in the history
* Added 'save' method for saving state dictionary
safetensors bytes to a file.
* Added 'ImmutableODict' class for 'StateDict' to inherit from,
which cracks down on ensuring that attributes are also immutable
after creation.
* Modified '_metadata' attribute to be of type 'ImmutableODict'
so that users can't also modify the key and values in this dictionary.
  • Loading branch information
lsetiawan committed Jan 19, 2024
1 parent a30900d commit 01ef378
Showing 1 changed file with 74 additions and 16 deletions.
90 changes: 74 additions & 16 deletions src/caustics/sims/state_dict.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,39 @@
from datetime import datetime as dt
from collections import OrderedDict
from typing import Any, Dict
from typing import Any, Dict, Optional
from pathlib import Path

from torch import Tensor
from .._version import __version__
from ..namespace_dict import NamespaceDict, NestedNamespaceDict
from .. import io

from safetensors.torch import save

IMMUTABLE_ERR = TypeError("'StateDict' cannot be modified after creation.")
STATIC_PARAMS = "static"


class StateDict(OrderedDict):
class ImmutableODict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._created = True

def __delitem__(self, _) -> None:
raise IMMUTABLE_ERR

def __setitem__(self, key: str, value: Any) -> None:
if hasattr(self, "_created"):
raise IMMUTABLE_ERR
super().__setitem__(key, value)

def __setattr__(self, name, value) -> None:
if hasattr(self, "_created"):
raise IMMUTABLE_ERR
return super().__setattr__(name, value)


class StateDict(ImmutableODict):
"""A dictionary object that is immutable after creation.
This is used to store the parameters of a simulator at a given
point in time.
Expand All @@ -23,24 +44,23 @@ class StateDict(OrderedDict):
Convert the state dict to a dictionary of parameters.
"""

__slots__ = ("_metadata", "_created")
__slots__ = ("_metadata", "_created", "_created_time")

def __init__(self, *args, **kwargs):
# Get created time
self._created_time = dt.now()
# Create metadata
metadata = {
"software_version": __version__,
"created_time": self._created_time.isoformat(),
}
# Set metadata
self._metadata = ImmutableODict(metadata)

# Now create the object, this will set _created
# to True, and prevent any further modification
super().__init__(*args, **kwargs)

self._metadata = {}
self._metadata["software_version"] = __version__
self._metadata["created_time"] = dt.utcnow().isoformat()
self._created = True

def __delitem__(self, _) -> None:
raise IMMUTABLE_ERR

def __setitem__(self, key: str, value: Any) -> None:
if hasattr(self, "_created"):
raise IMMUTABLE_ERR
super().__setitem__(key, value)

@classmethod
def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"):
"""Class method to create a StateDict
Expand Down Expand Up @@ -80,5 +100,43 @@ def to_params(self) -> NamespaceDict:
params[k] = Parameter(v)
return params

def save(self, file_path: Optional[str] = None) -> str:
"""
Saves the state dictionary to an optional
``file_path`` as safetensors format.
If ``file_path`` is not given,
this will default to a file in
the current working directory.
*Note: The path specified must
have a '.st' extension.*
Parameters
----------
file_path : str, optional
The file path to save the
state dictionary to, by default None
Returns
-------
str
The final path of the saved file
"""
if not file_path:
file_path = Path(".") / self.__st_file
elif isinstance(file_path, str):
file_path = Path(file_path)

ext = ".st"
if file_path.suffix != ext:
raise ValueError(f"File must have '{ext}' extension")

return io.to_file(file_path, self._to_safetensors())

@property
def __st_file(self) -> str:
file_format = "%Y%m%dT%H%M%S_caustics.st"
return self._created_time.strftime(file_format)

def _to_safetensors(self) -> bytes:
return save(self, metadata=self._metadata)

0 comments on commit 01ef378

Please sign in to comment.