Skip to content

Commit

Permalink
breaking: use all sets for training and test (#3862)
Browse files Browse the repository at this point in the history
Fix #3860.

Remove `train_dirs` and `test_dir` in `DeepmdData`.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- All data sets are now trained and tested by default, simplifying the
training process.

- **Bug Fixes**
- Improved logic for handling training directories and test set merging.

- **Tests**
  - Added new test cases for the updated data handling methods.
- Updated existing tests to reflect changes in data set handling and
batch sizes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jun 11, 2024
1 parent a7ab1af commit 7786126
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 63 deletions.
61 changes: 24 additions & 37 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DeepmdData:
modifier
Data modifier that has the method `modify_data`
trn_all_set
Use all sets as training dataset. Otherwise, if the number of sets is more than 1, the last set is left for test.
[DEPRECATED] Deprecated. Now all sets are trained and tested.
sort_atoms : bool
Sort atoms by atom types. Required to enable when the data is directly feeded to
descriptors except mixed types.
Expand Down Expand Up @@ -109,15 +109,6 @@ def __init__(
# make idx map
self.sort_atoms = sort_atoms
self.idx_map = self._make_idx_map(self.atom_type)
# train dirs
self.test_dir = self.dirs[-1]
if trn_all_set:
self.train_dirs = self.dirs
else:
if len(self.dirs) == 1:
self.train_dirs = self.dirs
else:
self.train_dirs = self.dirs[:-1]
self.data_dict = {}
# add box and coord
self.add("box", 9, must=self.pbc)
Expand Down Expand Up @@ -225,7 +216,7 @@ def get_data_dict(self) -> dict:

def check_batch_size(self, batch_size):
"""Check if the system can get a batch of data with `batch_size` frames."""
for ii in self.train_dirs:
for ii in self.dirs:
if self.data_dict["coord"]["high_prec"]:
tmpe = (
(ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
Expand All @@ -240,24 +231,7 @@ def check_batch_size(self, batch_size):

def check_test_size(self, test_size):
"""Check if the system can get a test dataset with `test_size` frames."""
if self.data_dict["coord"]["high_prec"]:
tmpe = (
(self.test_dir / "coord.npy")
.load_numpy()
.astype(GLOBAL_ENER_FLOAT_PRECISION)
)
else:
tmpe = (
(self.test_dir / "coord.npy")
.load_numpy()
.astype(GLOBAL_NP_FLOAT_PRECISION)
)
if tmpe.ndim == 1:
tmpe = tmpe.reshape([1, -1])
if tmpe.shape[0] < test_size:
return self.test_dir, tmpe.shape[0]
else:
return None
return self.check_batch_size(test_size)

def get_item_torch(self, index: int) -> dict:
"""Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets.
Expand Down Expand Up @@ -287,7 +261,7 @@ def get_batch(self, batch_size: int) -> dict:
else:
set_size = 0
if self.iterator + batch_size > set_size:
self._load_batch_set(self.train_dirs[self.set_count % self.get_numb_set()])
self._load_batch_set(self.dirs[self.set_count % self.get_numb_set()])
self.set_count += 1
set_size = self.batch_set["coord"].shape[0]
iterator_1 = self.iterator + batch_size
Expand All @@ -307,7 +281,7 @@ def get_test(self, ntests: int = -1) -> dict:
Size of the test data set. If `ntests` is -1, all test data will be get.
"""
if not hasattr(self, "test_set"):
self._load_test_set(self.test_dir, self.shuffle_test)
self._load_test_set(self.shuffle_test)
if ntests == -1:
idx = None
else:
Expand Down Expand Up @@ -340,11 +314,11 @@ def get_atom_type(self) -> List[int]:

def get_numb_set(self) -> int:
"""Get number of training sets."""
return len(self.train_dirs)
return len(self.dirs)

def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
"""Get the number of batches in a set."""
data = self._load_set(self.train_dirs[set_idx])
data = self._load_set(self.dirs[set_idx])
ret = data["coord"].shape[0] // batch_size
if ret == 0:
ret = 1
Expand All @@ -353,7 +327,7 @@ def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
def get_sys_numb_batch(self, batch_size: int) -> int:
"""Get the number of batches in the data system."""
ret = 0
for ii in range(len(self.train_dirs)):
for ii in range(len(self.dirs)):
ret += self.get_numb_batch(batch_size, ii)
return ret

Expand Down Expand Up @@ -388,7 +362,7 @@ def avg(self, key):
info = self.data_dict[key]
ndof = info["ndof"]
eners = []
for ii in self.train_dirs:
for ii in self.dirs:
data = self._load_set(ii)
ei = data[key].reshape([-1, ndof])
eners.append(ei)
Expand Down Expand Up @@ -441,8 +415,21 @@ def _load_batch_set(self, set_name: DPPath):
def reset_get_batch(self):
self.iterator = 0

def _load_test_set(self, set_name: DPPath, shuffle_test):
self.test_set = self._load_set(set_name)
def _load_test_set(self, shuffle_test: bool):
test_sets = []
for ii in self.dirs:
test_set = self._load_set(ii)
test_sets.append(test_set)
# merge test sets
self.test_set = {}
assert len(test_sets) > 0
for kk in test_sets[0]:
if "find_" in kk:
self.test_set[kk] = test_sets[0][kk]
else:
self.test_set[kk] = np.concatenate(
[test_set[kk] for test_set in test_sets], axis=0
)
if shuffle_test:
self.test_set, _ = self._shuffle_data(self.test_set)

Expand Down
33 changes: 25 additions & 8 deletions source/tests/tf/test_deepmd_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def setUp(self):
path = os.path.join(self.data_name, "set.bar", "test_frame.npy")
self.test_frame_bar = rng.random([self.nframes, 5])
np.save(path, self.test_frame_bar)
path = os.path.join(self.data_name, "set.tar", "test_frame.npy")
self.test_frame_tar = rng.random([2, 5])
np.save(path, self.test_frame_tar)
# t n
self.test_null = np.zeros([self.nframes, 2 * self.natoms])
# tensor shape
Expand All @@ -162,8 +165,9 @@ def test_init(self):
self.assertEqual(dd.idx_map[0], 1)
self.assertEqual(dd.idx_map[1], 0)
self.assertEqual(dd.type_map, ["foo", "bar"])
self.assertEqual(dd.test_dir, "test_data/set.tar")
self.assertEqual(dd.train_dirs, ["test_data/set.bar", "test_data/set.foo"])
self.assertEqual(
dd.dirs, ["test_data/set.bar", "test_data/set.foo", "test_data/set.tar"]
)

def test_init_type_map(self):
dd = DeepmdData(self.data_name, type_map=["bar", "foo", "tar"])
Expand All @@ -182,7 +186,7 @@ def test_load_set(self):
)
data = dd._load_set(os.path.join(self.data_name, "set.foo"))
nframes = data["coord"].shape[0]
self.assertEqual(dd.get_numb_set(), 2)
self.assertEqual(dd.get_numb_set(), 3)
self.assertEqual(dd.get_type_map(), ["foo", "bar"])
self.assertEqual(dd.get_natoms(), 2)
self.assertEqual(list(dd.get_natoms_vec(3)), [2, 2, 1, 1, 0])
Expand Down Expand Up @@ -257,7 +261,10 @@ def test_avg(self):
dd = DeepmdData(self.data_name).add("test_frame", 5, atomic=False, must=True)
favg = dd.avg("test_frame")
fcmp = np.average(
np.concatenate((self.test_frame, self.test_frame_bar), axis=0), axis=0
np.concatenate(
(self.test_frame, self.test_frame_bar, self.test_frame_tar), axis=0
),
axis=0,
)
np.testing.assert_almost_equal(favg, fcmp, places)

Expand All @@ -266,13 +273,17 @@ def test_check_batch_size(self):
ret = dd.check_batch_size(10)
self.assertEqual(ret, (os.path.join(self.data_name, "set.bar"), 5))
ret = dd.check_batch_size(5)
self.assertEqual(ret, (os.path.join(self.data_name, "set.tar"), 2))
ret = dd.check_batch_size(1)
self.assertEqual(ret, None)

def test_check_test_size(self):
dd = DeepmdData(self.data_name)
ret = dd.check_test_size(10)
self.assertEqual(ret, (os.path.join(self.data_name, "set.bar"), 5))
ret = dd.check_test_size(5)
self.assertEqual(ret, (os.path.join(self.data_name, "set.tar"), 2))
ret = dd.check_test_size(2)
ret = dd.check_test_size(1)
self.assertEqual(ret, None)

def test_get_batch(self):
Expand All @@ -284,6 +295,10 @@ def test_get_batch(self):
data = dd.get_batch(5)
self._comp_np_mat2(np.sort(data["coord"], axis=0), np.sort(self.coord, axis=0))
data = dd.get_batch(5)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_tar, axis=0)
)
data = dd.get_batch(5)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_bar, axis=0)
)
Expand All @@ -293,8 +308,11 @@ def test_get_batch(self):
def test_get_test(self):
dd = DeepmdData(self.data_name)
data = dd.get_test()
expected_coord = np.concatenate(
(self.coord_bar, self.coord, self.coord_tar), axis=0
)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_tar, axis=0)
np.sort(data["coord"], axis=0), np.sort(expected_coord, axis=0)
)

def test_get_nbatch(self):
Expand Down Expand Up @@ -368,8 +386,7 @@ def test_init(self):
dd = DeepmdData(self.data_name)
self.assertEqual(dd.idx_map[0], 0)
self.assertEqual(dd.type_map, ["X"])
self.assertEqual(dd.test_dir, self.data_name + "#/set.000")
self.assertEqual(dd.train_dirs, [self.data_name + "#/set.000"])
self.assertEqual(dd.dirs[0], self.data_name + "#/set.000")

def test_get_batch(self):
dd = DeepmdData(self.data_name)
Expand Down
Loading

0 comments on commit 7786126

Please sign in to comment.