diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py index e6bf19609..99ced8afe 100644 --- a/torchrec/optim/tests/test_keyed.py +++ b/torchrec/optim/tests/test_keyed.py @@ -12,7 +12,6 @@ import torch import torch.distributed as dist -from hpc.optimizers.optimizer_modules import OptimizerModule from torch.autograd import Variable from torch.distributed._shard import sharded_tensor, sharding_spec from torchrec.optim.keyed import ( @@ -24,14 +23,19 @@ from torchrec.test_utils import get_free_port -class DummyOptimizerModule(OptimizerModule): +class DummyOptimizerModule: def __init__( self, tensor: torch.Tensor, ) -> None: - super(DummyOptimizerModule, self).__init__() self.tensor = tensor + def state_dict(self) -> Dict[str, Any]: + return {"tensor": self.tensor} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.tensor.detach().copy_(state_dict["tensor"]) + class TestKeyedOptimizer(unittest.TestCase): def _assert_state_dict_equals(