diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py index 93942a8c..07459be4 100644 --- a/tests/sims/test_state_dict.py +++ b/tests/sims/test_state_dict.py @@ -1,3 +1,4 @@ +import hashlib from typing import Dict import pytest import torch @@ -74,5 +75,7 @@ def test_to_params(self, simple_state_dict): def test__to_safetensors(self, simple_state_dict): tensors_bytes = simple_state_dict._to_safetensors() + digest = hashlib.sha256(tensors_bytes).hexdigest() expected_bytes = save(simple_state_dict, metadata=simple_state_dict._metadata) - assert tensors_bytes == expected_bytes + expected_digest = hashlib.sha256(expected_bytes).hexdigest() + assert digest == expected_digest