(fix) Make bias statistics complete for all elements #4496

Draft
wants to merge 10 commits into
base: devel
from
12 changes: 12 additions & 0 deletions deepmd/pt/utils/dataset.py
@@ -40,6 +40,18 @@
b_data["natoms"] = self._natoms_vec
return b_data

def _build_element_to_frames(self):

Check warning (Code scanning / CodeQL): Unreachable code. This statement is unreachable.
    """Mapping element types to frame indexes"""
    element_to_frames = {element: [] for element in range(self._ntypes)}
    for frame_idx in range(len(self)):
        frame_data = self._data_system.get_item_torch(frame_idx)

        elements = frame_data["atype"]
        for element in set(elements):
            if len(element_to_frames[element]) < 10:
                element_to_frames[element].append(frame_idx)
    return element_to_frames
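
For readers following the diff, the indexing idea can be tried outside DeePMD-kit. The block below is a minimal sketch, not the PR's code: `build_element_to_frames`, `frames`, and `max_frames_per_element` are hypothetical names, each frame is assumed to be a plain list of integer atom types, and the default cap of 10 only mirrors the literal used in the diff.

```python
from collections.abc import Sequence


def build_element_to_frames(
    frames: Sequence[Sequence[int]],
    ntypes: int,
    max_frames_per_element: int = 10,
) -> dict[int, list[int]]:
    """Map each element type to the first few frame indices that contain it."""
    element_to_frames: dict[int, list[int]] = {t: [] for t in range(ntypes)}
    for frame_idx, atype in enumerate(frames):
        # Convert to built-in ints so dictionary lookups do not depend on
        # how array or tensor scalars happen to hash (an assumption about
        # the input; plain lists of ints pass through unchanged).
        for element in sorted({int(a) for a in atype}):
            bucket = element_to_frames[element]
            if len(bucket) < max_frames_per_element:
                bucket.append(frame_idx)
    return element_to_frames


# Three toy frames: types {0, 1}, {1}, and {0, 2}.
example = build_element_to_frames([[0, 0, 1], [1, 1], [0, 2]], ntypes=3)
```

Capping the list per element keeps the index small on large systems while still guaranteeing at least one candidate frame for every element that occurs.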

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
41 changes: 41 additions & 0 deletions deepmd/pt/utils/stat.py
@@ -82,6 +82,47 @@
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)

all_elements = set()
if datasets and hasattr(datasets[0], "element_to_frames"):
    all_elements.update(datasets[0].element_to_frames.keys())
    print("we want", all_elements)

collected_elements = set()
for sys_stat in lst:
    if "atype" in sys_stat:
        collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))
missing_elements = all_elements - collected_elements

for missing_element in missing_elements:
    for i, dataset in enumerate(datasets):
        if hasattr(dataset, "element_to_frames"):
            frame_indices = dataset.element_to_frames.get(
                missing_element, []
            )
            for frame_idx in frame_indices:
                if len(lst[i]["atype"]) >= nbatches:
                    break
                frame_data = dataset[frame_idx]
                for key in frame_data:
                    if key not in lst[i]:
                        lst[i][key] = []
                    lst[i][key].append(frame_data[key])

collected_elements = set()
for sys_stat in lst:
    if "atype" in sys_stat:
        collected_elements.update(
            np.unique(sys_stat["atype"].cpu().numpy())
        )
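
The backfill step in this hunk can be reasoned about with a small standalone model. The block below is a sketch under stated assumptions, not the PR's implementation: `backfill_frames` and its parameters are hypothetical, frames are plain lists of integer atom types, and the function tracks sampled frame indices instead of appending tensors to the per-system statistics.

```python
def backfill_frames(frames, sampled, element_to_frames, max_frames):
    """Add frames for elements missing from the sample, up to max_frames.

    Returns the extended list of sampled frame indices and the set of
    elements that are still missing afterwards.
    """
    sampled = list(sampled)
    collected = {int(a) for idx in sampled for a in frames[idx]}
    for element in sorted(set(element_to_frames) - collected):
        if element in collected:
            # An earlier backfilled frame already covered this element.
            continue
        for frame_idx in element_to_frames[element]:
            if len(sampled) >= max_frames:
                break
            if frame_idx not in sampled:
                sampled.append(frame_idx)
                collected.update(int(a) for a in frames[frame_idx])
                break  # one extra frame per missing element is enough here
    return sampled, set(element_to_frames) - collected


toy_frames = [[0, 0], [0, 1], [2, 0]]
toy_index = {0: [0, 1, 2], 1: [1], 2: [2], 3: []}
filled, still_missing = backfill_frames(toy_frames, [0], toy_index, max_frames=3)
```

Returning the still-missing set makes the two failure modes visible: an element with no frames at all (type 3 in the toy data) and an element dropped because the frame budget ran out.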


💡 Codebase verification

Based on the code analysis, I can now provide a final response:

Add validation check for missing elements after collection

After the final collection of elements, add a validation check to ensure all required elements were found:

# After the final collected_elements update
missing = all_elements - collected_elements
if missing:
    warnings.warn(f"Could not find frames containing elements: {missing}")
  • Add this check at line 118 in deepmd/pt/utils/stat.py after the final collection loop
  • Import warnings module at the top of the file
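
The suggested check can be exercised on its own before wiring it into `stat.py`. The block below is a minimal sketch assuming the two sets hold integer element types; `warn_on_missing_elements` is a hypothetical helper name and is not part of DeePMD-kit.

```python
import warnings


def warn_on_missing_elements(all_elements, collected_elements):
    """Warn if some expected element types never appear in the statistics."""
    missing = set(all_elements) - set(collected_elements)
    if missing:
        warnings.warn(
            f"Could not find frames containing elements: {sorted(missing)}",
            stacklevel=2,  # attribute the warning to the caller
        )
    return missing


# Capture the warning instead of printing it, to show that it fires.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    missing = warn_on_missing_elements({0, 1, 2}, {0, 2})
```

Returning the missing set alongside the warning lets the caller decide whether to continue with incomplete bias statistics or to raise.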
🔗 Analysis chain

Consider adding validation for collected elements.

After collecting elements, it would be good to verify that all required elements were found and log a warning if any are still missing.


Let me analyze the code further to understand the element collection and validation context.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if there are any elements that might be consistently missing across datasets

# Search for element type definitions and usages
rg -A 5 "atype|element_to_frames"

Length of output: 84908


Script:

#!/bin/bash
# Search for element validation or error handling related to missing elements
rg -A 5 "missing_elements|element_to_frames|all_elements"

Length of output: 2630

Check notice (Code scanning / CodeQL): Nested loops with same variable. Nested for statement uses loop variable 'sys_stat' of enclosing for statement.
for sys_stat in lst:
    for key in sys_stat:
        if isinstance(sys_stat[key], list) and isinstance(
            sys_stat[key][0], torch.Tensor
        ):
            sys_stat[key] = torch.cat(sys_stat[key], dim=0)

return lst
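
The CodeQL notice above concerns a general Python hazard that is easy to reproduce in isolation. The block below is an illustrative sketch with hypothetical function names. It does not reproduce the PR's loops; it only shows what happens when an inner `for` rebinds the outer loop variable.

```python
def collect_shadowed(items):
    """Buggy on purpose: the inner loop reuses the outer variable name."""
    seen = []
    for item in items:
        for item in items:  # rebinds `item`; the outer value is lost
            pass
        seen.append(item)  # always the last element of `items`
    return seen


def collect_distinct(items):
    """Same shape, but the inner loop uses its own variable."""
    seen = []
    for item in items:
        for _other in items:
            pass
        seen.append(item)
    return seen


shadowed = collect_shadowed(["a", "b", "c"])
distinct = collect_distinct(["a", "b", "c"])
```

With a distinct inner name, any statement after the inner loop still refers to the outer iteration's value, which is normally what was intended.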

