Skip to content

Commit

Permalink
discriminate virials labels in Multisystem class
Browse files Browse the repository at this point in the history
  • Loading branch information
PKUfjh committed May 24, 2022
1 parent c0bb798 commit 6e4b4fc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
2 changes: 2 additions & 0 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,8 @@ def __append(self, system):
return
self.check_atom_names(system)
formula = system.formula
if 'virials' in system.data:
formula = formula + "_virials"
if formula in self.systems:
self.systems[formula].append(system)
else:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_ase_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.multi_systems = dpdata.MultiSystems.from_file('ase_traj/HeAlO.traj', fmt='ase_traj/structure')
self.system_1 = self.multi_systems.systems['Al0He4O0']
self.system_1 = self.multi_systems.systems['Al0He4O0_virials']
self.system_2 = dpdata.LabeledSystem('ase_traj/Al0He4O0', fmt='deepmd')
self.places = 6
self.e_places = 6
Expand All @@ -25,11 +25,11 @@ def setUp (self) :
class TestASEtraj1(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.system_temp0 = dpdata.MultiSystems.from_file(file_name='ase_traj/HeAlO.traj', fmt='ase/structure')
self.system_1 = self.system_temp0.systems['Al2He1O3'] # .sort_atom_types()
self.system_1 = self.system_temp0.systems['Al2He1O3_virials'] # .sort_atom_types()
self.system_temp1 = dpdata.LabeledSystem('ase_traj/Al2He1O3', fmt='deepmd')
self.system_temp2 = dpdata.LabeledSystem('ase_traj/Al4He4O6', fmt='deepmd')
self.system_temp3 = dpdata.MultiSystems(self.system_temp2, self.system_temp1)
self.system_2 = self.system_temp3.systems['Al2He1O3']
self.system_2 = self.system_temp3.systems['Al2He1O3_virials']
self.places = 6
self.e_places = 6
self.f_places = 6
Expand Down
22 changes: 11 additions & 11 deletions tests/test_quip_gap_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class TestQuipGapxyz1(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.multi_systems = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz')
self.system_1 = self.multi_systems.systems['B1C9']
self.system_1 = self.multi_systems.systems['B1C9_virials']
self.system_2 = dpdata.LabeledSystem('xyz/B1C9', fmt='deepmd')
self.places = 6
self.e_places = 6
Expand All @@ -17,11 +17,11 @@ def setUp (self) :
class TestQuipGapxyz2(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.system_temp0 = dpdata.MultiSystems.from_file(file_name='xyz/xyz_unittest.xyz', fmt='quip/gap/xyz')
self.system_1 = self.system_temp0.systems['B5C7'] # .sort_atom_types()
self.system_1 = self.system_temp0.systems['B5C7_virials'] # .sort_atom_types()
self.system_temp1 = dpdata.LabeledSystem('xyz/B1C9', fmt='deepmd')
self.system_temp2 = dpdata.LabeledSystem('xyz/B5C7', fmt='deepmd')
self.system_temp3 = dpdata.MultiSystems(self.system_temp2, self.system_temp1)
self.system_2 = self.system_temp3.systems['B5C7']
self.system_2 = self.system_temp3.systems['B5C7_virials']
self.places = 6
self.e_places = 6
self.f_places = 6
Expand All @@ -30,10 +30,10 @@ def setUp (self) :
class TestQuipGapxyzsort1(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.sort.xyz','quip/gap/xyz')
self.system_1 = self.multi_systems_1.systems['B5C7']
self.system_1 = self.multi_systems_1.systems['B5C7_virials']
self.system_1.sort_atom_types()
self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz')
self.system_2 = self.multi_systems_2.systems['B5C7']
self.system_2 = self.multi_systems_2.systems['B5C7_virials']
self.places = 6
self.e_places = 6
self.f_places = 6
Expand All @@ -42,10 +42,10 @@ def setUp (self) :
class TestQuipGapxyzsort2(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.sort.xyz','quip/gap/xyz')
self.system_1 = self.multi_systems_1.systems['B1C9']
self.system_1 = self.multi_systems_1.systems['B1C9_virials']
self.system_1.sort_atom_types()
self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz')
self.system_2 = self.multi_systems_2.systems['B1C9']
self.system_2 = self.multi_systems_2.systems['B1C9_virials']
self.places = 6
self.e_places = 6
self.f_places = 6
Expand All @@ -54,10 +54,10 @@ def setUp (self) :
class TestQuipGapxyzfield(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.field.xyz','quip/gap/xyz')
self.system_1 = self.multi_systems_1.systems['B1C9']
self.system_1 = self.multi_systems_1.systems['B1C9_virials']
self.system_1.sort_atom_types()
self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz')
self.system_2 = self.multi_systems_2.systems['B1C9']
self.system_2 = self.multi_systems_2.systems['B1C9_virials']
self.places = 6
self.e_places = 6
self.f_places = 6
Expand All @@ -66,10 +66,10 @@ def setUp (self) :
class TestQuipGapxyzfield2(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp (self) :
self.multi_systems_1 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.field.xyz','quip/gap/xyz')
self.system_1 = self.multi_systems_1.systems['B5C7']
self.system_1 = self.multi_systems_1.systems['B5C7_virials']
self.system_1.sort_atom_types()
self.multi_systems_2 = dpdata.MultiSystems.from_file('xyz/xyz_unittest.xyz','quip/gap/xyz')
self.system_2 = self.multi_systems_2.systems['B5C7']
self.system_2 = self.multi_systems_2.systems['B5C7_virials']
self.places = 6
self.e_places = 6
self.f_places = 6
Expand Down

0 comments on commit 6e4b4fc

Please sign in to comment.