Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 26, 2024
1 parent 543a318 commit ba72382
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
15 changes: 8 additions & 7 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Optional,
)

import numpy as np
from torch.utils.data import (
Dataset,
)
Expand All @@ -13,7 +14,7 @@
DataRequirementItem,
DeepmdData,
)
import numpy as np


class DeepmdDataSetForLoader(Dataset):
def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None:
Expand All @@ -40,21 +41,21 @@ def __getitem__(self, index):
b_data = self._data_system.get_item_torch(index)
b_data["natoms"] = self._natoms_vec
return b_data

def _build_element_to_frames(self):
"""Build mapping from element types to frame indexes and return all unique element types."""
element_to_frames = {element: [] for element in range(self._ntypes)}
all_elements = set()
element_to_frames = {element: [] for element in range(self._ntypes)}
all_elements = set()
all_frame_data = self._data_system.get_batch(self._data_system.nframes)
all_elements = np.unique(all_frame_data["type"])
for i in range(len(self)):
for i in range(len(self)):
for element in all_elements:
element_to_frames[element].append(i)
return element_to_frames, all_elements

def get_frames_for_element(self, missing_element_name):
"""Get the frames that contain the specified element type."""
element_index = self._type_map.index(missing_element_name)
element_index = self._type_map.index(missing_element_name)
return self.element_to_frames.get(element_index, [])

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
Expand Down
8 changes: 5 additions & 3 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

log = logging.getLogger(__name__)


def make_stat_input(datasets, dataloaders, nbatches):
"""Pack data for statistics.
Expand Down Expand Up @@ -85,7 +86,7 @@ def make_stat_input(datasets, dataloaders, nbatches):
all_element = set()

for i in lst:
unique_values = np.unique(i['atype'].cpu().numpy())
unique_values = np.unique(i["atype"].cpu().numpy())
unique_elements.update(unique_values)
for i in datasets:
all_elements_in_dataset = i.get_all_atype
Expand All @@ -102,7 +103,7 @@ def make_stat_input(datasets, dataloaders, nbatches):
pass
sys_stat_new = {}
for dd in frame_data:
if dd == "type":
if dd == "type":
continue
if frame_data[dd] is None:
sys_stat_new[dd] = None
Expand All @@ -123,11 +124,12 @@ def make_stat_input(datasets, dataloaders, nbatches):
sys_stat_new[key] = None
elif isinstance(stat_data[dd], torch.Tensor):
sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0)
dict_to_device(sys_stat_new)
dict_to_device(sys_stat_new)
lst.append(sys_stat_new)

return lst


def _restore_from_file(
stat_file_path: DPPath,
keys: list[str] = ["energy"],
Expand Down
38 changes: 25 additions & 13 deletions source/tests/pt/test_make_stat_input.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import torch
from torch.utils.data import DataLoader
from deepmd.pt.utils.stat import make_stat_input
from torch.utils.data import (
DataLoader,
)

from deepmd.pt.utils.stat import (
make_stat_input,
)


class TestDataset:
def __init__(self, samples):

self.samples = samples
self.element_to_frames = {}
for idx, sample in enumerate(samples):
atypes = sample['atype']
atypes = sample["atype"]
for atype in atypes:
if atype not in self.element_to_frames:
self.element_to_frames[atype] = []
Expand All @@ -25,26 +32,31 @@ def __len__(self):
def __getitem__(self, idx):
sample = self.samples[idx]
return {
'atype': torch.tensor(sample['atype'], dtype=torch.long),
'energy': torch.tensor(sample['energy'], dtype=torch.float32),
"atype": torch.tensor(sample["atype"], dtype=torch.long),
"energy": torch.tensor(sample["energy"], dtype=torch.float32),
}


class TestMakeStatInput(unittest.TestCase):
def setUp(self):
self.system = TestDataset([
{'atype': [1], 'energy': -1.0},
{'atype': [2], 'energy': -2.0},
])
self.system = TestDataset(
[
{"atype": [1], "energy": -1.0},
{"atype": [2], "energy": -2.0},
]
)
self.datasets = [self.system]
self.dataloaders = [
DataLoader(self.system, batch_size=1, shuffle=False),
]

def test_make_stat_input(self):
nbatches = 1
nbatches = 1
lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches)
all_elements = self.system.get_all_atype
unique_elements = {1,2}
unique_elements = {1, 2}
self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements")

if __name__ == '__main__':

if __name__ == "__main__":
unittest.main()

0 comments on commit ba72382

Please sign in to comment.