Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix) Make bias statistics complete for all elements #4496

Open
wants to merge 106 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
32da243
4424
SumGuo-88 Dec 23, 2024
adf2315
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2024
4f6f63d
issues4424-2
SumGuo-88 Dec 26, 2024
b9bac38
ll
SumGuo-88 Dec 26, 2024
1db3408
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
543a318
ll
SumGuo-88 Dec 26, 2024
ba72382
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
26f9a17
lll
SumGuo-88 Dec 26, 2024
25a803c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
dc64307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
8f962b5
allchange
SumGuo-88 Jan 2, 2025
b88e7fc
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
f57498d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
faeb7c5
test
SumGuo-88 Jan 2, 2025
725f1dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
ca7fc84
stat
SumGuo-88 Jan 2, 2025
394cf04
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
05128d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
4828619
check
SumGuo-88 Jan 2, 2025
37ccce4
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
c9406e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
2224f61
chec坑
SumGuo-88 Jan 2, 2025
ba12c2c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
9fcee84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
fe8579e
check3
SumGuo-88 Jan 2, 2025
f004dff
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
11138ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
a4a97a3
test
SumGuo-88 Jan 3, 2025
88566fe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
10e538d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
203dc4e
ttt
SumGuo-88 Jan 3, 2025
bb9fbe1
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
6a65561
t
SumGuo-88 Jan 3, 2025
603aee9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
1c103c4
d
SumGuo-88 Jan 3, 2025
4173040
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
e3a1c9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
533e95e
ll
SumGuo-88 Jan 3, 2025
38dc18c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
714c197
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
6713c1a
last
SumGuo-88 Jan 4, 2025
e42e38d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
1c15cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
33c716d
q
SumGuo-88 Jan 4, 2025
6d38b94
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
6bbced8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
b462c97
ll
SumGuo-88 Jan 4, 2025
0d7154c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
0c7baa0
ll
SumGuo-88 Jan 4, 2025
28d94af
ll
SumGuo-88 Jan 4, 2025
87dcd66
l
SumGuo-88 Jan 4, 2025
5d33060
ll
SumGuo-88 Jan 4, 2025
521e3a6
ll
SumGuo-88 Jan 4, 2025
379d4ad
Merge branch 'devel' into devel
SumGuo-88 Jan 5, 2025
0dabf77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
a23528c
Update stat.py
SumGuo-88 Jan 5, 2025
0a97b54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
49744ed
Update deepmd/pt/utils/stat.py
SumGuo-88 Jan 6, 2025
556a684
Simplify logic and remove "not"
SumGuo-88 Jan 6, 2025
aa2633d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
27999af
check import
SumGuo-88 Jan 6, 2025
83b7f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
817d2ec
Add assert to ensure that the new frame contains the required elements
SumGuo-88 Jan 6, 2025
234e461
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
93a748f
check import
SumGuo-88 Jan 6, 2025
4a38f1d
check import
SumGuo-88 Jan 6, 2025
78b2a10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
6a5d169
check test.py
SumGuo-88 Jan 6, 2025
26205d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
3ccb4b9
check ut
SumGuo-88 Jan 6, 2025
f669ac5
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
87de0e8
check ut
SumGuo-88 Jan 6, 2025
0939ef1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
7ec779f
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
24d1386
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
708bc78
check msi defalut value
SumGuo-88 Jan 7, 2025
02f3f28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
c648e9e
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
d36a24a
check ut cuda
SumGuo-88 Jan 7, 2025
b6a483a
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
050dbaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
b00f8de
check ut
SumGuo-88 Jan 7, 2025
2f37dfe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
47fe45b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
98890c2
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
b9bdee5
make truetype for more sys
SumGuo-88 Jan 9, 2025
a30053f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
cfbc88a
Add skip element check function to Chang bias
SumGuo-88 Jan 9, 2025
73a20b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
0b29b05
make changebias control minframes
SumGuo-88 Jan 10, 2025
85d4da3
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
0400233
check merge
SumGuo-88 Jan 10, 2025
c05ffb1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
139f037
improve ut with all frames
SumGuo-88 Jan 10, 2025
5e826bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
edf1d91
check ut
SumGuo-88 Jan 10, 2025
3887013
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
eb9f068
check
SumGuo-88 Jan 10, 2025
10ef768
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
c2dc7ef
check skip logic and def name
SumGuo-88 Jan 10, 2025
8763165
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
d5596bf
improve warning readable
SumGuo-88 Jan 10, 2025
9f389ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
0c76ad9
check args
SumGuo-88 Jan 10, 2025
58647f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
4ce9cfb
check stat.py
SumGuo-88 Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Optional,
)

