diff --git a/kliff/dataset/extxyz.py b/kliff/dataset/extxyz.py index c8a0e239..5364b90f 100644 --- a/kliff/dataset/extxyz.py +++ b/kliff/dataset/extxyz.py @@ -46,8 +46,22 @@ def read_extxyz( cell = _parse_key_value(line, "Lattice", "float", 9, filename) cell = np.reshape(cell, (3, 3)) - # PBC - PBC = _parse_key_value(line, "PBC", "int", 3, filename) + # PBC integer (0, 1) or str ("T", "F") + PBC = _parse_key_value(line, "PBC", "str", 3, filename) + try: + PBC = [bool(int(i)) for i in PBC] + except ValueError: + tmp = [] + for p in PBC: + if p.lower() == "t": + tmp.append(True) + elif p.lower() == "f": + tmp.append(False) + else: + raise InputError( + f"Invalid PBC value {p} at line 2 of file {filename}." + ) + PBC = tmp # energy is optional try: @@ -125,6 +139,7 @@ def write_extxyz( energy: Optional[float] = None, forces: Optional[np.ndarray] = None, stress: Optional[List[float]] = None, + bool_as_str: bool = False, ): """ Write configuration info to a file in extended xyz file_format. @@ -139,6 +154,8 @@ def write_extxyz( forces: Nx3 array, forces on atoms; If `None`, not write to file stress: 1D array of size 6, stress on the cell in Voigt notation; If `None`, not write to file + bool_as_str: If `True`, write PBC as "T" or "F"; otherwise, write PBC as 1 or 0. + """ with open(filename, "w") as fout: @@ -155,7 +172,10 @@ def write_extxyz( else: fout.write("{:.15g} ".format(item)) - PBC = [int(i) for i in PBC] + if bool_as_str: + PBC = ["T" if i else "F" for i in PBC] + else: + PBC = [int(i) for i in PBC] fout.write('PBC="{} {} {}" '.format(PBC[0], PBC[1], PBC[2])) if energy is not None: @@ -204,7 +224,7 @@ def _parse_key_value( Args: line: The string line. key: Keyword to parse. - dtype: Expected data type of value, `int` or `float`. + dtype: Expected data type of value, `int`, `float` or `str`. size: Expected size of value. filename: File name where the line comes from. @@ -221,7 +241,7 @@ def _parse_key_value( else: value = value[value.index("=") + 1 :] value = value.lstrip(" ") - value += " " # add an whitespace at end in case this is the last key + value += " " # add a whitespace at end in case this is the last key value = value[: value.index(" ")] value = value.split() except Exception as e: @@ -237,6 +257,8 @@ def _parse_key_value( value = [float(i) for i in value] elif dtype == "int": value = [int(i) for i in value] + elif dtype == "str": + pass except Exception as e: raise InputError(f"{e}.\nCorrupted {key} data at line 2 of file {filename}.") diff --git a/tests/dataset/test_extxyz.py b/tests/dataset/test_extxyz.py index 3dfbd581..d1d8280c 100644 --- a/tests/dataset/test_extxyz.py +++ b/tests/dataset/test_extxyz.py @@ -1,7 +1,34 @@ import numpy as np import pytest -from kliff.dataset.dataset import Configuration, Dataset +from kliff.dataset.dataset import Configuration, Dataset, read_extxyz, write_extxyz + + +def test_read_write_extxyz(test_data_dir, tmp_dir): + path = test_data_dir.joinpath("configs/MoS2/MoS2_energy_forces_stress.xyz") + cell, species, coords, PBC, energy, forces, stress = read_extxyz(path) + + fname = "test.xyz" + write_extxyz( + fname, + cell, + species, + coords, + PBC, + energy, + forces, + stress, + bool_as_str=True, + ) + cell1, species1, coords1, PBC1, energy1, forces1, stress1 = read_extxyz(fname) + + assert np.allclose(cell, cell1) + assert species == species1 + assert np.allclose(coords, coords1) + assert PBC == PBC1 + assert energy == energy1 + assert np.allclose(forces, forces1) + assert stress == stress1 @pytest.mark.parametrize(