diff --git a/python/docs/nuri.rst b/python/docs/nuri.rst index 352f5f5b..30be737b 100644 --- a/python/docs/nuri.rst +++ b/python/docs/nuri.rst @@ -28,10 +28,25 @@ Submodules Top-level Functions ------------------- +Readers +------- + .. autofunction:: nuri.readfile .. autofunction:: nuri.readstring +Writers +------- + +These functions all release the GIL and are thread-safe. Thread-based +parallelism is recommended for writing multiple molecules in parallel. + +.. autofunction:: nuri.to_smiles + +.. autofunction:: nuri.to_mol2 + +.. autofunction:: nuri.to_sdf + -------------------- Top-level Attributes -------------------- diff --git a/python/include/nuri/python/utils.h b/python/include/nuri/python/utils.h index 688c6d9f..cde16c47 100644 --- a/python/include/nuri/python/utils.h +++ b/python/include/nuri/python/utils.h @@ -324,6 +324,7 @@ inline int py_check_index(int size, int idx, const char *onerror) { } constexpr inline static py::keep_alive<0, 1> kReturnsSubobject {}; +constexpr inline static py::call_guard kThreadSafe {}; template py::class_ &add_sequence_interface(py::class_ &cls, Size size, diff --git a/python/src/nuri/__init__.py b/python/src/nuri/__init__.py index 1da5dad2..07f23c51 100644 --- a/python/src/nuri/__init__.py +++ b/python/src/nuri/__init__.py @@ -11,6 +11,9 @@ __all__ = [ "readfile", "readstring", + "to_smiles", + "to_mol2", + "to_sdf", "periodic_table", "__version__", ] @@ -22,4 +25,4 @@ from . 
import _log_adapter from .core import periodic_table -from .fmt import readfile, readstring +from .fmt import readfile, readstring, to_smiles, to_mol2, to_sdf diff --git a/python/src/nuri/fmt/fmt_module.cpp b/python/src/nuri/fmt/fmt_module.cpp index c57934e8..d909bdf6 100644 --- a/python/src/nuri/fmt/fmt_module.cpp +++ b/python/src/nuri/fmt/fmt_module.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +25,9 @@ #include "nuri/core/molecule.h" #include "nuri/fmt/base.h" +#include "nuri/fmt/mol2.h" +#include "nuri/fmt/sdf.h" +#include "nuri/fmt/smiles.h" #include "nuri/python/core/core_module.h" #include "nuri/python/exception.h" #include "nuri/python/utils.h" @@ -101,6 +105,21 @@ class PyMoleculeReader { bool skip_on_error_; }; +template +std::string try_write(const Molecule &mol, std::string_view fmt, F writer, + Args &&...args) { + std::string buf; + if (!writer(buf, mol, std::forward(args)...)) + throw py::value_error(absl::StrCat("Failed to convert molecule to ", fmt)); + return buf; +} + +int writer_check_conf(const Molecule &mol, std::optional oconf) { + if (!oconf) + return -1; + return check_conf(mol, *oconf); +} + namespace fs = std::filesystem; NURI_PYTHON_MODULE(m) { @@ -167,6 +186,67 @@ The returned object is an iterable of molecules. >>> for mol in nuri.readstring("smi", "C"): ... print(mol[0].atomic_number) 6 +)doc"); + + m.def( + "to_smiles", + [](const PyMol &mol) { + return try_write(*mol, "smiles", write_smiles, false); + }, + py::arg("mol"), kThreadSafe, R"doc( +Convert a molecule to a SMILES string. + +:param mol: The molecule to convert. +:raises ValueError: If the conversion fails. +)doc") + .def( + "to_mol2", + [](const PyMol &mol, std::optional oconf) { + int conf = writer_check_conf(*mol, oconf); + return try_write(*mol, "Mol2", write_mol2, conf); + }, + py::arg("mol"), py::arg("conf") = py::none(), kThreadSafe, R"doc( +Convert a molecule to a Mol2 string. + +:param mol: The molecule to convert. 
+:param conf: The conformation to convert. If not specified, writes all + conformations. Ignored if the molecule has no conformations. +:raises IndexError: If the molecule has any conformations and `conf` is out of + range. +:raises ValueError: If the conversion fails. +)doc") + .def( + "to_sdf", + [](const PyMol &mol, std::optional oconf, + std::optional oversion) { + int conf = writer_check_conf(*mol, oconf); + + SDFVersion version = SDFVersion::kAutomatic; + if (oversion) { + int user_version = *oversion; + if (user_version == 2000) + version = SDFVersion::kV2000; + else if (user_version == 3000) + version = SDFVersion::kV3000; + else + throw py::value_error( + absl::StrCat("Invalid SDF version: ", user_version)); + } + + return try_write(*mol, "SDF", write_sdf, conf, version); + }, + py::arg("mol"), py::arg("conf") = py::none(), + py::arg("version") = py::none(), kThreadSafe, R"doc( +Convert a molecule to an SDF string. + +:param mol: The molecule to convert. +:param conf: The conformation to convert. If not specified, writes all + conformations. Ignored if the molecule has no conformations. +:param version: The SDF version to write. If not specified, the version is + automatically determined. Only 2000 and 3000 are supported. +:raises IndexError: If the molecule has any conformations and `conf` is out of + range. +:raises ValueError: If the conversion fails, or if the version is invalid. 
)doc"); } } // namespace diff --git a/python/test/core/conftest.py b/python/test/conftest.py similarity index 100% rename from python/test/core/conftest.py rename to python/test/conftest.py diff --git a/python/test/fmt/mol2_test.py b/python/test/fmt/mol2_test.py index 3d444c7c..fd526ff9 100644 --- a/python/test/fmt/mol2_test.py +++ b/python/test/fmt/mol2_test.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import List +import numpy as np +import pytest + import nuri from nuri.core import Molecule, Hyb @@ -108,3 +111,22 @@ def test_mol2_file(tmp_path: Path): def test_mol2_str(): mols = list(nuri.readstring("mol2", mol2_data)) _verify_mols(mols) + + mol2_re = "".join(map(nuri.to_mol2, mols)) + mols_re = list(nuri.readstring("mol2", mol2_re)) + _verify_mols(mols_re) + + +def test_mol2_options(mol3d: Molecule): + mol2s = nuri.to_mol2(mol3d) + + mols = list(nuri.readstring("mol2", mol2s)) + assert len(mols) == 2 + + mol2s = nuri.to_mol2(mol3d, conf=1) + mols = list(nuri.readstring("mol2", mol2s)) + assert len(mols) == 1 + assert np.allclose(mol3d.get_conf(1), mols[0].get_conf(0), atol=1e-3) + + with pytest.raises(IndexError): + mol2s = nuri.to_mol2(mol3d, conf=2) diff --git a/python/test/fmt/sdf_test.py b/python/test/fmt/sdf_test.py index a2b6a392..3629ef76 100644 --- a/python/test/fmt/sdf_test.py +++ b/python/test/fmt/sdf_test.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import List +import numpy as np +import pytest + import nuri from nuri.core import Molecule, Hyb @@ -102,3 +105,33 @@ def test_sdf_file(tmp_path: Path): def test_sdf_str(): mols = list(nuri.readstring("sdf", sdf_data)) _verify_mols(mols) + + sdf_re = "".join(map(nuri.to_sdf, mols)) + mols_re = list(nuri.readstring("sdf", sdf_re)) + _verify_mols(mols_re) + + +def test_sdf_options(mol3d: Molecule): + sdfs = nuri.to_sdf(mol3d) + + mols = list(nuri.readstring("sdf", sdfs)) + assert len(mols) == 2 + + sdfs = nuri.to_sdf(mol3d, version=2000) + mols = list(nuri.readstring("sdf", sdfs)) + 
assert len(mols) == 2 + + sdfs = nuri.to_sdf(mol3d, version=3000) + mols = list(nuri.readstring("sdf", sdfs)) + assert len(mols) == 2 + + sdfs = nuri.to_sdf(mol3d, conf=1) + mols = list(nuri.readstring("sdf", sdfs)) + assert len(mols) == 1 + assert np.allclose(mol3d.get_conf(1), mols[0].get_conf(0), atol=1e-3) + + with pytest.raises(IndexError): + sdfs = nuri.to_sdf(mol3d, conf=2) + + with pytest.raises(ValueError, match="Invalid SDF version"): + sdfs = nuri.to_sdf(mol3d, version=9999) diff --git a/python/test/fmt/smi_test.py b/python/test/fmt/smi_test.py index a05aada5..06e30e0b 100644 --- a/python/test/fmt/smi_test.py +++ b/python/test/fmt/smi_test.py @@ -45,3 +45,7 @@ def test_smiles_file(tmp_path: Path): def test_smiles_str(): mols = list(nuri.readstring("smi", smi_data)) _verify_mols(mols) + + smiles_re = "".join(map(nuri.to_smiles, mols)) + mols_re = list(nuri.readstring("smi", smiles_re)) + _verify_mols(mols_re)