From 6e4b4fcf3c6ea8977d9dc13e20ef338dac86488a Mon Sep 17 00:00:00 2001 From: PKUfjh <2001110077@pku.edu.cn> Date: Tue, 24 May 2022 19:48:37 +0800 Subject: [PATCH] discriminate virials labels in Multisystem class --- dpdata/system.py | 2 ++ tests/test_ase_traj.py | 6 +++--- tests/test_quip_gap_xyz.py | 22 +++++++++++----------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index df305a6e..184d191d 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1154,6 +1154,8 @@ def __append(self, system): return self.check_atom_names(system) formula = system.formula + if 'virials' in system.data: + formula = formula + "_virials" if formula in self.systems: self.systems[formula].append(system) else: diff --git a/tests/test_ase_traj.py b/tests/test_ase_traj.py index 6c37a31c..40aeea99 100644 --- a/tests/test_ase_traj.py +++ b/tests/test_ase_traj.py @@ -14,7 +14,7 @@ class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.multi_systems = dpdata.MultiSystems.from_file('ase_traj/HeAlO.traj', fmt='ase_traj/structure') - self.system_1 = self.multi_systems.systems['Al0He4O0'] + self.system_1 = self.multi_systems.systems['Al0He4O0_virials'] self.system_2 = dpdata.LabeledSystem('ase_traj/Al0He4O0', fmt='deepmd') self.places = 6 self.e_places = 6 @@ -25,11 +25,11 @@ def setUp (self) : class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.system_temp0 = dpdata.MultiSystems.from_file(file_name='ase_traj/HeAlO.traj', fmt='ase/structure') - self.system_1 = self.system_temp0.systems['Al2He1O3'] # .sort_atom_types() + self.system_1 = self.system_temp0.systems['Al2He1O3_virials'] # .sort_atom_types() self.system_temp1 = dpdata.LabeledSystem('ase_traj/Al2He1O3', fmt='deepmd') self.system_temp2 = dpdata.LabeledSystem('ase_traj/Al4He4O6', fmt='deepmd') self.system_temp3 = dpdata.MultiSystems(self.system_temp2, self.system_temp1) - self.system_2 = self.system_temp3.systems['Al2He1O3'] + self.system_2 = self.system_temp3.systems['Al2He1O3_virials'] self.places = 6 self.e_places = 6 self.f_places = 6 diff --git a/tests/test_quip_gap_xyz.py b/tests/test_quip_gap_xyz.py index 8a023bc4..adc187e4 100644 --- a/tests/test_quip_gap_xyz.py +++ b/tests/test_quip_gap_xyz.py @@ -7,7 +7,7 @@ class TestQuipGapxyz1(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.multi_systems = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz') - self.system_1 = self.multi_systems.systems['B1C9'] + self.system_1 = self.multi_systems.systems['B1C9_virials'] self.system_2 = dpdata.LabeledSystem('xyz/B1C9', fmt='deepmd') self.places = 6 self.e_places = 6 @@ -17,11 +17,11 @@ def setUp (self) : class TestQuipGapxyz2(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.system_temp0 = dpdata.MultiSystems.from_file(file_name='xyz/xyz_unittest.xyz', fmt='quip/gap/xyz') - self.system_1 = self.system_temp0.systems['B5C7'] # .sort_atom_types() + self.system_1 = self.system_temp0.systems['B5C7_virials'] # .sort_atom_types() self.system_temp1 = dpdata.LabeledSystem('xyz/B1C9', fmt='deepmd') self.system_temp2 = dpdata.LabeledSystem('xyz/B5C7', fmt='deepmd') self.system_temp3 = dpdata.MultiSystems(self.system_temp2, self.system_temp1) - self.system_2 = self.system_temp3.systems['B5C7'] + self.system_2 = self.system_temp3.systems['B5C7_virials'] self.places = 6 self.e_places = 6 self.f_places = 6 @@ -30,10 +30,10 @@ def setUp (self) : class TestQuipGapxyzsort1(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.sort.xyz','quip/gap/xyz') - self.system_1 = self.multi_systems_1.systems['B5C7'] + self.system_1 = self.multi_systems_1.systems['B5C7_virials'] self.system_1.sort_atom_types() self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz') - self.system_2 = self.multi_systems_2.systems['B5C7'] + self.system_2 = self.multi_systems_2.systems['B5C7_virials'] self.places = 6 self.e_places = 6 self.f_places = 6 @@ -42,10 +42,10 @@ def setUp (self) : class TestQuipGapxyzsort2(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.sort.xyz','quip/gap/xyz') - self.system_1 = self.multi_systems_1.systems['B1C9'] + self.system_1 = self.multi_systems_1.systems['B1C9_virials'] self.system_1.sort_atom_types() self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz') - self.system_2 = self.multi_systems_2.systems['B1C9'] + self.system_2 = self.multi_systems_2.systems['B1C9_virials'] self.places = 6 self.e_places = 6 self.f_places = 6 @@ -54,10 +54,10 @@ def setUp (self) : class TestQuipGapxyzfield(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.field.xyz','quip/gap/xyz') - self.system_1 = self.multi_systems_1.systems['B1C9'] + self.system_1 = self.multi_systems_1.systems['B1C9_virials'] self.system_1.sort_atom_types() self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz') - self.system_2 = self.multi_systems_2.systems['B1C9'] + self.system_2 = self.multi_systems_2.systems['B1C9_virials'] self.places = 6 self.e_places = 6 self.f_places = 6 @@ -66,10 +66,10 @@ def setUp (self) : class TestQuipGapxyzfield2(unittest.TestCase, CompLabeledSys, IsPBC): def setUp (self) : self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.field.xyz','quip/gap/xyz') - self.system_1 = self.multi_systems_1.systems['B5C7'] + self.system_1 = self.multi_systems_1.systems['B5C7_virials'] self.system_1.sort_atom_types() self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz') - self.system_2 = self.multi_systems_2.systems['B5C7'] + self.system_2 = self.multi_systems_2.systems['B5C7_virials'] self.places = 6 self.e_places = 6 self.f_places = 6