From ab2ee3455ea516e1d25e8fe07369c927226f9bf2 Mon Sep 17 00:00:00 2001 From: Michael Shi Date: Mon, 27 Nov 2023 16:09:40 -0800 Subject: [PATCH] Make test_keyed standalone and remove HPC dependency (#1540) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1540 We remove the `hpc.optimizers` dependency on `optimizer_modules` by enhancing the `DummyOptimizerModule` under `test_keyed.py`. This ensures that the tests are standalone, and can be properly OSS'ed. Thanks to joshuadeng for catching the error! Reviewed By: henrylhtsang Differential Revision: D51599854 fbshipit-source-id: cf77ffe994b946acda9abed5fbedb8e0a8a4dd42 --- torchrec/optim/tests/test_keyed.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py index e6bf19609..99ced8afe 100644 --- a/torchrec/optim/tests/test_keyed.py +++ b/torchrec/optim/tests/test_keyed.py @@ -12,7 +12,6 @@ import torch import torch.distributed as dist -from hpc.optimizers.optimizer_modules import OptimizerModule from torch.autograd import Variable from torch.distributed._shard import sharded_tensor, sharding_spec from torchrec.optim.keyed import ( @@ -24,14 +23,19 @@ from torchrec.test_utils import get_free_port -class DummyOptimizerModule(OptimizerModule): +class DummyOptimizerModule: def __init__( self, tensor: torch.Tensor, ) -> None: - super(DummyOptimizerModule, self).__init__() self.tensor = tensor + def state_dict(self) -> Dict[str, Any]: + return {"tensor": self.tensor} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.tensor.detach().copy_(state_dict["tensor"]) + class TestKeyedOptimizer(unittest.TestCase): def _assert_state_dict_equals(