import numpy as np
from torch.utils.data import (
Dataset,
)
Expand All @@ -17,10 +18,10 @@

class DeepmdDataSetForLoader(Dataset):
def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None:
"""Construct DeePMD-style dataset containing frames cross different systems.
"""Construct DeePMD-style dataset containing frames across different systems.

Args:
- systems: Paths to systems.
- system: Path to the system.
- type_map: Atom types.
"""
self.system = system
Expand All @@ -30,6 +31,7 @@ def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None:
self._ntypes = self._data_system.get_ntypes()
self._natoms = self._data_system.get_natoms()
self._natoms_vec = self._data_system.get_natoms_vec(self._ntypes)
self.element_to_frames, self.get_all_atype = self._build_element_to_frames()

def __len__(self) -> int:
return self._data_system.nframes
Expand All @@ -40,6 +42,22 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def _build_element_to_frames(self):
"""Build mapping from element types to frame indexes and return all unique element types."""
element_to_frames = {element: [] for element in range(self._ntypes)}
all_elements = set()
all_frame_data = self._data_system.get_batch(self._data_system.nframes)
all_elements = np.unique(all_frame_data["type"])
for i in range(len(self)):
for element in all_elements:
element_to_frames[element].append(i)
return element_to_frames, all_elements

def get_frames_for_element(self, missing_element_name):
"""Get the frames that contain the specified element type."""
element_index = self._type_map.index(missing_element_name)
return self.element_to_frames.get(element_index, [])

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
Expand Down
44 changes: 44 additions & 0 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,50 @@ def make_stat_input(datasets, dataloaders, nbatches):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)

collect_elements = set()
all_element = set()
for i in lst:
collect_values = np.unique(i["atype"].cpu().numpy())
collect_elements.update(collect_values)
for i in datasets:
all_elements_in_dataset = i.get_all_atype
all_element.update(all_elements_in_dataset)
missing_element = all_element - collect_elements
for miss in missing_element:
for i in datasets:
if i.element_to_frames.get(miss, []) is not None:
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
frame_indices = i.element_to_frames.get(miss, [])
frame_data = i.__getitem__(frame_indices[0])
break
else:
pass
sys_stat_new = {}
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
for dd in frame_data:
if dd == "type":
continue
if frame_data[dd] is None:
sys_stat_new[dd] = None
elif isinstance(frame_data[dd], np.ndarray):
if dd not in sys_stat_new:
sys_stat_new[dd] = []
frame_data[dd] = torch.from_numpy(frame_data[dd])
frame_data[dd] = frame_data[dd].unsqueeze(0)
sys_stat_new[dd].append(frame_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat_new[dd] = frame_data[dd]
else:
pass
for key in sys_stat_new:
if isinstance(sys_stat_new[key], np.float32):
pass
elif sys_stat_new[key] is None or sys_stat_new[key][0] is None:
sys_stat_new[key] = None
elif isinstance(stat_data[dd], torch.Tensor):
sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0)
dict_to_device(sys_stat_new)
lst.append(sys_stat_new)

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
return lst


Expand Down
62 changes: 62 additions & 0 deletions source/tests/pt/test_make_stat_input.py
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import torch
from torch.utils.data import (
DataLoader,
)

from deepmd.pt.utils.stat import (
make_stat_input,
)


class TestDataset:
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, samples):
self.samples = samples
self.element_to_frames = {}
for idx, sample in enumerate(samples):
atypes = sample["atype"]
for atype in atypes:
if atype not in self.element_to_frames:
self.element_to_frames[atype] = []
self.element_to_frames[atype].append(idx)

@property
def get_all_atype(self):
return set(self.element_to_frames.keys())

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
sample = self.samples[idx]
return {
"atype": torch.tensor(sample["atype"], dtype=torch.long),
"energy": torch.tensor(sample["energy"], dtype=torch.float32),
}


class TestMakeStatInput(unittest.TestCase):
def setUp(self):
self.system = TestDataset(
[
{"atype": [1], "energy": -1.0},
{"atype": [2], "energy": -2.0},
]
)
self.datasets = [self.system]
self.dataloaders = [
DataLoader(self.system, batch_size=1, shuffle=False),
]

def test_make_stat_input(self):
nbatches = 1
lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches)
all_elements = self.system.get_all_atype
unique_elements = {1, 2}
self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements")


if __name__ == "__main__":
unittest.main()
Loading