Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
test: Create utility and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lsetiawan committed Jan 19, 2024
1 parent 27b430f commit 9dea243
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 46 deletions.
50 changes: 50 additions & 0 deletions tests/sims/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,56 @@
from caustics.lenses import EPL
from caustics.light import Sersic
from caustics.cosmology import FlatLambdaCDM
from caustics.namespace_dict import NestedNamespaceDict
from caustics.sims.state_dict import _sanitize


class SimUtilities:
@staticmethod
def extract_tensors(params, include_params=False):
# Extract the "static" and "dynamic" parameters
param_dicts = list(params.values())

# Extract the "static" and "dynamic" parameters
# to a single merged dictionary
final_dict = NestedNamespaceDict()
for pdict in param_dicts:
for k, v in pdict.items():
if k not in final_dict:
final_dict[k] = v
else:
final_dict[k] = {**final_dict[k], **v}

# flatten function only exists for NestedNamespaceDict
all_params = final_dict.flatten()

tensors_dict = _sanitize({k: v.value for k, v in all_params.items()})
if include_params:
return tensors_dict, all_params
return tensors_dict

@staticmethod
def isEquals(a, b):
# Go through each key and values
# change empty torch to be None
# since we can't directly compare
# empty torch
truthy = []
for k, v in a.items():
if k not in b:
return False
kv = b[k]
if (v.nelement() == 0) or (kv.nelement() == 0):
v = None
kv = None
truthy.append(v == kv)

return all(truthy)


@pytest.fixture
def sim_utils():
return SimUtilities


@pytest.fixture
Expand Down
9 changes: 4 additions & 5 deletions tests/sims/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ def state_dict(simple_common_sim):


@pytest.fixture
def expected_tensors(simple_common_sim):
static_params = simple_common_sim.params["static"].flatten()
return {k: v.value for k, v in static_params.items()}
def expected_tensors(simple_common_sim, sim_utils):
return sim_utils.extract_tensors(simple_common_sim.params)


class TestSimulator:
def test_state_dict(self, state_dict, expected_tensors):
def test_state_dict(self, state_dict, expected_tensors, sim_utils):
# Check state_dict type and default keys
assert isinstance(state_dict, StateDict)

Expand All @@ -28,4 +27,4 @@ def test_state_dict(self, state_dict, expected_tensors):
assert "created_time" in state_dict._metadata

# Check params
assert dict(state_dict) == expected_tensors
assert sim_utils.isEquals(dict(state_dict), expected_tensors)
46 changes: 5 additions & 41 deletions tests/sims/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Dict
import pytest
import torch
from safetensors.torch import save, load
from datetime import datetime as dt
from caustics.parameter import Parameter
from caustics.namespace_dict import NamespaceDict, NestedNamespaceDict
from caustics.sims.state_dict import StateDict, IMMUTABLE_ERR, _sanitize
from caustics.sims.state_dict import StateDict, IMMUTABLE_ERR
from caustics import __version__


Expand Down Expand Up @@ -40,55 +39,20 @@ def test_delitem(self, simple_state_dict):
with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)):
del simple_state_dict["var1"]

def test_from_params(self, simple_common_sim):
def test_from_params(self, simple_common_sim, sim_utils):
params: NestedNamespaceDict = simple_common_sim.params

# Extract the "static" and "dynamic" parameters
param_dicts = list(params.values())

# Extract the "static" and "dynamic" parameters
# to a single merged dictionary
final_dict = NestedNamespaceDict()
for pdict in param_dicts:
for k, v in pdict.items():
if k not in final_dict:
final_dict[k] = v
else:
final_dict[k] = {**final_dict[k], **v}

# flatten function only exists for NestedNamespaceDict
all_params = final_dict.flatten()

tensors_dict: Dict[str, torch.Tensor] = _sanitize(
{k: v.value for k, v in all_params.items()}
)
tensors_dict, all_params = sim_utils.extract_tensors(params, True)

expected_state_dict = StateDict(tensors_dict)

def isEquals(a, b):
# Go through each key and values
# change empty torch to be None
# since we can't directly compare
# empty torch
truthy = []
for k, v in a.items():
if k not in b:
return False
kv = b[k]
if (v.nelement() == 0) or (kv.nelement() == 0):
v = None
kv = None
truthy.append(v == kv)

return all(truthy)

# Full parameters
state_dict = StateDict.from_params(params)
assert isEquals(state_dict, expected_state_dict)
assert sim_utils.isEquals(state_dict, expected_state_dict)

# Static only
state_dict = StateDict.from_params(all_params)
assert isEquals(state_dict, expected_state_dict)
assert sim_utils.isEquals(state_dict, expected_state_dict)

def test_to_params(self, simple_state_dict):
params = simple_state_dict.to_params()
Expand Down

0 comments on commit 9dea243

Please sign in to comment.