diff --git a/open_spiel/python/pytorch/dqn.py b/open_spiel/python/pytorch/dqn.py index f229fc8d9a..7b5bc775e9 100644 --- a/open_spiel/python/pytorch/dqn.py +++ b/open_spiel/python/pytorch/dqn.py @@ -428,7 +428,7 @@ def load(self, data_path, optimizer_data_path=None): relative or absolute but the filename should be included. For example: optimizer.pt or /path/to/optimizer.pt """ - torch.load(self._q_network, data_path) - torch.load(self._target_q_network, data_path) + self._q_network = torch.load(data_path) + self._target_q_network = torch.load(data_path) if optimizer_data_path is not None: - torch.load(self._optimizer, optimizer_data_path) + self._optimizer = torch.load(optimizer_data_path)