diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index cd0e9074ff..045145a4fc 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -5,6 +5,7 @@ Optional, ) +import numpy as np from torch.utils.data import ( Dataset, ) @@ -13,7 +14,7 @@ DataRequirementItem, DeepmdData, ) -import numpy as np + class DeepmdDataSetForLoader(Dataset): def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: @@ -40,21 +41,21 @@ def __getitem__(self, index): b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data - + def _build_element_to_frames(self): """Build mapping from element types to frame indexes and return all unique element types.""" - element_to_frames = {element: [] for element in range(self._ntypes)} - all_elements = set() + element_to_frames = {element: [] for element in range(self._ntypes)} + all_elements = set() all_frame_data = self._data_system.get_batch(self._data_system.nframes) all_elements = np.unique(all_frame_data["type"]) - for i in range(len(self)): + for i in range(len(self)): for element in all_elements: element_to_frames[element].append(i) return element_to_frames, all_elements - + def get_frames_for_element(self, missing_element_name): """Get the frames that contain the specified element type.""" - element_index = self._type_map.index(missing_element_name) + element_index = self._type_map.index(missing_element_name) return self.element_to_frames.get(element_index, []) def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index f624e3d0bd..ff40f6cc3d 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -36,6 +36,7 @@ log = logging.getLogger(__name__) + def make_stat_input(datasets, dataloaders, nbatches): """Pack data for statistics. @@ -101,7 +102,7 @@ def make_stat_input(datasets, dataloaders, nbatches): pass sys_stat_new = {} for dd in frame_data: - if dd == "type": + if dd == "type": continue if frame_data[dd] is None: sys_stat_new[dd] = None @@ -122,11 +123,12 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat_new[key] = None elif isinstance(stat_data[dd], torch.Tensor): sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0) - dict_to_device(sys_stat_new) + dict_to_device(sys_stat_new) lst.append(sys_stat_new) return lst + def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index fe2a4c9e51..2cd67193c2 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,15 +1,22 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later import unittest + import torch -from torch.utils.data import DataLoader -from deepmd.pt.utils.stat import make_stat_input +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.stat import ( + make_stat_input, +) + class TestDataset: def __init__(self, samples): - self.samples = samples self.element_to_frames = {} for idx, sample in enumerate(samples): - atypes = sample['atype'] + atypes = sample["atype"] for atype in atypes: if atype not in self.element_to_frames: self.element_to_frames[atype] = [] @@ -25,26 +32,31 @@ def __len__(self): def __getitem__(self, idx): sample = self.samples[idx] return { - 'atype': torch.tensor(sample['atype'], dtype=torch.long), - 'energy': torch.tensor(sample['energy'], dtype=torch.float32), + "atype": torch.tensor(sample["atype"], dtype=torch.long), + "energy": torch.tensor(sample["energy"], dtype=torch.float32), } + class TestMakeStatInput(unittest.TestCase): def setUp(self): - self.system = TestDataset([ - {'atype': [1], 'energy': -1.0}, - {'atype': [2], 'energy': -2.0}, - ]) + self.system = TestDataset( + [ + {"atype": [1], "energy": -1.0}, + {"atype": [2], "energy": -2.0}, + ] + ) self.datasets = [self.system] self.dataloaders = [ DataLoader(self.system, batch_size=1, shuffle=False), ] + def test_make_stat_input(self): - nbatches = 1 + nbatches = 1 lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches) all_elements = self.system.get_all_atype - unique_elements = {1,2} + unique_elements = {1, 2} self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main()